@@ -200,10 +200,24 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
200200 /**
 * Gets the String value of the `finalStorageLevel` param.
 * @group expertGetParam
 */
201201 def getFinalStorageLevel : String = $(finalStorageLevel)
202202
203+ /**
204+ * Param for the threshold that decides whether to stack factors into matrices when computing
205+ * dst factors: stacking is used only if the factor count and the rank both exceed it (>= 1).
206+ * Default: 1024
207+ * @group expertParam
208+ */
209+ val threshold = new IntParam (this , " threshold" , " threshold in computation of dst factors " +
210+ " to decide if stacking factors to speed up the computation." ,
211+ ParamValidators .gtEq(1 ))
212+
213+ /**
 * Gets the Int value of the `threshold` param.
 * @group expertGetParam
 */
214+ def getThreshold : Int = $(threshold)
215+
// Registers the default value of each listed param (e.g. rank 10, maxIter 10, threshold 1024).
203216 setDefault(rank -> 10 , maxIter -> 10 , regParam -> 0.1 , numUserBlocks -> 10 , numItemBlocks -> 10 ,
204217 implicitPrefs -> false , alpha -> 1.0 , userCol -> " user" , itemCol -> " item" ,
205218 ratingCol -> " rating" , nonnegative -> false , checkpointInterval -> 10 ,
206- intermediateStorageLevel -> " MEMORY_AND_DISK" , finalStorageLevel -> " MEMORY_AND_DISK" )
219+ intermediateStorageLevel -> " MEMORY_AND_DISK" , finalStorageLevel -> " MEMORY_AND_DISK" ,
220+ threshold -> 1024 )
207221
208222 /**
209223 * Validates and transforms the input schema.
@@ -436,6 +450,10 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
436450 @ Since (" 2.0.0" )
437451 def setFinalStorageLevel (value : String ): this .type = set(finalStorageLevel, value)
438452
453+ /**
 * Sets the `threshold` param (declared with a `>= 1` validator) to `value`.
 * @group expertSetParam
 */
454+ @ Since (" 2.1.0" )
455+ def setThreshold (value : Int ): this .type = set(threshold, value)
456+
439457 /**
440458 * Sets both numUserBlocks and numItemBlocks to the specific value.
441459 *
@@ -464,14 +482,15 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
464482 val instrLog = Instrumentation .create(this , ratings)
465483 instrLog.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,
466484 userCol, itemCol, ratingCol, predictionCol, maxIter,
467- regParam, nonnegative, checkpointInterval, seed)
485+ regParam, nonnegative, threshold, checkpointInterval, seed)
468486 val (userFactors, itemFactors) = ALS .train(ratings, rank = $(rank),
469487 numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
470488 maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
471489 alpha = $(alpha), nonnegative = $(nonnegative),
472490 intermediateRDDStorageLevel = StorageLevel .fromString($(intermediateStorageLevel)),
473491 finalRDDStorageLevel = StorageLevel .fromString($(finalStorageLevel)),
474- checkpointInterval = $(checkpointInterval), seed = $(seed))
492+ threshold = $(threshold), checkpointInterval = $(checkpointInterval),
493+ seed = $(seed))
475494 val userDF = userFactors.toDF(" id" , " features" )
476495 val itemDF = itemFactors.toDF(" id" , " features" )
477496 val model = new ALSModel (uid, $(rank), userDF, itemDF).setParent(this )
@@ -706,6 +725,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
706725 nonnegative : Boolean = false ,
707726 intermediateRDDStorageLevel : StorageLevel = StorageLevel .MEMORY_AND_DISK ,
708727 finalRDDStorageLevel : StorageLevel = StorageLevel .MEMORY_AND_DISK ,
728+ threshold : Int = 1024 ,
709729 checkpointInterval : Int = 10 ,
710730 seed : Long = 0L )(
711731 implicit ord : Ordering [ID ]): (RDD [(ID , Array [Float ])], RDD [(ID , Array [Float ])]) = {
@@ -752,7 +772,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
752772 userFactors.setName(s " userFactors- $iter" ).persist(intermediateRDDStorageLevel)
753773 val previousItemFactors = itemFactors
754774 itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
755- userLocalIndexEncoder, implicitPrefs, alpha, solver)
775+ userLocalIndexEncoder, implicitPrefs, alpha, solver, threshold )
756776 previousItemFactors.unpersist()
757777 itemFactors.setName(s " itemFactors- $iter" ).persist(intermediateRDDStorageLevel)
758778 // TODO: Generalize PeriodicGraphCheckpointer and use it here.
@@ -762,7 +782,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
762782 }
763783 val previousUserFactors = userFactors
764784 userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
765- itemLocalIndexEncoder, implicitPrefs, alpha, solver)
785+ itemLocalIndexEncoder, implicitPrefs, alpha, solver, threshold )
766786 if (shouldCheckpoint(iter)) {
767787 ALS .cleanShuffleDependencies(sc, deps)
768788 deletePreviousCheckpointFile()
@@ -773,7 +793,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
773793 } else {
774794 for (iter <- 0 until maxIter) {
775795 itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
776- userLocalIndexEncoder, solver = solver)
796+ userLocalIndexEncoder, solver = solver, threshold = threshold )
777797 if (shouldCheckpoint(iter)) {
778798 val deps = itemFactors.dependencies
779799 itemFactors.checkpoint()
@@ -783,7 +803,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
783803 previousCheckpointFile = itemFactors.getCheckpointFile
784804 }
785805 userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
786- itemLocalIndexEncoder, solver = solver)
806+ itemLocalIndexEncoder, solver = solver, threshold = threshold )
787807 }
788808 }
789809 val userIdAndFactors = userInBlocks
@@ -1297,7 +1317,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
12971317 srcEncoder : LocalIndexEncoder ,
12981318 implicitPrefs : Boolean = false ,
12991319 alpha : Double = 1.0 ,
1300- solver : LeastSquaresNESolver ): RDD [(Int , FactorBlock )] = {
1320+ solver : LeastSquaresNESolver ,
1321+ threshold : Int ): RDD [(Int , FactorBlock )] = {
13011322 val numSrcBlocks = srcFactorBlocks.partitions.length
13021323 val YtY = if (implicitPrefs) Some (computeYtY(srcFactorBlocks, rank)) else None
13031324 val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
@@ -1325,7 +1346,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
13251346 var numExplicits = 0
13261347 // Stack the factors (vectors) into matrices to speed up the computation
13271348 // when both the number of factors and the rank are large enough.
1328- val doStack = srcPtrs(j + 1 ) - srcPtrs(j) > 1024 && rank > 1024
1349+ val doStack = srcPtrs(j + 1 ) - srcPtrs(j) > threshold && rank > threshold
13291350 val srcFactorBuffer = mutable.ArrayBuilder .make[Double ]
13301351 val bBuffer = mutable.ArrayBuilder .make[Double ]
13311352 while (i < srcPtrs(j + 1 )) {
0 commit comments