# NOTE: navigation chrome from the OSCL-LXR code browser was captured with this
# file ("Back to home page", "OSCL-LXR"); it is not part of the source.

0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 from __future__ import print_function
0019 
0020 header = """#
0021 # Licensed to the Apache Software Foundation (ASF) under one or more
0022 # contributor license agreements.  See the NOTICE file distributed with
0023 # this work for additional information regarding copyright ownership.
0024 # The ASF licenses this file to You under the Apache License, Version 2.0
0025 # (the "License"); you may not use this file except in compliance with
0026 # the License.  You may obtain a copy of the License at
0027 #
0028 #    http://www.apache.org/licenses/LICENSE-2.0
0029 #
0030 # Unless required by applicable law or agreed to in writing, software
0031 # distributed under the License is distributed on an "AS IS" BASIS,
0032 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0033 # See the License for the specific language governing permissions and
0034 # limitations under the License.
0035 #"""
0036 
0037 # Code generator for shared params (shared.py). Run under this folder with:
0038 # python _shared_params_code_gen.py > shared.py
0039 
0040 
0041 def _gen_param_header(name, doc, defaultValueStr, typeConverter):
0042     """
0043     Generates the header part for shared variables
0044 
0045     :param name: param name
0046     :param doc: param doc
0047     """
0048     template = '''class Has$Name(Params):
0049     """
0050     Mixin for param $name: $doc
0051     """
0052 
0053     $name = Param(Params._dummy(), "$name", "$doc", typeConverter=$typeConverter)
0054 
0055     def __init__(self):
0056         super(Has$Name, self).__init__()'''
0057 
0058     if defaultValueStr is not None:
0059         template += '''
0060         self._setDefault($name=$defaultValueStr)'''
0061 
0062     Name = name[0].upper() + name[1:]
0063     if typeConverter is None:
0064         typeConverter = str(None)
0065     return template \
0066         .replace("$name", name) \
0067         .replace("$Name", Name) \
0068         .replace("$doc", doc) \
0069         .replace("$defaultValueStr", str(defaultValueStr)) \
0070         .replace("$typeConverter", typeConverter)
0071 
0072 
0073 def _gen_param_code(name, doc, defaultValueStr):
0074     """
0075     Generates Python code for a shared param class.
0076 
0077     :param name: param name
0078     :param doc: param doc
0079     :param defaultValueStr: string representation of the default value
0080     :return: code string
0081     """
0082     # TODO: How to correctly inherit instance attributes?
0083     template = '''
0084     def get$Name(self):
0085         """
0086         Gets the value of $name or its default value.
0087         """
0088         return self.getOrDefault(self.$name)'''
0089 
0090     Name = name[0].upper() + name[1:]
0091     return template \
0092         .replace("$name", name) \
0093         .replace("$Name", Name) \
0094         .replace("$doc", doc) \
0095         .replace("$defaultValueStr", str(defaultValueStr))
0096 
if __name__ == "__main__":
    # Emit the generated shared.py to stdout: license header, a do-not-edit
    # warning, the wildcard import the generated classes rely on, then one
    # mixin class (plus getter) per shared param.
    print(header)
    print("\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n")
    print("from pyspark.ml.param import *\n\n")
    # Each entry is (name, doc, defaultValueStr, typeConverter).
    # defaultValueStr is Python *source text* evaluated in the generated
    # class (e.g. "'features'", "hash(type(self).__name__)"), or None to
    # omit the default; typeConverter is the source text of a
    # TypeConverters member, or None.
    shared = [
        ("maxIter", "max number of iterations (>= 0).", None, "TypeConverters.toInt"),
        ("regParam", "regularization parameter (>= 0).", None, "TypeConverters.toFloat"),
        ("featuresCol", "features column name.", "'features'", "TypeConverters.toString"),
        ("labelCol", "label column name.", "'label'", "TypeConverters.toString"),
        ("predictionCol", "prediction column name.", "'prediction'", "TypeConverters.toString"),
        ("probabilityCol", "Column name for predicted class conditional probabilities. " +
         "Note: Not all models output well-calibrated probability estimates! These probabilities " +
         "should be treated as confidences, not precise probabilities.", "'probability'",
         "TypeConverters.toString"),
        ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", "'rawPrediction'",
         "TypeConverters.toString"),
        ("inputCol", "input column name.", None, "TypeConverters.toString"),
        ("inputCols", "input column names.", None, "TypeConverters.toListString"),
        ("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"),
        ("outputCols", "output column names.", None, "TypeConverters.toListString"),
        ("numFeatures", "Number of features. Should be greater than 0.", "262144",
         "TypeConverters.toInt"),
        ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " +
         "E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: " +
         "this setting will be ignored if the checkpoint directory is not set in the SparkContext.",
         None, "TypeConverters.toInt"),
        ("seed", "random seed.", "hash(type(self).__name__)", "TypeConverters.toInt"),
        ("tol", "the convergence tolerance for iterative algorithms (>= 0).", None,
         "TypeConverters.toFloat"),
        ("relativeError", "the relative target precision for the approximate quantile " +
         "algorithm. Must be in the range [0, 1]", "0.001", "TypeConverters.toFloat"),
        ("stepSize", "Step size to be used for each iteration of optimization (>= 0).", None,
         "TypeConverters.toFloat"),
        ("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " +
         "out rows with bad values), or error (which will throw an error). More options may be " +
         "added later.", None, "TypeConverters.toString"),
        ("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
         "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0",
         "TypeConverters.toFloat"),
        ("fitIntercept", "whether to fit an intercept term.", "True", "TypeConverters.toBoolean"),
        ("standardization", "whether to standardize the training features before fitting the " +
         "model.", "True", "TypeConverters.toBoolean"),
        ("thresholds", "Thresholds in multi-class classification to adjust the probability of " +
         "predicting each class. Array must have length equal to the number of classes, with " +
         "values > 0, excepting that at most one value may be 0. " +
         "The class with largest value p/t is predicted, where p is the original " +
         "probability of that class and t is the class's threshold.", None,
         "TypeConverters.toListFloat"),
        ("threshold", "threshold in binary classification prediction, in range [0, 1]",
         "0.5", "TypeConverters.toFloat"),
        ("weightCol", "weight column name. If this is not set or empty, we treat " +
         "all instance weights as 1.0.", None, "TypeConverters.toString"),
        ("solver", "the solver algorithm for optimization. If this is not set or empty, " +
         "default value is 'auto'.", "'auto'", "TypeConverters.toString"),
        ("varianceCol", "column name for the biased sample variance of prediction.",
         None, "TypeConverters.toString"),
        ("aggregationDepth", "suggested depth for treeAggregate (>= 2).", "2",
         "TypeConverters.toInt"),
        ("parallelism", "the number of threads to use when running parallel algorithms (>= 1).",
         "1", "TypeConverters.toInt"),
        ("collectSubModels", "Param for whether to collect a list of sub-models trained during " +
         "tuning. If set to false, then only the single best sub-model will be available after " +
         "fitting. If set to true, then all sub-models will be available. Warning: For large " +
         "models, collecting all sub-models can cause OOMs on the Spark driver.",
         "False", "TypeConverters.toBoolean"),
        ("loss", "the loss function to be optimized.", None, "TypeConverters.toString"),
        ("distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.",
         "'euclidean'", "TypeConverters.toString"),
        ("validationIndicatorCol", "name of the column that indicates whether each row is for " +
         "training or for validation. False indicates training; true indicates validation.",
         None, "TypeConverters.toString"),
        ("blockSize", "block size for stacking input data in matrices. Data is stacked within "
         "partitions. If block size is more than remaining data in a partition then it is "
         "adjusted to the size of this data.", None, "TypeConverters.toInt")]

    code = []
    for name, doc, defaultValueStr, typeConverter in shared:
        # Class header (Param declaration + __init__) followed by the getter.
        param_code = _gen_param_header(name, doc, defaultValueStr, typeConverter)
        code.append(param_code + "\n" + _gen_param_code(name, doc, defaultValueStr))

    # Two blank lines between generated top-level classes (PEP 8).
    print("\n\n\n".join(code))