@@ -75,21 +75,25 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
7575 }
7676
7777 test(" Linear SVC binary classification" ) {
78- val svm = new LinearSVC ()
79- val model = svm.fit(smallBinaryDataset)
80- assert(model.transform(smallValidationDataset)
81- .where(" prediction=label" ).count() > nPoints * 0.8 )
82- val sparseModel = svm.fit(smallSparseBinaryDataset)
83- checkModels(model, sparseModel)
78+ LinearSVC .supportedOptimizers.foreach { opt =>
79+ val svm = new LinearSVC ().setOptimizer(opt)
80+ val model = svm.fit(smallBinaryDataset)
81+ assert(model.transform(smallValidationDataset)
82+ .where(" prediction=label" ).count() > nPoints * 0.8 )
83+ val sparseModel = svm.fit(smallSparseBinaryDataset)
84+ checkModels(model, sparseModel)
85+ }
8486 }
8587
8688 test(" Linear SVC binary classification with regularization" ) {
87- val svm = new LinearSVC ()
88- val model = svm.setRegParam(0.1 ).fit(smallBinaryDataset)
89- assert(model.transform(smallValidationDataset)
90- .where(" prediction=label" ).count() > nPoints * 0.8 )
91- val sparseModel = svm.fit(smallSparseBinaryDataset)
92- checkModels(model, sparseModel)
89+ LinearSVC .supportedOptimizers.foreach { opt =>
90+ val svm = new LinearSVC ().setOptimizer(opt).setMaxIter(10 )
91+ val model = svm.setRegParam(0.1 ).fit(smallBinaryDataset)
92+ assert(model.transform(smallValidationDataset)
93+ .where(" prediction=label" ).count() > nPoints * 0.8 )
94+ val sparseModel = svm.fit(smallSparseBinaryDataset)
95+ checkModels(model, sparseModel)
96+ }
9397 }
9498
9599 test(" params" ) {
@@ -112,6 +116,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
112116 assert(lsvc.getFeaturesCol === " features" )
113117 assert(lsvc.getPredictionCol === " prediction" )
114118 assert(lsvc.getRawPredictionCol === " rawPrediction" )
119+ assert(lsvc.getOptimizer === " lbfgs" )
115120 val model = lsvc.setMaxIter(5 ).fit(smallBinaryDataset)
116121 model.transform(smallBinaryDataset)
117122 .select(" label" , " prediction" , " rawPrediction" )
@@ -154,22 +159,23 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
154159
155160 test(" linearSVC with sample weights" ) {
156161 def modelEquals (m1 : LinearSVCModel , m2 : LinearSVCModel ): Unit = {
157- assert(m1.coefficients ~== m2.coefficients absTol 0.05 )
162+ assert(m1.coefficients ~== m2.coefficients absTol 0.07 )
158163 assert(m1.intercept ~== m2.intercept absTol 0.05 )
159164 }
160-
161- val estimator = new LinearSVC ().setRegParam(0.01 ).setTol(0.01 )
162- val dataset = smallBinaryDataset
163- MLTestingUtils .testArbitrarilyScaledWeights[LinearSVCModel , LinearSVC ](
164- dataset.as[LabeledPoint ], estimator, modelEquals)
165- MLTestingUtils .testOutliersWithSmallWeights[LinearSVCModel , LinearSVC ](
166- dataset.as[LabeledPoint ], estimator, 2 , modelEquals, outlierRatio = 3 )
167- MLTestingUtils .testOversamplingVsWeighting[LinearSVCModel , LinearSVC ](
168- dataset.as[LabeledPoint ], estimator, modelEquals, 42L )
165+ LinearSVC .supportedOptimizers.foreach { opt =>
166+ val estimator = new LinearSVC ().setRegParam(0.02 ).setTol(0.01 ).setOptimizer(opt)
167+ val dataset = smallBinaryDataset
168+ MLTestingUtils .testArbitrarilyScaledWeights[LinearSVCModel , LinearSVC ](
169+ dataset.as[LabeledPoint ], estimator, modelEquals)
170+ MLTestingUtils .testOutliersWithSmallWeights[LinearSVCModel , LinearSVC ](
171+ dataset.as[LabeledPoint ], estimator, 2 , modelEquals, outlierRatio = 3 )
172+ MLTestingUtils .testOversamplingVsWeighting[LinearSVCModel , LinearSVC ](
173+ dataset.as[LabeledPoint ], estimator, modelEquals, 42L )
174+ }
169175 }
170176
171- test(" linearSVC comparison with R e1071 and scikit-learn" ) {
172- val trainer1 = new LinearSVC ()
177+ test(" linearSVC OWLQN comparison with R e1071 and scikit-learn" ) {
178+ val trainer1 = new LinearSVC ().setOptimizer( " owlqn " )
173179 .setRegParam(0.00002 ) // set regParam = 2.0 / datasize / c
174180 .setMaxIter(200 )
175181 .setTol(1e-4 )
@@ -223,6 +229,25 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
223229 assert(model1.coefficients ~== coefficientsSK relTol 4E-3 )
224230 }
225231
232+ test(" linearSVC LBFGS comparison with R e1071 and scikit-learn" ) {
233+ val trainer1 = new LinearSVC ().setOptimizer(" LBFGS" )
234+ .setRegParam(0.00003 )
235+ .setMaxIter(200 )
236+ .setTol(1e-4 )
237+ val model1 = trainer1.fit(binaryDataset)
238+
239+ // refer to last unit test for R and python code
240+ val coefficientsR = Vectors .dense(7.310338 , 14.89741 , 22.21005 , 29.83508 )
241+ val interceptR = 7.440177
242+ assert(model1.intercept ~== interceptR relTol 2E-2 )
243+ assert(model1.coefficients ~== coefficientsR relTol 1E-2 )
244+
245+ val coefficientsSK = Vectors .dense(7.24690165 , 14.77029087 , 21.99924004 , 29.5575729 )
246+ val interceptSK = 7.36947518
247+ assert(model1.intercept ~== interceptSK relTol 1E-2 )
248+ assert(model1.coefficients ~== coefficientsSK relTol 1E-2 )
249+ }
250+
226251 test(" read/write: SVM" ) {
227252 def checkModelData (model : LinearSVCModel , model2 : LinearSVCModel ): Unit = {
228253 assert(model.intercept === model2.intercept)
0 commit comments