#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import tempfile
import unittest

from pyspark.mllib.common import _to_java_object_rdd
from pyspark.mllib.util import LinearDataGenerator
from pyspark.mllib.util import MLUtils
from pyspark.mllib.linalg import SparseVector, DenseVector, Vectors
from pyspark.mllib.random import RandomRDDs
from pyspark.testing.mllibutils import MLlibTestCase

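# MLUtils helper tests: appendBias should append a bias term of 1.0 as the last
# element and preserve the dense/sparse type of its input; loadVectors should
# read back vectors previously saved as a text file.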
class MLUtilsTests(MLlibTestCase):
    def test_append_bias(self):
        data = [2.0, 2.0, 2.0]
        ret = MLUtils.appendBias(data)
        self.assertEqual(ret[3], 1.0)
        self.assertEqual(type(ret), DenseVector)

    def test_append_bias_with_vector(self):
        data = Vectors.dense([2.0, 2.0, 2.0])
        ret = MLUtils.appendBias(data)
        self.assertEqual(ret[3], 1.0)
        self.assertEqual(type(ret), DenseVector)

    def test_append_bias_with_sp_vector(self):
        data = Vectors.sparse(3, {0: 2.0, 2: 2.0})
        expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0})
        # Returned value must be SparseVector
        ret = MLUtils.appendBias(data)
        self.assertEqual(ret, expected)
        self.assertEqual(type(ret), SparseVector)

    def test_load_vectors(self):
        import shutil
        data = [
            [1.0, 2.0, 3.0],
            [1.0, 2.0, 3.0]
        ]
        temp_dir = tempfile.mkdtemp()
        load_vectors_path = os.path.join(temp_dir, "test_load_vectors")
        try:
            self.sc.parallelize(data).saveAsTextFile(load_vectors_path)
            ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path)
            ret = ret_rdd.collect()
            self.assertEqual(len(ret), 2)
            self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0]))
            self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0]))
        except Exception:
            self.fail("loadVectors failed to read back the saved vectors")
        finally:
            # Remove the whole temporary directory, not just the saved output.
            shutil.rmtree(temp_dir)


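# LinearDataGenerator tests: the generated LabeledPoints should match the
# requested number of examples and feature dimensions.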
class LinearDataGeneratorTests(MLlibTestCase):
    def test_dim(self):
        linear_data = LinearDataGenerator.generateLinearInput(
            intercept=0.0, weights=[0.0, 0.0, 0.0],
            xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33],
            nPoints=4, seed=0, eps=0.1)
        self.assertEqual(len(linear_data), 4)
        for point in linear_data:
            self.assertEqual(len(point.features), 3)

        linear_data = LinearDataGenerator.generateLinearRDD(
            sc=self.sc, nexamples=6, nfeatures=2, eps=0.1,
            nParts=2, intercept=0.0).collect()
        self.assertEqual(len(linear_data), 6)
        for point in linear_data:
            self.assertEqual(len(point.features), 2)


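# SerDe test for SPARK-6660: an RDD converted with _to_java_object_rdd should
# keep its element count.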
class SerDeTest(MLlibTestCase):
    def test_to_java_object_rdd(self):  # SPARK-6660
        data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0)
        self.assertEqual(_to_java_object_rdd(data).count(), 10)


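# Write XML test reports to target/test-reports when xmlrunner is installed;
# otherwise fall back to the default unittest runner.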
if __name__ == "__main__":
    from pyspark.mllib.tests.test_util import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)