
Commit 0c94b55

ENH add randomized hyperparameter optimization
1 parent db6f005 commit 0c94b55

File tree

10 files changed, +622 -167 lines changed

doc/modules/classes.rst

Lines changed: 3 additions & 1 deletion

@@ -455,7 +455,9 @@ From text
    :template: class.rst
 
    grid_search.GridSearchCV
-   grid_search.IterGrid
+   grid_search.ParameterGrid
+   grid_search.ParameterSampler
+   grid_search.RandomizedSearchCV
 
 
 .. _hmm_ref:

doc/modules/grid_search.rst

Lines changed: 66 additions & 14 deletions

@@ -1,11 +1,11 @@
 .. _grid_search:
 
+.. currentmodule:: sklearn.grid_search
+
 ==========================================
 Grid Search: setting estimator parameters
 ==========================================
 
-.. currentmodule:: sklearn
-
 Grid Search is used to optimize the parameters of a model (e.g. ``C``,
 ``kernel`` and ``gamma`` for Support Vector Classifier, ``alpha`` for
 Lasso, etc.) using an internal :ref:`cross_validation` scheme).

@@ -15,10 +15,10 @@ GridSearchCV
 ============
 
 The main class for implementing hyperparameters grid search in
-scikit-learn is :class:`grid_search.GridSearchCV`. This class is passed
+scikit-learn is :class:`GridSearchCV`. This class is passed
 a base model instance (for example ``sklearn.svm.SVC()``) along with a
 grid of potential hyper-parameter values specified with the `param_grid`
-attribute. For instace the following `param_grid`::
+attribute. For instance the following `param_grid`::
 
     param_grid = [
         {'C': [1, 10, 100, 1000], 'kernel': ['linear']},

@@ -30,7 +30,7 @@ C values in [1, 10, 100, 1000], and the second one with an RBG kernel,
 and the cross-product of C values ranging in [1, 10, 100, 1000] and gamma
 values in [0.001, 0.0001].
 
-The :class:`grid_search.GridSearchCV` instance implements the usual
+The :class:`GridSearchCV` instance implements the usual
 estimator API: when "fitting" it on a dataset all the possible
 combinations of hyperparameter values are evaluated and the best
 combinations is retained.

@@ -64,24 +64,76 @@ alternative scoring function can be specified via the ``scoring`` parameter to
 :class:`GridSearchCV`.
 See :ref:`score_func_objects` for more details.
 
-Examples
-========
+.. topic:: Examples:
 
-- See :ref:`example_grid_search_digits.py` for an example of
-  Grid Search computation on the digits dataset.
+    - See :ref:`example_grid_search_digits.py` for an example of
+      Grid Search computation on the digits dataset.
 
-- See :ref:`example_grid_search_text_feature_extraction.py` for an example
-  of Grid Search coupling parameters from a text documents feature
-  extractor (n-gram count vectorizer and TF-IDF transformer) with a
-  classifier (here a linear SVM trained with SGD with either elastic
-  net or L2 penalty) using a :class:`pipeline.Pipeline` instance.
+    - See :ref:`example_grid_search_text_feature_extraction.py` for an example
+      of Grid Search coupling parameters from a text documents feature
+      extractor (n-gram count vectorizer and TF-IDF transformer) with a
+      classifier (here a linear SVM trained with SGD with either elastic
+      net or L2 penalty) using a :class:`pipeline.Pipeline` instance.
 
 .. note::
 
     Computations can be run in parallel if your OS supports it, by using
    the keyword n_jobs=-1, see function signature for more details.
 
 
+Randomized Hyper-Parameter Optimization
+=======================================
+
+While using a grid of parameter settings is currently the most widely used
+method for hyper-parameter optimization, other search methods have more
+favourable properties.
+:class:`RandomizedSearchCV` implements a randomized search over hyperparameters,
+where each setting is sampled from a distribution over possible parameter values.
+This has two main benefits over searching over a grid:
+
+* A budget can be chosen independent of the number of parameters and possible values.
+
+* Adding parameters that do not influence the performance does not decrease efficiency.
+
+Specifying how parameters should be sampled is done using a dictionary, very
+similar to specifying parameters for :class:`GridSearchCV`. Additionally,
+a computation budget is specified using ``n_iter``, which is the number
+of iterations (parameter samples) to be used.
+For each parameter, either a distribution over possible values or a list of
+discrete choices (which will be sampled uniformly) can be specified::
+
+    [{'C': scipy.stats.expon(scale=100), 'gamma': scipy.stats.expon(scale=.1),
+      'kernel': ['rbf'], 'class_weight': ['auto', None]}]
+
+This example uses the ``scipy.stats`` module, which contains many useful
+distributions for sampling hyperparameters, such as ``expon``, ``gamma``,
+``uniform`` or ``randint``.
+In principle, any function can be passed that provides a ``rvs`` (random
+variate sample) method to sample a value. A call to the ``rvs`` function should
+provide independent random samples from possible parameter values on
+consecutive calls.
+
+.. warning::
+
+    The distributions in ``scipy.stats`` do not allow specifying a random
+    state. Instead, they use the global numpy random state, which can be seeded
+    via ``np.random.seed`` or set using ``np.random.set_state``.
+
+For continuous parameters, such as ``C`` above, it is important to specify
+a continuous distribution to take full advantage of the randomization. This way,
+increasing ``n_iter`` will always lead to a finer search.
+
+.. topic:: Examples:
+
+    * :ref:`example_randomized_search.py` compares the usage and efficiency
+      of randomized search and grid search.
+
+.. topic:: References:
+
+    * Bergstra, J. and Bengio, Y.,
+      Random search for hyper-parameter optimization,
+      The Journal of Machine Learning Research (2012)
+
+
 Alternatives to brute force grid search
 =======================================
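To make the new section concrete, here is a minimal sketch of driving the randomized search described above end to end. The constructor arguments (base estimator, distribution dictionary, ``n_iter``) are assumed from the prose of the diff, and the ``LogUniform`` class is a hypothetical illustration of the ``rvs`` protocol, not part of the commit::

    import numpy as np
    import scipy.stats
    from sklearn.datasets import load_digits
    from sklearn.svm import SVC
    from sklearn.grid_search import RandomizedSearchCV

    digits = load_digits()

    # scipy.stats distributions draw from the global numpy random state
    # (see the warning above), so seed it for reproducible samples.
    np.random.seed(0)

    class LogUniform(object):
        # Hypothetical custom distribution: any object whose ``rvs`` method
        # returns one independent sample per call can stand in for
        # a scipy.stats distribution.
        def __init__(self, low, high):
            self.low, self.high = np.log(low), np.log(high)

        def rvs(self):
            return np.exp(np.random.uniform(self.low, self.high))

    # Distributions for continuous parameters, lists for discrete choices.
    param_distributions = {
        'C': LogUniform(1e-2, 1e3),            # custom rvs-providing object
        'gamma': scipy.stats.expon(scale=.1),  # scipy.stats distribution
        'kernel': ['rbf'],                     # sampled uniformly from the list
        'class_weight': ['auto', None],
    }

    # n_iter is the computation budget: the number of settings sampled.
    search = RandomizedSearchCV(SVC(), param_distributions, n_iter=20)
    search.fit(digits.data, digits.target)
    print(search.best_score_)
    print(search.best_params_)

Because ``C`` and ``gamma`` are drawn from continuous distributions, increasing ``n_iter`` refines the search rather than revisiting fixed grid points, which is exactly the property the documentation highlights.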

doc/tutorial/statistical_inference/model_selection.rst

Lines changed: 1 addition & 1 deletion

@@ -146,7 +146,7 @@ estimator during the construction and exposes an estimator API::
     >>> clf.fit(X_digits[:1000], y_digits[:1000]) # doctest: +ELLIPSIS
     GridSearchCV(cv=None,...
     >>> clf.best_score_
-    0.988991985997974
+    0.98899999999999999
     >>> clf.best_estimator_.gamma
     9.9999999999999995e-07
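For context, the doctest above exercises the estimator API that :class:`GridSearchCV` exposes; the construction itself sits outside the hunk. A hedged reconstruction, assuming the tutorial's digits data and a gamma grid (suggested by the ``best_estimator_.gamma`` check)::

    import numpy as np
    from sklearn import datasets, svm
    from sklearn.grid_search import GridSearchCV

    digits = datasets.load_digits()
    X_digits, y_digits = digits.data, digits.target

    # A gamma grid is assumed because the doctest inspects
    # best_estimator_.gamma; the real construction is above this hunk.
    gammas = np.logspace(-6, -1, 10)
    clf = GridSearchCV(estimator=svm.SVC(C=1), param_grid=dict(gamma=gammas))
    clf.fit(X_digits[:1000], y_digits[:1000])
    print(clf.best_score_)            # the value the doctest checks
    print(clf.best_estimator_.gamma)  # the selected gamma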

doc/whats_new.rst

Lines changed: 11 additions & 0 deletions

@@ -35,6 +35,10 @@ Changelog
     attribute. Setting ``compute_importances=True`` is no longer required.
     By `Gilles Louppe`_.
 
+   - Added :class:`grid_search.RandomizedSearchCV` and
+     :class:`grid_search.ParameterSampler` for randomized hyperparameter
+     optimization. By `Andreas Müller`_.
+
    - :class:`LinearSVC`, :class:`SGDClassifier` and :class:`SGDRegressor`
      now have a ``sparsify`` method that converts their ``coef_`` into a
      sparse matrix, meaning stored models trained using these estimators

@@ -46,6 +50,13 @@ Changelog
    - Fixed bug in :class:`MinMaxScaler` causing incorrect scaling of the
      features for non-default ``feature_range`` settings. By `Andreas Müller`_.
 
+
+API changes summary
+-------------------
+
+   - :class:`grid_search.IterGrid` was renamed to
+     :class:`grid_search.ParameterGrid`.
+
    - Fixed bug in :class:`KFold` causing imperfect class balance in some
      cases. By `Alexandre Gramfort`_ and Tadej Janež.
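The renamed :class:`ParameterGrid` and the new :class:`ParameterSampler` can also be used on their own. A minimal sketch, with the signatures (dict input, ``n_iter`` budget) assumed from the names and prose above::

    import scipy.stats
    from sklearn.grid_search import ParameterGrid, ParameterSampler

    # ParameterGrid (formerly IterGrid) enumerates the full cross-product.
    for params in ParameterGrid({'kernel': ['linear', 'rbf'], 'C': [1, 10]}):
        print(params)  # four dicts in total

    # ParameterSampler draws a fixed number of random settings instead.
    dist = {'C': scipy.stats.expon(scale=100), 'kernel': ['linear', 'rbf']}
    for params in ParameterSampler(dist, n_iter=4):
        print(params)  # C sampled from the distribution, kernel uniform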

examples/grid_search_digits.py

Lines changed: 1 addition & 1 deletion

@@ -59,7 +59,7 @@
     print()
     print("Grid scores on development set:")
     print()
-    for params, mean_score, scores in clf.grid_scores_:
+    for params, mean_score, scores in clf.cv_scores_:
         print("%0.3f (+/-%0.03f) for %r"
               % (mean_score, scores.std() / 2, params))
     print()

examples/svm/plot_rbf_parameters.py

Lines changed: 2 additions & 2 deletions

@@ -105,8 +105,8 @@
 pl.axis('tight')
 
 # plot the scores of the grid
-# grid_scores_ contains parameter settings and scores
-score_dict = grid.grid_scores_
+# cv_scores_ contains parameter settings and scores
+score_dict = grid.cv_scores_
 
 # We extract just the scores
 scores = [x[1] for x in score_dict]
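Both example scripts unpack ``cv_scores_`` the same way, which implies each entry is a ``(params, mean_score, fold_scores)`` tuple. A self-contained sketch under that assumption::

    from sklearn.datasets import load_digits
    from sklearn.grid_search import GridSearchCV
    from sklearn.svm import SVC

    digits = load_digits()
    grid = GridSearchCV(SVC(), {'C': [1, 10], 'gamma': [0.001, 0.0001]})
    grid.fit(digits.data, digits.target)

    # Assumed entry layout, inferred from the loops in the diffs:
    # (params_dict, mean_validation_score, array_of_per_fold_scores).
    for params, mean_score, fold_scores in grid.cv_scores_:
        print("%0.3f (+/-%0.3f) for %r"
              % (mean_score, fold_scores.std() / 2, params))

    # Just the mean scores, as both example scripts do:
    scores = [mean for _, mean, _ in grid.cv_scores_]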

examples/svm/plot_svm_scale_c.py

Lines changed: 1 addition & 1 deletion

@@ -131,7 +131,7 @@
                     cv=ShuffleSplit(n=n_samples, train_size=train_size,
                                     n_iter=250, random_state=1))
     grid.fit(X, y)
-    scores = [x[1] for x in grid.grid_scores_]
+    scores = [x[1] for x in grid.cv_scores_]
 
     scales = [(1, 'No scaling'),
               ((n_samples * train_size), '1/n_samples'),

0 commit comments