Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.streaming;
0019 
0020 import java.io.Serializable;
0021 import java.util.ArrayList;
0022 import java.util.Arrays;
0023 import java.util.Collections;
0024 import java.util.List;
0025 import java.util.Set;
0026 
0027 import scala.Tuple2;
0028 
0029 import com.google.common.collect.Sets;
0030 import org.apache.spark.streaming.api.java.JavaDStream;
0031 import org.apache.spark.util.ManualClock;
0032 import org.junit.Assert;
0033 import org.junit.Test;
0034 
0035 import org.apache.spark.HashPartitioner;
0036 import org.apache.spark.api.java.JavaPairRDD;
0037 import org.apache.spark.api.java.Optional;
0038 import org.apache.spark.api.java.function.Function3;
0039 import org.apache.spark.api.java.function.Function4;
0040 import org.apache.spark.streaming.api.java.JavaPairDStream;
0041 import org.apache.spark.streaming.api.java.JavaMapWithStateDStream;
0042 
0043 public class JavaMapWithStateSuite extends LocalJavaStreamingContext implements Serializable {
0044 
0045   /**
0046    * This test is only for testing the APIs. It's not necessary to run it.
0047    */
0048   public void testAPI() {
0049     JavaPairRDD<String, Boolean> initialRDD = null;
0050     JavaPairDStream<String, Integer> wordsDstream = null;
0051 
0052     Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>> mappingFunc =
0053         (time, word, one, state) -> {
0054           // Use all State's methods here
0055           state.exists();
0056           state.get();
0057           state.isTimingOut();
0058           state.remove();
0059           state.update(true);
0060           return Optional.of(2.0);
0061         };
0062 
0063     JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream =
0064         wordsDstream.mapWithState(
0065             StateSpec.function(mappingFunc)
0066                 .initialState(initialRDD)
0067                 .numPartitions(10)
0068                 .partitioner(new HashPartitioner(10))
0069                 .timeout(Durations.seconds(10)));
0070 
0071     stateDstream.stateSnapshots();
0072 
0073     Function3<String, Optional<Integer>, State<Boolean>, Double> mappingFunc2 =
0074         (key, one, state) -> {
0075           // Use all State's methods here
0076           state.exists();
0077           state.get();
0078           state.isTimingOut();
0079           state.remove();
0080           state.update(true);
0081           return 2.0;
0082         };
0083 
0084     JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream2 =
0085         wordsDstream.mapWithState(
0086             StateSpec.function(mappingFunc2)
0087                 .initialState(initialRDD)
0088                 .numPartitions(10)
0089                 .partitioner(new HashPartitioner(10))
0090                 .timeout(Durations.seconds(10)));
0091 
0092     stateDstream2.stateSnapshots();
0093   }
0094 
0095   @Test
0096   public void testBasicFunction() {
0097     List<List<String>> inputData = Arrays.asList(
0098         Collections.<String>emptyList(),
0099         Arrays.asList("a"),
0100         Arrays.asList("a", "b"),
0101         Arrays.asList("a", "b", "c"),
0102         Arrays.asList("a", "b"),
0103         Arrays.asList("a"),
0104         Collections.<String>emptyList()
0105     );
0106 
0107     List<Set<Integer>> outputData = Arrays.asList(
0108         Collections.<Integer>emptySet(),
0109         Sets.newHashSet(1),
0110         Sets.newHashSet(2, 1),
0111         Sets.newHashSet(3, 2, 1),
0112         Sets.newHashSet(4, 3),
0113         Sets.newHashSet(5),
0114         Collections.<Integer>emptySet()
0115     );
0116 
0117     @SuppressWarnings("unchecked")
0118     List<Set<Tuple2<String, Integer>>> stateData = Arrays.asList(
0119         Collections.<Tuple2<String, Integer>>emptySet(),
0120         Sets.newHashSet(new Tuple2<>("a", 1)),
0121         Sets.newHashSet(new Tuple2<>("a", 2), new Tuple2<>("b", 1)),
0122         Sets.newHashSet(new Tuple2<>("a", 3), new Tuple2<>("b", 2), new Tuple2<>("c", 1)),
0123         Sets.newHashSet(new Tuple2<>("a", 4), new Tuple2<>("b", 3), new Tuple2<>("c", 1)),
0124         Sets.newHashSet(new Tuple2<>("a", 5), new Tuple2<>("b", 3), new Tuple2<>("c", 1)),
0125         Sets.newHashSet(new Tuple2<>("a", 5), new Tuple2<>("b", 3), new Tuple2<>("c", 1))
0126     );
0127 
0128     Function3<String, Optional<Integer>, State<Integer>, Integer> mappingFunc =
0129         (key, value, state) -> {
0130           int sum = value.orElse(0) + (state.exists() ? state.get() : 0);
0131           state.update(sum);
0132           return sum;
0133         };
0134     testOperation(
0135         inputData,
0136         StateSpec.function(mappingFunc),
0137         outputData,
0138         stateData);
0139   }
0140 
0141   private <K, S, T> void testOperation(
0142       List<List<K>> input,
0143       StateSpec<K, Integer, S, T> mapWithStateSpec,
0144       List<Set<T>> expectedOutputs,
0145       List<Set<Tuple2<K, S>>> expectedStateSnapshots) {
0146     int numBatches = expectedOutputs.size();
0147     JavaDStream<K> inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2);
0148     JavaMapWithStateDStream<K, Integer, S, T> mapWithStateDStream = JavaPairDStream.fromJavaDStream(
0149       inputStream.map(x -> new Tuple2<>(x, 1))).mapWithState(mapWithStateSpec);
0150 
0151     List<Set<T>> collectedOutputs =
0152         Collections.synchronizedList(new ArrayList<>());
0153     mapWithStateDStream.foreachRDD(rdd -> collectedOutputs.add(Sets.newHashSet(rdd.collect())));
0154     List<Set<Tuple2<K, S>>> collectedStateSnapshots =
0155         Collections.synchronizedList(new ArrayList<>());
0156     mapWithStateDStream.stateSnapshots().foreachRDD(rdd ->
0157         collectedStateSnapshots.add(Sets.newHashSet(rdd.collect())));
0158     BatchCounter batchCounter = new BatchCounter(ssc.ssc());
0159     ssc.start();
0160     ((ManualClock) ssc.ssc().scheduler().clock())
0161         .advance(ssc.ssc().progressListener().batchDuration() * numBatches + 1);
0162     batchCounter.waitUntilBatchesCompleted(numBatches, 10000);
0163 
0164     Assert.assertEquals(expectedOutputs, collectedOutputs);
0165     Assert.assertEquals(expectedStateSnapshots, collectedStateSnapshots);
0166   }
0167 }