Skip to content

Commit 2ca5a74

Browse files
committed
fix r and python
1 parent 2ffd0eb commit 2ca5a74

File tree

4 files changed

+17
-8
lines changed

4 files changed

+17
-8
lines changed

R/pkg/R/mllib_classification.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
5858
#' @param regParam The regularization parameter. Only supports L2 regularization currently.
5959
#' @param maxIter Maximum iteration number.
6060
#' @param tol Convergence tolerance of iterations.
61+
#' @param solver solver parameter, supported options: "owlqn" or "l-bfgs".
62+
#' @param loss loss function, supported options: "hinge" and "squared_hinge".
6163
#' @param standardization Whether to standardize the training features before fitting the model. The coefficients
6264
#' of models will be always returned on the original scale, so it will be transparent for
6365
#' users. Note that with/without standardization, the models should be always converged
@@ -96,7 +98,8 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
9698
#' @note spark.svmLinear since 2.2.0
9799
setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formula"),
98100
function(data, formula, regParam = 0.0, maxIter = 100, tol = 1E-6, standardization = TRUE,
99-
threshold = 0.0, weightCol = NULL, aggregationDepth = 2) {
101+
threshold = 0.0, weightCol = NULL, aggregationDepth = 2, solver = "l-bfgs",
102+
loss = "squared_hinge") {
100103
formula <- paste(deparse(formula), collapse = "")
101104

102105
if (!is.null(weightCol) && weightCol == "") {
@@ -108,7 +111,8 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu
108111
jobj <- callJStatic("org.apache.spark.ml.r.LinearSVCWrapper", "fit",
109112
data@sdf, formula, as.numeric(regParam), as.integer(maxIter),
110113
as.numeric(tol), as.logical(standardization), as.numeric(threshold),
111-
weightCol, as.integer(aggregationDepth))
114+
weightCol, as.integer(aggregationDepth), as.character(solver),
115+
as.character(loss))
112116
new("LinearSVCModel", jobj = jobj)
113117
})
114118

R/pkg/tests/fulltests/test_mllib_classification.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ absoluteSparkPath <- function(x) {
3030
test_that("spark.svmLinear", {
3131
df <- suppressWarnings(createDataFrame(iris))
3232
training <- df[df$Species %in% c("versicolor", "virginica"), ]
33-
model <- spark.svmLinear(training, Species ~ ., regParam = 0.01, maxIter = 10)
33+
model <- spark.svmLinear(training, Species ~ ., regParam = 0.01, maxIter = 10,
34+
loss = "hinge", solver = "owlqn")
3435
summary <- summary(model)
3536

3637
# test summary coefficients return matrix type

mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ private[r] object LinearSVCWrapper
7070
val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
7171
val PREDICTED_LABEL_COL = "prediction"
7272

73-
def fit(
73+
def fit( // scalastyle:ignore
7474
data: DataFrame,
7575
formula: String,
7676
regParam: Double,
@@ -79,7 +79,9 @@ private[r] object LinearSVCWrapper
7979
standardization: Boolean,
8080
threshold: Double,
8181
weightCol: String,
82-
aggregationDepth: Int
82+
aggregationDepth: Int,
83+
solver: String,
84+
loss: String
8385
): LinearSVCWrapper = {
8486

8587
val rFormula = new RFormula()
@@ -105,6 +107,8 @@ private[r] object LinearSVCWrapper
105107
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)
106108
.setThreshold(threshold)
107109
.setAggregationDepth(aggregationDepth)
110+
.setSolver(solver)
111+
.setLoss(loss)
108112

109113
if (weightCol != null) svc.setWeightCol(weightCol)
110114

python/pyspark/ml/classification.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, Ha
8080
>>> svm = LinearSVC(maxIter=5, regParam=0.01)
8181
>>> model = svm.fit(df)
8282
>>> model.coefficients
83-
DenseVector([0.0, -0.2792, -0.1833])
83+
DenseVector([0.0, 0.0759, -0.6167])
8484
>>> model.intercept
85-
1.0206118982229047
85+
1.3113904822325306
8686
>>> model.numClasses
8787
2
8888
>>> model.numFeatures
@@ -92,7 +92,7 @@ class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, Ha
9292
>>> result.prediction
9393
1.0
9494
>>> result.rawPrediction
95-
DenseVector([-1.4831, 1.4831])
95+
DenseVector([-1.8521, 1.8521])
9696
>>> svm_path = temp_path + "/svm"
9797
>>> svm.save(svm_path)
9898
>>> svm2 = LinearSVC.load(svm_path)

0 commit comments

Comments
 (0)