Skip to content

Commit a3015cf

Browse files
committed
add Estimator and Transformer
1 parent 46eea43 commit a3015cf

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

python/pyspark/ml/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pyspark import SparkContext
44
from pyspark.ml.param import Param
55

6-
__all__ = ["Pipeline"]
6+
__all__ = ["Pipeline", "Transformer", "Estimator"]
77

88
# An implementation of PEP3102 for Python 2.
99
_keyword_only_secret = 70861589
@@ -60,3 +60,15 @@ def transform(self, dataset):
6060
for t in self.transformers:
6161
dataset = t.transform(dataset)
6262
return dataset
63+
64+
65+
class Estimator(object):
    """
    Abstract base class for estimators.

    This class only defines the interface: subclasses are expected to
    override :meth:`fit`.
    """

    def fit(self, dataset, params=None):
        """
        Fit to the given dataset.

        :param dataset: input dataset to fit on.
        :param params: optional extra parameters for this call. The
                       default is None rather than ``{}`` so that no
                       mutable object is shared between calls.
        :raises NotImplementedError: always; this base implementation
                                     must be overridden by subclasses.
        """
        raise NotImplementedError()
69+
70+
71+
class Transformer(object):
    """
    Abstract base class for transformers.

    This class only defines the interface: subclasses are expected to
    override :meth:`transform`.
    """

    def transform(self, dataset, paramMap=None):
        """
        Transform the given dataset.

        :param dataset: input dataset to transform.
        :param paramMap: optional extra parameters for this call. The
                         default is None rather than ``{}`` so that no
                         mutable object is shared between calls.
        :raises NotImplementedError: always; this base implementation
                                     must be overridden by subclasses.
        """
        raise NotImplementedError()

python/pyspark/ml/classification.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from pyspark.sql import SchemaRDD
2-
from pyspark.ml import _jvm
2+
from pyspark.ml import Estimator, Transformer, _jvm
33
from pyspark.ml.param import Param
44

55

6-
class LogisticRegression(object):
6+
class LogisticRegression(Estimator):
77
"""
88
Logistic regression.
99
"""
@@ -45,7 +45,7 @@ def fit(self, dataset, params=None):
4545
return LogisticRegressionModel(java_model)
4646

4747

48-
class LogisticRegressionModel(object):
48+
class LogisticRegressionModel(Transformer):
4949
"""
5050
Model fitted by LogisticRegression.
5151
"""

0 commit comments

Comments
 (0)