0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import unittest
0019
0020 from pyspark.ml.linalg import Vectors
0021 from pyspark.ml.stat import ChiSquareTest
0022 from pyspark.sql import DataFrame
0023 from pyspark.testing.mlutils import SparkSessionTestCase
0024
0025
class ChiSquareTestTests(SparkSessionTestCase):
    """Tests for the Python ``ChiSquareTest`` wrapper in ``pyspark.ml.stat``."""

    def test_chisquaretest(self):
        """Check the type and schema of the result of ``ChiSquareTest.test``.

        Builds a three-row DataFrame with an integer ``label`` column and a
        dense-vector ``feat`` column, runs the test, and verifies that the
        result is a DataFrame whose schema contains the ``pValues``,
        ``degreesOfFreedom`` and ``statistics`` columns. Only the schema is
        inspected; the result rows are not collected here.
        """
        data = [[0, Vectors.dense([0, 1, 2])],
                [1, Vectors.dense([1, 1, 1])],
                [2, Vectors.dense([2, 1, 0])]]
        df = self.spark.createDataFrame(data, ['label', 'feat'])
        res = ChiSquareTest.test(df, 'feat', 'label')

        self.assertIsInstance(res, DataFrame)
        fieldNames = {field.name for field in res.schema.fields}
        expectedFields = ["pValues", "degreesOfFreedom", "statistics"]
        # Collect the absent columns so a failure names them, rather than
        # reporting the uninformative "False is not true" of assertTrue(all(...)).
        missingFields = [name for name in expectedFields if name not in fieldNames]
        self.assertEqual(
            missingFields, [],
            "result schema is missing expected columns: %s" % missingFields)
0040
0041
if __name__ == "__main__":
    from pyspark.ml.tests.test_stat import *

    # Default to unittest's built-in runner; switch to the XML-reporting
    # runner only when the optional xmlrunner package can be imported.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        # xmlrunner is unavailable: keep the default (None) runner.
        pass
    unittest.main(testRunner=testRunner, verbosity=2)