0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 package org.apache.spark.sql.execution.vectorized;
0018
0019 import java.math.BigDecimal;
0020 import java.math.BigInteger;
0021 import java.nio.charset.StandardCharsets;
0022 import java.sql.Date;
0023 import java.util.HashMap;
0024 import java.util.Iterator;
0025 import java.util.List;
0026 import java.util.Map;
0027
0028 import org.apache.spark.memory.MemoryMode;
0029 import org.apache.spark.sql.Row;
0030 import org.apache.spark.sql.catalyst.InternalRow;
0031 import org.apache.spark.sql.catalyst.util.DateTimeUtils;
0032 import org.apache.spark.sql.types.*;
0033 import org.apache.spark.sql.vectorized.ColumnarArray;
0034 import org.apache.spark.sql.vectorized.ColumnarBatch;
0035 import org.apache.spark.sql.vectorized.ColumnarMap;
0036 import org.apache.spark.unsafe.types.CalendarInterval;
0037 import org.apache.spark.unsafe.types.UTF8String;
0038
0039
0040
0041
0042
0043
0044 public class ColumnVectorUtils {
0045
0046
0047
0048 public static void populate(WritableColumnVector col, InternalRow row, int fieldIdx) {
0049 int capacity = col.capacity;
0050 DataType t = col.dataType();
0051
0052 if (row.isNullAt(fieldIdx)) {
0053 col.putNulls(0, capacity);
0054 } else {
0055 if (t == DataTypes.BooleanType) {
0056 col.putBooleans(0, capacity, row.getBoolean(fieldIdx));
0057 } else if (t == DataTypes.ByteType) {
0058 col.putBytes(0, capacity, row.getByte(fieldIdx));
0059 } else if (t == DataTypes.ShortType) {
0060 col.putShorts(0, capacity, row.getShort(fieldIdx));
0061 } else if (t == DataTypes.IntegerType) {
0062 col.putInts(0, capacity, row.getInt(fieldIdx));
0063 } else if (t == DataTypes.LongType) {
0064 col.putLongs(0, capacity, row.getLong(fieldIdx));
0065 } else if (t == DataTypes.FloatType) {
0066 col.putFloats(0, capacity, row.getFloat(fieldIdx));
0067 } else if (t == DataTypes.DoubleType) {
0068 col.putDoubles(0, capacity, row.getDouble(fieldIdx));
0069 } else if (t == DataTypes.StringType) {
0070 UTF8String v = row.getUTF8String(fieldIdx);
0071 byte[] bytes = v.getBytes();
0072 for (int i = 0; i < capacity; i++) {
0073 col.putByteArray(i, bytes);
0074 }
0075 } else if (t instanceof DecimalType) {
0076 DecimalType dt = (DecimalType)t;
0077 Decimal d = row.getDecimal(fieldIdx, dt.precision(), dt.scale());
0078 if (dt.precision() <= Decimal.MAX_INT_DIGITS()) {
0079 col.putInts(0, capacity, (int)d.toUnscaledLong());
0080 } else if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) {
0081 col.putLongs(0, capacity, d.toUnscaledLong());
0082 } else {
0083 final BigInteger integer = d.toJavaBigDecimal().unscaledValue();
0084 byte[] bytes = integer.toByteArray();
0085 for (int i = 0; i < capacity; i++) {
0086 col.putByteArray(i, bytes, 0, bytes.length);
0087 }
0088 }
0089 } else if (t instanceof CalendarIntervalType) {
0090 CalendarInterval c = (CalendarInterval)row.get(fieldIdx, t);
0091 col.getChild(0).putInts(0, capacity, c.months);
0092 col.getChild(1).putLongs(0, capacity, c.microseconds);
0093 } else if (t instanceof DateType) {
0094 col.putInts(0, capacity, row.getInt(fieldIdx));
0095 } else if (t instanceof TimestampType) {
0096 col.putLongs(0, capacity, row.getLong(fieldIdx));
0097 }
0098 }
0099 }
0100
0101
0102
0103
0104
0105
0106 public static int[] toJavaIntArray(ColumnarArray array) {
0107 for (int i = 0; i < array.numElements(); i++) {
0108 if (array.isNullAt(i)) {
0109 throw new RuntimeException("Cannot handle NULL values.");
0110 }
0111 }
0112 return array.toIntArray();
0113 }
0114
0115 public static Map<Integer, Integer> toJavaIntMap(ColumnarMap map) {
0116 int[] keys = toJavaIntArray(map.keyArray());
0117 int[] values = toJavaIntArray(map.valueArray());
0118 assert keys.length == values.length;
0119
0120 Map<Integer, Integer> result = new HashMap<>();
0121 for (int i = 0; i < keys.length; i++) {
0122 result.put(keys[i], values[i]);
0123 }
0124 return result;
0125 }
0126
0127 private static void appendValue(WritableColumnVector dst, DataType t, Object o) {
0128 if (o == null) {
0129 if (t instanceof CalendarIntervalType) {
0130 dst.appendStruct(true);
0131 } else {
0132 dst.appendNull();
0133 }
0134 } else {
0135 if (t == DataTypes.BooleanType) {
0136 dst.appendBoolean((Boolean) o);
0137 } else if (t == DataTypes.ByteType) {
0138 dst.appendByte((Byte) o);
0139 } else if (t == DataTypes.ShortType) {
0140 dst.appendShort((Short) o);
0141 } else if (t == DataTypes.IntegerType) {
0142 dst.appendInt((Integer) o);
0143 } else if (t == DataTypes.LongType) {
0144 dst.appendLong((Long) o);
0145 } else if (t == DataTypes.FloatType) {
0146 dst.appendFloat((Float) o);
0147 } else if (t == DataTypes.DoubleType) {
0148 dst.appendDouble((Double) o);
0149 } else if (t == DataTypes.StringType) {
0150 byte[] b =((String)o).getBytes(StandardCharsets.UTF_8);
0151 dst.appendByteArray(b, 0, b.length);
0152 } else if (t instanceof DecimalType) {
0153 DecimalType dt = (DecimalType) t;
0154 Decimal d = Decimal.apply((BigDecimal) o, dt.precision(), dt.scale());
0155 if (dt.precision() <= Decimal.MAX_INT_DIGITS()) {
0156 dst.appendInt((int) d.toUnscaledLong());
0157 } else if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) {
0158 dst.appendLong(d.toUnscaledLong());
0159 } else {
0160 final BigInteger integer = d.toJavaBigDecimal().unscaledValue();
0161 byte[] bytes = integer.toByteArray();
0162 dst.appendByteArray(bytes, 0, bytes.length);
0163 }
0164 } else if (t instanceof CalendarIntervalType) {
0165 CalendarInterval c = (CalendarInterval)o;
0166 dst.appendStruct(false);
0167 dst.getChild(0).appendInt(c.months);
0168 dst.getChild(1).appendInt(c.days);
0169 dst.getChild(2).appendLong(c.microseconds);
0170 } else if (t instanceof DateType) {
0171 dst.appendInt(DateTimeUtils.fromJavaDate((Date)o));
0172 } else {
0173 throw new UnsupportedOperationException("Type " + t);
0174 }
0175 }
0176 }
0177
0178 private static void appendValue(WritableColumnVector dst, DataType t, Row src, int fieldIdx) {
0179 if (t instanceof ArrayType) {
0180 ArrayType at = (ArrayType)t;
0181 if (src.isNullAt(fieldIdx)) {
0182 dst.appendNull();
0183 } else {
0184 List<Object> values = src.getList(fieldIdx);
0185 dst.appendArray(values.size());
0186 for (Object o : values) {
0187 appendValue(dst.arrayData(), at.elementType(), o);
0188 }
0189 }
0190 } else if (t instanceof StructType) {
0191 StructType st = (StructType)t;
0192 if (src.isNullAt(fieldIdx)) {
0193 dst.appendStruct(true);
0194 } else {
0195 dst.appendStruct(false);
0196 Row c = src.getStruct(fieldIdx);
0197 for (int i = 0; i < st.fields().length; i++) {
0198 appendValue(dst.getChild(i), st.fields()[i].dataType(), c, i);
0199 }
0200 }
0201 } else {
0202 appendValue(dst, t, src.get(fieldIdx));
0203 }
0204 }
0205
0206
0207
0208
0209 public static ColumnarBatch toBatch(
0210 StructType schema, MemoryMode memMode, Iterator<Row> row) {
0211 int capacity = 4 * 1024;
0212 WritableColumnVector[] columnVectors;
0213 if (memMode == MemoryMode.OFF_HEAP) {
0214 columnVectors = OffHeapColumnVector.allocateColumns(capacity, schema);
0215 } else {
0216 columnVectors = OnHeapColumnVector.allocateColumns(capacity, schema);
0217 }
0218
0219 int n = 0;
0220 while (row.hasNext()) {
0221 Row r = row.next();
0222 for (int i = 0; i < schema.fields().length; i++) {
0223 appendValue(columnVectors[i], schema.fields()[i].dataType(), r, i);
0224 }
0225 n++;
0226 }
0227 ColumnarBatch batch = new ColumnarBatch(columnVectors);
0228 batch.setNumRows(n);
0229 return batch;
0230 }
0231 }