Merged
Changes from all commits
Commits
47 commits
9f9b6b0
[MRG] FIX remove duplicated code and handle 0 seed for our_rand_r (#5…
Jun 15, 2018
322434f
Don’t share random state between threads
ClemDoum Mar 7, 2019
fe23a0c
Update unit tests
ClemDoum Mar 8, 2019
2015b20
Add dirty prints to check seeds and numpy max on different platforms
ClemDoum Mar 12, 2019
e114908
Add dirty prints to check seeds and numpy max on different platforms
ClemDoum Mar 12, 2019
9a7efad
Add dirty prints to check seeds and numpy max on different platforms
ClemDoum Mar 12, 2019
a1e950e
Add dirty prints to check seeds and numpy max on different platforms
ClemDoum Mar 12, 2019
88ebe79
Add dirty prints to check seeds and numpy max on different platforms
ClemDoum Mar 12, 2019
7d24faa
Add dirty prints to check seeds and numpy max on different platforms
ClemDoum Mar 12, 2019
313ae17
Add dirty prints to check seeds and numpy max on different platforms
ClemDoum Mar 12, 2019
4b0fecd
Add dirty prints to check seeds and numpy max on different platforms
ClemDoum Mar 12, 2019
a384247
Add dirty prints to check seeds and numpy max on different platforms
ClemDoum Mar 12, 2019
171d863
Add dirty prints to check seeds and numpy max on different platforms
ClemDoum Mar 12, 2019
a388265
Add dirty prints to check seeds and numpy max on different platforms
ClemDoum Mar 12, 2019
799b9d9
Enrich sequential dataset tests to check they work cross platform
ClemDoum Mar 13, 2019
5916742
Run only some tests in the CI
ClemDoum Mar 13, 2019
3efdcb0
test our_rand_r
ClemDoum Mar 13, 2019
6908be0
test our_rand_r
ClemDoum Mar 13, 2019
d0a1142
test our_rand_r
ClemDoum Mar 13, 2019
448979c
test our_rand_r
ClemDoum Mar 13, 2019
2706cc3
test our_rand_r
ClemDoum Mar 13, 2019
07ecd49
test our_rand_r
ClemDoum Mar 13, 2019
3d66b32
test our_rand_r
ClemDoum Mar 13, 2019
c150127
test our_rand_r
ClemDoum Mar 14, 2019
d8b9de6
test our_rand_r
ClemDoum Mar 14, 2019
3ab4b92
test our_rand_r
ClemDoum Mar 14, 2019
a080586
test our_rand_r
ClemDoum Mar 14, 2019
6ad0ab0
test our_rand_r
ClemDoum Mar 14, 2019
d909470
Merge remote-tracking branch 'jdnc/fix5015' into fix/sgd-random-state
ClemDoum Mar 14, 2019
31093ce
Fix, factorize and test our_rand_r
ClemDoum Mar 14, 2019
7cf9a84
_cast_py, _our_rand_r_py should stay private
ClemDoum Mar 14, 2019
5a62b3b
Test our_rand_r with 0 seed
ClemDoum Mar 14, 2019
cb09fab
Fix _cast_py docstring
ClemDoum Mar 14, 2019
8b46c88
Update changelog
ClemDoum Mar 20, 2019
f97cb0c
Update Doctest
ClemDoum Mar 20, 2019
35aac38
Improve documentation
ClemDoum Mar 21, 2019
a4197da
Remove test_cython_cast
ClemDoum Mar 21, 2019
3b404b5
Improve documentation
ClemDoum Mar 21, 2019
39e211d
Remove unnecessary parentheses
ClemDoum Mar 21, 2019
7c58525
our_rand_r should be inline
ClemDoum Mar 21, 2019
bae07e6
Remove useless default seed
ClemDoum Mar 22, 2019
522ea52
Fix _random.pxd
ClemDoum Mar 22, 2019
2fcb780
Fix sphinx gallery version
ClemDoum Mar 22, 2019
a0a1efa
Merge branch 'master' into fix/sgd-random-state
ClemDoum Mar 22, 2019
3657c27
Revert "Remove useless default seed"
ClemDoum Mar 22, 2019
93b5d09
Merge branch 'master' into fix/sgd-random-state
ClemDoum Mar 27, 2019
5bdf6c8
Merge branch 'master' into fix/sgd-random-state
jnothman Apr 4, 2019
2 changes: 1 addition & 1 deletion build_tools/circle/build_doc.sh
Expand Up @@ -124,7 +124,7 @@ conda create -n $CONDA_ENV_NAME --yes --quiet python="${PYTHON_VERSION:-*}" \
joblib

source activate testenv
pip install sphinx-gallery
pip install "sphinx-gallery>=0.2,<0.3"
pip install numpydoc==0.8

# Build and install scikit-learn in dev mode
Expand Down
23 changes: 11 additions & 12 deletions doc/tutorial/text_analytics/working_with_text_data.rst
Expand Up @@ -367,7 +367,7 @@ classifier object into our pipeline::
Pipeline(...)
>>> predicted = text_clf.predict(docs_test)
>>> np.mean(predicted == twenty_test.target) # doctest: +ELLIPSIS
0.9127...
0.9101...

We achieved 91.3% accuracy using the SVM. ``scikit-learn`` provides further
utilities for more detailed performance analysis of the results::
Expand All @@ -378,22 +378,21 @@ utilities for more detailed performance analysis of the results::
... # doctest: +NORMALIZE_WHITESPACE
precision recall f1-score support
<BLANKLINE>
alt.atheism 0.95 0.81 0.87 319
comp.graphics 0.88 0.97 0.92 389
sci.med 0.94 0.90 0.92 396
alt.atheism 0.95 0.80 0.87 319
comp.graphics 0.87 0.98 0.92 389
sci.med 0.94 0.89 0.91 396
soc.religion.christian 0.90 0.95 0.93 398
<BLANKLINE>
accuracy 0.91 1502
macro avg 0.92 0.91 0.91 1502
weighted avg 0.92 0.91 0.91 1502
macro avg 0.91 0.91 0.91 1502
weighted avg 0.91 0.91 0.91 1502
<BLANKLINE>

>>> metrics.confusion_matrix(twenty_test.target, predicted)
array([[258, 11, 15, 35],
[ 4, 379, 3, 3],
[ 5, 33, 355, 3],
[ 5, 10, 4, 379]])

array([[256, 11, 16, 36],
[ 4, 380, 3, 2],
[ 5, 35, 353, 3],
[ 5, 11, 4, 378]])

As expected the confusion matrix shows that posts from the newsgroups
on atheism and Christianity are more often confused for one another than
Expand Down Expand Up @@ -471,7 +470,7 @@ mean score and the parameters setting corresponding to that score::
...
clf__alpha: 0.001
tfidf__use_idf: True
vect__ngram_range: (1, 2)
vect__ngram_range: (1, 1)

A more detailed summary of the search is available at ``gs_clf.cv_results_``.

Expand Down
20 changes: 17 additions & 3 deletions doc/whats_new/v0.21.rst
Expand Up @@ -30,6 +30,11 @@ random sampling procedures.
- :class:`neural_network.MLPClassifier` |Fix|
- :func:`svm.SVC.decision_function` and
:func:`multiclass.OneVsOneClassifier.decision_function`. |Fix|
- :class:`linear_model.SGDClassifier` and any derived classifiers. |Fix|
- Any model using the :func:`linear_model.sag.sag_solver` function with a `0`
seed, including :class:`linear_model.LogisticRegression`,
:class:`linear_model.LogisticRegressionCV`, :class:`linear_model.Ridge`,
and :class:`linear_model.RidgeCV` with 'sag' solver. |Fix|


Details are listed in the changelog below.
Expand Down Expand Up @@ -304,6 +309,11 @@ Support for Python 3.4 and below has been officially dropped.
in version 0.21 and will be removed in version 0.23.
:issue:`12821` by :user:`Nicolas Hug <NicolasHug>`.

- |Fix| Fixed a bug in
:class:`linear_model.stochastic_gradient.BaseSGDClassifier` that was not
deterministic when trained in a multi-class setting on several threads.
:issue:`13422` by :user:`Clément Doumouro <ClemDoum>`.

:mod:`sklearn.manifold`
............................

Expand Down Expand Up @@ -544,15 +554,19 @@ Multiple modules
but this can be altered with the ``print_changed_only`` option in
:func:`sklearn.set_config`. :issue:`11705` by :user:`Nicolas Hug
<NicolasHug>`.
- |Efficiency| Memory copies are avoided when casting arrays to a different
dtype in multiple estimators. :issue:`11973` by :user:`Roman Yurchak
<rth>`.
- |MajorFeature| Add estimators tags: these are annotations of estimators
that allow programmatic inspection of their capabilities, such as sparse
matrix support, supported output types and supported methods. Estimator
tags also determine the tests that are run on an estimator when
`check_estimator` is called. Read more in the :ref:`User Guide
<estimator_tags>`. :issue:`8022` by :user:`Andreas Müller <amueller>`.
- |Efficiency| Memory copies are avoided when casting arrays to a different
dtype in multiple estimators. :issue:`11973` by :user:`Roman Yurchak
<rth>`.
- |Fix| Fixed a bug in the implementation of the :func:`our_rand_r`
helper function that was not behaving consistently across platforms.
:issue:`13422` by :user:`Madhura Parikh <jdnc>` and
:user:`Clément Doumouro <ClemDoum>`.

Changes to estimator checks
---------------------------
Expand Down
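The two changelog entries above point at the same root cause: the hand-rolled rand_r replacement did not wrap to 32 bits identically on every platform, and a seed of exactly 0 is a fixed point of the XorShift update, so the generator never advances. Below is a rough pure-Python sketch of that behaviour, written only for illustration; the real helper is shared Cython code under sklearn/utils, and the names and the exact zero-seed policy here are assumptions, not the library API.

    RAND_R_MAX = 0x7FFFFFFF    # same cap the Cython helpers use
    UINT32_MASK = 0xFFFFFFFF   # emulate unsigned 32-bit wrap-around

    def xorshift32_sketch(state):
        """Advance the one-element list ``state`` and return a value in
        [0, RAND_R_MAX]."""
        if state[0] == 0:
            # an all-zero XorShift state never changes, so a zero seed has
            # to be replaced by some non-zero default before updating
            state[0] = 1
        state[0] ^= (state[0] << 13) & UINT32_MASK
        state[0] ^= state[0] >> 17
        state[0] ^= (state[0] << 5) & UINT32_MASK
        return state[0] % (RAND_R_MAX + 1)

Without the explicit 32-bit masking (the <UINT32_t> casts in the Cython version), an implementation that happens to run on a wider integer type produces a different stream, which is the kind of cross-platform discrepancy the entry describes.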
2 changes: 1 addition & 1 deletion sklearn/kernel_approximation.py
Expand Up @@ -291,7 +291,7 @@ class AdditiveChi2Sampler(BaseEstimator, TransformerMixin):
random_state=0, shuffle=True, tol=0.001, validation_fraction=0.1,
verbose=0, warm_start=False)
>>> clf.score(X_transformed, y) # doctest: +ELLIPSIS
0.9543...
0.9499...

Notes
-----
Expand Down
12 changes: 3 additions & 9 deletions sklearn/linear_model/cd_fast.pyx
Expand Up @@ -24,6 +24,8 @@ from ..utils._cython_blas cimport (_axpy, _dot, _asum, _ger, _gemv, _nrm2,
from ..utils._cython_blas cimport RowMajor, ColMajor, Trans, NoTrans


from sklearn.utils cimport _random

ctypedef np.float64_t DOUBLE
ctypedef np.uint32_t UINT32_t

Expand All @@ -38,17 +40,9 @@ cdef enum:
RAND_R_MAX = 0x7FFFFFFF


cdef inline UINT32_t our_rand_r(UINT32_t* seed) nogil:
seed[0] ^= <UINT32_t>(seed[0] << 13)
seed[0] ^= <UINT32_t>(seed[0] >> 17)
seed[0] ^= <UINT32_t>(seed[0] << 5)

return seed[0] % (<UINT32_t>RAND_R_MAX + 1)


cdef inline UINT32_t rand_int(UINT32_t end, UINT32_t* random_state) nogil:
"""Generate a random integer in [0; end)."""
return our_rand_r(random_state) % end
return _random.our_rand_r(random_state) % end


cdef inline floating fmax(floating x, floating y) nogil:
Expand Down
4 changes: 2 additions & 2 deletions sklearn/linear_model/passive_aggressive.py
Expand Up @@ -139,9 +139,9 @@ class PassiveAggressiveClassifier(BaseSGDClassifier):
random_state=0, shuffle=True, tol=0.001,
validation_fraction=0.1, verbose=0, warm_start=False)
>>> print(clf.coef_)
[[-0.6543424 1.54603022 1.35361642 0.22199435]]
[[0.26642044 0.45070924 0.67251877 0.64185414]]
>>> print(clf.intercept_)
[0.63310933]
[1.84127814]
>>> print(clf.predict([[0, 0, 0, 0]]))
[1]

Expand Down
2 changes: 1 addition & 1 deletion sklearn/linear_model/perceptron.py
Expand Up @@ -128,7 +128,7 @@ class Perceptron(BaseSGDClassifier):
penalty=None, random_state=0, shuffle=True, tol=0.001,
validation_fraction=0.1, verbose=0, warm_start=False)
>>> clf.score(X, y) # doctest: +ELLIPSIS
0.946...
0.939...

See also
--------
Expand Down
33 changes: 25 additions & 8 deletions sklearn/linear_model/stochastic_gradient.py
Expand Up @@ -42,6 +42,8 @@
DEFAULT_EPSILON = 0.1
# Default value of ``epsilon`` parameter.

MAX_INT = np.iinfo(np.int32).max


class _ValidationScoreCallback:
"""Callback for early stopping based on validation score"""
Expand Down Expand Up @@ -322,7 +324,8 @@ def _prepare_fit_binary(est, y, i):


def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter,
pos_weight, neg_weight, sample_weight, validation_mask=None):
pos_weight, neg_weight, sample_weight, validation_mask=None,
random_state=None):
"""Fit a single binary classifier.

The i'th class is considered the "positive" class.
Expand Down Expand Up @@ -366,13 +369,22 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter,
validation_mask : numpy array of shape [n_samples, ] or None
Precomputed validation mask in case _fit_binary is called in the
context of a one-vs-rest reduction.

random_state : int, RandomState instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
If RandomState instance, random_state is the random number generator;
If None, the random number generator is the RandomState instance used
by `np.random`.
"""
# if average is not true, average_coef, and average_intercept will be
# unused
y_i, coef, intercept, average_coef, average_intercept = \
_prepare_fit_binary(est, y, i)
assert y_i.shape[0] == y.shape[0] == sample_weight.shape[0]
dataset, intercept_decay = make_dataset(X, y_i, sample_weight)

random_state = check_random_state(random_state)
dataset, intercept_decay = make_dataset(
X, y_i, sample_weight, random_state=random_state)

penalty_type = est._get_penalty_type(est.penalty)
learning_rate_type = est._get_learning_rate_type(learning_rate)
Expand All @@ -383,11 +395,9 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter,
validation_score_cb = est._make_validation_score_cb(
validation_mask, X, y_i, sample_weight, classes=classes)

# XXX should have random_state_!
random_state = check_random_state(est.random_state)
# numpy mtrand expects a C long which is a signed 32 bit integer under
# Windows
seed = random_state.randint(0, np.iinfo(np.int32).max)
seed = random_state.randint(MAX_INT)

tol = est.tol if est.tol is not None else -np.inf

Expand Down Expand Up @@ -558,7 +568,8 @@ def _fit_binary(self, X, y, alpha, C, sample_weight,
learning_rate, max_iter,
self._expanded_class_weight[1],
self._expanded_class_weight[0],
sample_weight)
sample_weight,
random_state=self.random_state)

self.t_ += n_iter_ * X.shape[0]
self.n_iter_ = n_iter_
Expand Down Expand Up @@ -589,13 +600,19 @@ def _fit_multiclass(self, X, y, alpha, C, learning_rate,
validation_mask = self._make_validation_split(y)

# Use joblib to fit OvA in parallel.
# Pick the random seed for each job outside of fit_binary to avoid
# sharing the estimator random state between threads which could lead
# to non-deterministic behavior
random_state = check_random_state(self.random_state)
seeds = random_state.randint(MAX_INT, size=len(self.classes_))
result = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,
**_joblib_parallel_args(require="sharedmem"))(
delayed(fit_binary)(self, i, X, y, alpha, C, learning_rate,
max_iter, self._expanded_class_weight[i],
1., sample_weight,
validation_mask=validation_mask)
for i in range(len(self.classes_)))
validation_mask=validation_mask,
random_state=seed)
for i, seed in enumerate(seeds))

# take the maximum of n_iter_ over every binary fit
n_iter_ = 0.
Expand Down
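The stochastic_gradient.py changes above follow one pattern: draw an integer seed per one-vs-rest job from the estimator's random state before dispatching to joblib, so the result no longer depends on the order in which threads consume a shared RandomState. A minimal standalone sketch of that pattern is below; fit_one and fit_all are illustrative stand-ins for fit_binary and _fit_multiclass, not the real functions.

    import numpy as np
    from joblib import Parallel, delayed
    from sklearn.utils import check_random_state

    # numpy's RandomState expects a C long, which is a signed 32-bit integer
    # on Windows, hence the int32 bound on the seeds
    MAX_INT = np.iinfo(np.int32).max

    def fit_one(class_index, seed):
        # each job builds its own RandomState from a pre-drawn seed
        rng = check_random_state(seed)
        return class_index, rng.randint(MAX_INT)

    def fit_all(n_classes, random_state=None):
        rng = check_random_state(random_state)
        # the seeds are drawn sequentially up front, so for a fixed
        # random_state the run is reproducible however the jobs are scheduled
        seeds = rng.randint(MAX_INT, size=n_classes)
        return Parallel(n_jobs=2, require="sharedmem")(
            delayed(fit_one)(i, seed) for i, seed in enumerate(seeds))

For a fixed random_state, fit_all(3, random_state=0) returns the same per-class values on every call, which is exactly the determinism the fix restores for multi-class SGDClassifier.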
6 changes: 3 additions & 3 deletions sklearn/linear_model/tests/test_logistic.py
Expand Up @@ -1671,14 +1671,14 @@ def test_elastic_net_versus_sgd(C, l1_ratio):
n_samples = 500
X, y = make_classification(n_samples=n_samples, n_classes=2, n_features=5,
n_informative=5, n_redundant=0, n_repeated=0,
random_state=0)
random_state=1)
X = scale(X)

sgd = SGDClassifier(
penalty='elasticnet', random_state=0, fit_intercept=False, tol=-np.inf,
penalty='elasticnet', random_state=1, fit_intercept=False, tol=-np.inf,
max_iter=2000, l1_ratio=l1_ratio, alpha=1. / C / n_samples, loss='log')
log = LogisticRegression(
penalty='elasticnet', random_state=0, fit_intercept=False, tol=1e-5,
penalty='elasticnet', random_state=1, fit_intercept=False, tol=1e-5,
max_iter=1000, l1_ratio=l1_ratio, C=C, solver='saga')

sgd.fit(X, y)
Expand Down
2 changes: 1 addition & 1 deletion sklearn/linear_model/tests/test_passive_aggressive.py
Expand Up @@ -76,7 +76,7 @@ def test_classifier_accuracy():
for average in (False, True):
clf = PassiveAggressiveClassifier(
C=1.0, max_iter=30, fit_intercept=fit_intercept,
random_state=0, average=average, tol=None)
random_state=1, average=average, tol=None)
clf.fit(data, y)
score = clf.score(data, y)
assert_greater(score, 0.79)
Expand Down
11 changes: 7 additions & 4 deletions sklearn/linear_model/tests/test_sgd.py
Expand Up @@ -1032,22 +1032,25 @@ def test_partial_fit_equal_fit_classif(klass, lr):

@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier])
def test_regression_losses(klass):
random_state = np.random.RandomState(1)
clf = klass(alpha=0.01, learning_rate="constant",
eta0=0.1, loss="epsilon_insensitive")
eta0=0.1, loss="epsilon_insensitive",
random_state=random_state)
clf.fit(X, Y)
assert_equal(1.0, np.mean(clf.predict(X) == Y))

clf = klass(alpha=0.01, learning_rate="constant",
eta0=0.1, loss="squared_epsilon_insensitive")
eta0=0.1, loss="squared_epsilon_insensitive",
random_state=random_state)
clf.fit(X, Y)
assert_equal(1.0, np.mean(clf.predict(X) == Y))

clf = klass(alpha=0.01, loss="huber")
clf = klass(alpha=0.01, loss="huber", random_state=random_state)
clf.fit(X, Y)
assert_equal(1.0, np.mean(clf.predict(X) == Y))

clf = klass(alpha=0.01, learning_rate="constant", eta0=0.01,
loss="squared_loss")
loss="squared_loss", random_state=random_state)
clf.fit(X, Y)
assert_equal(1.0, np.mean(clf.predict(X) == Y))

Expand Down
4 changes: 2 additions & 2 deletions sklearn/tests/test_multioutput.py
Expand Up @@ -332,14 +332,14 @@ def test_multi_output_classification_partial_fit_sample_weights():
Xw = [[1, 2, 3], [4, 5, 6], [1.5, 2.5, 3.5]]
yw = [[3, 2], [2, 3], [3, 2]]
w = np.asarray([2., 1., 1.])
sgd_linear_clf = SGDClassifier(random_state=1, max_iter=5)
sgd_linear_clf = SGDClassifier(random_state=1, max_iter=20)
clf_w = MultiOutputClassifier(sgd_linear_clf)
clf_w.fit(Xw, yw, w)

# unweighted, but with repeated samples
X = [[1, 2, 3], [1, 2, 3], [4, 5, 6], [1.5, 2.5, 3.5]]
y = [[3, 2], [3, 2], [2, 3], [3, 2]]
sgd_linear_clf = SGDClassifier(random_state=1, max_iter=5)
sgd_linear_clf = SGDClassifier(random_state=1, max_iter=20)
clf = MultiOutputClassifier(sgd_linear_clf)
clf.fit(X, y)
X_test = [[1.5, 2.5, 3.5]]
Expand Down
1 change: 1 addition & 0 deletions sklearn/tree/_utils.pxd
Expand Up @@ -23,6 +23,7 @@ ctypedef np.npy_intp SIZE_t # Type for indices and counters
ctypedef np.npy_int32 INT32_t # Signed 32 bit integer
ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer


cdef enum:
# Max value for our rand_r replacement (near the bottom).
# We don't use RAND_MAX because it's different across platforms and
Expand Down
16 changes: 4 additions & 12 deletions sklearn/tree/_utils.pyx
Expand Up @@ -21,6 +21,8 @@ import numpy as np
cimport numpy as np
np.import_array()

from sklearn.utils cimport _random

# =============================================================================
# Helper functions
# =============================================================================
Expand Down Expand Up @@ -53,16 +55,6 @@ def _realloc_test():
assert False


# rand_r replacement using a 32bit XorShift generator
# See https://www.jstatsoft.org/v08/i14/paper for details
cdef inline UINT32_t our_rand_r(UINT32_t* seed) nogil:
seed[0] ^= <UINT32_t>(seed[0] << 13)
seed[0] ^= <UINT32_t>(seed[0] >> 17)
seed[0] ^= <UINT32_t>(seed[0] << 5)

return seed[0] % (<UINT32_t>RAND_R_MAX + 1)


cdef inline np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size):
"""Return copied data as 1D numpy array of intp's."""
cdef np.npy_intp shape[1]
Expand All @@ -73,13 +65,13 @@ cdef inline np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size):
cdef inline SIZE_t rand_int(SIZE_t low, SIZE_t high,
UINT32_t* random_state) nogil:
"""Generate a random integer in [low; end)."""
return low + our_rand_r(random_state) % (high - low)
return low + _random.our_rand_r(random_state) % (high - low)


cdef inline double rand_uniform(double low, double high,
UINT32_t* random_state) nogil:
"""Generate a random double in [low; high)."""
return ((high - low) * <double> our_rand_r(random_state) /
return ((high - low) * <double> _random.our_rand_r(random_state) /
<double> RAND_R_MAX) + low


Expand Down
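As in cd_fast.pyx above, the tree helpers now delegate to the single shared generator and keep only thin rand_int / rand_uniform wrappers. A self-contained Python illustration of the same arithmetic follows; the real versions are nogil Cython, and the names below are illustrative.

    RAND_R_MAX = 0x7FFFFFFF

    def _xorshift32(state):
        # same 32-bit XorShift update as sketched for the changelog above
        if state[0] == 0:
            state[0] = 1   # avoid the all-zero fixed point
        state[0] ^= (state[0] << 13) & 0xFFFFFFFF
        state[0] ^= state[0] >> 17
        state[0] ^= (state[0] << 5) & 0xFFFFFFFF
        return state[0] % (RAND_R_MAX + 1)

    def rand_int_sketch(low, high, state):
        """Random integer in [low, high), mirroring the Cython rand_int."""
        return low + _xorshift32(state) % (high - low)

    def rand_uniform_sketch(low, high, state):
        """Random double in [low, high), mirroring the Cython rand_uniform."""
        return (high - low) * _xorshift32(state) / RAND_R_MAX + low

Keeping one implementation means a change such as the zero-seed handling only has to be made once, instead of in every .pyx module that previously carried its own copy of our_rand_r.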