Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.examples;
0019 
0020 import org.apache.spark.api.java.JavaRDD;
0021 import org.apache.spark.api.java.function.Function;
0022 import org.apache.spark.api.java.function.Function2;
0023 import org.apache.spark.sql.SparkSession;
0024 
0025 import java.io.Serializable;
0026 import java.util.Arrays;
0027 import java.util.Random;
0028 import java.util.regex.Pattern;
0029 
0030 /**
0031  * Logistic regression based classification.
0032  *
0033  * This is an example implementation for learning how to use Spark. For more conventional use,
0034  * please refer to org.apache.spark.ml.classification.LogisticRegression.
0035  */
0036 public final class JavaHdfsLR {
0037 
0038   private static final int D = 10;   // Number of dimensions
0039   private static final Random rand = new Random(42);
0040 
0041   static void showWarning() {
0042     String warning = "WARN: This is a naive implementation of Logistic Regression " +
0043             "and is given as an example!\n" +
0044             "Please use org.apache.spark.ml.classification.LogisticRegression " +
0045             "for more conventional use.";
0046     System.err.println(warning);
0047   }
0048 
0049   static class DataPoint implements Serializable {
0050     DataPoint(double[] x, double y) {
0051       this.x = x;
0052       this.y = y;
0053     }
0054 
0055     double[] x;
0056     double y;
0057   }
0058 
0059   static class ParsePoint implements Function<String, DataPoint> {
0060     private static final Pattern SPACE = Pattern.compile(" ");
0061 
0062     @Override
0063     public DataPoint call(String line) {
0064       String[] tok = SPACE.split(line);
0065       double y = Double.parseDouble(tok[0]);
0066       double[] x = new double[D];
0067       for (int i = 0; i < D; i++) {
0068         x[i] = Double.parseDouble(tok[i + 1]);
0069       }
0070       return new DataPoint(x, y);
0071     }
0072   }
0073 
0074   static class VectorSum implements Function2<double[], double[], double[]> {
0075     @Override
0076     public double[] call(double[] a, double[] b) {
0077       double[] result = new double[D];
0078       for (int j = 0; j < D; j++) {
0079         result[j] = a[j] + b[j];
0080       }
0081       return result;
0082     }
0083   }
0084 
0085   static class ComputeGradient implements Function<DataPoint, double[]> {
0086     private final double[] weights;
0087 
0088     ComputeGradient(double[] weights) {
0089       this.weights = weights;
0090     }
0091 
0092     @Override
0093     public double[] call(DataPoint p) {
0094       double[] gradient = new double[D];
0095       for (int i = 0; i < D; i++) {
0096         double dot = dot(weights, p.x);
0097         gradient[i] = (1 / (1 + Math.exp(-p.y * dot)) - 1) * p.y * p.x[i];
0098       }
0099       return gradient;
0100     }
0101   }
0102 
0103   public static double dot(double[] a, double[] b) {
0104     double x = 0;
0105     for (int i = 0; i < D; i++) {
0106       x += a[i] * b[i];
0107     }
0108     return x;
0109   }
0110 
0111   public static void printWeights(double[] a) {
0112     System.out.println(Arrays.toString(a));
0113   }
0114 
0115   public static void main(String[] args) {
0116 
0117     if (args.length < 2) {
0118       System.err.println("Usage: JavaHdfsLR <file> <iters>");
0119       System.exit(1);
0120     }
0121 
0122     showWarning();
0123 
0124     SparkSession spark = SparkSession
0125       .builder()
0126       .appName("JavaHdfsLR")
0127       .getOrCreate();
0128 
0129     JavaRDD<String> lines = spark.read().textFile(args[0]).javaRDD();
0130     JavaRDD<DataPoint> points = lines.map(new ParsePoint()).cache();
0131     int ITERATIONS = Integer.parseInt(args[1]);
0132 
0133     // Initialize w to a random value
0134     double[] w = new double[D];
0135     for (int i = 0; i < D; i++) {
0136       w[i] = 2 * rand.nextDouble() - 1;
0137     }
0138 
0139     System.out.print("Initial w: ");
0140     printWeights(w);
0141 
0142     for (int i = 1; i <= ITERATIONS; i++) {
0143       System.out.println("On iteration " + i);
0144 
0145       double[] gradient = points.map(
0146         new ComputeGradient(w)
0147       ).reduce(new VectorSum());
0148 
0149       for (int j = 0; j < D; j++) {
0150         w[j] -= gradient[j];
0151       }
0152 
0153     }
0154 
0155     System.out.print("Final w: ");
0156     printWeights(w);
0157     spark.stop();
0158   }
0159 }