0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package test.org.apache.spark.sql.connector;
0019
0020 import java.io.IOException;
0021 import java.util.*;
0022
0023 import org.apache.spark.sql.catalyst.InternalRow;
0024 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
0025 import org.apache.spark.sql.connector.TestingV2Source;
0026 import org.apache.spark.sql.connector.catalog.Table;
0027 import org.apache.spark.sql.connector.read.*;
0028 import org.apache.spark.sql.sources.Filter;
0029 import org.apache.spark.sql.sources.GreaterThan;
0030 import org.apache.spark.sql.types.StructType;
0031 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
0032
0033 public class JavaAdvancedDataSourceV2 implements TestingV2Source {
0034
0035 @Override
0036 public Table getTable(CaseInsensitiveStringMap options) {
0037 return new JavaSimpleBatchTable() {
0038 @Override
0039 public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
0040 return new AdvancedScanBuilder();
0041 }
0042 };
0043 }
0044
0045 static class AdvancedScanBuilder implements ScanBuilder, Scan,
0046 SupportsPushDownFilters, SupportsPushDownRequiredColumns {
0047
0048 private StructType requiredSchema = TestingV2Source.schema();
0049 private Filter[] filters = new Filter[0];
0050
0051 @Override
0052 public void pruneColumns(StructType requiredSchema) {
0053 this.requiredSchema = requiredSchema;
0054 }
0055
0056 @Override
0057 public StructType readSchema() {
0058 return requiredSchema;
0059 }
0060
0061 @Override
0062 public Filter[] pushFilters(Filter[] filters) {
0063 Filter[] supported = Arrays.stream(filters).filter(f -> {
0064 if (f instanceof GreaterThan) {
0065 GreaterThan gt = (GreaterThan) f;
0066 return gt.attribute().equals("i") && gt.value() instanceof Integer;
0067 } else {
0068 return false;
0069 }
0070 }).toArray(Filter[]::new);
0071
0072 Filter[] unsupported = Arrays.stream(filters).filter(f -> {
0073 if (f instanceof GreaterThan) {
0074 GreaterThan gt = (GreaterThan) f;
0075 return !gt.attribute().equals("i") || !(gt.value() instanceof Integer);
0076 } else {
0077 return true;
0078 }
0079 }).toArray(Filter[]::new);
0080
0081 this.filters = supported;
0082 return unsupported;
0083 }
0084
0085 @Override
0086 public Filter[] pushedFilters() {
0087 return filters;
0088 }
0089
0090 @Override
0091 public Scan build() {
0092 return this;
0093 }
0094
0095 @Override
0096 public Batch toBatch() {
0097 return new AdvancedBatch(requiredSchema, filters);
0098 }
0099 }
0100
0101 public static class AdvancedBatch implements Batch {
0102
0103 public StructType requiredSchema;
0104 public Filter[] filters;
0105
0106 AdvancedBatch(StructType requiredSchema, Filter[] filters) {
0107 this.requiredSchema = requiredSchema;
0108 this.filters = filters;
0109 }
0110
0111 @Override
0112 public InputPartition[] planInputPartitions() {
0113 List<InputPartition> res = new ArrayList<>();
0114
0115 Integer lowerBound = null;
0116 for (Filter filter : filters) {
0117 if (filter instanceof GreaterThan) {
0118 GreaterThan f = (GreaterThan) filter;
0119 if ("i".equals(f.attribute()) && f.value() instanceof Integer) {
0120 lowerBound = (Integer) f.value();
0121 break;
0122 }
0123 }
0124 }
0125
0126 if (lowerBound == null) {
0127 res.add(new JavaRangeInputPartition(0, 5));
0128 res.add(new JavaRangeInputPartition(5, 10));
0129 } else if (lowerBound < 4) {
0130 res.add(new JavaRangeInputPartition(lowerBound + 1, 5));
0131 res.add(new JavaRangeInputPartition(5, 10));
0132 } else if (lowerBound < 9) {
0133 res.add(new JavaRangeInputPartition(lowerBound + 1, 10));
0134 }
0135
0136 return res.stream().toArray(InputPartition[]::new);
0137 }
0138
0139 @Override
0140 public PartitionReaderFactory createReaderFactory() {
0141 return new AdvancedReaderFactory(requiredSchema);
0142 }
0143 }
0144
0145 static class AdvancedReaderFactory implements PartitionReaderFactory {
0146 StructType requiredSchema;
0147
0148 AdvancedReaderFactory(StructType requiredSchema) {
0149 this.requiredSchema = requiredSchema;
0150 }
0151
0152 @Override
0153 public PartitionReader<InternalRow> createReader(InputPartition partition) {
0154 JavaRangeInputPartition p = (JavaRangeInputPartition) partition;
0155 return new PartitionReader<InternalRow>() {
0156 private int current = p.start - 1;
0157
0158 @Override
0159 public boolean next() throws IOException {
0160 current += 1;
0161 return current < p.end;
0162 }
0163
0164 @Override
0165 public InternalRow get() {
0166 Object[] values = new Object[requiredSchema.size()];
0167 for (int i = 0; i < values.length; i++) {
0168 if ("i".equals(requiredSchema.apply(i).name())) {
0169 values[i] = current;
0170 } else if ("j".equals(requiredSchema.apply(i).name())) {
0171 values[i] = -current;
0172 }
0173 }
0174 return new GenericInternalRow(values);
0175 }
0176
0177 @Override
0178 public void close() throws IOException {
0179
0180 }
0181 };
0182 }
0183 }
0184 }