|
14 | 14 | the decision surface smooth, while a high C aims at classifying |
15 | 15 | all training examples correctly. |
16 | 16 |
|
17 | | -Two plots are generated. The first is a visualization of the |
18 | | -decision function for a variety of parameter values, and the second |
19 | | -is a heatmap of the classifier's cross-validation accuracy as |
20 | | -a function of `C` and `gamma`. |
| 17 | +Two plots are generated. The first is a visualization of the decision function |
| 18 | +for a variety of parameter values, and the second is a heatmap of the |
| 19 | +classifier's cross-validation accuracy and training time as a function of `C` |
| 20 | +and `gamma`. |
| 21 | +
|
| 22 | +An interesting observation on overfitting can be made when comparing validation
| 23 | +and training error: a higher C always results in a lower training error, as it
| 24 | +increases the complexity of the classifier.
| 25 | +
|
| 26 | +For the validation set on the other hand, there is a tradeoff of goodness of |
| 27 | +fit and generalization. |
| 28 | +
|
| 29 | +We can observe that the lower right half of the plot (below the diagonal, with
| 30 | +high C and gamma) yields overfitting models: the training score is very high
| 31 | +but there is a wide gap between training and validation scores. The top and
| 32 | +left parts of the plots show underfitting models: C and gamma can individually
| 33 | +or in conjunction constrain the model too much, leading to low training scores
| 34 | +(hence low validation scores too, as validation scores are on average upper
| 35 | +bounded by training scores).
| 36 | +
|
| 37 | +
|
| 38 | +We can also see that the training time is quite sensitive to the parameter
| 39 | +settings, while the prediction time is not much affected. This is probably a
| 40 | +consequence of the small size of the data set.
21 | 41 | ''' |
22 | 42 | print(__doc__) |
23 | 43 |
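The docstring's claim that a higher C lowers training error can be checked directly. A minimal sketch, not part of this diff, on synthetic data (`make_classification` and `SVC` are standard scikit-learn APIs):

```python
# Sketch: with gamma fixed, a larger C reduces regularization, so the
# training score should not decrease. Data here is synthetic.
from sklearn.datasets import make_classification
from sklearn.svm import SVC

X, y = make_classification(n_samples=200, n_features=2, n_informative=2,
                           n_redundant=0, random_state=0)

train_scores = {}
for C in (0.01, 1.0, 100.0):
    clf = SVC(kernel="rbf", gamma=1.0, C=C).fit(X, y)
    train_scores[C] = clf.score(X, y)  # accuracy on the training set
```

On this data the weakly regularized model (C=100) fits the training set much more closely than the heavily regularized one (C=0.01).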
|
|
65 | 85 | gamma_range = 10.0 ** np.arange(-5, 4) |
66 | 86 | param_grid = dict(gamma=gamma_range, C=C_range) |
67 | 87 | cv = StratifiedKFold(y=Y, n_folds=3) |
68 | | -grid = GridSearchCV(SVC(), param_grid=param_grid, cv=cv) |
| 88 | +grid = GridSearchCV(SVC(), param_grid=param_grid, cv=cv, |
| 89 | + compute_training_score=True) |
69 | 90 | grid.fit(X, Y) |
70 | 91 |
|
71 | 92 | print("The best classifier is: ", grid.best_estimator_) |
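For reference, the `compute_training_score` flag used above was a proposal at the time of this diff; current scikit-learn exposes the same information through `return_train_score` and the `cv_results_` dict. A minimal sketch on synthetic data, assuming a recent scikit-learn version:

```python
# Sketch: the modern equivalent of compute_training_score is
# return_train_score; mean training scores, fit times and score times
# all end up as keys of cv_results_.
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

X, y = make_classification(n_samples=100, random_state=0)
grid = GridSearchCV(SVC(), {"C": [0.1, 1.0], "gamma": [0.01, 0.1]},
                    cv=3, return_train_score=True)
grid.fit(X, y)
```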
|
108 | 129 | # cv_scores_ contains parameter settings and scores |
109 | 130 | score_dict = grid.cv_scores_ |
110 | 131 |
|
111 | | -# We extract just the scores |
112 | | -scores = [x[1] for x in score_dict] |
113 | | -scores = np.array(scores).reshape(len(C_range), len(gamma_range)) |
114 | | - |
115 | | -# draw heatmap of accuracy as a function of gamma and C |
116 | | -pl.figure(figsize=(8, 6)) |
117 | | -pl.subplots_adjust(left=0.05, right=0.95, bottom=0.15, top=0.95) |
118 | | -pl.imshow(scores, interpolation='nearest', cmap=pl.cm.spectral) |
119 | | -pl.xlabel('gamma') |
120 | | -pl.ylabel('C') |
121 | | -pl.colorbar() |
122 | | -pl.xticks(np.arange(len(gamma_range)), gamma_range, rotation=45) |
123 | | -pl.yticks(np.arange(len(C_range)), C_range) |
| 132 | +# We extract validation and training scores, as well as training and prediction |
| 133 | +# times |
| 134 | +_, val_scores, _, train_scores, train_time, pred_time = zip(*score_dict) |
| 135 | + |
| 136 | +arrays = [val_scores, train_scores, train_time, pred_time] |
| 137 | +titles = ["Validation Score", "Training Score", "Training Time", |
| 138 | + "Prediction Time"] |
| 139 | + |
| 140 | +# for each value draw heatmap as a function of gamma and C |
| 141 | +pl.figure(figsize=(12, 8)) |
| 142 | +for i, (arr, title) in enumerate(zip(arrays, titles)): |
| 143 | + pl.subplot(2, 2, i + 1) |
| 144 | + arr = np.array(arr).reshape(len(C_range), len(gamma_range)) |
| 145 | + pl.title(title) |
| 146 | + pl.imshow(arr, interpolation='nearest', cmap=pl.cm.spectral) |
| 147 | + pl.xlabel('gamma') |
| 148 | + pl.ylabel('C') |
| 149 | + pl.colorbar() |
| 150 | + pl.xticks(np.arange(len(gamma_range)), ["%.e" % g for g in gamma_range], |
| 151 | + rotation=45) |
| 152 | + pl.yticks(np.arange(len(C_range)), ["%.e" % C for C in C_range]) |
| 153 | + |
| 154 | +pl.subplots_adjust(top=.95, hspace=.35, left=.0, right=.8, wspace=.05) |
124 | 155 |
|
125 | 156 | pl.show() |
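The `reshape` inside the plotting loop relies on the grid scores being ordered with gamma varying fastest (parameters are enumerated in sorted-key order: C outer, gamma inner). A self-contained sketch of that layout with synthetic scores:

```python
# Sketch with synthetic scores: turning the flat per-parameter score
# list into the (C, gamma) matrix that imshow draws, assuming gamma
# varies fastest in the grid iteration order.
import numpy as np

C_range = 10.0 ** np.arange(-2, 2)      # heatmap rows
gamma_range = 10.0 ** np.arange(-3, 1)  # heatmap columns

flat_scores = np.linspace(0.5, 1.0, len(C_range) * len(gamma_range))
score_grid = flat_scores.reshape(len(C_range), len(gamma_range))
# score_grid[i, j] is the score for C_range[i] and gamma_range[j]
```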