0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 package org.apache.spark.sql.vectorized;
0018
0019 import java.util.*;
0020
0021 import org.apache.spark.annotation.Evolving;
0022 import org.apache.spark.sql.catalyst.InternalRow;
0023 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
0024 import org.apache.spark.sql.types.*;
0025 import org.apache.spark.unsafe.types.CalendarInterval;
0026 import org.apache.spark.unsafe.types.UTF8String;
0027
0028
0029
0030
0031
0032
0033 @Evolving
0034 public final class ColumnarBatch implements AutoCloseable {
0035 private int numRows;
0036 private final ColumnVector[] columns;
0037
0038
0039 private final ColumnarBatchRow row;
0040
0041
0042
0043
0044
0045 @Override
0046 public void close() {
0047 for (ColumnVector c: columns) {
0048 c.close();
0049 }
0050 }
0051
0052
0053
0054
0055 public Iterator<InternalRow> rowIterator() {
0056 final int maxRows = numRows;
0057 final ColumnarBatchRow row = new ColumnarBatchRow(columns);
0058 return new Iterator<InternalRow>() {
0059 int rowId = 0;
0060
0061 @Override
0062 public boolean hasNext() {
0063 return rowId < maxRows;
0064 }
0065
0066 @Override
0067 public InternalRow next() {
0068 if (rowId >= maxRows) {
0069 throw new NoSuchElementException();
0070 }
0071 row.rowId = rowId++;
0072 return row;
0073 }
0074
0075 @Override
0076 public void remove() {
0077 throw new UnsupportedOperationException();
0078 }
0079 };
0080 }
0081
0082
0083
0084
0085 public void setNumRows(int numRows) {
0086 this.numRows = numRows;
0087 }
0088
0089
0090
0091
0092 public int numCols() { return columns.length; }
0093
0094
0095
0096
0097 public int numRows() { return numRows; }
0098
0099
0100
0101
0102 public ColumnVector column(int ordinal) { return columns[ordinal]; }
0103
0104
0105
0106
0107 public InternalRow getRow(int rowId) {
0108 assert(rowId >= 0 && rowId < numRows);
0109 row.rowId = rowId;
0110 return row;
0111 }
0112
0113 public ColumnarBatch(ColumnVector[] columns) {
0114 this(columns, 0);
0115 }
0116
0117
0118
0119
0120
0121
0122 public ColumnarBatch(ColumnVector[] columns, int numRows) {
0123 this.columns = columns;
0124 this.numRows = numRows;
0125 this.row = new ColumnarBatchRow(columns);
0126 }
0127 }
0128
0129
0130
0131
0132 class ColumnarBatchRow extends InternalRow {
0133 public int rowId;
0134 private final ColumnVector[] columns;
0135
0136 ColumnarBatchRow(ColumnVector[] columns) {
0137 this.columns = columns;
0138 }
0139
0140 @Override
0141 public int numFields() { return columns.length; }
0142
0143 @Override
0144 public InternalRow copy() {
0145 GenericInternalRow row = new GenericInternalRow(columns.length);
0146 for (int i = 0; i < numFields(); i++) {
0147 if (isNullAt(i)) {
0148 row.setNullAt(i);
0149 } else {
0150 DataType dt = columns[i].dataType();
0151 if (dt instanceof BooleanType) {
0152 row.setBoolean(i, getBoolean(i));
0153 } else if (dt instanceof ByteType) {
0154 row.setByte(i, getByte(i));
0155 } else if (dt instanceof ShortType) {
0156 row.setShort(i, getShort(i));
0157 } else if (dt instanceof IntegerType) {
0158 row.setInt(i, getInt(i));
0159 } else if (dt instanceof LongType) {
0160 row.setLong(i, getLong(i));
0161 } else if (dt instanceof FloatType) {
0162 row.setFloat(i, getFloat(i));
0163 } else if (dt instanceof DoubleType) {
0164 row.setDouble(i, getDouble(i));
0165 } else if (dt instanceof StringType) {
0166 row.update(i, getUTF8String(i).copy());
0167 } else if (dt instanceof BinaryType) {
0168 row.update(i, getBinary(i));
0169 } else if (dt instanceof DecimalType) {
0170 DecimalType t = (DecimalType)dt;
0171 row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision());
0172 } else if (dt instanceof DateType) {
0173 row.setInt(i, getInt(i));
0174 } else if (dt instanceof TimestampType) {
0175 row.setLong(i, getLong(i));
0176 } else {
0177 throw new RuntimeException("Not implemented. " + dt);
0178 }
0179 }
0180 }
0181 return row;
0182 }
0183
0184 @Override
0185 public boolean anyNull() {
0186 throw new UnsupportedOperationException();
0187 }
0188
0189 @Override
0190 public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); }
0191
0192 @Override
0193 public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); }
0194
0195 @Override
0196 public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }
0197
0198 @Override
0199 public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); }
0200
0201 @Override
0202 public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); }
0203
0204 @Override
0205 public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }
0206
0207 @Override
0208 public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); }
0209
0210 @Override
0211 public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); }
0212
0213 @Override
0214 public Decimal getDecimal(int ordinal, int precision, int scale) {
0215 return columns[ordinal].getDecimal(rowId, precision, scale);
0216 }
0217
0218 @Override
0219 public UTF8String getUTF8String(int ordinal) {
0220 return columns[ordinal].getUTF8String(rowId);
0221 }
0222
0223 @Override
0224 public byte[] getBinary(int ordinal) {
0225 return columns[ordinal].getBinary(rowId);
0226 }
0227
0228 @Override
0229 public CalendarInterval getInterval(int ordinal) {
0230 return columns[ordinal].getInterval(rowId);
0231 }
0232
0233 @Override
0234 public ColumnarRow getStruct(int ordinal, int numFields) {
0235 return columns[ordinal].getStruct(rowId);
0236 }
0237
0238 @Override
0239 public ColumnarArray getArray(int ordinal) {
0240 return columns[ordinal].getArray(rowId);
0241 }
0242
0243 @Override
0244 public ColumnarMap getMap(int ordinal) {
0245 return columns[ordinal].getMap(rowId);
0246 }
0247
0248 @Override
0249 public Object get(int ordinal, DataType dataType) {
0250 if (dataType instanceof BooleanType) {
0251 return getBoolean(ordinal);
0252 } else if (dataType instanceof ByteType) {
0253 return getByte(ordinal);
0254 } else if (dataType instanceof ShortType) {
0255 return getShort(ordinal);
0256 } else if (dataType instanceof IntegerType) {
0257 return getInt(ordinal);
0258 } else if (dataType instanceof LongType) {
0259 return getLong(ordinal);
0260 } else if (dataType instanceof FloatType) {
0261 return getFloat(ordinal);
0262 } else if (dataType instanceof DoubleType) {
0263 return getDouble(ordinal);
0264 } else if (dataType instanceof StringType) {
0265 return getUTF8String(ordinal);
0266 } else if (dataType instanceof BinaryType) {
0267 return getBinary(ordinal);
0268 } else if (dataType instanceof DecimalType) {
0269 DecimalType t = (DecimalType) dataType;
0270 return getDecimal(ordinal, t.precision(), t.scale());
0271 } else if (dataType instanceof DateType) {
0272 return getInt(ordinal);
0273 } else if (dataType instanceof TimestampType) {
0274 return getLong(ordinal);
0275 } else if (dataType instanceof ArrayType) {
0276 return getArray(ordinal);
0277 } else if (dataType instanceof StructType) {
0278 return getStruct(ordinal, ((StructType)dataType).fields().length);
0279 } else if (dataType instanceof MapType) {
0280 return getMap(ordinal);
0281 } else {
0282 throw new UnsupportedOperationException("Datatype not supported " + dataType);
0283 }
0284 }
0285
0286 @Override
0287 public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); }
0288
0289 @Override
0290 public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); }
0291 }