Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package test.org.apache.spark.sql.connector;
0019 
0020 import java.io.IOException;
0021 import java.util.*;
0022 
0023 import org.apache.spark.sql.catalyst.InternalRow;
0024 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
0025 import org.apache.spark.sql.connector.TestingV2Source;
0026 import org.apache.spark.sql.connector.catalog.Table;
0027 import org.apache.spark.sql.connector.read.*;
0028 import org.apache.spark.sql.sources.Filter;
0029 import org.apache.spark.sql.sources.GreaterThan;
0030 import org.apache.spark.sql.types.StructType;
0031 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
0032 
0033 public class JavaAdvancedDataSourceV2 implements TestingV2Source {
0034 
0035   @Override
0036   public Table getTable(CaseInsensitiveStringMap options) {
0037     return new JavaSimpleBatchTable() {
0038       @Override
0039       public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
0040         return new AdvancedScanBuilder();
0041       }
0042     };
0043   }
0044 
0045   static class AdvancedScanBuilder implements ScanBuilder, Scan,
0046     SupportsPushDownFilters, SupportsPushDownRequiredColumns {
0047 
0048     private StructType requiredSchema = TestingV2Source.schema();
0049     private Filter[] filters = new Filter[0];
0050 
0051     @Override
0052     public void pruneColumns(StructType requiredSchema) {
0053       this.requiredSchema = requiredSchema;
0054     }
0055 
0056     @Override
0057     public StructType readSchema() {
0058       return requiredSchema;
0059     }
0060 
0061     @Override
0062     public Filter[] pushFilters(Filter[] filters) {
0063       Filter[] supported = Arrays.stream(filters).filter(f -> {
0064         if (f instanceof GreaterThan) {
0065           GreaterThan gt = (GreaterThan) f;
0066           return gt.attribute().equals("i") && gt.value() instanceof Integer;
0067         } else {
0068           return false;
0069         }
0070       }).toArray(Filter[]::new);
0071 
0072       Filter[] unsupported = Arrays.stream(filters).filter(f -> {
0073         if (f instanceof GreaterThan) {
0074           GreaterThan gt = (GreaterThan) f;
0075           return !gt.attribute().equals("i") || !(gt.value() instanceof Integer);
0076         } else {
0077           return true;
0078         }
0079       }).toArray(Filter[]::new);
0080 
0081       this.filters = supported;
0082       return unsupported;
0083     }
0084 
0085     @Override
0086     public Filter[] pushedFilters() {
0087       return filters;
0088     }
0089 
0090     @Override
0091     public Scan build() {
0092       return this;
0093     }
0094 
0095     @Override
0096     public Batch toBatch() {
0097       return new AdvancedBatch(requiredSchema, filters);
0098     }
0099   }
0100 
0101   public static class AdvancedBatch implements Batch {
0102     // Exposed for testing.
0103     public StructType requiredSchema;
0104     public Filter[] filters;
0105 
0106     AdvancedBatch(StructType requiredSchema, Filter[] filters) {
0107       this.requiredSchema = requiredSchema;
0108       this.filters = filters;
0109     }
0110 
0111     @Override
0112     public InputPartition[] planInputPartitions() {
0113       List<InputPartition> res = new ArrayList<>();
0114 
0115       Integer lowerBound = null;
0116       for (Filter filter : filters) {
0117         if (filter instanceof GreaterThan) {
0118           GreaterThan f = (GreaterThan) filter;
0119           if ("i".equals(f.attribute()) && f.value() instanceof Integer) {
0120             lowerBound = (Integer) f.value();
0121             break;
0122           }
0123         }
0124       }
0125 
0126       if (lowerBound == null) {
0127         res.add(new JavaRangeInputPartition(0, 5));
0128         res.add(new JavaRangeInputPartition(5, 10));
0129       } else if (lowerBound < 4) {
0130         res.add(new JavaRangeInputPartition(lowerBound + 1, 5));
0131         res.add(new JavaRangeInputPartition(5, 10));
0132       } else if (lowerBound < 9) {
0133         res.add(new JavaRangeInputPartition(lowerBound + 1, 10));
0134       }
0135 
0136       return res.stream().toArray(InputPartition[]::new);
0137     }
0138 
0139     @Override
0140     public PartitionReaderFactory createReaderFactory() {
0141       return new AdvancedReaderFactory(requiredSchema);
0142     }
0143   }
0144 
0145   static class AdvancedReaderFactory implements PartitionReaderFactory {
0146     StructType requiredSchema;
0147 
0148     AdvancedReaderFactory(StructType requiredSchema) {
0149       this.requiredSchema = requiredSchema;
0150     }
0151 
0152     @Override
0153     public PartitionReader<InternalRow> createReader(InputPartition partition) {
0154       JavaRangeInputPartition p = (JavaRangeInputPartition) partition;
0155       return new PartitionReader<InternalRow>() {
0156         private int current = p.start - 1;
0157 
0158         @Override
0159         public boolean next() throws IOException {
0160           current += 1;
0161           return current < p.end;
0162         }
0163 
0164         @Override
0165         public InternalRow get() {
0166           Object[] values = new Object[requiredSchema.size()];
0167           for (int i = 0; i < values.length; i++) {
0168             if ("i".equals(requiredSchema.apply(i).name())) {
0169               values[i] = current;
0170             } else if ("j".equals(requiredSchema.apply(i).name())) {
0171               values[i] = -current;
0172             }
0173           }
0174           return new GenericInternalRow(values);
0175         }
0176 
0177         @Override
0178         public void close() throws IOException {
0179 
0180         }
0181       };
0182     }
0183   }
0184 }