0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 import unittest
0018
0019 from pyspark.ml.pipeline import Pipeline
0020 from pyspark.testing.mlutils import MockDataset, MockEstimator, MockTransformer, PySparkTestCase
0021
0022
class PipelineTests(PySparkTestCase):
    """Tests for :class:`pyspark.ml.pipeline.Pipeline` built from mock stages."""

    def test_pipeline(self):
        """Fit and transform an estimator/transformer pipeline and check each stage.

        The pipeline is fitted with a param map that overrides ``fake`` on the
        first two stages. The test then checks the ``dataset_index`` recorded by
        each fitted stage and the dataset's ``index`` after fit and after
        transform.
        """
        data = MockDataset()
        stages = [MockEstimator(), MockTransformer(), MockEstimator(), MockTransformer()]
        first_est, first_tr = stages[0], stages[1]

        fitted = Pipeline(stages=stages).fit(data, {first_est.fake: 0, first_tr.fake: 1})
        # Work with the fitted model's stages rather than the objects created
        # above; fitting with a param map may produce copies (TODO confirm).
        model_a, trans_a, model_b, trans_b = fitted.stages

        # The stages ahead of the last estimator ran during fit and picked up
        # the overridden ``fake`` values.
        for expected, stage in ((0, model_a), (1, trans_a)):
            self.assertEqual(expected, stage.dataset_index)
            self.assertEqual(expected, stage.getFake())
        self.assertEqual(2, data.index)
        self.assertIsNone(model_b.dataset_index, "The last model shouldn't be called in fit.")
        self.assertIsNone(trans_b.dataset_index,
                          "The last transformer shouldn't be called in fit.")

        # Transforming runs all four stages in order, continuing from index 2.
        data = fitted.transform(data)
        for expected, stage in enumerate((model_a, trans_a, model_b, trans_b), start=2):
            self.assertEqual(expected, stage.dataset_index)
        self.assertEqual(6, data.index)

    def test_identity_pipeline(self):
        """An empty pipeline leaves the dataset's index unchanged; an unset one raises."""
        data = MockDataset()

        def fit_then_transform(pipe):
            # Fit on the shared dataset, then transform that same dataset.
            return pipe.fit(data).transform(data)

        empty_pipeline = Pipeline(stages=[])
        self.assertEqual(data.index, fit_then_transform(empty_pipeline).index)

        # A Pipeline constructed without ``stages`` is expected to raise KeyError.
        with self.assertRaises(KeyError):
            fit_then_transform(Pipeline())
0059
0060
if __name__ == "__main__":
    # Wildcard re-import of this module under its package path. This is the
    # project's convention; presumably it makes the test classes resolve as
    # ``pyspark.ml.tests.test_pipeline`` members (TODO confirm).
    from pyspark.ml.tests.test_pipeline import *

    def _select_runner():
        """Return an XML test runner when ``xmlrunner`` is installed, else None."""
        try:
            import xmlrunner
            return xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
        except ImportError:
            # xmlrunner is optional; None makes unittest use its default runner.
            return None

    unittest.main(testRunner=_select_runner(), verbosity=2)