#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
from shutil import rmtree
import tempfile
import unittest

from pyspark.ml import Transformer
from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, OneVsRest, \
    OneVsRestModel
from pyspark.ml.clustering import KMeans
from pyspark.ml.feature import Binarizer, HashingTF, PCA
from pyspark.ml.linalg import Vectors
from pyspark.ml.param import Params
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.ml.regression import DecisionTreeRegressor, LinearRegression
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWriter
from pyspark.ml.wrapper import JavaParams
from pyspark.testing.mlutils import MockUnaryTransformer, SparkSessionTestCase


class PersistenceTest(SparkSessionTestCase):
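    """
    Round-trip save/load tests for ML estimators, models, and pipelines,
    checking that uids, Param values, and default Params survive persistence.
    """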

    def test_linear_regression(self):
        lr = LinearRegression(maxIter=1)
        path = tempfile.mkdtemp()
        lr_path = path + "/lr"
        lr.save(lr_path)
        lr2 = LinearRegression.load(lr_path)
        self.assertEqual(lr.uid, lr2.uid)
        self.assertEqual(type(lr.uid), type(lr2.uid))
        self.assertEqual(lr2.uid, lr2.maxIter.parent,
                         "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)"
                         % (lr2.uid, lr2.maxIter.parent))
        self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter],
                         "Loaded LinearRegression instance default params did not match " +
                         "original defaults")
        try:
            rmtree(path)
        except OSError:
            pass

    def test_linear_regression_pmml_basic(self):
        # Most of the validation is done on the Scala side; here we just check
        # that we output text rather than Parquet (i.e. that the format flag
        # was respected).
        df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
                                         (0.0, 2.0, Vectors.sparse(1, [], []))],
                                        ["label", "weight", "features"])
        lr = LinearRegression(maxIter=1)
        model = lr.fit(df)
        path = tempfile.mkdtemp()
        lr_path = path + "/lr-pmml"
        model.write().format("pmml").save(lr_path)
        pmml_text_list = self.sc.textFile(lr_path).collect()
        pmml_text = "\n".join(pmml_text_list)
        self.assertIn("Apache Spark", pmml_text)
        self.assertIn("PMML", pmml_text)

    def test_logistic_regression(self):
        lr = LogisticRegression(maxIter=1)
        path = tempfile.mkdtemp()
        lr_path = path + "/logreg"
        lr.save(lr_path)
        lr2 = LogisticRegression.load(lr_path)
        self.assertEqual(lr2.uid, lr2.maxIter.parent,
                         "Loaded LogisticRegression instance uid (%s) "
                         "did not match Param's uid (%s)"
                         % (lr2.uid, lr2.maxIter.parent))
        self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter],
                         "Loaded LogisticRegression instance default params did not match " +
                         "original defaults")
        try:
            rmtree(path)
        except OSError:
            pass

    def test_kmeans(self):
        kmeans = KMeans(k=2, seed=1)
        path = tempfile.mkdtemp()
        km_path = path + "/km"
        kmeans.save(km_path)
        kmeans2 = KMeans.load(km_path)
        self.assertEqual(kmeans.uid, kmeans2.uid)
        self.assertEqual(type(kmeans.uid), type(kmeans2.uid))
        self.assertEqual(kmeans2.uid, kmeans2.k.parent,
                         "Loaded KMeans instance uid (%s) did not match Param's uid (%s)"
                         % (kmeans2.uid, kmeans2.k.parent))
        self.assertEqual(kmeans._defaultParamMap[kmeans.k], kmeans2._defaultParamMap[kmeans2.k],
                         "Loaded KMeans instance default params did not match " +
                         "original defaults")
        try:
            rmtree(path)
        except OSError:
            pass

    def test_kmeans_pmml_basic(self):
        # Most of the validation is done on the Scala side; here we just check
        # that we output text rather than Parquet (i.e. that the format flag
        # was respected).
        data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
                (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
        df = self.spark.createDataFrame(data, ["features"])
        kmeans = KMeans(k=2, seed=1)
        model = kmeans.fit(df)
        path = tempfile.mkdtemp()
        km_path = path + "/km-pmml"
        model.write().format("pmml").save(km_path)
        pmml_text_list = self.sc.textFile(km_path).collect()
        pmml_text = "\n".join(pmml_text_list)
        self.assertIn("Apache Spark", pmml_text)
        self.assertIn("PMML", pmml_text)

    def _compare_params(self, m1, m2, param):
        """
        Compare two ML Params instances for the given param, asserting that both have the
        same param value and parent. The param must be a parameter of m1.
        """
        # Guard against a KeyError when the param is in neither the paramMap nor the
        # defaultParamMap.
        if m1.isDefined(param):
            paramValue1 = m1.getOrDefault(param)
            paramValue2 = m2.getOrDefault(m2.getParam(param.name))
            if isinstance(paramValue1, Params):
                self._compare_pipelines(paramValue1, paramValue2)
            else:
                self.assertEqual(paramValue1, paramValue2)  # for params of general types
            # Assert parents are equal
            self.assertEqual(param.parent, m2.getParam(param.name).parent)
        else:
            # If the param is not defined on m1, it must not be defined on m2 either.
            # See SPARK-14931.
            self.assertFalse(m2.isDefined(m2.getParam(param.name)))

    def _compare_pipelines(self, m1, m2):
        """
        Compare 2 ML types, asserting that they are equivalent.
        This currently supports:
         - basic types
         - Pipeline, PipelineModel
         - OneVsRest, OneVsRestModel
        This checks:
         - uid
         - type
         - Param values and parents
        """
        self.assertEqual(m1.uid, m2.uid)
        self.assertEqual(type(m1), type(m2))
        if isinstance(m1, (JavaParams, Transformer)):
            self.assertEqual(len(m1.params), len(m2.params))
            for p in m1.params:
                self._compare_params(m1, m2, p)
        elif isinstance(m1, Pipeline):
            self.assertEqual(len(m1.getStages()), len(m2.getStages()))
            for s1, s2 in zip(m1.getStages(), m2.getStages()):
                self._compare_pipelines(s1, s2)
        elif isinstance(m1, PipelineModel):
            self.assertEqual(len(m1.stages), len(m2.stages))
            for s1, s2 in zip(m1.stages, m2.stages):
                self._compare_pipelines(s1, s2)
        elif isinstance(m1, (OneVsRest, OneVsRestModel)):
            for p in m1.params:
                self._compare_params(m1, m2, p)
            if isinstance(m1, OneVsRestModel):
                self.assertEqual(len(m1.models), len(m2.models))
                for x, y in zip(m1.models, m2.models):
                    self._compare_pipelines(x, y)
        else:
            raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1))

    def test_pipeline_persistence(self):
        """
        Pipeline[HashingTF, PCA]
        """
        temp_path = tempfile.mkdtemp()

        try:
            df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
            tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
            pca = PCA(k=2, inputCol="features", outputCol="pca_features")
            pl = Pipeline(stages=[tf, pca])
            model = pl.fit(df)

            pipeline_path = temp_path + "/pipeline"
            pl.save(pipeline_path)
            loaded_pipeline = Pipeline.load(pipeline_path)
            self._compare_pipelines(pl, loaded_pipeline)

            model_path = temp_path + "/pipeline-model"
            model.save(model_path)
            loaded_model = PipelineModel.load(model_path)
            self._compare_pipelines(model, loaded_model)
        finally:
            try:
                rmtree(temp_path)
            except OSError:
                pass

    def test_nested_pipeline_persistence(self):
        """
        Pipeline[HashingTF, Pipeline[PCA]]
        """
        temp_path = tempfile.mkdtemp()

        try:
            df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
            tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
            pca = PCA(k=2, inputCol="features", outputCol="pca_features")
            p0 = Pipeline(stages=[pca])
            pl = Pipeline(stages=[tf, p0])
            model = pl.fit(df)

            pipeline_path = temp_path + "/pipeline"
            pl.save(pipeline_path)
            loaded_pipeline = Pipeline.load(pipeline_path)
            self._compare_pipelines(pl, loaded_pipeline)

            model_path = temp_path + "/pipeline-model"
            model.save(model_path)
            loaded_model = PipelineModel.load(model_path)
            self._compare_pipelines(model, loaded_model)
        finally:
            try:
                rmtree(temp_path)
            except OSError:
                pass

    def test_python_transformer_pipeline_persistence(self):
        """
        Pipeline[MockUnaryTransformer, Binarizer]
        """
        temp_path = tempfile.mkdtemp()

        try:
            df = self.spark.range(0, 10).toDF('input')
            tf = MockUnaryTransformer(shiftVal=2)\
                .setInputCol("input").setOutputCol("shiftedInput")
            tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized")
            pl = Pipeline(stages=[tf, tf2])
            model = pl.fit(df)

            pipeline_path = temp_path + "/pipeline"
            pl.save(pipeline_path)
            loaded_pipeline = Pipeline.load(pipeline_path)
            self._compare_pipelines(pl, loaded_pipeline)

            model_path = temp_path + "/pipeline-model"
            model.save(model_path)
            loaded_model = PipelineModel.load(model_path)
            self._compare_pipelines(model, loaded_model)
        finally:
            try:
                rmtree(temp_path)
            except OSError:
                pass

    def test_onevsrest(self):
        temp_path = tempfile.mkdtemp()
        df = self.spark.createDataFrame([(0.0, 0.5, Vectors.dense(1.0, 0.8)),
                                         (1.0, 0.5, Vectors.sparse(2, [], [])),
                                         (2.0, 1.0, Vectors.dense(0.5, 0.5))] * 10,
                                        ["label", "wt", "features"])

        lr = LogisticRegression(maxIter=5, regParam=0.01)
        ovr = OneVsRest(classifier=lr)

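        # Helper: fit the given OneVsRest estimator, save and reload both the estimator and
        # its fitted model, and compare each reloaded instance against the original.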
        def reload_and_compare(ovr, suffix):
            model = ovr.fit(df)
            ovrPath = temp_path + "/{}".format(suffix)
            ovr.save(ovrPath)
            loadedOvr = OneVsRest.load(ovrPath)
            self._compare_pipelines(ovr, loadedOvr)
            modelPath = temp_path + "/{}Model".format(suffix)
            model.save(modelPath)
            loadedModel = OneVsRestModel.load(modelPath)
            self._compare_pipelines(model, loadedModel)

        reload_and_compare(OneVsRest(classifier=lr), "ovr")
        reload_and_compare(OneVsRest(classifier=lr).setWeightCol("wt"), "ovrw")

    def test_decisiontree_classifier(self):
        dt = DecisionTreeClassifier(maxDepth=1)
        path = tempfile.mkdtemp()
        dtc_path = path + "/dtc"
        dt.save(dtc_path)
        dt2 = DecisionTreeClassifier.load(dtc_path)
        self.assertEqual(dt2.uid, dt2.maxDepth.parent,
                         "Loaded DecisionTreeClassifier instance uid (%s) "
                         "did not match Param's uid (%s)"
                         % (dt2.uid, dt2.maxDepth.parent))
        self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
                         "Loaded DecisionTreeClassifier instance default params did not match " +
                         "original defaults")
        try:
            rmtree(path)
        except OSError:
            pass

    def test_decisiontree_regressor(self):
        dt = DecisionTreeRegressor(maxDepth=1)
        path = tempfile.mkdtemp()
        dtr_path = path + "/dtr"
        dt.save(dtr_path)
        dt2 = DecisionTreeRegressor.load(dtr_path)
        self.assertEqual(dt2.uid, dt2.maxDepth.parent,
                         "Loaded DecisionTreeRegressor instance uid (%s) "
                         "did not match Param's uid (%s)"
                         % (dt2.uid, dt2.maxDepth.parent))
        self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
                         "Loaded DecisionTreeRegressor instance default params did not match " +
                         "original defaults")
        try:
            rmtree(path)
        except OSError:
            pass

    def test_default_read_write(self):
        temp_path = tempfile.mkdtemp()

        lr = LogisticRegression()
        lr.setMaxIter(50)
        lr.setThreshold(.75)
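        # DefaultParamsWriter saves the instance's uid and params as JSON metadata;
        # DefaultParamsReadable.read() returns the matching reader for loading it back.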
        writer = DefaultParamsWriter(lr)

        savePath = temp_path + "/lr"
        writer.save(savePath)

        reader = DefaultParamsReadable.read()
        lr2 = reader.load(savePath)

        self.assertEqual(lr.uid, lr2.uid)
        self.assertEqual(lr.extractParamMap(), lr2.extractParamMap())

        # test overwrite
        lr.setThreshold(.8)
        writer.overwrite().save(savePath)

        reader = DefaultParamsReadable.read()
        lr3 = reader.load(savePath)

        self.assertEqual(lr.uid, lr3.uid)
        self.assertEqual(lr.extractParamMap(), lr3.extractParamMap())

    def test_default_read_write_default_params(self):
        lr = LogisticRegression()
        self.assertFalse(lr.isSet(lr.getParam("threshold")))

        lr.setMaxIter(50)
        lr.setThreshold(.75)

        # `threshold` is set by the user; the default param `predictionCol` is not.
        self.assertTrue(lr.isSet(lr.getParam("threshold")))
        self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
        self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))

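        # _get_metadata_to_save returns the JSON metadata string that would be written out,
        # including the uid, paramMap, and defaultParamMap sections.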
        writer = DefaultParamsWriter(lr)
        metadata = json.loads(writer._get_metadata_to_save(lr, self.sc))
        self.assertTrue("defaultParamMap" in metadata)

        reader = DefaultParamsReadable.read()
        metadataStr = json.dumps(metadata, separators=[',', ':'])
        loadedMetadata = reader._parseMetaData(metadataStr)
        reader.getAndSetParams(lr, loadedMetadata)

        self.assertTrue(lr.isSet(lr.getParam("threshold")))
        self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
        self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))

        # Manually create metadata without the `defaultParamMap` section.
        del metadata['defaultParamMap']
        metadataStr = json.dumps(metadata, separators=[',', ':'])
        loadedMetadata = reader._parseMetaData(metadataStr)
        with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"):
            reader.getAndSetParams(lr, loadedMetadata)

        # Prior to 2.4.0, metadata doesn't have `defaultParamMap`.
        metadata['sparkVersion'] = '2.3.0'
        metadataStr = json.dumps(metadata, separators=[',', ':'])
        loadedMetadata = reader._parseMetaData(metadataStr)
        reader.getAndSetParams(lr, loadedMetadata)


if __name__ == "__main__":
    from pyspark.ml.tests.test_persistence import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)