0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 from __future__ import print_function
0019
# Apache license header that is printed verbatim at the top of the generated
# shared params module, so the generated file carries the same license as
# this generator.
header = """#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#"""
0036
0037
0038
0039
0040
0041 def _gen_param_header(name, doc, defaultValueStr, typeConverter):
0042 """
0043 Generates the header part for shared variables
0044
0045 :param name: param name
0046 :param doc: param doc
0047 """
0048 template = '''class Has$Name(Params):
0049 """
0050 Mixin for param $name: $doc
0051 """
0052
0053 $name = Param(Params._dummy(), "$name", "$doc", typeConverter=$typeConverter)
0054
0055 def __init__(self):
0056 super(Has$Name, self).__init__()'''
0057
0058 if defaultValueStr is not None:
0059 template += '''
0060 self._setDefault($name=$defaultValueStr)'''
0061
0062 Name = name[0].upper() + name[1:]
0063 if typeConverter is None:
0064 typeConverter = str(None)
0065 return template \
0066 .replace("$name", name) \
0067 .replace("$Name", Name) \
0068 .replace("$doc", doc) \
0069 .replace("$defaultValueStr", str(defaultValueStr)) \
0070 .replace("$typeConverter", typeConverter)
0071
0072
0073 def _gen_param_code(name, doc, defaultValueStr):
0074 """
0075 Generates Python code for a shared param class.
0076
0077 :param name: param name
0078 :param doc: param doc
0079 :param defaultValueStr: string representation of the default value
0080 :return: code string
0081 """
0082
0083 template = '''
0084 def get$Name(self):
0085 """
0086 Gets the value of $name or its default value.
0087 """
0088 return self.getOrDefault(self.$name)'''
0089
0090 Name = name[0].upper() + name[1:]
0091 return template \
0092 .replace("$name", name) \
0093 .replace("$Name", Name) \
0094 .replace("$doc", doc) \
0095 .replace("$defaultValueStr", str(defaultValueStr))
0096
if __name__ == "__main__":
    # Emit the generated module to stdout: license header, warning banner,
    # imports, then one generated mixin class per shared param.
    print(header)
    print("\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n")
    print("from pyspark.ml.param import *\n\n")
    # Table of shared params. Each entry is
    # (name, doc, defaultValueStr, typeConverter-source-string);
    # defaultValueStr is None when the param has no default, and is emitted
    # verbatim into the generated code (hence values like "'features'" and
    # "hash(type(self).__name__)", which are evaluated in the generated file).
    shared = [
        ("maxIter", "max number of iterations (>= 0).", None, "TypeConverters.toInt"),
        ("regParam", "regularization parameter (>= 0).", None, "TypeConverters.toFloat"),
        ("featuresCol", "features column name.", "'features'", "TypeConverters.toString"),
        ("labelCol", "label column name.", "'label'", "TypeConverters.toString"),
        ("predictionCol", "prediction column name.", "'prediction'", "TypeConverters.toString"),
        ("probabilityCol", "Column name for predicted class conditional probabilities. " +
         "Note: Not all models output well-calibrated probability estimates! These probabilities " +
         "should be treated as confidences, not precise probabilities.", "'probability'",
         "TypeConverters.toString"),
        ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", "'rawPrediction'",
         "TypeConverters.toString"),
        ("inputCol", "input column name.", None, "TypeConverters.toString"),
        ("inputCols", "input column names.", None, "TypeConverters.toListString"),
        ("outputCol", "output column name.", "self.uid + '__output'", "TypeConverters.toString"),
        ("outputCols", "output column names.", None, "TypeConverters.toListString"),
        ("numFeatures", "Number of features. Should be greater than 0.", "262144",
         "TypeConverters.toInt"),
        ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " +
         "E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: " +
         "this setting will be ignored if the checkpoint directory is not set in the SparkContext.",
         None, "TypeConverters.toInt"),
        ("seed", "random seed.", "hash(type(self).__name__)", "TypeConverters.toInt"),
        ("tol", "the convergence tolerance for iterative algorithms (>= 0).", None,
         "TypeConverters.toFloat"),
        ("relativeError", "the relative target precision for the approximate quantile " +
         "algorithm. Must be in the range [0, 1]", "0.001", "TypeConverters.toFloat"),
        ("stepSize", "Step size to be used for each iteration of optimization (>= 0).", None,
         "TypeConverters.toFloat"),
        ("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " +
         "out rows with bad values), or error (which will throw an error). More options may be " +
         "added later.", None, "TypeConverters.toString"),
        ("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
         "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0",
         "TypeConverters.toFloat"),
        ("fitIntercept", "whether to fit an intercept term.", "True", "TypeConverters.toBoolean"),
        ("standardization", "whether to standardize the training features before fitting the " +
         "model.", "True", "TypeConverters.toBoolean"),
        ("thresholds", "Thresholds in multi-class classification to adjust the probability of " +
         "predicting each class. Array must have length equal to the number of classes, with " +
         "values > 0, excepting that at most one value may be 0. " +
         "The class with largest value p/t is predicted, where p is the original " +
         "probability of that class and t is the class's threshold.", None,
         "TypeConverters.toListFloat"),
        ("threshold", "threshold in binary classification prediction, in range [0, 1]",
         "0.5", "TypeConverters.toFloat"),
        ("weightCol", "weight column name. If this is not set or empty, we treat " +
         "all instance weights as 1.0.", None, "TypeConverters.toString"),
        ("solver", "the solver algorithm for optimization. If this is not set or empty, " +
         "default value is 'auto'.", "'auto'", "TypeConverters.toString"),
        ("varianceCol", "column name for the biased sample variance of prediction.",
         None, "TypeConverters.toString"),
        ("aggregationDepth", "suggested depth for treeAggregate (>= 2).", "2",
         "TypeConverters.toInt"),
        ("parallelism", "the number of threads to use when running parallel algorithms (>= 1).",
         "1", "TypeConverters.toInt"),
        ("collectSubModels", "Param for whether to collect a list of sub-models trained during " +
         "tuning. If set to false, then only the single best sub-model will be available after " +
         "fitting. If set to true, then all sub-models will be available. Warning: For large " +
         "models, collecting all sub-models can cause OOMs on the Spark driver.",
         "False", "TypeConverters.toBoolean"),
        ("loss", "the loss function to be optimized.", None, "TypeConverters.toString"),
        ("distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.",
         "'euclidean'", "TypeConverters.toString"),
        ("validationIndicatorCol", "name of the column that indicates whether each row is for " +
         "training or for validation. False indicates training; true indicates validation.",
         None, "TypeConverters.toString"),
        ("blockSize", "block size for stacking input data in matrices. Data is stacked within "
         "partitions. If block size is more than remaining data in a partition then it is "
         "adjusted to the size of this data.", None, "TypeConverters.toInt")]

    # One generated snippet per shared param: mixin class header + getter.
    code = []
    for name, doc, defaultValueStr, typeConverter in shared:
        param_code = _gen_param_header(name, doc, defaultValueStr, typeConverter)
        code.append(param_code + "\n" + _gen_param_code(name, doc, defaultValueStr))

    # Triple newline between classes keeps two blank lines (PEP 8) in output.
    print("\n\n\n".join(code))