0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.sql.execution.datasources.parquet;
0019
0020 import java.io.IOException;
0021 import java.time.ZoneId;
0022 import java.util.Arrays;
0023 import java.util.List;
0024
0025 import org.apache.hadoop.mapreduce.InputSplit;
0026 import org.apache.hadoop.mapreduce.TaskAttemptContext;
0027 import org.apache.parquet.column.ColumnDescriptor;
0028 import org.apache.parquet.column.page.PageReadStore;
0029 import org.apache.parquet.schema.Type;
0030
0031 import org.apache.spark.memory.MemoryMode;
0032 import org.apache.spark.sql.catalyst.InternalRow;
0033 import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
0034 import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
0035 import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
0036 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
0037 import org.apache.spark.sql.vectorized.ColumnarBatch;
0038 import org.apache.spark.sql.types.StructField;
0039 import org.apache.spark.sql.types.StructType;
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052 public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBase<Object> {
0053
0054
0055 private int capacity;
0056
0057
0058
0059
0060
0061 private int batchIdx = 0;
0062 private int numBatched = 0;
0063
0064
0065
0066
0067
0068 private VectorizedColumnReader[] columnReaders;
0069
0070
0071
0072
0073 private long rowsReturned;
0074
0075
0076
0077
0078 private long totalCountLoadedSoFar = 0;
0079
0080
0081
0082
0083 private boolean[] missingColumns;
0084
0085
0086
0087
0088
0089 private final ZoneId convertTz;
0090
0091
0092
0093
0094 private final String datetimeRebaseMode;
0095
0096
0097
0098
0099
0100
0101
0102
0103
0104
0105
0106
0107
0108
0109
0110 private ColumnarBatch columnarBatch;
0111
0112 private WritableColumnVector[] columnVectors;
0113
0114
0115
0116
0117 private boolean returnColumnarBatch;
0118
0119
0120
0121
0122 private final MemoryMode MEMORY_MODE;
0123
0124 public VectorizedParquetRecordReader(
0125 ZoneId convertTz, String datetimeRebaseMode, boolean useOffHeap, int capacity) {
0126 this.convertTz = convertTz;
0127 this.datetimeRebaseMode = datetimeRebaseMode;
0128 MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP;
0129 this.capacity = capacity;
0130 }
0131
0132
0133 public VectorizedParquetRecordReader(boolean useOffHeap, int capacity) {
0134 this(null, "CORRECTED", useOffHeap, capacity);
0135 }
0136
0137
0138
0139
0140 @Override
0141 public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext)
0142 throws IOException, InterruptedException, UnsupportedOperationException {
0143 super.initialize(inputSplit, taskAttemptContext);
0144 initializeInternal();
0145 }
0146
0147
0148
0149
0150
0151 @Override
0152 public void initialize(String path, List<String> columns) throws IOException,
0153 UnsupportedOperationException {
0154 super.initialize(path, columns);
0155 initializeInternal();
0156 }
0157
0158 @Override
0159 public void close() throws IOException {
0160 if (columnarBatch != null) {
0161 columnarBatch.close();
0162 columnarBatch = null;
0163 }
0164 super.close();
0165 }
0166
0167 @Override
0168 public boolean nextKeyValue() throws IOException {
0169 resultBatch();
0170
0171 if (returnColumnarBatch) return nextBatch();
0172
0173 if (batchIdx >= numBatched) {
0174 if (!nextBatch()) return false;
0175 }
0176 ++batchIdx;
0177 return true;
0178 }
0179
0180 @Override
0181 public Object getCurrentValue() {
0182 if (returnColumnarBatch) return columnarBatch;
0183 return columnarBatch.getRow(batchIdx - 1);
0184 }
0185
0186 @Override
0187 public float getProgress() {
0188 return (float) rowsReturned / totalRowCount;
0189 }
0190
0191
0192
0193
0194
0195
0196
0197 private void initBatch(
0198 MemoryMode memMode,
0199 StructType partitionColumns,
0200 InternalRow partitionValues) {
0201 StructType batchSchema = new StructType();
0202 for (StructField f: sparkSchema.fields()) {
0203 batchSchema = batchSchema.add(f);
0204 }
0205 if (partitionColumns != null) {
0206 for (StructField f : partitionColumns.fields()) {
0207 batchSchema = batchSchema.add(f);
0208 }
0209 }
0210
0211 if (memMode == MemoryMode.OFF_HEAP) {
0212 columnVectors = OffHeapColumnVector.allocateColumns(capacity, batchSchema);
0213 } else {
0214 columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema);
0215 }
0216 columnarBatch = new ColumnarBatch(columnVectors);
0217 if (partitionColumns != null) {
0218 int partitionIdx = sparkSchema.fields().length;
0219 for (int i = 0; i < partitionColumns.fields().length; i++) {
0220 ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i);
0221 columnVectors[i + partitionIdx].setIsConstant();
0222 }
0223 }
0224
0225
0226 for (int i = 0; i < missingColumns.length; i++) {
0227 if (missingColumns[i]) {
0228 columnVectors[i].putNulls(0, capacity);
0229 columnVectors[i].setIsConstant();
0230 }
0231 }
0232 }
0233
0234 private void initBatch() {
0235 initBatch(MEMORY_MODE, null, null);
0236 }
0237
0238 public void initBatch(StructType partitionColumns, InternalRow partitionValues) {
0239 initBatch(MEMORY_MODE, partitionColumns, partitionValues);
0240 }
0241
0242
0243
0244
0245
0246
0247 public ColumnarBatch resultBatch() {
0248 if (columnarBatch == null) initBatch();
0249 return columnarBatch;
0250 }
0251
0252
0253
0254
0255 public void enableReturningBatches() {
0256 returnColumnarBatch = true;
0257 }
0258
0259
0260
0261
0262 public boolean nextBatch() throws IOException {
0263 for (WritableColumnVector vector : columnVectors) {
0264 vector.reset();
0265 }
0266 columnarBatch.setNumRows(0);
0267 if (rowsReturned >= totalRowCount) return false;
0268 checkEndOfRowGroup();
0269
0270 int num = (int) Math.min((long) capacity, totalCountLoadedSoFar - rowsReturned);
0271 for (int i = 0; i < columnReaders.length; ++i) {
0272 if (columnReaders[i] == null) continue;
0273 columnReaders[i].readBatch(num, columnVectors[i]);
0274 }
0275 rowsReturned += num;
0276 columnarBatch.setNumRows(num);
0277 numBatched = num;
0278 batchIdx = 0;
0279 return true;
0280 }
0281
0282 private void initializeInternal() throws IOException, UnsupportedOperationException {
0283
0284 missingColumns = new boolean[requestedSchema.getFieldCount()];
0285 List<ColumnDescriptor> columns = requestedSchema.getColumns();
0286 List<String[]> paths = requestedSchema.getPaths();
0287 for (int i = 0; i < requestedSchema.getFieldCount(); ++i) {
0288 Type t = requestedSchema.getFields().get(i);
0289 if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) {
0290 throw new UnsupportedOperationException("Complex types not supported.");
0291 }
0292
0293 String[] colPath = paths.get(i);
0294 if (fileSchema.containsPath(colPath)) {
0295 ColumnDescriptor fd = fileSchema.getColumnDescription(colPath);
0296 if (!fd.equals(columns.get(i))) {
0297 throw new UnsupportedOperationException("Schema evolution not supported.");
0298 }
0299 missingColumns[i] = false;
0300 } else {
0301 if (columns.get(i).getMaxDefinitionLevel() == 0) {
0302
0303 throw new IOException("Required column is missing in data file. Col: " +
0304 Arrays.toString(colPath));
0305 }
0306 missingColumns[i] = true;
0307 }
0308 }
0309 }
0310
0311 private void checkEndOfRowGroup() throws IOException {
0312 if (rowsReturned != totalCountLoadedSoFar) return;
0313 PageReadStore pages = reader.readNextRowGroup();
0314 if (pages == null) {
0315 throw new IOException("expecting more rows but reached last block. Read "
0316 + rowsReturned + " out of " + totalRowCount);
0317 }
0318 List<ColumnDescriptor> columns = requestedSchema.getColumns();
0319 List<Type> types = requestedSchema.asGroupType().getFields();
0320 columnReaders = new VectorizedColumnReader[columns.size()];
0321 for (int i = 0; i < columns.size(); ++i) {
0322 if (missingColumns[i]) continue;
0323 columnReaders[i] = new VectorizedColumnReader(columns.get(i), types.get(i).getOriginalType(),
0324 pages.getPageReader(columns.get(i)), convertTz, datetimeRebaseMode);
0325 }
0326 totalCountLoadedSoFar += pages.getRowCount();
0327 }
0328 }