/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.execution.vectorized;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.sql.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
import org.apache.spark.sql.types.*;
import org.apache.spark.sql.vectorized.ColumnarArray;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;

/**
 * Utilities to help manipulate data associated with ColumnVectors. These should be used mostly
 * for debugging or other non-performance critical paths.
 * These utilities are mostly used to convert ColumnVectors into other formats.
 */
public class ColumnVectorUtils {
  /**
   * Populates the entire `col` with `row[fieldIdx]`.
   */
  public static void populate(WritableColumnVector col, InternalRow row, int fieldIdx) {
    int capacity = col.capacity;
    DataType t = col.dataType();

    if (row.isNullAt(fieldIdx)) {
      col.putNulls(0, capacity);
    } else {
      if (t == DataTypes.BooleanType) {
        col.putBooleans(0, capacity, row.getBoolean(fieldIdx));
      } else if (t == DataTypes.ByteType) {
        col.putBytes(0, capacity, row.getByte(fieldIdx));
      } else if (t == DataTypes.ShortType) {
        col.putShorts(0, capacity, row.getShort(fieldIdx));
      } else if (t == DataTypes.IntegerType) {
        col.putInts(0, capacity, row.getInt(fieldIdx));
      } else if (t == DataTypes.LongType) {
        col.putLongs(0, capacity, row.getLong(fieldIdx));
      } else if (t == DataTypes.FloatType) {
        col.putFloats(0, capacity, row.getFloat(fieldIdx));
      } else if (t == DataTypes.DoubleType) {
        col.putDoubles(0, capacity, row.getDouble(fieldIdx));
      } else if (t == DataTypes.StringType) {
        UTF8String v = row.getUTF8String(fieldIdx);
        byte[] bytes = v.getBytes();
        for (int i = 0; i < capacity; i++) {
          col.putByteArray(i, bytes);
        }
      } else if (t instanceof DecimalType) {
        DecimalType dt = (DecimalType) t;
        Decimal d = row.getDecimal(fieldIdx, dt.precision(), dt.scale());
        if (dt.precision() <= Decimal.MAX_INT_DIGITS()) {
          col.putInts(0, capacity, (int) d.toUnscaledLong());
        } else if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) {
          col.putLongs(0, capacity, d.toUnscaledLong());
        } else {
          final BigInteger integer = d.toJavaBigDecimal().unscaledValue();
          byte[] bytes = integer.toByteArray();
          for (int i = 0; i < capacity; i++) {
            col.putByteArray(i, bytes, 0, bytes.length);
          }
        }
      } else if (t instanceof CalendarIntervalType) {
        // An interval is stored as a struct of (months, days, microseconds),
        // matching the layout written by appendValue() below.
        CalendarInterval c = (CalendarInterval) row.get(fieldIdx, t);
        col.getChild(0).putInts(0, capacity, c.months);
        col.getChild(1).putInts(0, capacity, c.days);
        col.getChild(2).putLongs(0, capacity, c.microseconds);
      } else if (t instanceof DateType) {
        col.putInts(0, capacity, row.getInt(fieldIdx));
      } else if (t instanceof TimestampType) {
        col.putLongs(0, capacity, row.getLong(fieldIdx));
      }
    }
  }
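
  // Illustrative usage sketch (not part of the original file): `populate` fills
  // every slot of a vector with one constant value, e.g. a partition-column
  // value shared by all rows in a batch. Assuming an on-heap IntegerType vector
  // and a GenericInternalRow holding the constant:
  //
  //   WritableColumnVector col =
  //       new OnHeapColumnVector(4 * 1024, DataTypes.IntegerType);
  //   InternalRow constants =
  //       new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(
  //           new Object[] { 2023 });
  //   ColumnVectorUtils.populate(col, constants, 0);
  //   // col.getInt(i) == 2023 for every i in [0, col.capacity)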

  /**
   * Returns the array data as a Java primitive array. For example, an array of
   * IntegerType will return an int[]. Throws a RuntimeException if any element
   * is NULL, and for unhandled schemas.
   */
  public static int[] toJavaIntArray(ColumnarArray array) {
    for (int i = 0; i < array.numElements(); i++) {
      if (array.isNullAt(i)) {
        throw new RuntimeException("Cannot handle NULL values.");
      }
    }
    return array.toIntArray();
  }

  /**
   * Returns the map data as a java.util.Map from Integer keys to Integer values.
   */
  public static Map<Integer, Integer> toJavaIntMap(ColumnarMap map) {
    int[] keys = toJavaIntArray(map.keyArray());
    int[] values = toJavaIntArray(map.valueArray());
    assert keys.length == values.length;

    Map<Integer, Integer> result = new HashMap<>();
    for (int i = 0; i < keys.length; i++) {
      result.put(keys[i], values[i]);
    }
    return result;
  }
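
  // Illustrative sketch (not part of the original file): pulling a map column
  // value into plain Java collections. Assuming `v` is a ColumnVector of type
  // MapType(IntegerType, IntegerType) and `rowId` is a valid row index:
  //
  //   ColumnarMap m = v.getMap(rowId);
  //   Map<Integer, Integer> javaMap = ColumnVectorUtils.toJavaIntMap(m);
  //
  // Both helpers above throw a RuntimeException if any element is NULL, since
  // Java primitive int arrays cannot represent missing values.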

  /**
   * Appends a single external Java value `o` of Catalyst type `t` to `dst`.
   */
  private static void appendValue(WritableColumnVector dst, DataType t, Object o) {
    if (o == null) {
      if (t instanceof CalendarIntervalType) {
        dst.appendStruct(true);
      } else {
        dst.appendNull();
      }
    } else {
      if (t == DataTypes.BooleanType) {
        dst.appendBoolean((Boolean) o);
      } else if (t == DataTypes.ByteType) {
        dst.appendByte((Byte) o);
      } else if (t == DataTypes.ShortType) {
        dst.appendShort((Short) o);
      } else if (t == DataTypes.IntegerType) {
        dst.appendInt((Integer) o);
      } else if (t == DataTypes.LongType) {
        dst.appendLong((Long) o);
      } else if (t == DataTypes.FloatType) {
        dst.appendFloat((Float) o);
      } else if (t == DataTypes.DoubleType) {
        dst.appendDouble((Double) o);
      } else if (t == DataTypes.StringType) {
        byte[] b = ((String) o).getBytes(StandardCharsets.UTF_8);
        dst.appendByteArray(b, 0, b.length);
      } else if (t instanceof DecimalType) {
        DecimalType dt = (DecimalType) t;
        Decimal d = Decimal.apply((BigDecimal) o, dt.precision(), dt.scale());
        if (dt.precision() <= Decimal.MAX_INT_DIGITS()) {
          dst.appendInt((int) d.toUnscaledLong());
        } else if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) {
          dst.appendLong(d.toUnscaledLong());
        } else {
          final BigInteger integer = d.toJavaBigDecimal().unscaledValue();
          byte[] bytes = integer.toByteArray();
          dst.appendByteArray(bytes, 0, bytes.length);
        }
      } else if (t instanceof CalendarIntervalType) {
        CalendarInterval c = (CalendarInterval) o;
        dst.appendStruct(false);
        dst.getChild(0).appendInt(c.months);
        dst.getChild(1).appendInt(c.days);
        dst.getChild(2).appendLong(c.microseconds);
      } else if (t instanceof DateType) {
        dst.appendInt(DateTimeUtils.fromJavaDate((Date) o));
      } else {
        throw new UnsupportedOperationException("Type " + t);
      }
    }
  }

  /**
   * Appends field `fieldIdx` of external row `src` to `dst`, recursing into arrays and structs.
   */
  private static void appendValue(WritableColumnVector dst, DataType t, Row src, int fieldIdx) {
    if (t instanceof ArrayType) {
      ArrayType at = (ArrayType) t;
      if (src.isNullAt(fieldIdx)) {
        dst.appendNull();
      } else {
        List<Object> values = src.getList(fieldIdx);
        dst.appendArray(values.size());
        for (Object o : values) {
          appendValue(dst.arrayData(), at.elementType(), o);
        }
      }
    } else if (t instanceof StructType) {
      StructType st = (StructType) t;
      if (src.isNullAt(fieldIdx)) {
        dst.appendStruct(true);
      } else {
        dst.appendStruct(false);
        Row c = src.getStruct(fieldIdx);
        for (int i = 0; i < st.fields().length; i++) {
          appendValue(dst.getChild(i), st.fields()[i].dataType(), c, i);
        }
      }
    } else {
      appendValue(dst, t, src.get(fieldIdx));
    }
  }

  /**
   * Converts an iterator of rows into a single ColumnarBatch.
   */
  public static ColumnarBatch toBatch(
      StructType schema, MemoryMode memMode, Iterator<Row> row) {
    int capacity = 4 * 1024;
    WritableColumnVector[] columnVectors;
    if (memMode == MemoryMode.OFF_HEAP) {
      columnVectors = OffHeapColumnVector.allocateColumns(capacity, schema);
    } else {
      columnVectors = OnHeapColumnVector.allocateColumns(capacity, schema);
    }

    int n = 0;
    while (row.hasNext()) {
      Row r = row.next();
      for (int i = 0; i < schema.fields().length; i++) {
        appendValue(columnVectors[i], schema.fields()[i].dataType(), r, i);
      }
      n++;
    }
    ColumnarBatch batch = new ColumnarBatch(columnVectors);
    batch.setNumRows(n);
    return batch;
  }
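
  // Illustrative usage sketch (not part of the original file): building a small
  // on-heap batch from external Rows with a hypothetical two-column schema
  // (RowFactory here is org.apache.spark.sql.RowFactory). Per the class Javadoc,
  // this is meant for debugging and other non-performance-critical paths.
  //
  //   StructType schema = new StructType()
  //       .add("id", DataTypes.IntegerType)
  //       .add("name", DataTypes.StringType);
  //   List<Row> rows = java.util.Arrays.asList(
  //       RowFactory.create(1, "a"),
  //       RowFactory.create(2, "b"));
  //   ColumnarBatch batch =
  //       ColumnVectorUtils.toBatch(schema, MemoryMode.ON_HEAP, rows.iterator());
  //   // batch.numRows() == 2; batch.column(0).getInt(1) == 2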
}