0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.streaming;
0019
0020 import java.io.Serializable;
0021 import java.util.ArrayList;
0022 import java.util.Arrays;
0023 import java.util.Collections;
0024 import java.util.List;
0025 import java.util.Set;
0026
0027 import scala.Tuple2;
0028
0029 import com.google.common.collect.Sets;
0030 import org.apache.spark.streaming.api.java.JavaDStream;
0031 import org.apache.spark.util.ManualClock;
0032 import org.junit.Assert;
0033 import org.junit.Test;
0034
0035 import org.apache.spark.HashPartitioner;
0036 import org.apache.spark.api.java.JavaPairRDD;
0037 import org.apache.spark.api.java.Optional;
0038 import org.apache.spark.api.java.function.Function3;
0039 import org.apache.spark.api.java.function.Function4;
0040 import org.apache.spark.streaming.api.java.JavaPairDStream;
0041 import org.apache.spark.streaming.api.java.JavaMapWithStateDStream;
0042
0043 public class JavaMapWithStateSuite extends LocalJavaStreamingContext implements Serializable {
0044
0045
0046
0047
0048 public void testAPI() {
0049 JavaPairRDD<String, Boolean> initialRDD = null;
0050 JavaPairDStream<String, Integer> wordsDstream = null;
0051
0052 Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>> mappingFunc =
0053 (time, word, one, state) -> {
0054
0055 state.exists();
0056 state.get();
0057 state.isTimingOut();
0058 state.remove();
0059 state.update(true);
0060 return Optional.of(2.0);
0061 };
0062
0063 JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream =
0064 wordsDstream.mapWithState(
0065 StateSpec.function(mappingFunc)
0066 .initialState(initialRDD)
0067 .numPartitions(10)
0068 .partitioner(new HashPartitioner(10))
0069 .timeout(Durations.seconds(10)));
0070
0071 stateDstream.stateSnapshots();
0072
0073 Function3<String, Optional<Integer>, State<Boolean>, Double> mappingFunc2 =
0074 (key, one, state) -> {
0075
0076 state.exists();
0077 state.get();
0078 state.isTimingOut();
0079 state.remove();
0080 state.update(true);
0081 return 2.0;
0082 };
0083
0084 JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream2 =
0085 wordsDstream.mapWithState(
0086 StateSpec.function(mappingFunc2)
0087 .initialState(initialRDD)
0088 .numPartitions(10)
0089 .partitioner(new HashPartitioner(10))
0090 .timeout(Durations.seconds(10)));
0091
0092 stateDstream2.stateSnapshots();
0093 }
0094
0095 @Test
0096 public void testBasicFunction() {
0097 List<List<String>> inputData = Arrays.asList(
0098 Collections.<String>emptyList(),
0099 Arrays.asList("a"),
0100 Arrays.asList("a", "b"),
0101 Arrays.asList("a", "b", "c"),
0102 Arrays.asList("a", "b"),
0103 Arrays.asList("a"),
0104 Collections.<String>emptyList()
0105 );
0106
0107 List<Set<Integer>> outputData = Arrays.asList(
0108 Collections.<Integer>emptySet(),
0109 Sets.newHashSet(1),
0110 Sets.newHashSet(2, 1),
0111 Sets.newHashSet(3, 2, 1),
0112 Sets.newHashSet(4, 3),
0113 Sets.newHashSet(5),
0114 Collections.<Integer>emptySet()
0115 );
0116
0117 @SuppressWarnings("unchecked")
0118 List<Set<Tuple2<String, Integer>>> stateData = Arrays.asList(
0119 Collections.<Tuple2<String, Integer>>emptySet(),
0120 Sets.newHashSet(new Tuple2<>("a", 1)),
0121 Sets.newHashSet(new Tuple2<>("a", 2), new Tuple2<>("b", 1)),
0122 Sets.newHashSet(new Tuple2<>("a", 3), new Tuple2<>("b", 2), new Tuple2<>("c", 1)),
0123 Sets.newHashSet(new Tuple2<>("a", 4), new Tuple2<>("b", 3), new Tuple2<>("c", 1)),
0124 Sets.newHashSet(new Tuple2<>("a", 5), new Tuple2<>("b", 3), new Tuple2<>("c", 1)),
0125 Sets.newHashSet(new Tuple2<>("a", 5), new Tuple2<>("b", 3), new Tuple2<>("c", 1))
0126 );
0127
0128 Function3<String, Optional<Integer>, State<Integer>, Integer> mappingFunc =
0129 (key, value, state) -> {
0130 int sum = value.orElse(0) + (state.exists() ? state.get() : 0);
0131 state.update(sum);
0132 return sum;
0133 };
0134 testOperation(
0135 inputData,
0136 StateSpec.function(mappingFunc),
0137 outputData,
0138 stateData);
0139 }
0140
0141 private <K, S, T> void testOperation(
0142 List<List<K>> input,
0143 StateSpec<K, Integer, S, T> mapWithStateSpec,
0144 List<Set<T>> expectedOutputs,
0145 List<Set<Tuple2<K, S>>> expectedStateSnapshots) {
0146 int numBatches = expectedOutputs.size();
0147 JavaDStream<K> inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2);
0148 JavaMapWithStateDStream<K, Integer, S, T> mapWithStateDStream = JavaPairDStream.fromJavaDStream(
0149 inputStream.map(x -> new Tuple2<>(x, 1))).mapWithState(mapWithStateSpec);
0150
0151 List<Set<T>> collectedOutputs =
0152 Collections.synchronizedList(new ArrayList<>());
0153 mapWithStateDStream.foreachRDD(rdd -> collectedOutputs.add(Sets.newHashSet(rdd.collect())));
0154 List<Set<Tuple2<K, S>>> collectedStateSnapshots =
0155 Collections.synchronizedList(new ArrayList<>());
0156 mapWithStateDStream.stateSnapshots().foreachRDD(rdd ->
0157 collectedStateSnapshots.add(Sets.newHashSet(rdd.collect())));
0158 BatchCounter batchCounter = new BatchCounter(ssc.ssc());
0159 ssc.start();
0160 ((ManualClock) ssc.ssc().scheduler().clock())
0161 .advance(ssc.ssc().progressListener().batchDuration() * numBatches + 1);
0162 batchCounter.waitUntilBatchesCompleted(numBatches, 10000);
0163
0164 Assert.assertEquals(expectedOutputs, collectedOutputs);
0165 Assert.assertEquals(expectedStateSnapshots, collectedStateSnapshots);
0166 }
0167 }