0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.ml.classification;
0019
0020 import java.util.Arrays;
0021 import java.util.List;
0022
0023 import org.junit.Assert;
0024 import org.junit.Test;
0025
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.ml.feature.LabeledPoint;
0028 import org.apache.spark.ml.linalg.Vectors;
0029 import org.apache.spark.sql.Dataset;
0030 import org.apache.spark.sql.Row;
0031
0032 public class JavaMultilayerPerceptronClassifierSuite extends SharedSparkSession {
0033
0034 @Test
0035 public void testMLPC() {
0036 List<LabeledPoint> data = Arrays.asList(
0037 new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
0038 new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
0039 new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
0040 new LabeledPoint(0.0, Vectors.dense(1.0, 1.0))
0041 );
0042 Dataset<Row> dataFrame = spark.createDataFrame(data, LabeledPoint.class);
0043
0044 MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier()
0045 .setLayers(new int[]{2, 5, 2})
0046 .setBlockSize(1)
0047 .setSeed(123L)
0048 .setMaxIter(100);
0049 MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
0050 Dataset<Row> result = model.transform(dataFrame);
0051 List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList();
0052 for (Row r : predictionAndLabels) {
0053 Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1));
0054 }
0055 }
0056 }