0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import unittest
0019
0020 import py4j
0021
0022 from pyspark.ml.linalg import DenseVector, Vectors
0023 from pyspark.ml.regression import LinearRegression
0024 from pyspark.ml.wrapper import _java2py, _py2java, JavaParams, JavaWrapper
0025 from pyspark.testing.mllibutils import MLlibTestCase
0026 from pyspark.testing.mlutils import SparkSessionTestCase
0027 from pyspark.testing.utils import eventually
0028
0029
class JavaWrapperMemoryTests(SparkSessionTestCase):
    """Check that the Java objects behind Python ML wrappers go away after ``__del__``."""

    # Message expected in the Py4JError raised when the gateway no longer knows an object id.
    _ERROR_NO_OBJECT = 'Target Object ID does not exist for this gateway'

    def _assert_java_object_detached(self, wrapper):
        """Assert that calling into ``wrapper._java_obj`` fails with the missing-object error.

        :param wrapper: a Python wrapper exposing a ``_java_obj`` attribute.
        :raises AssertionError: if the call succeeds, or the error message does not match.
        """
        # ``assertRaisesRegexp`` is a deprecated alias that was removed in Python 3.12;
        # prefer the current name and only fall back where it does not exist.
        assert_raises_regex = getattr(self, "assertRaisesRegex", None) or self.assertRaisesRegexp
        with assert_raises_regex(py4j.protocol.Py4JError, self._ERROR_NO_OBJECT):
            wrapper._java_obj.toString()

    def test_java_object_gets_detached(self):
        """Fit a model, then delete the model and its summary wrappers one after the other.

        After each ``__del__`` call, the corresponding ``_java_obj`` must become unusable,
        while the wrapper that has not been deleted yet must keep working.
        """
        df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
                                         (0.0, 2.0, Vectors.sparse(1, [], []))],
                                        ["label", "weight", "features"])
        lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight",
                              fitIntercept=False)

        model = lr.fit(df)
        summary = model.summary

        # Both objects are Java wrappers, but only the model carries params.
        self.assertIsInstance(model, JavaWrapper)
        self.assertIsInstance(summary, JavaWrapper)
        self.assertIsInstance(model, JavaParams)
        self.assertNotIsInstance(summary, JavaParams)

        # Sanity check: both Java objects are reachable before anything is deleted.
        self.assertIn("LinearRegression_", model._java_obj.toString())
        self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())

        model.__del__()

        def model_detached_summary_alive():
            self._assert_java_object_detached(model)
            # Deleting the model must not take the summary's Java object with it.
            self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
            return True

        # Retried for up to 10 seconds, with assertion failures treated as "not yet"
        # (presumably because the detach is not guaranteed to be visible immediately).
        eventually(model_detached_summary_alive, timeout=10, catch_assertions=True)

        try:
            summary.__del__()
        except Exception:
            # Best-effort: a failure here is tolerated on purpose, the check below decides
            # whether the summary was actually detached. ``Exception`` (rather than a bare
            # ``except``) keeps KeyboardInterrupt/SystemExit from being swallowed.
            pass

        def model_and_summary_detached():
            self._assert_java_object_detached(model)
            self._assert_java_object_detached(summary)
            return True

        eventually(model_and_summary_detached, timeout=10, catch_assertions=True)
0075
0076
class WrapperTests(MLlibTestCase):
    """Tests for helper functionality exposed by ``JavaWrapper``."""

    def test_new_java_array(self):
        """Build Java arrays from Python lists and compare what ``_java2py`` gives back."""
        jvm = self.sc._gateway.jvm

        def check(elements, element_class, expected):
            # Python list -> Java array -> Python again, compared with the expectation.
            converted = JavaWrapper._new_java_array(elements, element_class)
            self.assertEqual(_java2py(self.sc, converted), expected)

        # Flat arrays of strings, ints, floats and booleans.
        strings = ["a", "b", "c"]
        check(strings, jvm.java.lang.String, strings)

        ints = [1, 2, 3]
        check(ints, jvm.java.lang.Integer, ints)

        floats = [0.1, 0.2, 0.3]
        check(floats, jvm.java.lang.Double, floats)

        bools = [False, True, True]
        check(bools, jvm.java.lang.Boolean, bools)

        # Array of Java objects: the elements are converted to Java before being packed.
        first = DenseVector([0.0, 1.0])
        second = DenseVector([1.0, 0.0])
        java_vectors = [_py2java(self.sc, first), _py2java(self.sc, second)]
        check(java_vectors, jvm.org.apache.spark.ml.linalg.DenseVector, [first, second])

        # Empty input.
        check([], jvm.java.lang.Integer, [])

        # Nested lists of unequal length: the expected rows are tuples of the longest
        # row's length, with None filling the missing positions.
        ragged = [["a", "b", "c"], ["d", "e"], ["f", "g", "h", "i"], []]
        padded = [("a", "b", "c", None), ("d", "e", None, None), ("f", "g", "h", "i"),
                  (None, None, None, None)]
        check(ragged, jvm.java.lang.String, padded)
0118
if __name__ == "__main__":
    # Script entry point: run this module's tests, writing XML reports when possible.
    from pyspark.ml.tests.test_wrapper import *

    # Default to unittest's own runner; switch to the XML reporter only if the
    # optional ``xmlrunner`` package can be imported.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)