Skip to content

Commit e4d799b

Browse files
committed
Addresses indentation and doc comments
1 parent b48a70f commit e4d799b

File tree

4 files changed

+43
-45
lines changed

4 files changed

+43
-45
lines changed

docs/mllib-ensembles.md

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -429,16 +429,15 @@ We omit some decision tree parameters since those are covered in the [decision t
429429

430430
#### Validation while training
431431

432-
Gradient boosting can overfit when trained with more number of trees. In order to prevent overfitting, it might
433-
be useful to validate while training. The method **`runWithValidation`** has been provided to make use of this
434-
option. It takes a pair of RDD's as arguments, the first one being the training dataset and the second being the validation dataset.
432+
Gradient boosting can overfit when trained with more trees. In order to prevent overfitting, it is useful to validate while
433+
training. The method runWithValidation has been provided to make use of this option. It takes a pair of RDDs as arguments, the
434+
first one being the training dataset and the second being the validation dataset.
435435

436436
The training is stopped when the improvement in the validation error is not more than a certain tolerance
437-
(supplied by the **`validationTol`** argument in **`BoostingStrategy`**). In practice, the validation error
438-
decreases with the increase in number of trees and then increases as the model starts to overfit. There might
439-
be cases, in which the validation error does not change monotonically, and the user is advised to set a large
440-
enough negative tolerance and examine the validation curve to make further inference.
441-
437+
(supplied by the validationTol argument in BoostingStrategy). In practice, the validation error
438+
decreases initially and later increases. There might be cases in which the validation error does not change monotonically,
439+
and the user is advised to set a large enough negative tolerance and examine the validation curve to tune the number of
440+
iterations.
442441

443442
### Examples
444443

mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -80,26 +80,28 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
8080

8181
/**
8282
* Method to validate a gradient boosting model
83-
* @param trainInput Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
84-
* @param validateInput Validation dataset:
83+
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
84+
* @param validationInput Validation dataset:
8585
RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
86-
Should follow same distribution as trainInput.
86+
Should be different from and follow the same distribution as input.
87+
e.g., these two datasets could be created from an original dataset
88+
by using [[org.apache.spark.rdd.RDD.randomSplit()]]
8789
* @return a gradient boosted trees model that can be used for prediction
8890
*/
8991
def runWithValidation(
90-
trainInput: RDD[LabeledPoint],
91-
validateInput: RDD[LabeledPoint]): GradientBoostedTreesModel = {
92+
input: RDD[LabeledPoint],
93+
validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = {
9294
val algo = boostingStrategy.treeStrategy.algo
9395
algo match {
9496
case Regression => GradientBoostedTrees.boost(
95-
trainInput, validateInput, boostingStrategy, validate=true)
97+
input, validationInput, boostingStrategy, validate=true)
9698
case Classification =>
9799
// Map labels to -1, +1 so binary classification can be treated as regression.
98-
val remappedTrainInput = trainInput.map(
100+
val remappedInput = input.map(
99101
x => new LabeledPoint((x.label * 2) - 1, x.features))
100-
val remappedValidateInput = trainInput.map(
102+
val remappedValidationInput = validationInput.map(
101103
x => new LabeledPoint((x.label * 2) - 1, x.features))
102-
GradientBoostedTrees.boost(remappedTrainInput, remappedValidateInput, boostingStrategy,
104+
GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
103105
validate=true)
104106
case _ =>
105107
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
@@ -110,9 +112,9 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
110112
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation]].
111113
*/
112114
def runWithValidation(
113-
trainInput: JavaRDD[LabeledPoint],
114-
validateInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
115-
runWithValidation(trainInput.rdd, validateInput.rdd)
115+
input: JavaRDD[LabeledPoint],
116+
validationInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
117+
runWithValidation(input.rdd, validationInput.rdd)
116118
}
117119
}
118120

@@ -145,16 +147,16 @@ object GradientBoostedTrees extends Logging {
145147
/**
146148
* Internal method for performing regression using trees as base learners.
147149
* @param input training dataset
148-
* @param validateInput validation dataset, ignored if validate is set to false.
150+
* @param validationInput validation dataset, ignored if validate is set to false.
149151
* @param boostingStrategy boosting parameters
150152
* @param validate whether or not to use the validation dataset.
151153
* @return a gradient boosted trees model that can be used for prediction
152154
*/
153155
private def boost(
154156
input: RDD[LabeledPoint],
155-
validateInput: RDD[LabeledPoint],
157+
validationInput: RDD[LabeledPoint],
156158
boostingStrategy: BoostingStrategy,
157-
validate: Boolean = false): GradientBoostedTreesModel = {
159+
validate: Boolean): GradientBoostedTreesModel = {
158160

159161
val timer = new TimeTracker()
160162
timer.start("total")
@@ -198,7 +200,7 @@ object GradientBoostedTrees extends Logging {
198200
// Note: A model of type regression is used since we require raw prediction
199201
timer.stop("building tree 0")
200202

201-
var bestValidateError = if (validate) loss.computeError(startingModel, validateInput) else 0.0
203+
var bestValidateError = if (validate) loss.computeError(startingModel, validationInput) else 0.0
202204
var bestM = 1
203205

204206
// pseudo-residual for second iteration
@@ -225,19 +227,18 @@ object GradientBoostedTrees extends Logging {
225227

226228
if (validate) {
227229
// Stop training early if
228-
// 1. Reduction in error is lesser than the validationTol or
230+
// 1. Reduction in error is less than the validationTol or
229231
// 2. If the error increases, that is if the model is overfit.
230232
// We want the model returned corresponding to the best validation error.
231-
val currentValidateError = loss.computeError(partialModel, validateInput)
233+
val currentValidateError = loss.computeError(partialModel, validationInput)
232234
if (bestValidateError - currentValidateError < validationTol) {
233235
return new GradientBoostedTreesModel(
234236
boostingStrategy.treeStrategy.algo,
235237
baseLearners.slice(0, bestM),
236238
baseLearnerWeights.slice(0, bestM))
237-
}
238-
else if (currentValidateError < bestValidateError){
239-
bestValidateError = currentValidateError
240-
bestM = m + 1
239+
} else if (currentValidateError < bestValidateError){
240+
bestValidateError = currentValidateError
241+
bestM = m + 1
241242
}
242243
}
243244
// Update data with pseudo-residuals

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,11 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
3434
* weak hypotheses used in the final model.
3535
* @param learningRate Learning rate for shrinking the contribution of each estimator. The
3636
* learning rate should be in the interval (0, 1]
37-
* @param validationTol Useful when runWithValidation is used. If the error rate between two
38-
iterations is lesser than the validationTol, then stop. If run
39-
is used, then this parameter is ignored.
40-
41-
a pair of RDD's are supplied to run. If the error rate
42-
* between two iterations is lesser than convergenceTol, then training stops.
37+
* @param validationTol Useful when runWithValidation is used. If the error rate on the
38+
* validation input between two iterations is less than the validationTol
39+
* then stop. Ignored when [[run]] is used.
4340
*/
41+
4442
@Experimental
4543
case class BoostingStrategy(
4644
// Required boosting parameters

mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
160160
}
161161

162162
test("runWithValidation performs better on a validation dataset (Regression)") {
163-
// Set numIterations large enough so that it early stops.
163+
// Set numIterations large enough so that it stops early.
164164
val numIterations = 20
165165
val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
166166
val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2)
@@ -171,9 +171,9 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
171171
val boostingStrategy =
172172
new BoostingStrategy(treeStrategy, error, numIterations, validationTol = 0.0)
173173

174-
val gbtValidate = new GradientBoostedTrees(boostingStrategy).runWithValidation(
175-
trainRdd, validateRdd)
176-
assert(gbtValidate.numTrees != numIterations)
174+
val gbtValidate = new GradientBoostedTrees(boostingStrategy).
175+
runWithValidation(trainRdd, validateRdd)
176+
assert(gbtValidate.numTrees !== numIterations)
177177

178178
val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
179179
val errorWithoutValidation = error.computeError(gbt, validateRdd)
@@ -183,7 +183,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
183183
}
184184

185185
test("runWithValidation performs better on a validation dataset (Classification)") {
186-
// Set numIterations large enough so that it early stops.
186+
// Set numIterations large enough so that it stops early.
187187
val numIterations = 20
188188
val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
189189
val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2)
@@ -194,9 +194,9 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
194194
new BoostingStrategy(treeStrategy, LogLoss, numIterations, validationTol = 0.0)
195195

196196
// Test that it stops early.
197-
val gbtValidate = new GradientBoostedTrees(boostingStrategy).runWithValidation(
198-
trainRdd, validateRdd)
199-
assert(gbtValidate.numTrees != numIterations)
197+
val gbtValidate = new GradientBoostedTrees(boostingStrategy).
198+
runWithValidation(trainRdd, validateRdd)
199+
assert(gbtValidate.numTrees !== numIterations)
200200

201201
// Remap labels to {-1, 1}
202202
val remappedInput = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
@@ -213,7 +213,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
213213
val errorWithoutValidation = LogLoss.computeError(gbtRegressor, remappedInput)
214214

215215
assert(errorWithValidation < errorWithoutValidation)
216-
}
216+
}
217217

218218
}
219219

0 commit comments

Comments
 (0)