0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.streaming;
0019
0020 import java.util.ArrayList;
0021 import java.nio.ByteBuffer;
0022 import java.util.Arrays;
0023 import java.util.Iterator;
0024 import java.util.List;
0025
0026 import com.google.common.collect.Iterators;
0027 import org.apache.spark.SparkConf;
0028 import org.apache.spark.network.util.JavaUtils;
0029 import org.apache.spark.streaming.util.WriteAheadLog;
0030 import org.apache.spark.streaming.util.WriteAheadLogRecordHandle;
0031 import org.apache.spark.streaming.util.WriteAheadLogUtils;
0032
0033 import org.junit.Test;
0034 import org.junit.Assert;
0035
0036 public class JavaWriteAheadLogSuite extends WriteAheadLog {
0037
0038 static class JavaWriteAheadLogSuiteHandle extends WriteAheadLogRecordHandle {
0039 int index = -1;
0040 JavaWriteAheadLogSuiteHandle(int idx) {
0041 index = idx;
0042 }
0043 }
0044
0045 static class Record {
0046 long time;
0047 int index;
0048 ByteBuffer buffer;
0049
0050 Record(long tym, int idx, ByteBuffer buf) {
0051 index = idx;
0052 time = tym;
0053 buffer = buf;
0054 }
0055 }
0056 private int index = -1;
0057 private final List<Record> records = new ArrayList<>();
0058
0059
0060
0061 @Override
0062 public WriteAheadLogRecordHandle write(ByteBuffer record, long time) {
0063 index += 1;
0064 records.add(new Record(time, index, record));
0065 return new JavaWriteAheadLogSuiteHandle(index);
0066 }
0067
0068 @Override
0069 public ByteBuffer read(WriteAheadLogRecordHandle handle) {
0070 if (handle instanceof JavaWriteAheadLogSuiteHandle) {
0071 int reqdIndex = ((JavaWriteAheadLogSuiteHandle) handle).index;
0072 for (Record record: records) {
0073 if (record.index == reqdIndex) {
0074 return record.buffer;
0075 }
0076 }
0077 }
0078 return null;
0079 }
0080
0081 @Override
0082 public Iterator<ByteBuffer> readAll() {
0083 return Iterators.transform(records.iterator(), input -> input.buffer);
0084 }
0085
0086 @Override
0087 public void clean(long threshTime, boolean waitForCompletion) {
0088 for (int i = 0; i < records.size(); i++) {
0089 if (records.get(i).time < threshTime) {
0090 records.remove(i);
0091 i--;
0092 }
0093 }
0094 }
0095
0096 @Override
0097 public void close() {
0098 records.clear();
0099 }
0100
0101 @Test
0102 public void testCustomWAL() {
0103 SparkConf conf = new SparkConf();
0104 conf.set("spark.streaming.driver.writeAheadLog.class", JavaWriteAheadLogSuite.class.getName());
0105 conf.set("spark.streaming.driver.writeAheadLog.allowBatching", "false");
0106 WriteAheadLog wal = WriteAheadLogUtils.createLogForDriver(conf, null, null);
0107
0108 String data1 = "data1";
0109 WriteAheadLogRecordHandle handle = wal.write(JavaUtils.stringToBytes(data1), 1234);
0110 Assert.assertTrue(handle instanceof JavaWriteAheadLogSuiteHandle);
0111 Assert.assertEquals(data1, JavaUtils.bytesToString(wal.read(handle)));
0112
0113 wal.write(JavaUtils.stringToBytes("data2"), 1235);
0114 wal.write(JavaUtils.stringToBytes("data3"), 1236);
0115 wal.write(JavaUtils.stringToBytes("data4"), 1237);
0116 wal.clean(1236, false);
0117
0118 Iterator<ByteBuffer> dataIterator = wal.readAll();
0119 List<String> readData = new ArrayList<>();
0120 while (dataIterator.hasNext()) {
0121 readData.add(JavaUtils.bytesToString(dataIterator.next()));
0122 }
0123 Assert.assertEquals(Arrays.asList("data3", "data4"), readData);
0124 }
0125 }