
Commit df4d263

create param HasHandlePersistence
1 parent 936d466 commit df4d263

File tree

8 files changed: +81 -40 lines


mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 9 additions & 11 deletions
@@ -51,7 +51,8 @@ import org.apache.spark.util.VersionUtils
  */
 private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
   with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
-  with HasStandardization with HasWeightCol with HasThreshold with HasAggregationDepth {
+  with HasStandardization with HasWeightCol with HasThreshold with HasAggregationDepth
+  with HasHandlePersistence {

   import org.apache.spark.ml.classification.LogisticRegression.supportedFamilyNames

@@ -431,6 +432,10 @@ class LogisticRegression @Since("1.2.0") (
   @Since("2.2.0")
   def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value)

+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
   private def assertBoundConstrainedOptimizationParamsValid(
       numCoefficientSets: Int,
       numFeatures: Int): Unit = {
@@ -483,22 +488,15 @@ class LogisticRegression @Since("1.2.0") (
     this
   }

-  override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    train(dataset, handlePersistence)
-  }
-
-  protected[spark] def train(
-      dataset: Dataset[_],
-      handlePersistence: Boolean): LogisticRegressionModel = {
+  protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
       dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
         case Row(label: Double, weight: Double, features: Vector) =>
           Instance(label, weight, features)
       }

-    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+    if ($(handlePersistence)) instances.persist(StorageLevel.MEMORY_AND_DISK)

     val instr = Instrumentation.create(this, dataset)
     instr.logParams(regParam, elasticNetParam, standardization, threshold,
@@ -878,7 +876,7 @@ class LogisticRegression @Since("1.2.0") (
       }
     }

-    if (handlePersistence) instances.unpersist()
+    if ($(handlePersistence)) instances.unpersist()

     val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
       numClasses, isMultinomial))
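
The caller-facing effect of this change: whether the internal RDD[Instance] is cached during training is now read from the param (default true) rather than inferred from dataset.storageLevel. A minimal usage sketch, assuming an existing SparkSession and a labeled DataFrame named training (hypothetical names):

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.storage.StorageLevel

    val lr = new LogisticRegression().setMaxIter(100).setRegParam(0.01)

    // Default behavior: handlePersistence is true, so train() persists the
    // extracted RDD[Instance] at MEMORY_AND_DISK and unpersists it afterwards.
    val model = lr.fit(training)

    // A caller that manages caching itself now opts out explicitly
    // (previously this was auto-detected via dataset.storageLevel == NONE).
    training.persist(StorageLevel.MEMORY_AND_DISK)
    val model2 = lr.setHandlePersistence(false).fit(training)
    training.unpersist()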

mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala

Lines changed: 18 additions & 11 deletions
@@ -32,7 +32,7 @@ import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
-import org.apache.spark.ml.param.shared.HasWeightCol
+import org.apache.spark.ml.param.shared.{HasHandlePersistence, HasWeightCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -53,7 +53,7 @@ private[ml] trait ClassifierTypeTrait {
  * Params for [[OneVsRest]].
  */
 private[ml] trait OneVsRestParams extends PredictorParams
-  with ClassifierTypeTrait with HasWeightCol {
+  with ClassifierTypeTrait with HasWeightCol with HasHandlePersistence {

   /**
    * param for the base binary classifier that we reduce multiclass classification into.
@@ -65,6 +65,10 @@ private[ml] trait OneVsRestParams extends PredictorParams

   /** @group getParam */
   def getClassifier: ClassifierType = $(classifier)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
 }

 private[ml] object OneVsRestParams extends ClassifierTypeTrait {
@@ -161,9 +165,9 @@ final class OneVsRestModel private[ml] (
     val initUDF = udf { () => Map[Int, Double]() }
     val newDataset = dataset.withColumn(accColName, initUDF())

-    // persist if underlying dataset is not persistent.
-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    if (handlePersistence) newDataset.persist(StorageLevel.MEMORY_AND_DISK)
+    if ($(handlePersistence)) {
+      newDataset.persist(StorageLevel.MEMORY_AND_DISK)
+    }

     // update the accumulator column with the result of prediction of models
     val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) {
@@ -186,7 +190,9 @@ final class OneVsRestModel private[ml] (
       updatedDataset.select(newColumns: _*).withColumnRenamed(tmpColName, accColName)
     }

-    if (handlePersistence) newDataset.unpersist()
+    if ($(handlePersistence)) {
+      newDataset.unpersist()
+    }

     // output the index of the classifier with highest confidence as prediction
     val labelUDF = udf { (predictions: Map[Int, Double]) =>
@@ -340,10 +346,9 @@ final class OneVsRest @Since("1.4.0") (
       dataset.select($(labelCol), $(featuresCol))
     }

-    // persist if underlying dataset is not persistent.
-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    if (handlePersistence) multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
-
+    if ($(handlePersistence)) {
+      multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
+    }

     // create k columns, one for each binary classifier.
     val models = Range(0, numClasses).par.map { index =>
@@ -367,7 +372,9 @@ final class OneVsRest @Since("1.4.0") (
     }.toArray[ClassificationModel[_, _]]
     instr.logNumFeatures(models.head.numFeatures)

-    if (handlePersistence) multiclassLabeled.unpersist()
+    if ($(handlePersistence)) {
+      multiclassLabeled.unpersist()
+    }

     // extract label metadata from label column if present, or create a nominal attribute
     // to output the number of labels
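
Making this a shared param also lets persistence compose across meta-estimators: OneVsRest can cache the (label, features) projection it reuses for every per-class fit, while the base classifier is told not to cache its own derived data again. A hedged sketch under those assumptions (df is a hypothetical input DataFrame):

    import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

    // Cache once at the OneVsRest level (the default), and skip the
    // per-class caching inside each base LogisticRegression fit.
    val base = new LogisticRegression()
      .setMaxIter(10)
      .setHandlePersistence(false)

    val ovr = new OneVsRest()
      .setClassifier(base)
      .setHandlePersistence(true) // default, shown for contrast

    val ovrModel = ovr.fit(df)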

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 11 additions & 4 deletions
@@ -40,7 +40,7 @@ import org.apache.spark.util.VersionUtils.majorVersion
  * Common params for KMeans and KMeansModel
  */
 private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
-  with HasSeed with HasPredictionCol with HasTol {
+  with HasSeed with HasPredictionCol with HasTol with HasHandlePersistence {

   /**
    * The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than
@@ -300,16 +300,21 @@ class KMeans @Since("1.5.0") (
   @Since("1.5.0")
   def setSeed(value: Long): this.type = set(seed, value)

+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): KMeansModel = {
     transformSchema(dataset.schema, logging = true)

-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
     val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }

-    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+    if ($(handlePersistence)) {
+      instances.persist(StorageLevel.MEMORY_AND_DISK)
+    }

     val instr = Instrumentation.create(this, dataset)
     instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol)
@@ -327,7 +332,9 @@ class KMeans @Since("1.5.0") (

     model.setSummary(Some(summary))
     instr.logSuccess(model)
-    if (handlePersistence) instances.unpersist()
+    if ($(handlePersistence)) {
+      instances.unpersist()
+    }
     model
   }
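
Worth noting for KMeans (and the other estimators touched here): the old dataset.storageLevel == StorageLevel.NONE check was arguably only a heuristic, since the RDD[OldVector] extracted from the DataFrame is a separate object whose storage level is independent of the DataFrame's. The param makes the tradeoff an explicit caller decision. A sketch, assuming a DataFrame named features with a vector column "features":

    import org.apache.spark.ml.clustering.KMeans

    // Accept the default caching of the extracted RDD: iterative k-means
    // rescans the data on every iteration, so caching usually pays off.
    val kmeans = new KMeans().setK(4).setSeed(1L)
    val kmModel = kmeans.fit(features)

    // Or disable it when memory is tight and recomputation is acceptable.
    val kmModel2 = kmeans.setHandlePersistence(false).fit(features)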

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 2 additions & 1 deletion
@@ -82,7 +82,8 @@ private[shared] object SharedParamsCodeGen {
       "all instance weights as 1.0"),
     ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false),
     ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
-      isValid = "ParamValidators.gtEq(2)", isExpertParam = true))
+      isValid = "ParamValidators.gtEq(2)", isExpertParam = true),
+    ParamDesc[Boolean]("handlePersistence", "whether to handle data persistence", Some("true")))

   val code = genSharedParams(params)
   val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 17 additions & 0 deletions
@@ -402,4 +402,21 @@ private[ml] trait HasAggregationDepth extends Params {
   /** @group expertGetParam */
   final def getAggregationDepth: Int = $(aggregationDepth)
 }
+
+/**
+ * Trait for shared param handlePersistence (default: true).
+ */
+private[ml] trait HasHandlePersistence extends Params {
+
+  /**
+   * Param for whether to handle data persistence.
+   * @group param
+   */
+  final val handlePersistence: BooleanParam = new BooleanParam(this, "handlePersistence", "whether to handle data persistence")
+
+  setDefault(handlePersistence, true)
+
+  /** @group getParam */
+  final def getHandlePersistence: Boolean = $(handlePersistence)
+}
 // scalastyle:on
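
The generated trait follows the standard shared-param pattern, so the usual Params machinery applies to any estimator that mixes it in. A small sketch of the default/override/clear lifecycle, assuming the public set/clear/explainParam helpers of org.apache.spark.ml.param.Params and using LogisticRegression from this commit:

    import org.apache.spark.ml.classification.LogisticRegression

    val lr = new LogisticRegression()

    // setDefault(handlePersistence, true) supplies the value until overridden.
    assert(lr.getHandlePersistence)

    // An explicit set wins over the default; clear() restores the default.
    lr.setHandlePersistence(false)
    assert(!lr.getHandlePersistence)
    lr.clear(lr.handlePersistence)
    assert(lr.getHandlePersistence)

    // Every Param carries its doc string and default for introspection;
    // expected output along the lines of:
    // handlePersistence: whether to handle data persistence (default: true)
    println(lr.explainParam(lr.handlePersistence))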

mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala

Lines changed: 8 additions & 4 deletions
@@ -46,7 +46,8 @@ import org.apache.spark.storage.StorageLevel
  */
 private[regression] trait AFTSurvivalRegressionParams extends Params
   with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter
-  with HasTol with HasFitIntercept with HasAggregationDepth with Logging {
+  with HasTol with HasFitIntercept with HasAggregationDepth with HasHandlePersistence
+  with Logging {

   /**
    * Param for censor column name.
@@ -197,6 +198,10 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
   def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
   setDefault(aggregationDepth -> 2)

+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
   /**
    * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
    * and put it in an RDD with strong types.
@@ -213,8 +218,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
   override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
     transformSchema(dataset.schema, logging = true)
     val instances = extractAFTPoints(dataset)
-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+    if ($(handlePersistence)) instances.persist(StorageLevel.MEMORY_AND_DISK)

     val featuresSummarizer = {
       val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features)
@@ -273,7 +277,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
     }

     bcFeaturesStd.destroy(blocking = false)
-    if (handlePersistence) instances.unpersist()
+    if ($(handlePersistence)) instances.unpersist()

     val rawCoefficients = parameters.slice(2, parameters.length)
     var i = 0

mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala

Lines changed: 8 additions & 4 deletions
@@ -39,7 +39,8 @@ import org.apache.spark.storage.StorageLevel
  * Params for isotonic regression.
  */
 private[regression] trait IsotonicRegressionBase extends Params with HasFeaturesCol
-  with HasLabelCol with HasPredictionCol with HasWeightCol with Logging {
+  with HasLabelCol with HasPredictionCol with HasWeightCol with HasHandlePersistence
+  with Logging {

   /**
    * Param for whether the output sequence should be isotonic/increasing (true) or
@@ -157,6 +158,10 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
   @Since("1.5.0")
   def setFeatureIndex(value: Int): this.type = set(featureIndex, value)

+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
   @Since("1.5.0")
   override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra)

@@ -165,8 +170,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
     transformSchema(dataset.schema, logging = true)
     // Extract columns from data. If dataset is persisted, do not persist oldDataset.
     val instances = extractWeightedLabeledPoints(dataset)
-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+    if ($(handlePersistence)) instances.persist(StorageLevel.MEMORY_AND_DISK)

     val instr = Instrumentation.create(this, dataset)
     instr.logParams(labelCol, featuresCol, weightCol, predictionCol, featureIndex, isotonic)
@@ -175,7 +179,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
     val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic))
     val oldModel = isotonicRegression.run(instances)

-    if (handlePersistence) instances.unpersist()
+    if ($(handlePersistence)) instances.unpersist()

     val model = copyValues(new IsotonicRegressionModel(uid, oldModel).setParent(this))
     instr.logSuccess(model)

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 8 additions & 5 deletions
@@ -53,7 +53,7 @@ import org.apache.spark.storage.StorageLevel
 private[regression] trait LinearRegressionParams extends PredictorParams
   with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
   with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
-  with HasAggregationDepth {
+  with HasAggregationDepth with HasHandlePersistence {

   import LinearRegression._

@@ -208,6 +208,10 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
   setDefault(aggregationDepth -> 2)

+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandlePersistence(value: Boolean): this.type = set(handlePersistence, value)
+
   override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
     // Extract the number of features before deciding optimization solver.
     val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
@@ -251,8 +255,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       return lrModel
     }

-    val handlePersistence = dataset.storageLevel == StorageLevel.NONE
-    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+    if ($(handlePersistence)) instances.persist(StorageLevel.MEMORY_AND_DISK)

     val (featuresSummarizer, ySummarizer) = {
       val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer),
@@ -285,7 +288,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
         s"zeros and the intercept will be the mean of the label; as a result, " +
         s"training is not needed.")
     }
-    if (handlePersistence) instances.unpersist()
+    if ($(handlePersistence)) instances.unpersist()
     val coefficients = Vectors.sparse(numFeatures, Seq.empty)
     val intercept = yMean

@@ -422,7 +425,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       0.0
     }

-    if (handlePersistence) instances.unpersist()
+    if ($(handlePersistence)) instances.unpersist()

     val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept))
     // Handle possible missing or invalid prediction columns
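
One estimator-specific wrinkle is visible in the hunk context above: when LinearRegression takes the normal-equation path, it hits "return lrModel" before the persist call, so handlePersistence only affects the iterative solver path. A hedged sketch:

    import org.apache.spark.ml.regression.LinearRegression

    // Iterative solver: handlePersistence controls caching of the
    // extracted instances across the quasi-Newton iterations.
    val lrIterative = new LinearRegression()
      .setSolver("l-bfgs")
      .setHandlePersistence(false)

    // Normal-equation solver: training returns before the persist call,
    // so the param has no effect on this path.
    val lrNormal = new LinearRegression().setSolver("normal")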
