@@ -80,26 +80,28 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
8080
8181 /**
8282 * Method to validate a gradient boosting model
83- * @param trainInput Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
84- * @param validateInput Validation dataset:
83+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
84+ * @param validationInput Validation dataset:
8585 RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
86- Should follow same distribution as trainInput.
86+ Should be different from, and follow the same distribution as, input.
87+ For example, these two datasets could be created from an original dataset
88+ by using [[org.apache.spark.rdd.RDD.randomSplit() ]].
8789 * @return a gradient boosted trees model that can be used for prediction
8890 */
8991 def runWithValidation (
90- trainInput : RDD [LabeledPoint ],
91- validateInput : RDD [LabeledPoint ]): GradientBoostedTreesModel = {
92+ input : RDD [LabeledPoint ],
93+ validationInput : RDD [LabeledPoint ]): GradientBoostedTreesModel = {
9294 val algo = boostingStrategy.treeStrategy.algo
9395 algo match {
9496 case Regression => GradientBoostedTrees .boost(
95- trainInput, validateInput , boostingStrategy, validate= true )
97+ input, validationInput , boostingStrategy, validate= true )
9698 case Classification =>
9799 // Map labels to -1, +1 so binary classification can be treated as regression.
98- val remappedTrainInput = trainInput .map(
100+ val remappedInput = input .map(
99101 x => new LabeledPoint ((x.label * 2 ) - 1 , x.features))
100- val remappedValidateInput = trainInput .map(
102+ val remappedValidationInput = validationInput .map(
101103 x => new LabeledPoint ((x.label * 2 ) - 1 , x.features))
102- GradientBoostedTrees .boost(remappedTrainInput, remappedValidateInput , boostingStrategy,
104+ GradientBoostedTrees .boost(remappedInput, remappedValidationInput , boostingStrategy,
103105 validate= true )
104106 case _ =>
105107 throw new IllegalArgumentException (s " $algo is not supported by the gradient boosting. " )
@@ -110,9 +112,9 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
110112 * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation ]].
111113 */
112114 def runWithValidation (
113- trainInput : JavaRDD [LabeledPoint ],
114- validateInput : JavaRDD [LabeledPoint ]): GradientBoostedTreesModel = {
115- runWithValidation(trainInput .rdd, validateInput .rdd)
115+ input : JavaRDD [LabeledPoint ],
116+ validationInput : JavaRDD [LabeledPoint ]): GradientBoostedTreesModel = {
117+ runWithValidation(input .rdd, validationInput .rdd)
116118 }
117119}
118120
@@ -145,16 +147,16 @@ object GradientBoostedTrees extends Logging {
145147 /**
146148 * Internal method for performing regression using trees as base learners.
147149 * @param input training dataset
148- * @param validateInput validation dataset, ignored if validate is set to false.
150+ * @param validationInput validation dataset, ignored if validate is set to false.
149151 * @param boostingStrategy boosting parameters
150152 * @param validate whether or not to use the validation dataset.
151153 * @return a gradient boosted trees model that can be used for prediction
152154 */
153155 private def boost (
154156 input : RDD [LabeledPoint ],
155- validateInput : RDD [LabeledPoint ],
157+ validationInput : RDD [LabeledPoint ],
156158 boostingStrategy : BoostingStrategy ,
157- validate : Boolean = false ): GradientBoostedTreesModel = {
159+ validate : Boolean ): GradientBoostedTreesModel = {
158160
159161 val timer = new TimeTracker ()
160162 timer.start(" total" )
@@ -198,7 +200,7 @@ object GradientBoostedTrees extends Logging {
198200 // Note: A model of type regression is used since we require raw prediction
199201 timer.stop(" building tree 0" )
200202
201- var bestValidateError = if (validate) loss.computeError(startingModel, validateInput ) else 0.0
203+ var bestValidateError = if (validate) loss.computeError(startingModel, validationInput ) else 0.0
202204 var bestM = 1
203205
204206 // pseudo-residual for second iteration
@@ -225,19 +227,18 @@ object GradientBoostedTrees extends Logging {
225227
226228 if (validate) {
227229 // Stop training early if
228- // 1. Reduction in error is lesser than the validationTol or
230+ // 1. Reduction in error is less than the validationTol or
229231 // 2. The error increases, that is, the model is overfit.
230232 // We want the model returned corresponding to the best validation error.
231- val currentValidateError = loss.computeError(partialModel, validateInput )
233+ val currentValidateError = loss.computeError(partialModel, validationInput )
232234 if (bestValidateError - currentValidateError < validationTol) {
233235 return new GradientBoostedTreesModel (
234236 boostingStrategy.treeStrategy.algo,
235237 baseLearners.slice(0 , bestM),
236238 baseLearnerWeights.slice(0 , bestM))
237- }
238- else if (currentValidateError < bestValidateError){
239- bestValidateError = currentValidateError
240- bestM = m + 1
239+ } else if (currentValidateError < bestValidateError){
240+ bestValidateError = currentValidateError
241+ bestM = m + 1
241242 }
242243 }
243244 // Update data with pseudo-residuals
0 commit comments