Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 import unittest
0018 
0019 from pyspark.ml.pipeline import Pipeline
0020 from pyspark.testing.mlutils import MockDataset, MockEstimator, MockTransformer, PySparkTestCase
0021 
0022 
class PipelineTests(PySparkTestCase):
    """Tests for ``Pipeline`` fit/transform behaviour, driven by mock stages."""

    def test_pipeline(self):
        """Fit and transform a four-stage pipeline and verify stage bookkeeping.

        The stages alternate estimator, transformer, estimator, transformer.
        The test checks the ``dataset_index`` attribute of each fitted stage,
        the ``fake`` param values passed through the fit-time param map, and
        the dataset's ``index`` after ``fit`` and again after ``transform``.
        """
        data = MockDataset()
        est_a = MockEstimator()
        xf_a = MockTransformer()
        est_b = MockEstimator()
        xf_b = MockTransformer()

        # Param overrides supplied at fit time for the first two stages only.
        fit_params = {est_a.fake: 0, xf_a.fake: 1}
        fitted = Pipeline(stages=[est_a, xf_a, est_b, xf_b]).fit(data, fit_params)
        model_a, fitted_xf_a, model_b, fitted_xf_b = fitted.stages

        # For the first two stages both the dataset index and the ``fake``
        # value are expected to equal the stage's position in the pipeline.
        for position, stage in enumerate((model_a, fitted_xf_a)):
            self.assertEqual(position, stage.dataset_index)
            self.assertEqual(position, stage.getFake())
        self.assertEqual(2, data.index)

        # The trailing model/transformer pair must not have been used by fit.
        self.assertIsNone(model_b.dataset_index, "The last model shouldn't be called in fit.")
        self.assertIsNone(fitted_xf_b.dataset_index,
                          "The last transformer shouldn't be called in fit.")

        transformed = fitted.transform(data)
        # After transform the stages are expected to report indices 2..5 in order.
        for expected, stage in enumerate((model_a, fitted_xf_a, model_b, fitted_xf_b), start=2):
            self.assertEqual(expected, stage.dataset_index)
        self.assertEqual(6, transformed.index)

    def test_identity_pipeline(self):
        """Check an empty pipeline and a pipeline whose stages were never set."""
        data = MockDataset()

        def fit_then_transform(pipe):
            # Fit on the shared dataset, then run the fitted model over it.
            return pipe.fit(data).transform(data)

        # check that empty pipeline did not perform any transformation
        index_before = data.index
        result = fit_then_transform(Pipeline(stages=[]))
        self.assertEqual(index_before, result.index)

        # check that failure to set stages param will raise KeyError for missing param
        with self.assertRaises(KeyError):
            fit_then_transform(Pipeline())
0059 
0060 
if __name__ == "__main__":
    from pyspark.ml.tests.test_pipeline import *

    # Start with no explicit runner; switch to the XML-reporting runner only
    # when the optional ``xmlrunner`` package can be imported.
    runner = None
    try:
        import xmlrunner
        runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=runner, verbosity=2)