Skip to content

Commit d46e5ed

Browse files
committed
add lbfgs as default optimizer of LinearSVC
1 parent 13eb37c commit d46e5ed

File tree

2 files changed

+98
-30
lines changed

2 files changed

+98
-30
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@
1717

1818
package org.apache.spark.ml.classification
1919

20+
import java.util.Locale
21+
2022
import scala.collection.mutable
2123

2224
import breeze.linalg.{DenseVector => BDV}
23-
import breeze.optimize.{CachedDiffFunction, DiffFunction, OWLQN => BreezeOWLQN}
25+
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS,
26+
OWLQN => BreezeOWLQN}
2427
import org.apache.hadoop.fs.Path
2528

2629
import org.apache.spark.SparkException
@@ -42,7 +45,21 @@ import org.apache.spark.sql.functions.{col, lit}
4245
/** Params for linear SVM Classifier. */
4346
private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
4447
with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
45-
with HasThreshold with HasAggregationDepth
48+
with HasThreshold with HasAggregationDepth {
49+
50+
/**
51+
* The optimization algorithm for LinearSVC.
52+
* Supported options: "lbfgs" and "owlqn".
53+
* (default: "lbfgs")
54+
* @group param
55+
*/
56+
final val optimizer: Param[String] = new Param[String](this, "optimizer", "The optimization" +
57+
" algorithm to be used", ParamValidators.inArray[String](LinearSVC.supportedOptimizers))
58+
59+
/** @group getParam */
60+
final def getOptimizer: String = $(optimizer)
61+
62+
}
4663

4764
/**
4865
* :: Experimental ::
@@ -60,6 +77,8 @@ class LinearSVC @Since("2.2.0") (
6077
extends Classifier[Vector, LinearSVC, LinearSVCModel]
6178
with LinearSVCParams with DefaultParamsWritable {
6279

80+
import LinearSVC._
81+
6382
@Since("2.2.0")
6483
def this() = this(Identifiable.randomUID("linearsvc"))
6584

@@ -145,6 +164,15 @@ class LinearSVC @Since("2.2.0") (
145164
def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
146165
setDefault(aggregationDepth -> 2)
147166

167+
/**
168+
* Set optimizer for LinearSVC. Supported options: "lbfgs" and "owlqn".
169+
*
170+
* @group setParam
171+
*/
172+
@Since("2.2.0")
173+
def setOptimizer(value: String): this.type = set(optimizer, value.toLowerCase(Locale.ROOT))
174+
setDefault(optimizer -> "lbfgs")
175+
148176
@Since("2.2.0")
149177
override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra)
150178

@@ -205,15 +233,21 @@ class LinearSVC @Since("2.2.0") (
205233
val costFun = new LinearSVCCostFun(instances, $(fitIntercept),
206234
$(standardization), bcFeaturesStd, regParamL2, $(aggregationDepth))
207235

208-
def regParamL1Fun = (index: Int) => 0D
209-
val optimizer = new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
236+
val optimizerAlgo = $(optimizer) match {
237+
case LBFGS => new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
238+
case OWLQN =>
239+
def regParamL1Fun = (index: Int) => 0D
240+
new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
241+
case _ => throw new SparkException ("unexpected optimizer: " + $(optimizer))
242+
}
243+
210244
val initialCoefWithIntercept = Vectors.zeros(numFeaturesPlusIntercept)
211245

212-
val states = optimizer.iterations(new CachedDiffFunction(costFun),
246+
val states = optimizerAlgo.iterations(new CachedDiffFunction(costFun),
213247
initialCoefWithIntercept.asBreeze.toDenseVector)
214248

215249
val scaledObjectiveHistory = mutable.ArrayBuilder.make[Double]
216-
var state: optimizer.State = null
250+
var state: optimizerAlgo.State = null
217251
while (states.hasNext) {
218252
state = states.next()
219253
scaledObjectiveHistory += state.adjustedValue
@@ -258,6 +292,15 @@ class LinearSVC @Since("2.2.0") (
258292
@Since("2.2.0")
259293
object LinearSVC extends DefaultParamsReadable[LinearSVC] {
260294

295+
/** String name for Limited-memory BFGS. */
296+
private[classification] val LBFGS: String = "lbfgs".toLowerCase(Locale.ROOT)
297+
298+
/** String name for Orthant-Wise Limited-memory Quasi-Newton. */
299+
private[classification] val OWLQN: String = "owlqn".toLowerCase(Locale.ROOT)
300+
301+
/* Set of optimizers that LinearSVC supports */
302+
private[classification] val supportedOptimizers = Array(LBFGS, OWLQN)
303+
261304
@Since("2.2.0")
262305
override def load(path: String): LinearSVC = super.load(path)
263306
}

mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -75,21 +75,25 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
7575
}
7676

7777
test("Linear SVC binary classification") {
78-
val svm = new LinearSVC()
79-
val model = svm.fit(smallBinaryDataset)
80-
assert(model.transform(smallValidationDataset)
81-
.where("prediction=label").count() > nPoints * 0.8)
82-
val sparseModel = svm.fit(smallSparseBinaryDataset)
83-
checkModels(model, sparseModel)
78+
LinearSVC.supportedOptimizers.foreach { opt =>
79+
val svm = new LinearSVC().setOptimizer(opt)
80+
val model = svm.fit(smallBinaryDataset)
81+
assert(model.transform(smallValidationDataset)
82+
.where("prediction=label").count() > nPoints * 0.8)
83+
val sparseModel = svm.fit(smallSparseBinaryDataset)
84+
checkModels(model, sparseModel)
85+
}
8486
}
8587

8688
test("Linear SVC binary classification with regularization") {
87-
val svm = new LinearSVC()
88-
val model = svm.setRegParam(0.1).fit(smallBinaryDataset)
89-
assert(model.transform(smallValidationDataset)
90-
.where("prediction=label").count() > nPoints * 0.8)
91-
val sparseModel = svm.fit(smallSparseBinaryDataset)
92-
checkModels(model, sparseModel)
89+
LinearSVC.supportedOptimizers.foreach { opt =>
90+
val svm = new LinearSVC().setOptimizer(opt).setMaxIter(10)
91+
val model = svm.setRegParam(0.1).fit(smallBinaryDataset)
92+
assert(model.transform(smallValidationDataset)
93+
.where("prediction=label").count() > nPoints * 0.8)
94+
val sparseModel = svm.fit(smallSparseBinaryDataset)
95+
checkModels(model, sparseModel)
96+
}
9397
}
9498

9599
test("params") {
@@ -112,6 +116,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
112116
assert(lsvc.getFeaturesCol === "features")
113117
assert(lsvc.getPredictionCol === "prediction")
114118
assert(lsvc.getRawPredictionCol === "rawPrediction")
119+
assert(lsvc.getOptimizer === "lbfgs")
115120
val model = lsvc.setMaxIter(5).fit(smallBinaryDataset)
116121
model.transform(smallBinaryDataset)
117122
.select("label", "prediction", "rawPrediction")
@@ -154,22 +159,23 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
154159

155160
test("linearSVC with sample weights") {
156161
def modelEquals(m1: LinearSVCModel, m2: LinearSVCModel): Unit = {
157-
assert(m1.coefficients ~== m2.coefficients absTol 0.05)
162+
assert(m1.coefficients ~== m2.coefficients absTol 0.07)
158163
assert(m1.intercept ~== m2.intercept absTol 0.05)
159164
}
160-
161-
val estimator = new LinearSVC().setRegParam(0.01).setTol(0.01)
162-
val dataset = smallBinaryDataset
163-
MLTestingUtils.testArbitrarilyScaledWeights[LinearSVCModel, LinearSVC](
164-
dataset.as[LabeledPoint], estimator, modelEquals)
165-
MLTestingUtils.testOutliersWithSmallWeights[LinearSVCModel, LinearSVC](
166-
dataset.as[LabeledPoint], estimator, 2, modelEquals, outlierRatio = 3)
167-
MLTestingUtils.testOversamplingVsWeighting[LinearSVCModel, LinearSVC](
168-
dataset.as[LabeledPoint], estimator, modelEquals, 42L)
165+
LinearSVC.supportedOptimizers.foreach { opt =>
166+
val estimator = new LinearSVC().setRegParam(0.02).setTol(0.01).setOptimizer(opt)
167+
val dataset = smallBinaryDataset
168+
MLTestingUtils.testArbitrarilyScaledWeights[LinearSVCModel, LinearSVC](
169+
dataset.as[LabeledPoint], estimator, modelEquals)
170+
MLTestingUtils.testOutliersWithSmallWeights[LinearSVCModel, LinearSVC](
171+
dataset.as[LabeledPoint], estimator, 2, modelEquals, outlierRatio = 3)
172+
MLTestingUtils.testOversamplingVsWeighting[LinearSVCModel, LinearSVC](
173+
dataset.as[LabeledPoint], estimator, modelEquals, 42L)
174+
}
169175
}
170176

171-
test("linearSVC comparison with R e1071 and scikit-learn") {
172-
val trainer1 = new LinearSVC()
177+
test("linearSVC OWLQN comparison with R e1071 and scikit-learn") {
178+
val trainer1 = new LinearSVC().setOptimizer("owlqn")
173179
.setRegParam(0.00002) // set regParam = 2.0 / datasize / c
174180
.setMaxIter(200)
175181
.setTol(1e-4)
@@ -223,6 +229,25 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
223229
assert(model1.coefficients ~== coefficientsSK relTol 4E-3)
224230
}
225231

232+
test("linearSVC LBFGS comparison with R e1071 and scikit-learn") {
233+
val trainer1 = new LinearSVC().setOptimizer("LBFGS")
234+
.setRegParam(0.00003)
235+
.setMaxIter(200)
236+
.setTol(1e-4)
237+
val model1 = trainer1.fit(binaryDataset)
238+
239+
// refer to last unit test for R and python code
240+
val coefficientsR = Vectors.dense(7.310338, 14.89741, 22.21005, 29.83508)
241+
val interceptR = 7.440177
242+
assert(model1.intercept ~== interceptR relTol 2E-2)
243+
assert(model1.coefficients ~== coefficientsR relTol 1E-2)
244+
245+
val coefficientsSK = Vectors.dense(7.24690165, 14.77029087, 21.99924004, 29.5575729)
246+
val interceptSK = 7.36947518
247+
assert(model1.intercept ~== interceptSK relTol 1E-2)
248+
assert(model1.coefficients ~== coefficientsSK relTol 1E-2)
249+
}
250+
226251
test("read/write: SVM") {
227252
def checkModelData(model: LinearSVCModel, model2: LinearSVCModel): Unit = {
228253
assert(model.intercept === model2.intercept)

0 commit comments

Comments (0)