0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import numpy as np
0019
0020 from pyspark.ml import Estimator, Model, Transformer, UnaryTransformer
0021 from pyspark.ml.param import Param, Params, TypeConverters
0022 from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
0023 from pyspark.ml.wrapper import _java2py
0024 from pyspark.sql import DataFrame, SparkSession
0025 from pyspark.sql.types import DoubleType
0026 from pyspark.testing.utils import ReusedPySparkTestCase as PySparkTestCase
0027
0028
def check_params(test_self, py_stage, check_params_exist=True):
    """
    Checks common requirements for Params.params:
    - set of params exist in Java and Python and are ordered by names
    - param parent has the same UID as the object's UID
    - default param value from Java matches value in Python
    - optionally check if all params from Java also exist in Python
    """
    stage_desc = "%s %s" % (type(py_stage), py_stage)
    # Stages without a Java counterpart cannot be cross-checked.
    if not hasattr(py_stage, "_to_java"):
        return
    java_stage = py_stage._to_java()
    if java_stage is None:
        return
    test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=stage_desc)
    if check_params_exist:
        py_names = [param.name for param in py_stage.params]
        jp_names = [jp.name() for jp in list(java_stage.params())]
        # Python's Params.params is sorted by name, so compare against the
        # sorted Java list.
        test_self.assertEqual(
            py_names, sorted(jp_names),
            "Param list in Python does not match Java for %s:\nJava = %s\nPython = %s"
            % (stage_desc, jp_names, py_names))
    for param in py_stage.params:
        test_self.assertEqual(param.parent, py_stage.uid)
        java_param = java_stage.getParam(param.name)
        has_default_py = py_stage.hasDefault(param)
        has_default_java = java_stage.hasDefault(java_param)
        test_self.assertEqual(
            has_default_py, has_default_java,
            "Default value mismatch of param %s for Params %s"
            % (param.name, str(py_stage)))
        if not has_default_py:
            continue
        if param.name == "seed":
            # Seeds are intentionally different between Java and Python.
            continue
        java_default = _java2py(
            test_self.sc,
            java_stage.clear(java_param).getOrDefault(java_param))
        py_stage.clear(param)
        py_default = py_stage.getOrDefault(param)

        # NaN != NaN, so normalize both sides to comparable markers before
        # the equality assertion.
        if isinstance(java_default, float) and np.isnan(java_default):
            java_default = "NaN"
            py_default = "NaN" if np.isnan(py_default) else "not NaN"
        test_self.assertEqual(
            java_default, py_default,
            "Java default %s != python default %s of param %s for Params %s"
            % (str(java_default), str(py_default), param.name, str(py_stage)))
0075
0076
class SparkSessionTestCase(PySparkTestCase):
    """Test case that wraps the shared SparkContext in a SparkSession."""

    @classmethod
    def setUpClass(cls):
        # Reuse the SparkContext created by the parent test case.
        PySparkTestCase.setUpClass()
        cls.spark = SparkSession(cls.sc)

    @classmethod
    def tearDownClass(cls):
        # Parent teardown first, then stop the session we created.
        PySparkTestCase.tearDownClass()
        cls.spark.stop()
0087
0088
class MockDataset(DataFrame):
    """Minimal stand-in for a DataFrame used by the mock stages.

    Only carries a mutable ``index`` counter; deliberately skips
    ``DataFrame.__init__`` so no real Spark plumbing is required.
    """

    def __init__(self):
        # Counter incremented by mock transformers to record call order.
        self.index = 0
0093
0094
class HasFake(Params):
    """Params mixin exposing a single placeholder param named ``fake``."""

    def __init__(self):
        super(HasFake, self).__init__()
        self.fake = Param(self, "fake", "fake param")

    def getFake(self):
        """Return the current (or default) value of the fake param."""
        return self.getOrDefault(self.fake)
0103
0104
class MockTransformer(Transformer, HasFake):
    """Transformer mock that records and bumps the dataset's index."""

    def __init__(self):
        super(MockTransformer, self).__init__()
        # Index of the dataset observed by the most recent _transform call.
        self.dataset_index = None

    def _transform(self, dataset):
        # Remember where we were called in the pipeline, then advance the
        # shared counter so ordering across stages is observable.
        self.dataset_index = dataset.index
        dataset.index += 1
        return dataset
0115
0116
class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable):
    """UnaryTransformer mock that adds a constant shift to a double column."""

    shift = Param(
        Params._dummy(), "shift",
        "The amount by which to shift data in a DataFrame",
        typeConverter=TypeConverters.toFloat)

    def __init__(self, shiftVal=1):
        super(MockUnaryTransformer, self).__init__()
        # Default stays 1 even when a different shiftVal is set explicitly.
        self._setDefault(shift=1)
        self._set(shift=shiftVal)

    def getShift(self):
        """Return the current (or default) shift amount."""
        return self.getOrDefault(self.shift)

    def setShift(self, shift):
        """Set the shift amount."""
        self._set(shift=shift)

    def createTransformFunc(self):
        # Capture the shift once so the returned function is self-contained.
        amount = self.getShift()

        def add_shift(value):
            return value + amount

        return add_shift

    def outputDataType(self):
        return DoubleType()

    def validateInputType(self, inputType):
        """Raise TypeError unless the input column is of DoubleType."""
        if inputType != DoubleType():
            raise TypeError(
                "Bad input type: {}. Requires Double.".format(inputType))
0145
0146
class MockEstimator(Estimator, HasFake):
    """Estimator mock that records the dataset index and emits a MockModel."""

    def __init__(self):
        super(MockEstimator, self).__init__()
        # Index of the dataset observed by the most recent _fit call.
        self.dataset_index = None

    def _fit(self, dataset):
        self.dataset_index = dataset.index
        # Produce a model carrying this estimator's param values.
        fitted = MockModel()
        self._copyValues(fitted)
        return fitted
0158
0159
class MockModel(MockTransformer, Model, HasFake):
    """Model mock: behaves exactly like MockTransformer, typed as a Model."""