0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples;
0019
0020 import org.apache.spark.api.java.JavaRDD;
0021 import org.apache.spark.api.java.function.Function;
0022 import org.apache.spark.api.java.function.Function2;
0023 import org.apache.spark.sql.SparkSession;
0024
0025 import java.io.Serializable;
0026 import java.util.Arrays;
0027 import java.util.Random;
0028 import java.util.regex.Pattern;
0029
0030
0031
0032
0033
0034
0035
0036 public final class JavaHdfsLR {
0037
0038 private static final int D = 10;
0039 private static final Random rand = new Random(42);
0040
0041 static void showWarning() {
0042 String warning = "WARN: This is a naive implementation of Logistic Regression " +
0043 "and is given as an example!\n" +
0044 "Please use org.apache.spark.ml.classification.LogisticRegression " +
0045 "for more conventional use.";
0046 System.err.println(warning);
0047 }
0048
0049 static class DataPoint implements Serializable {
0050 DataPoint(double[] x, double y) {
0051 this.x = x;
0052 this.y = y;
0053 }
0054
0055 double[] x;
0056 double y;
0057 }
0058
0059 static class ParsePoint implements Function<String, DataPoint> {
0060 private static final Pattern SPACE = Pattern.compile(" ");
0061
0062 @Override
0063 public DataPoint call(String line) {
0064 String[] tok = SPACE.split(line);
0065 double y = Double.parseDouble(tok[0]);
0066 double[] x = new double[D];
0067 for (int i = 0; i < D; i++) {
0068 x[i] = Double.parseDouble(tok[i + 1]);
0069 }
0070 return new DataPoint(x, y);
0071 }
0072 }
0073
0074 static class VectorSum implements Function2<double[], double[], double[]> {
0075 @Override
0076 public double[] call(double[] a, double[] b) {
0077 double[] result = new double[D];
0078 for (int j = 0; j < D; j++) {
0079 result[j] = a[j] + b[j];
0080 }
0081 return result;
0082 }
0083 }
0084
0085 static class ComputeGradient implements Function<DataPoint, double[]> {
0086 private final double[] weights;
0087
0088 ComputeGradient(double[] weights) {
0089 this.weights = weights;
0090 }
0091
0092 @Override
0093 public double[] call(DataPoint p) {
0094 double[] gradient = new double[D];
0095 for (int i = 0; i < D; i++) {
0096 double dot = dot(weights, p.x);
0097 gradient[i] = (1 / (1 + Math.exp(-p.y * dot)) - 1) * p.y * p.x[i];
0098 }
0099 return gradient;
0100 }
0101 }
0102
0103 public static double dot(double[] a, double[] b) {
0104 double x = 0;
0105 for (int i = 0; i < D; i++) {
0106 x += a[i] * b[i];
0107 }
0108 return x;
0109 }
0110
0111 public static void printWeights(double[] a) {
0112 System.out.println(Arrays.toString(a));
0113 }
0114
0115 public static void main(String[] args) {
0116
0117 if (args.length < 2) {
0118 System.err.println("Usage: JavaHdfsLR <file> <iters>");
0119 System.exit(1);
0120 }
0121
0122 showWarning();
0123
0124 SparkSession spark = SparkSession
0125 .builder()
0126 .appName("JavaHdfsLR")
0127 .getOrCreate();
0128
0129 JavaRDD<String> lines = spark.read().textFile(args[0]).javaRDD();
0130 JavaRDD<DataPoint> points = lines.map(new ParsePoint()).cache();
0131 int ITERATIONS = Integer.parseInt(args[1]);
0132
0133
0134 double[] w = new double[D];
0135 for (int i = 0; i < D; i++) {
0136 w[i] = 2 * rand.nextDouble() - 1;
0137 }
0138
0139 System.out.print("Initial w: ");
0140 printWeights(w);
0141
0142 for (int i = 1; i <= ITERATIONS; i++) {
0143 System.out.println("On iteration " + i);
0144
0145 double[] gradient = points.map(
0146 new ComputeGradient(w)
0147 ).reduce(new VectorSum());
0148
0149 for (int j = 0; j < D; j++) {
0150 w[j] -= gradient[j];
0151 }
0152
0153 }
0154
0155 System.out.print("Final w: ");
0156 printWeights(w);
0157 spark.stop();
0158 }
0159 }