Skip to content

Commit f6b0c67

Browse files
NicolasHugTomDLT
authored and committed
TST Added estimator check for idempotence of fit() (#12328)
1 parent d4802ae commit f6b0c67

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

doc/whats_new/v0.21.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,9 @@ Changes to estimator checks
115115
---------------------------
116116

117117
These changes mostly affect library developers.
118+
119+
- Add ``check_fit_idempotent`` to
120+
:func:`~utils.estimator_checks.check_estimator`, which checks that
121+
when `fit` is called twice with the same data, the output of
122+
`predict`, `predict_proba`, `transform`, and `decision_function` does not
123+
change. :issue:`12328` by :user:`Nicolas Hug<NicolasHug>`

sklearn/utils/estimator_checks.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@
5353
from sklearn.exceptions import DataConversionWarning
5454
from sklearn.exceptions import SkipTestWarning
5555
from sklearn.model_selection import train_test_split
56+
from sklearn.model_selection import ShuffleSplit
57+
from sklearn.model_selection._validation import _safe_split
5658
from sklearn.metrics.pairwise import (rbf_kernel, linear_kernel,
5759
pairwise_distances)
5860

@@ -265,6 +267,7 @@ def _yield_all_checks(name, estimator):
265267
yield check_set_params
266268
yield check_dict_unchanged
267269
yield check_dont_overwrite_parameters
270+
yield check_fit_idempotent
268271

269272

270273
def check_estimator(Estimator):
@@ -2330,3 +2333,50 @@ def check_outliers_fit_predict(name, estimator_orig):
23302333
for contamination in [-0.5, 2.3]:
23312334
estimator.set_params(contamination=contamination)
23322335
assert_raises(ValueError, estimator.fit_predict, X)
2336+
2337+
2338+
def check_fit_idempotent(name, estimator_orig):
    # Fitting twice on the same data must not change the model's output,
    # i.e. est.fit(X).fit(X) behaves like est.fit(X). Comparing the fitted
    # attributes (e.g. coef_) directly would require a universal comparison
    # helper full of edge cases, so we instead compare what predict(),
    # predict_proba(), decision_function() and transform() return before
    # and after the second call to fit().

    methods_to_check = ("predict", "transform", "decision_function",
                        "predict_proba")
    rng = np.random.RandomState(0)

    estimator = clone(estimator_orig)
    set_random_state(estimator)
    if 'warm_start' in estimator.get_params().keys():
        # With warm_start=True the second fit() would resume from the first
        # one instead of re-fitting from scratch, defeating the check.
        estimator.set_params(warm_start=False)

    n_samples = 100
    X = rng.normal(loc=100, size=(n_samples, 2))
    X = pairwise_estimator_convert_X(X, estimator)
    if is_regressor(estimator_orig):
        y = rng.normal(size=n_samples)
    else:
        y = rng.randint(low=0, high=2, size=n_samples)
    y = multioutput_estimator_convert_y_2d(estimator, y)

    train, test = next(ShuffleSplit(test_size=.2, random_state=rng).split(X))
    X_train, y_train = _safe_split(estimator, X, y, train)
    X_test, y_test = _safe_split(estimator, X, y, test, train)

    # First fit: record the output of every prediction method the
    # estimator exposes.
    estimator.fit(X_train, y_train)
    first_results = {method: getattr(estimator, method)(X_test)
                     for method in methods_to_check
                     if hasattr(estimator, method)}

    # Second fit on the exact same data: each method must reproduce the
    # results recorded after the first fit.
    estimator.fit(X_train, y_train)
    for method, first_output in first_results.items():
        second_output = getattr(estimator, method)(X_test)
        assert_allclose_dense_sparse(first_output, second_output)

0 commit comments

Comments
 (0)