Commit d884180
ENH add training score to GridSearchCV.cv_scores_
add docstring for GridSearchCV, RandomizedSearchCV and fit_grid_point. In "fit_grid_point" I used test_score rather than validation_score, as the split is given to the function. rbf svm grid search example now also shows training scores - which illustrates overfitting for high C, and training/prediction times... which pasically serve to illustrate that this is possible. Maybe random forests would be better to evaluate training times?
1 parent d2254e4 commit d884180
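The overfitting signal this commit exposes is the gap between training and validation scores per parameter setting. A minimal standalone sketch of that idea, using made-up records in a tuple layout assumed for illustration (params, validation score, validation std, training score, train time, prediction time) — this is not the actual scikit-learn API:

```python
# Hypothetical cv_scores_-style records (layout is an assumption for
# illustration): (params, val score, val std, train score, fit time, pred time)
records = [
    ({"C": 0.01, "gamma": 0.1}, 0.62, 0.03, 0.64, 0.8, 0.1),   # underfit
    ({"C": 1.0, "gamma": 0.1}, 0.91, 0.02, 0.93, 1.1, 0.1),    # good fit
    ({"C": 1e4, "gamma": 10.0}, 0.70, 0.05, 1.00, 3.2, 0.2),   # overfit
]


def overfit_candidates(records, gap_threshold=0.1):
    """Return parameter dicts whose training score exceeds the
    validation score by more than gap_threshold."""
    return [params
            for params, val, _, train, _, _ in records
            if train - val > gap_threshold]


print(overfit_candidates(records))  # only the high-C, high-gamma setting
```

Only the third record has a train/validation gap above the threshold (1.00 − 0.70 = 0.30), matching the "high C and gamma overfit" observation in the example's docstring.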

File tree

4 files changed: +183 −71 lines


doc/tutorial/statistical_inference/model_selection.rst

Lines changed: 1 addition & 1 deletion
```diff
@@ -144,7 +144,7 @@ estimator during the construction and exposes an estimator API::
     >>> clf = GridSearchCV(estimator=svc, param_grid=dict(gamma=gammas),
     ...                     n_jobs=-1)
     >>> clf.fit(X_digits[:1000], y_digits[:1000]) # doctest: +ELLIPSIS
-    GridSearchCV(cv=None,...
+    GridSearchCV(compute_training_score=False,...
     >>> clf.best_score_
     0.98899999999999999
     >>> clf.best_estimator_.gamma
```
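The doctest above uses the `+ELLIPSIS` directive, so only the repr prefix up to `...` must match — which is why just the first keyword of `GridSearchCV`'s repr needed updating. A small self-contained illustration of that doctest mechanism (the repr string here is made up for the sketch):

```python
import doctest

# A doctest mirroring the pattern in the diff: with +ELLIPSIS, the "..."
# in the expected output matches any trailing text of the printed repr.
sample = """
>>> repr_string = "GridSearchCV(compute_training_score=False, cv=None, n_jobs=1)"
>>> print(repr_string)  # doctest: +ELLIPSIS
GridSearchCV(compute_training_score=False,...
"""

test = doctest.DocTestParser().get_doctest(sample, {}, "sample", None, 0)
results = doctest.DocTestRunner(verbose=False).run(test)
print(results.failed)  # 0: the elided match succeeded
```

Because only the leading `compute_training_score=False,` is compared literally, the doctest stays valid regardless of which keyword arguments follow in the repr.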

examples/svm/plot_rbf_parameters.py

Lines changed: 49 additions & 18 deletions
```diff
@@ -14,10 +14,30 @@
 the decision surface smooth, while a high C aims at classifying
 all training examples correctly.
 
-Two plots are generated. The first is a visualization of the
-decision function for a variety of parameter values, and the second
-is a heatmap of the classifier's cross-validation accuracy as
-a function of `C` and `gamma`.
+Two plots are generated. The first is a visualization of the decision function
+for a variety of parameter values, and the second is a heatmap of the
+classifier's cross-validation accuracy and training time as a function of `C`
+and `gamma`.
+
+An interesting observation on overfitting can be made when comparing validation
+and training error: a higher C always results in a lower training error, as it
+increases the complexity of the classifier.
+
+For the validation set, on the other hand, there is a tradeoff between goodness
+of fit and generalization.
+
+We can observe that the lower right half of the parameters (below the diagonal,
+with high C and gamma values) is characteristic of parameters that yield an
+overfitting model: the training score is very high but there is a wide gap
+between training and validation scores. The top and left parts of the parameter
+plots show underfitting models: the C and gamma values can individually or in
+conjunction constrain the model too much, leading to low training scores (and
+hence low validation scores too, as validation scores are on average upper
+bounded by training scores).
+
+We can also see that the training time is quite sensitive to the parameter
+setting, while the prediction time is not impacted very much. This is probably
+a consequence of the small size of the data set.
 '''
 print(__doc__)
 
@@ -65,7 +85,8 @@
 gamma_range = 10.0 ** np.arange(-5, 4)
 param_grid = dict(gamma=gamma_range, C=C_range)
 cv = StratifiedKFold(y=Y, n_folds=3)
-grid = GridSearchCV(SVC(), param_grid=param_grid, cv=cv)
+grid = GridSearchCV(SVC(), param_grid=param_grid, cv=cv,
+                    compute_training_score=True)
 grid.fit(X, Y)
 
 print("The best classifier is: ", grid.best_estimator_)
@@ -108,18 +129,28 @@
 # cv_scores_ contains parameter settings and scores
 score_dict = grid.cv_scores_
 
-# We extract just the scores
-scores = [x[1] for x in score_dict]
-scores = np.array(scores).reshape(len(C_range), len(gamma_range))
-
-# draw heatmap of accuracy as a function of gamma and C
-pl.figure(figsize=(8, 6))
-pl.subplots_adjust(left=0.05, right=0.95, bottom=0.15, top=0.95)
-pl.imshow(scores, interpolation='nearest', cmap=pl.cm.spectral)
-pl.xlabel('gamma')
-pl.ylabel('C')
-pl.colorbar()
-pl.xticks(np.arange(len(gamma_range)), gamma_range, rotation=45)
-pl.yticks(np.arange(len(C_range)), C_range)
+# We extract validation and training scores, as well as training and prediction
+# times
+_, val_scores, _, train_scores, train_time, pred_time = zip(*score_dict)
+
+arrays = [val_scores, train_scores, train_time, pred_time]
+titles = ["Validation Score", "Training Score", "Training Time",
+          "Prediction Time"]
+
+# for each value draw a heatmap as a function of gamma and C
+pl.figure(figsize=(12, 8))
+for i, (arr, title) in enumerate(zip(arrays, titles)):
+    pl.subplot(2, 2, i + 1)
+    arr = np.array(arr).reshape(len(C_range), len(gamma_range))
+    pl.title(title)
+    pl.imshow(arr, interpolation='nearest', cmap=pl.cm.spectral)
+    pl.xlabel('gamma')
+    pl.ylabel('C')
+    pl.colorbar()
+    pl.xticks(np.arange(len(gamma_range)), ["%.e" % g for g in gamma_range],
+              rotation=45)
+    pl.yticks(np.arange(len(C_range)), ["%.e" % C for C in C_range])
+
+pl.subplots_adjust(top=.95, hspace=.35, left=.0, right=.8, wspace=.05)
 
 pl.show()
```
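The `zip(*score_dict)` idiom in the last hunk transposes a list of per-parameter tuples into parallel sequences, which are then reshaped into a C × gamma grid for `imshow`. A minimal standalone sketch of that pattern, with made-up scores and a 2 × 2 grid (the tuple layout is an assumption mirroring the unpacking in the diff, not the real API):

```python
import numpy as np

# Fake per-parameter records in the order the example assumes:
# (params, validation score, std, training score, fit time, predict time)
score_dict = [
    ({"C": 1, "gamma": 0.1}, 0.80, 0.01, 0.85, 0.5, 0.05),
    ({"C": 1, "gamma": 1.0}, 0.75, 0.02, 0.95, 0.6, 0.05),
    ({"C": 10, "gamma": 0.1}, 0.82, 0.01, 0.90, 0.7, 0.05),
    ({"C": 10, "gamma": 1.0}, 0.70, 0.03, 0.99, 0.9, 0.06),
]

# zip(*...) turns N 6-tuples into 6 N-tuples, one sequence per field.
_, val_scores, _, train_scores, train_time, pred_time = zip(*score_dict)

C_range = [1, 10]
gamma_range = [0.1, 1.0]

# The records are ordered C-major, so reshaping yields rows indexed by C
# and columns indexed by gamma — exactly the layout imshow expects.
val_grid = np.array(val_scores).reshape(len(C_range), len(gamma_range))
print(val_grid)
```

The same reshape is applied to each of the four sequences in the example's loop, producing one heatmap per quantity.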
