/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml;

// $example on$
import java.util.Arrays;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
// $example off$
import org.apache.spark.sql.SparkSession;

/**
 * Java example of a simple text-document 'Pipeline': tokenize text, hash the tokens
 * into term-frequency feature vectors, and fit a logistic regression model; the
 * fitted PipelineModel is then used to predict on unlabeled documents.
 */
public class JavaPipelineExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("JavaPipelineExample")
      .getOrCreate();

    // $example on$
    // Prepare training documents, which are labeled.
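    // JavaLabeledDocument is a JavaBean; Spark SQL infers the DataFrame schema
    // (id, text, label columns) from its getters. A sketch of the bean classes,
    // which live in separate files in this package, appears at the end of this file.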
    Dataset<Row> training = spark.createDataFrame(Arrays.asList(
      new JavaLabeledDocument(0L, "a b c d e spark", 1.0),
      new JavaLabeledDocument(1L, "b d", 0.0),
      new JavaLabeledDocument(2L, "spark f g h", 1.0),
      new JavaLabeledDocument(3L, "hadoop mapreduce", 0.0)
    ), JavaLabeledDocument.class);

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
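    // Tokenizer and HashingTF are Transformers (they map one DataFrame to another);
    // LogisticRegression is an Estimator (it must be fit to data to produce a model).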
    Tokenizer tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words");
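    // HashingTF uses the hashing trick: each word is hashed into one of numFeatures
    // buckets, yielding a fixed-length term-frequency vector with no vocabulary to fit.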
    HashingTF hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol())
      .setOutputCol("features");
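    // maxIter caps the optimizer's iterations; regParam sets the regularization
    // strength (L2 by default, since elasticNetParam defaults to 0.0).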
    LogisticRegression lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.001);
    Pipeline pipeline = new Pipeline()
      .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});

    // Fit the pipeline to training documents.
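    // fit() runs the stages in order: each Transformer transforms the data and each
    // Estimator is fit on it; the result is a PipelineModel, itself a Transformer.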
    PipelineModel model = pipeline.fit(training);

    // Prepare test documents, which are unlabeled.
    Dataset<Row> test = spark.createDataFrame(Arrays.asList(
      new JavaDocument(4L, "spark i j k"),
      new JavaDocument(5L, "l m n"),
      new JavaDocument(6L, "spark hadoop spark"),
      new JavaDocument(7L, "apache hadoop")
    ), JavaDocument.class);

    // Make predictions on test documents.
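    // transform() replays the fitted stages (tokenize -> hash -> predict), appending
    // columns such as probability and prediction to the test DataFrame.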
    Dataset<Row> predictions = model.transform(test);
    for (Row r : predictions.select("id", "text", "probability", "prediction").collectAsList()) {
      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
        + ", prediction=" + r.get(3));
    }
    // $example off$

    spark.stop();
  }
}
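
// NOTE: In the Spark source tree the two bean classes used above live in separate
// files in this same package (JavaDocument.java and JavaLabeledDocument.java).
// The sketch below is a minimal reconstruction that makes this listing
// self-contained; getter names match the columns used above, but the exact Spark
// source may differ slightly.

class JavaDocument implements java.io.Serializable {
  private long id;
  private String text;

  public JavaDocument(long id, String text) {
    this.id = id;
    this.text = text;
  }

  public long getId() { return this.id; }

  public String getText() { return this.text; }
}

class JavaLabeledDocument extends JavaDocument {
  private double label;

  public JavaLabeledDocument(long id, String text, double label) {
    super(id, text);
    this.label = label;
  }

  public double getLabel() { return this.label; }
}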