
Commit a5a3189

facaiy authored and yanboliang committed
[SPARK-21306][ML] OneVsRest should support setWeightCol
## What changes were proposed in this pull request?

Add a `setWeightCol` method for OneVsRest. `weightCol` is ignored if the classifier doesn't inherit the HasWeightCol trait.

## How was this patch tested?

+ [x] Added a unit test.

Author: Yan Facai (颜发才) <[email protected]>

Closes apache#18554 from facaiy/BUG/oneVsRest_missing_weightCol.
1 parent f44ead8 commit a5a3189

File tree

4 files changed, +81 -9 lines changed


mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 36 additions & 3 deletions
```diff
@@ -34,6 +34,7 @@ import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -53,7 +54,8 @@ private[ml] trait ClassifierTypeTrait {
 /**
  * Params for [[OneVsRest]].
  */
-private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait {
+private[ml] trait OneVsRestParams extends PredictorParams
+  with ClassifierTypeTrait with HasWeightCol {
 
   /**
    * param for the base binary classifier that we reduce multiclass classification into.
@@ -294,6 +296,18 @@ final class OneVsRest @Since("1.4.0") (
   @Since("1.5.0")
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
+  /**
+   * Sets the value of param [[weightCol]].
+   *
+   * This is ignored if weight is not supported by [[classifier]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * Default is not set, so all instances have weight one.
+   *
+   * @group setParam
+   */
+  @Since("2.3.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
@@ -317,7 +331,20 @@ final class OneVsRest @Since("1.4.0") (
     val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
     instr.logNumClasses(numClasses)
 
-    val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
+    val weightColIsUsed = isDefined(weightCol) && $(weightCol).nonEmpty && {
+      getClassifier match {
+        case _: HasWeightCol => true
+        case c =>
+          logWarning(s"weightCol is ignored, as it is not supported by $c now.")
+          false
+      }
+    }
+
+    val multiclassLabeled = if (weightColIsUsed) {
+      dataset.select($(labelCol), $(featuresCol), $(weightCol))
+    } else {
+      dataset.select($(labelCol), $(featuresCol))
+    }
 
     // persist if underlying dataset is not persistent.
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -337,7 +364,13 @@ final class OneVsRest @Since("1.4.0") (
       paramMap.put(classifier.labelCol -> labelColName)
      paramMap.put(classifier.featuresCol -> getFeaturesCol)
      paramMap.put(classifier.predictionCol -> getPredictionCol)
-      classifier.fit(trainingDataset, paramMap)
+      if (weightColIsUsed) {
+        val classifier_ = classifier.asInstanceOf[ClassifierType with HasWeightCol]
+        paramMap.put(classifier_.weightCol -> getWeightCol)
+        classifier_.fit(trainingDataset, paramMap)
+      } else {
+        classifier.fit(trainingDataset, paramMap)
+      }
     }.toArray[ClassificationModel[_, _]]
     instr.logNumFeatures(models.head.numFeatures)
```
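For orientation, here is a minimal usage sketch of the Scala API added above. The SparkSession setup, the libsvm data path, and the constant `weight` column are illustrative assumptions, not part of this commit.

```scala
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit

// Minimal sketch, assuming Spark with this patch applied. The session, the
// libsvm data path, and the constant weight column are illustrative only.
val spark = SparkSession.builder().appName("OneVsRestWeightCol").getOrCreate()

val training = spark.read.format("libsvm")
  .load("data/mllib/sample_multiclass_classification_data.txt")
  .withColumn("weight", lit(1.0))

// LogisticRegression mixes in HasWeightCol, so fit() keeps the weight column
// and forwards it to every binary sub-classifier through the ParamMap.
val ovr = new OneVsRest()
  .setClassifier(new LogisticRegression().setMaxIter(10))
  .setWeightCol("weight")

val ovrModel = ovr.fit(training)

// A base classifier without HasWeightCol would instead trigger the
// "weightCol is ignored ..." warning and train with unit weights.
```

As the diff shows, when the base classifier does not mix in `HasWeightCol`, the weight column is simply dropped from the selected training columns and a warning is logged.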
mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala

Lines changed: 10 additions & 0 deletions
```diff
@@ -156,6 +156,16 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
     assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
   }
 
+  test("SPARK-21306: OneVsRest should support setWeightCol") {
+    val dataset2 = dataset.withColumn("weight", lit(1))
+    // classifier inherits hasWeightCol
+    val ova = new OneVsRest().setWeightCol("weight").setClassifier(new LogisticRegression())
+    assert(ova.fit(dataset2) !== null)
+    // classifier doesn't inherit hasWeightCol
+    val ova2 = new OneVsRest().setWeightCol("weight").setClassifier(new DecisionTreeClassifier())
+    assert(ova2.fit(dataset2) !== null)
+  }
+
   test("OneVsRest.copy and OneVsRestModel.copy") {
     val lr = new LogisticRegression()
       .setMaxIter(1)
```

python/pyspark/ml/classification.py

Lines changed: 21 additions & 6 deletions
```diff
@@ -1447,7 +1447,7 @@ def weights(self):
         return self._call_java("weights")
 
 
-class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol):
+class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol):
     """
     Parameters for OneVsRest and OneVsRestModel.
     """
@@ -1517,20 +1517,22 @@ class OneVsRest(Estimator, OneVsRestParams, JavaMLReadable, JavaMLWritable):
 
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
-                 classifier=None):
+                 classifier=None, weightCol=None):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
-                 classifier=None)
+                 classifier=None, weightCol=None)
         """
         super(OneVsRest, self).__init__()
         kwargs = self._input_kwargs
         self._set(**kwargs)
 
     @keyword_only
     @since("2.0.0")
-    def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
+    def setParams(self, featuresCol=None, labelCol=None, predictionCol=None,
+                  classifier=None, weightCol=None):
         """
-        setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
+        setParams(self, featuresCol=None, labelCol=None, predictionCol=None, \
+                  classifier=None, weightCol=None):
         Sets params for OneVsRest.
         """
         kwargs = self._input_kwargs
@@ -1546,7 +1548,18 @@ def _fit(self, dataset):
 
         numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1
 
-        multiclassLabeled = dataset.select(labelCol, featuresCol)
+        weightCol = None
+        if (self.isDefined(self.weightCol) and self.getWeightCol()):
+            if isinstance(classifier, HasWeightCol):
+                weightCol = self.getWeightCol()
+            else:
+                warnings.warn("weightCol is ignored, "
+                              "as it is not supported by {} now.".format(classifier))
+
+        if weightCol:
+            multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol)
+        else:
+            multiclassLabeled = dataset.select(labelCol, featuresCol)
 
         # persist if underlying dataset is not persistent.
         handlePersistence = \
@@ -1562,6 +1575,8 @@ def trainSingleClass(index):
             paramMap = dict([(classifier.labelCol, binaryLabelCol),
                              (classifier.featuresCol, featuresCol),
                              (classifier.predictionCol, predictionCol)])
+            if weightCol:
+                paramMap[classifier.weightCol] = weightCol
             return classifier.fit(trainingDataset, paramMap)
 
         # TODO: Parallel training for all classes.
```
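Correspondingly, a minimal pyspark sketch of the new `weightCol` keyword argument follows; the active SparkSession bound to `spark` and the toy DataFrame are assumptions for illustration only.

```python
# Minimal sketch, assuming an active SparkSession bound to `spark` and this
# patch applied; the toy DataFrame and its "weight" column are illustrative.
from pyspark.ml.classification import LogisticRegression, OneVsRest
from pyspark.ml.linalg import Vectors

df = spark.createDataFrame(
    [(0.0, Vectors.dense(1.0, 0.8), 1.0),
     (1.0, Vectors.dense(0.2, 0.3), 2.0),
     (2.0, Vectors.dense(0.5, 0.5), 0.5)],
    ["label", "features", "weight"])

# LogisticRegression implements HasWeightCol, so _fit() keeps the weight column
# and adds classifier.weightCol to the per-class paramMap before training.
lr = LogisticRegression(maxIter=5, regParam=0.01)
ovr = OneVsRest(classifier=lr, weightCol="weight")
model = ovr.fit(df)

# A classifier that does not implement HasWeightCol would hit the
# warnings.warn branch above and the weight column would be ignored.
```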

python/pyspark/ml/tests.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -1394,6 +1394,20 @@ def test_output_columns(self):
         output = model.transform(df)
         self.assertEqual(output.columns, ["label", "features", "prediction"])
 
+    def test_support_for_weightCol(self):
+        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0),
+                                         (1.0, Vectors.sparse(2, [], []), 1.0),
+                                         (2.0, Vectors.dense(0.5, 0.5), 1.0)],
+                                        ["label", "features", "weight"])
+        # classifier inherits hasWeightCol
+        lr = LogisticRegression(maxIter=5, regParam=0.01)
+        ovr = OneVsRest(classifier=lr, weightCol="weight")
+        self.assertIsNotNone(ovr.fit(df))
+        # classifier doesn't inherit hasWeightCol
+        dt = DecisionTreeClassifier()
+        ovr2 = OneVsRest(classifier=dt, weightCol="weight")
+        self.assertIsNotNone(ovr2.fit(df))
+
 
 class HashingTFTest(SparkSessionTestCase):
 
```