Skip to content

Commit 971e52c

Browse files
committed
fix
1 parent df4d263 commit 971e52c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -444,13 +444,13 @@ class LogisticRegressionWithLBFGS
444444
lr.setFitIntercept(addIntercept)
445445
lr.setMaxIter(optimizer.getNumIterations())
446446
lr.setTol(optimizer.getConvergenceTol())
447+
// Determine if we should cache the DF
448+
lr.setHandlePersistence(input.getStorageLevel == StorageLevel.NONE)
447449
// Convert our input into a DataFrame
448450
val spark = SparkSession.builder().sparkContext(input.context).getOrCreate()
449451
val df = spark.createDataFrame(input.map(_.asML))
450-
// Determine if we should cache the DF
451-
val handlePersistence = input.getStorageLevel == StorageLevel.NONE
452452
// Train our model
453-
val mlLogisticRegressionModel = lr.train(df, handlePersistence)
453+
val mlLogisticRegressionModel = lr.train(df)
454454
// convert the model
455455
val weights = Vectors.dense(mlLogisticRegressionModel.coefficients.toArray)
456456
createModel(weights, mlLogisticRegressionModel.intercept)

0 commit comments

Comments
 (0)