
Commit d29fd67

add threshold param to ALS
1 parent dc4f4ba commit d29fd67

File tree: mllib/src

2 files changed: +39 -10 lines changed

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 30 additions & 9 deletions
@@ -200,10 +200,24 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
   /** @group expertGetParam */
   def getFinalStorageLevel: String = $(finalStorageLevel)
 
+  /**
+   * Param for threshold in computation of dst factors to decide
+   * if stacking factors to speed up the computation.(>= 1).
+   * Default: 1024
+   * @group expertParam
+   */
+  val threshold = new IntParam(this, "threshold", "threshold in computation of dst factors " +
+    "to decide if stacking factors to speed up the computation.",
+    ParamValidators.gtEq(1))
+
+  /** @group expertGetParam */
+  def getThreshold: Int = $(threshold)
+
   setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
     implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
     ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
-    intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK")
+    intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK",
+    threshold -> 1024)
 
   /**
    * Validates and transforms the input schema.
@@ -436,6 +450,10 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
   @Since("2.0.0")
   def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, value)
 
+  /** @group expertSetParam */
+  @Since("2.1.0")
+  def setThreshold(value: Int): this.type = set(threshold, value)
+
   /**
    * Sets both numUserBlocks and numItemBlocks to the specific value.
    *
@@ -464,14 +482,15 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
     val instrLog = Instrumentation.create(this, ratings)
     instrLog.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,
       userCol, itemCol, ratingCol, predictionCol, maxIter,
-      regParam, nonnegative, checkpointInterval, seed)
+      regParam, nonnegative, threshold, checkpointInterval, seed)
     val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank),
       numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
       maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
       alpha = $(alpha), nonnegative = $(nonnegative),
       intermediateRDDStorageLevel = StorageLevel.fromString($(intermediateStorageLevel)),
       finalRDDStorageLevel = StorageLevel.fromString($(finalStorageLevel)),
-      checkpointInterval = $(checkpointInterval), seed = $(seed))
+      threshold = $(threshold), checkpointInterval = $(checkpointInterval),
+      seed = $(seed))
     val userDF = userFactors.toDF("id", "features")
     val itemDF = itemFactors.toDF("id", "features")
     val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this)
@@ -706,6 +725,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
       nonnegative: Boolean = false,
       intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
       finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+      threshold: Int = 1024,
       checkpointInterval: Int = 10,
       seed: Long = 0L)(
       implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
@@ -752,7 +772,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
         userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
         val previousItemFactors = itemFactors
         itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
-          userLocalIndexEncoder, implicitPrefs, alpha, solver)
+          userLocalIndexEncoder, implicitPrefs, alpha, solver, threshold)
         previousItemFactors.unpersist()
         itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
         // TODO: Generalize PeriodicGraphCheckpointer and use it here.
@@ -762,7 +782,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
         }
         val previousUserFactors = userFactors
         userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
-          itemLocalIndexEncoder, implicitPrefs, alpha, solver)
+          itemLocalIndexEncoder, implicitPrefs, alpha, solver, threshold)
         if (shouldCheckpoint(iter)) {
           ALS.cleanShuffleDependencies(sc, deps)
           deletePreviousCheckpointFile()
@@ -773,7 +793,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
     } else {
       for (iter <- 0 until maxIter) {
         itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
-          userLocalIndexEncoder, solver = solver)
+          userLocalIndexEncoder, solver = solver, threshold = threshold)
         if (shouldCheckpoint(iter)) {
           val deps = itemFactors.dependencies
           itemFactors.checkpoint()
@@ -783,7 +803,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
           previousCheckpointFile = itemFactors.getCheckpointFile
         }
         userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
-          itemLocalIndexEncoder, solver = solver)
+          itemLocalIndexEncoder, solver = solver, threshold = threshold)
       }
     }
     val userIdAndFactors = userInBlocks
@@ -1297,7 +1317,8 @@
       srcEncoder: LocalIndexEncoder,
       implicitPrefs: Boolean = false,
       alpha: Double = 1.0,
-      solver: LeastSquaresNESolver): RDD[(Int, FactorBlock)] = {
+      solver: LeastSquaresNESolver,
+      threshold: Int): RDD[(Int, FactorBlock)] = {
     val numSrcBlocks = srcFactorBlocks.partitions.length
     val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
     val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
@@ -1325,7 +1346,7 @@
         var numExplicits = 0
         // Stacking factors(vectors) in matrices to speed up the computation,
         // when the number of factors and the rank is large enough.
-        val doStack = srcPtrs(j + 1) - srcPtrs(j) > 1024 && rank > 1024
+        val doStack = srcPtrs(j + 1) - srcPtrs(j) > threshold && rank > threshold
         val srcFactorBuffer = mutable.ArrayBuilder.make[Double]
         val bBuffer = mutable.ArrayBuilder.make[Double]
         while (i < srcPtrs(j + 1)) {
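
For orientation, below is a minimal sketch of how the new expert parameter would be set through the Pipelines API once this change is in place. The ratings DataFrame, its contents, and the chosen values are illustrative assumptions; only setThreshold, the default column names, and the 1024 default come from the diff above.

```scala
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.SparkSession

// Hypothetical driver code; only setThreshold and its default of 1024 come from this commit.
val spark = SparkSession.builder().appName("ALSThresholdSketch").master("local[2]").getOrCreate()
import spark.implicits._

// Tiny illustrative ratings set using the default column names set in ALSParams.
val ratings = Seq(
  (0, 0, 4.0f), (0, 1, 2.0f),
  (1, 1, 3.0f), (1, 2, 5.0f)
).toDF("user", "item", "rating")

val als = new ALS()
  .setMaxIter(5)
  .setRegParam(0.1)
  .setRank(10)
  // Expert param added by this commit: factors feeding a destination block are
  // stacked into matrices only when both the block size and the rank exceed it.
  .setThreshold(1024)

val model = als.fit(ratings)
model.transform(ratings).show()
```

Lowering the value makes ALS switch to the matrix-stacked solve sooner, which is what the new test in ALSSuite below exercises with threshold = 128 and rank = 129.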

mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala

Lines changed: 9 additions & 1 deletion
@@ -301,7 +301,8 @@ class ALSSuite
       implicitPrefs: Boolean = false,
       numUserBlocks: Int = 2,
       numItemBlocks: Int = 3,
-      targetRMSE: Double = 0.05): Unit = {
+      targetRMSE: Double = 0.05,
+      threshold: Int = 1024): Unit = {
     val spark = this.spark
     import spark.implicits._
     val als = new ALS()
@@ -311,6 +312,7 @@
       .setNumUserBlocks(numUserBlocks)
       .setNumItemBlocks(numItemBlocks)
       .setSeed(0)
+      .setThreshold(threshold)
     val alpha = als.getAlpha
     val model = als.fit(training.toDF())
     val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map {
@@ -382,6 +384,12 @@
       numItemBlocks = 5, numUserBlocks = 5)
   }
 
+  test("do stacking factors in matrices") {
+    val (training, test) = genExplicitTestData(numUsers = 200, numItems = 20, rank = 1)
+    testALS(training, test, maxIter = 1, rank = 129, regParam = 0.01, targetRMSE = 0.02,
+      threshold = 128)
+  }
+
   test("implicit feedback") {
     val (training, test) =
       genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
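
To make the test's pairing of rank = 129 with threshold = 128 concrete, here is a standalone sketch of the gating condition that computeFactors evaluates after this change. The srcPtrs values are made up for illustration; only the shape of the expression mirrors the doStack line in this diff.

```scala
// Standalone illustration of the stacking decision (names follow the diff; values are hypothetical).
// srcPtrs delimits, per destination factor, the range of source factors that contribute to it.
val srcPtrs = Array(0, 2000, 2100)
val rank = 129
val threshold = 128

// Same shape as the condition in computeFactors after this change:
def doStack(j: Int): Boolean =
  srcPtrs(j + 1) - srcPtrs(j) > threshold && rank > threshold

println(doStack(0)) // true: 2000 contributing factors and rank 129 both exceed 128
println(doStack(1)) // false: only 100 contributing factors, so the non-stacked path is used
```

With the previous hard-coded 1024, a test at this small scale could not reach the stacked code path, which is presumably why the cutoff was made configurable.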
