Skip to content

Commit 9ba0e2b

Browse files
committed
TST: classifier doesn't have weightCol
1 parent 54e0fca commit 9ba0e2b

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,12 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
158158

159159
test("SPARK-21306: OneVsRest should support setWeightCol") {
160160
val dataset2 = dataset.withColumn("weight", lit(1))
161+
// classifier inherits hasWeightCol
161162
val ova = new OneVsRest().setWeightCol("weight").setClassifier(new LogisticRegression())
162-
val ovaModel = ova.fit(dataset2)
163-
assert(ovaModel !== null)
163+
assert(ova.fit(dataset2) !== null)
164+
// classifier doesn't inherit hasWeightCol
165+
val ova2 = new OneVsRest().setWeightCol("weight").setClassifier(new DecisionTreeClassifier())
166+
assert(ova2.fit(dataset2) !== null)
164167
}
165168

166169
test("OneVsRest.copy and OneVsRestModel.copy") {

python/pyspark/ml/tests.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,11 +1260,18 @@ def test_support_for_weightCol(self):
12601260
(1.0, Vectors.sparse(2, [], []), 1.0),
12611261
(2.0, Vectors.dense(0.5, 0.5), 1.0)],
12621262
["label", "features", "weight"])
1263+
# classifier inherits hasWeightCol
12631264
lr = LogisticRegression(maxIter=5, regParam=0.01)
12641265
ovr = OneVsRest(classifier=lr, weightCol="weight")
12651266
self.assertIsNotNone(ovr.fit(df))
12661267
ovr2 = OneVsRest(classifier=lr).setWeightCol("weight")
12671268
self.assertIsNotNone(ovr2.fit(df))
1269+
# classifier doesn't inherit hasWeightCol
1270+
dt = DecisionTreeClassifier()
1271+
ovr3 = OneVsRest(classifier=dt, weightCol="weight")
1272+
self.assertIsNotNone(ovr3.fit(df))
1273+
ovr4 = OneVsRest(classifier=dt).setWeightCol("weight")
1274+
self.assertIsNotNone(ovr4.fit(df))
12681275

12691276

12701277
class HashingTFTest(SparkSessionTestCase):

0 commit comments

Comments
 (0)