|
53 | 53 | from sklearn.exceptions import DataConversionWarning |
54 | 54 | from sklearn.exceptions import SkipTestWarning |
55 | 55 | from sklearn.model_selection import train_test_split |
| 56 | +from sklearn.model_selection import ShuffleSplit |
| 57 | +from sklearn.model_selection._validation import _safe_split |
56 | 58 | from sklearn.metrics.pairwise import (rbf_kernel, linear_kernel, |
57 | 59 | pairwise_distances) |
58 | 60 |
|
@@ -265,6 +267,7 @@ def _yield_all_checks(name, estimator): |
265 | 267 | yield check_set_params |
266 | 268 | yield check_dict_unchanged |
267 | 269 | yield check_dont_overwrite_parameters |
| 270 | + yield check_fit_idempotent |
268 | 271 |
|
269 | 272 |
|
270 | 273 | def check_estimator(Estimator): |
@@ -2330,3 +2333,50 @@ def check_outliers_fit_predict(name, estimator_orig): |
2330 | 2333 | for contamination in [-0.5, 2.3]: |
2331 | 2334 | estimator.set_params(contamination=contamination) |
2332 | 2335 | assert_raises(ValueError, estimator.fit_predict, X) |
| 2336 | + |
| 2337 | + |
def check_fit_idempotent(name, estimator_orig):
    """Check that calling fit twice yields the same model behavior.

    est.fit(X) should be equivalent to est.fit(X).fit(X). Ideally we would
    compare the fitted attributes (e.g. coef_) directly, but writing a
    universal comparison for arbitrary fitted attributes is difficult and
    full of edge cases. Instead we check that predict(), predict_proba(),
    decision_function() and transform() return the same results on held-out
    data before and after the second fit.

    Parameters
    ----------
    name : str
        Estimator name (unused here; kept for the common check signature).
    estimator_orig : estimator instance
        The estimator to check. It is cloned, so the original is untouched.
    """
    check_methods = ["predict", "transform", "decision_function",
                     "predict_proba"]
    rng = np.random.RandomState(0)

    estimator = clone(estimator_orig)
    set_random_state(estimator)
    # warm_start=True would make the second fit continue from the first,
    # which is expected to change the model; disable it for this check.
    if 'warm_start' in estimator.get_params().keys():
        estimator.set_params(warm_start=False)

    n_samples = 100
    X = rng.normal(loc=100, size=(n_samples, 2))
    X = pairwise_estimator_convert_X(X, estimator)
    if is_regressor(estimator_orig):
        y = rng.normal(size=n_samples)
    else:
        y = rng.randint(low=0, high=2, size=n_samples)
    y = multioutput_estimator_convert_y_2d(estimator, y)

    # _safe_split handles precomputed-kernel estimators, where the test
    # "X" must be sliced against the training indices as well.
    train, test = next(ShuffleSplit(test_size=.2, random_state=rng).split(X))
    X_train, y_train = _safe_split(estimator, X, y, train)
    X_test, _ = _safe_split(estimator, X, y, test, train)

    # Fit for the first time and record the outputs of every available
    # prediction-like method on the held-out set.
    estimator.fit(X_train, y_train)
    result = {method: getattr(estimator, method)(X_test)
              for method in check_methods
              if hasattr(estimator, method)}

    # Fit again on the same data: outputs must not change.
    estimator.fit(X_train, y_train)

    for method in check_methods:
        if hasattr(estimator, method):
            new_result = getattr(estimator, method)(X_test)
            assert_allclose_dense_sparse(result[method], new_result)
0 commit comments