Skip to content

Commit 17ecfb9

Browse files
committed
code gen for shared params
1 parent d9ea77c commit 17ecfb9

File tree

7 files changed

+262
-34
lines changed

7 files changed

+262
-34
lines changed

python/docs/pyspark.ml.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,5 @@ pyspark.ml.classification module
3434
.. automodule:: pyspark.ml.classification
3535
:members:
3636
:undoc-members:
37+
:inherited-members:
3738
:show-inheritance:

python/pyspark/ml/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,20 @@ def _jvm():
2929
return SparkContext._jvm
3030

3131

32+
def _inherit_doc(cls):
33+
for name, func in vars(cls).items():
34+
# only inherit docstring for public functions
35+
if name.startswith("_"):
36+
continue
37+
if not func.__doc__:
38+
for parent in cls.__bases__:
39+
parent_func = getattr(parent, name, None)
40+
if parent_func and getattr(parent_func, "__doc__", None):
41+
func.__doc__ = parent_func.__doc__
42+
break
43+
return cls
44+
45+
3246
@inherit_doc
3347
class PipelineStage(Params):
3448
"""

python/pyspark/ml/classification.py

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,43 +17,22 @@
1717

1818
from pyspark.sql import SchemaRDD, inherit_doc
1919
from pyspark.ml import Estimator, Transformer, _jvm
20-
from pyspark.ml.param import Param
20+
from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
21+
HasRegParam
2122

2223

2324
@inherit_doc
24-
class LogisticRegression(Estimator):
25+
class LogisticRegression(Estimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
26+
HasRegParam):
2527
"""
2628
Logistic regression.
2729
"""
2830

2931
# _java_class = "org.apache.spark.ml.classification.LogisticRegression"
3032

3133
def __init__(self):
34+
super(LogisticRegression, self).__init__()
3235
self._java_obj = _jvm().org.apache.spark.ml.classification.LogisticRegression()
33-
self.maxIter = Param(self, "maxIter", "max number of iterations", 100)
34-
self.regParam = Param(self, "regParam", "regularization constant", 0.1)
35-
self.featuresCol = Param(self, "featuresCol", "features column name", "features")
36-
37-
def setMaxIter(self, value):
38-
self._java_obj.setMaxIter(value)
39-
return self
40-
41-
def getMaxIter(self):
42-
return self._java_obj.getMaxIter()
43-
44-
def setRegParam(self, value):
45-
self._java_obj.setRegParam(value)
46-
return self
47-
48-
def getRegParam(self):
49-
return self._java_obj.getRegParam()
50-
51-
def setFeaturesCol(self, value):
52-
self._java_obj.setFeaturesCol(value)
53-
return self
54-
55-
def getFeaturesCol(self):
56-
return self._java_obj.getFeaturesCol()
5736

5837
def fit(self, dataset, params=None):
5938
"""
Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
from pyspark.ml.util import Identifiable
2121

22-
2322
__all__ = ["Param"]
2423

2524

@@ -29,16 +28,18 @@ class Param(object):
2928
"""
3029

3130
def __init__(self, parent, name, doc, defaultValue=None):
31+
if not isinstance(parent, Identifiable):
32+
raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__)
3233
self.parent = parent
33-
self.name = name
34-
self.doc = doc
34+
self.name = str(name)
35+
self.doc = str(doc)
3536
self.defaultValue = defaultValue
3637

3738
def __str__(self):
38-
return self.parent + "_" + self.name
39+
return str(self.parent) + "_" + self.name
3940

4041
def __repr__(self):
41-
return self.parent + "_" + self.name
42+
return str(self.parent) + "_" + self.name
4243

4344

4445
class Params(Identifiable):
@@ -49,15 +50,16 @@ class Params(Identifiable):
4950

5051
__metaclass__ = ABCMeta
5152

53+
#: Internal param map.
54+
paramMap = {}
55+
5256
def __init__(self):
5357
super(Params, self).__init__()
54-
#: Internal param map.
55-
self.paramMap = {}
5658

5759
def params(self):
5860
"""
5961
Returns all params. The default implementation uses
6062
:py:func:`dir` to get all attributes of type
6163
:py:class:`Param`.
6264
"""
63-
return [attr for attr in dir(self) if isinstance(attr, Param)]
65+
return filter(lambda x: isinstance(x, Param), map(lambda x: getattr(self, x), dir(self)))
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
header = """#
19+
# Licensed to the Apache Software Foundation (ASF) under one or more
20+
# contributor license agreements. See the NOTICE file distributed with
21+
# this work for additional information regarding copyright ownership.
22+
# The ASF licenses this file to You under the Apache License, Version 2.0
23+
# (the "License"); you may not use this file except in compliance with
24+
# the License. You may obtain a copy of the License at
25+
#
26+
# http://www.apache.org/licenses/LICENSE-2.0
27+
#
28+
# Unless required by applicable law or agreed to in writing, software
29+
# distributed under the License is distributed on an "AS IS" BASIS,
30+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31+
# See the License for the specific language governing permissions and
32+
# limitations under the License.
33+
#"""
34+
35+
36+
def _gen_param_code(name, doc, defaultValue):
37+
"""
38+
Generates Python code for a shared param class.
39+
40+
:param name: param name
41+
:param doc: param doc
42+
:param defaultValue: string representation of the param
43+
:return: code string
44+
"""
45+
upperCamelName = name[0].upper() + name[1:]
46+
return """class Has%s(Params):
47+
48+
def __init__(self):
49+
super(Has%s, self).__init__()
50+
#: %s
51+
self.%s = Param(self, "%s", "%s", %s)
52+
53+
def set%s(self, value):
54+
self.paramMap[self.%s] = value
55+
return self
56+
57+
def get%s(self, value):
58+
if self.%s in self.paramMap:
59+
return self.paramMap[self.%s]
60+
else:
61+
return self.defaultValue""" % (
62+
upperCamelName, upperCamelName, doc, name, name, doc, defaultValue, upperCamelName, name,
63+
upperCamelName, name, name)
64+
65+
if __name__ == "__main__":
66+
print header
67+
print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n"
68+
print "from pyspark.ml.param import Param, Params\n\n"
69+
shared = [
70+
("maxIter", "max number of iterations", "100"),
71+
("regParam", "regularization constant", "0.1"),
72+
("featuresCol", "features column name", "'features'"),
73+
("labelCol", "label column name", "'label'"),
74+
("predictionCol", "prediction column name", "'prediction'"),
75+
("inputCol", "input column name", "'input'"),
76+
("outputCol", "output column name", "'output'")]
77+
code = []
78+
for name, doc, defaultValue in shared:
79+
code.append(_gen_param_code(name, doc, defaultValue))
80+
print "\n\n\n".join(code)

python/pyspark/ml/param/shared.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
# DO NOT MODIFY. The code is generated by _gen_shared_params.py.
19+
20+
from pyspark.ml.param import Param, Params
21+
22+
23+
class HasMaxIter(Params):
24+
25+
def __init__(self):
26+
super(HasMaxIter, self).__init__()
27+
#: max number of iterations
28+
self.maxIter = Param(self, "maxIter", "max number of iterations", 100)
29+
30+
def setMaxIter(self, value):
31+
self.paramMap[self.maxIter] = value
32+
return self
33+
34+
def getMaxIter(self, value):
35+
if self.maxIter in self.paramMap:
36+
return self.paramMap[self.maxIter]
37+
else:
38+
return self.defaultValue
39+
40+
41+
class HasRegParam(Params):
42+
43+
def __init__(self):
44+
super(HasRegParam, self).__init__()
45+
#: regularization constant
46+
self.regParam = Param(self, "regParam", "regularization constant", 0.1)
47+
48+
def setRegParam(self, value):
49+
self.paramMap[self.regParam] = value
50+
return self
51+
52+
def getRegParam(self, value):
53+
if self.regParam in self.paramMap:
54+
return self.paramMap[self.regParam]
55+
else:
56+
return self.defaultValue
57+
58+
59+
class HasFeaturesCol(Params):
60+
61+
def __init__(self):
62+
super(HasFeaturesCol, self).__init__()
63+
#: features column name
64+
self.featuresCol = Param(self, "featuresCol", "features column name", 'features')
65+
66+
def setFeaturesCol(self, value):
67+
self.paramMap[self.featuresCol] = value
68+
return self
69+
70+
def getFeaturesCol(self, value):
71+
if self.featuresCol in self.paramMap:
72+
return self.paramMap[self.featuresCol]
73+
else:
74+
return self.defaultValue
75+
76+
77+
class HasLabelCol(Params):
78+
79+
def __init__(self):
80+
super(HasLabelCol, self).__init__()
81+
#: label column name
82+
self.labelCol = Param(self, "labelCol", "label column name", 'label')
83+
84+
def setLabelCol(self, value):
85+
self.paramMap[self.labelCol] = value
86+
return self
87+
88+
def getLabelCol(self, value):
89+
if self.labelCol in self.paramMap:
90+
return self.paramMap[self.labelCol]
91+
else:
92+
return self.defaultValue
93+
94+
95+
class HasPredictionCol(Params):
96+
97+
def __init__(self):
98+
super(HasPredictionCol, self).__init__()
99+
#: prediction column name
100+
self.predictionCol = Param(self, "predictionCol", "prediction column name", 'prediction')
101+
102+
def setPredictionCol(self, value):
103+
self.paramMap[self.predictionCol] = value
104+
return self
105+
106+
def getPredictionCol(self, value):
107+
if self.predictionCol in self.paramMap:
108+
return self.paramMap[self.predictionCol]
109+
else:
110+
return self.defaultValue
111+
112+
113+
class HasInputCol(Params):
114+
115+
def __init__(self):
116+
super(HasInputCol, self).__init__()
117+
#: input column name
118+
self.inputCol = Param(self, "inputCol", "input column name", 'input')
119+
120+
def setInputCol(self, value):
121+
self.paramMap[self.inputCol] = value
122+
return self
123+
124+
def getInputCol(self, value):
125+
if self.inputCol in self.paramMap:
126+
return self.paramMap[self.inputCol]
127+
else:
128+
return self.defaultValue
129+
130+
131+
class HasOutputCol(Params):
132+
133+
def __init__(self):
134+
super(HasOutputCol, self).__init__()
135+
#: output column name
136+
self.outputCol = Param(self, "outputCol", "output column name", 'output')
137+
138+
def setOutputCol(self, value):
139+
self.paramMap[self.outputCol] = value
140+
return self
141+
142+
def getOutputCol(self, value):
143+
if self.outputCol in self.paramMap:
144+
return self.paramMap[self.outputCol]
145+
else:
146+
return self.defaultValue

python/pyspark/ml/util.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,9 @@ def __init__(self):
2727
#: A unique id for the object. The default implementation
2828
#: concatenates the class name, "-", and 8 random hex chars.
2929
self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8]
30+
31+
def __str__(self):
32+
return self.uid
33+
34+
def __repr__(self):
35+
return str(self)

0 commit comments

Comments
 (0)