#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np

from pyspark.ml import Estimator, Model, Transformer, UnaryTransformer
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.ml.wrapper import _java2py
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import DoubleType
from pyspark.testing.utils import ReusedPySparkTestCase as PySparkTestCase

def check_params(test_self, py_stage, check_params_exist=True):
    """
    Checks common requirements for Params.params:
      - the set of params exists in both Java and Python and is ordered by name
      - each param's parent has the same UID as the object's UID
      - the default param value from Java matches the value in Python
      - optionally, checks that every param from Java also exists in Python
    """
    py_stage_str = "%s %s" % (type(py_stage), py_stage)
    if not hasattr(py_stage, "_to_java"):
        return
    java_stage = py_stage._to_java()
    if java_stage is None:
        return
    test_self.assertEqual(py_stage.uid, java_stage.uid(), msg=py_stage_str)
    if check_params_exist:
        param_names = [p.name for p in py_stage.params]
        java_params = list(java_stage.params())
        java_param_names = [jp.name() for jp in java_params]
        test_self.assertEqual(
            param_names, sorted(java_param_names),
            "Param list in Python does not match Java for %s:\nJava = %s\nPython = %s"
            % (py_stage_str, java_param_names, param_names))
    for p in py_stage.params:
        test_self.assertEqual(p.parent, py_stage.uid)
        java_param = java_stage.getParam(p.name)
        py_has_default = py_stage.hasDefault(p)
        java_has_default = java_stage.hasDefault(java_param)
        test_self.assertEqual(py_has_default, java_has_default,
                              "Default value mismatch of param %s for Params %s"
                              % (p.name, str(py_stage)))
        if py_has_default:
            if p.name == "seed":
                continue  # Random seeds between Spark and PySpark are different
            java_default = _java2py(test_self.sc,
                                    java_stage.clear(java_param).getOrDefault(java_param))
            py_stage.clear(p)
            py_default = py_stage.getOrDefault(p)
            # equality test for NaN is always False
            if isinstance(java_default, float) and np.isnan(java_default):
                java_default = "NaN"
                py_default = "NaN" if np.isnan(py_default) else "not NaN"
            test_self.assertEqual(
                java_default, py_default,
                "Java default %s != python default %s of param %s for Params %s"
                % (str(java_default), str(py_default), p.name, str(py_stage)))

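# A hedged usage sketch: `check_params` is meant to be called from a test case
# that exposes `self.sc` (e.g., the SparkSessionTestCase defined below) on a
# freshly constructed wrapper stage. `LogisticRegression` is illustrative only;
# any stage with a `_to_java` counterpart would do.
#
#     from pyspark.ml.classification import LogisticRegression
#
#     class DefaultValuesTests(SparkSessionTestCase):
#         def test_logistic_regression_params(self):
#             check_params(self, LogisticRegression(), check_params_exist=True)
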
class SparkSessionTestCase(PySparkTestCase):
    @classmethod
    def setUpClass(cls):
        PySparkTestCase.setUpClass()
        cls.spark = SparkSession(cls.sc)

    @classmethod
    def tearDownClass(cls):
        PySparkTestCase.tearDownClass()
        cls.spark.stop()

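# A minimal sketch of the intended fixture usage; the test name and schema are
# illustrative:
#
#     class ExampleTests(SparkSessionTestCase):
#         def test_create_dataframe(self):
#             df = self.spark.createDataFrame([(1.0,), (2.0,)], ["value"])
#             self.assertEqual(df.count(), 2)
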
class MockDataset(DataFrame):
    """Minimal in-memory stand-in for a DataFrame; `index` counts how often
    it has been transformed."""

    def __init__(self):
        self.index = 0

class HasFake(Params):
    """Mixin providing a single `fake` Param for the mock stages below."""

    def __init__(self):
        super(HasFake, self).__init__()
        self.fake = Param(self, "fake", "fake param")

    def getFake(self):
        return self.getOrDefault(self.fake)

class MockTransformer(Transformer, HasFake):
    """Transformer that records, then increments, the `index` of a MockDataset."""

    def __init__(self):
        super(MockTransformer, self).__init__()
        self.dataset_index = None

    def _transform(self, dataset):
        self.dataset_index = dataset.index
        dataset.index += 1
        return dataset

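# A sketch of how the mocks above interact; no SparkSession is required since
# MockDataset never touches a real DataFrame:
#
#     dataset = MockDataset()
#     transformer = MockTransformer()
#     transformer.transform(dataset)   # records index 0, bumps dataset.index to 1
#     assert transformer.dataset_index == 0
#     assert dataset.index == 1
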
class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable):
    """UnaryTransformer that adds a constant `shift` to a column of doubles;
    also exercises DefaultParamsReadable/DefaultParamsWritable."""

    shift = Param(Params._dummy(), "shift", "The amount by which to shift " +
                  "data in a DataFrame",
                  typeConverter=TypeConverters.toFloat)

    def __init__(self, shiftVal=1):
        super(MockUnaryTransformer, self).__init__()
        self._setDefault(shift=1)
        self._set(shift=shiftVal)

    def getShift(self):
        return self.getOrDefault(self.shift)

    def setShift(self, shift):
        self._set(shift=shift)

    def createTransformFunc(self):
        shiftVal = self.getShift()
        return lambda x: x + shiftVal

    def outputDataType(self):
        return DoubleType()

    def validateInputType(self, inputType):
        if inputType != DoubleType():
            raise TypeError("Bad input type: {}. ".format(inputType) +
                            "Requires Double.")

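# A hedged usage sketch for MockUnaryTransformer; `spark` stands for an active
# SparkSession (e.g., the one created by SparkSessionTestCase), and the column
# names are illustrative:
#
#     df = spark.createDataFrame([(0.0,), (1.0,)], ["input"])
#     shifter = MockUnaryTransformer(shiftVal=2.0) \
#         .setInputCol("input").setOutputCol("output")
#     shifter.transform(df).head().output   # 0.0 shifted by 2.0 -> 2.0
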
class MockEstimator(Estimator, HasFake):
    """Estimator that records the `index` of a MockDataset and returns a MockModel."""

    def __init__(self):
        super(MockEstimator, self).__init__()
        self.dataset_index = None

    def _fit(self, dataset):
        self.dataset_index = dataset.index
        model = MockModel()
        self._copyValues(model)
        return model

class MockModel(MockTransformer, Model, HasFake):
    pass
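

# A hedged end-to-end sketch of the mock fit/transform cycle, all in memory:
#
#     dataset = MockDataset()
#     estimator = MockEstimator()
#     model = estimator.fit(dataset)   # _fit records dataset.index == 0
#     model.transform(dataset)         # MockModel transforms like MockTransformer
#     assert estimator.dataset_index == 0
#     assert model.dataset_index == 0
#     assert dataset.index == 1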