-
-
Notifications
You must be signed in to change notification settings - Fork 26.9k
[MRG] Add common test and estimator tag for preserving float32 dtype in transformers #16290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
05fc904
1d037f1
0f5fd23
b76e1c2
4ec84b6
d687a3e
e532681
d4d8971
00ac315
670bd4f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,7 +32,10 @@ | |
| from ..linear_model import Ridge | ||
|
|
||
| from ..base import (clone, ClusterMixin, is_classifier, is_regressor, | ||
| RegressorMixin, is_outlier_detector) | ||
| RegressorMixin, is_outlier_detector, | ||
| MetaEstimatorMixin) | ||
| from ..impute import MissingIndicator | ||
| from ..kernel_approximation import SkewedChi2Sampler | ||
|
|
||
| from ..metrics import accuracy_score, adjusted_rand_score, f1_score | ||
| from ..random_projection import BaseRandomProjection | ||
|
|
@@ -58,6 +61,7 @@ | |
| CROSS_DECOMPOSITION = ['PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD'] | ||
|
|
||
|
|
||
|
|
||
| def _yield_checks(estimator): | ||
| name = estimator.__class__.__name__ | ||
| tags = estimator._get_tags() | ||
|
|
@@ -195,6 +199,11 @@ def _yield_transformer_checks(transformer): | |
| yield check_transformer_data_not_an_array | ||
| # these don't actually fit the data, so don't raise errors | ||
| yield check_transformer_general | ||
| # it's not possible to preserve dtypes in transform with clustering | ||
| # same for MissingIndicator | ||
| if (not isinstance(transformer, (ClusterMixin, MissingIndicator)) and | ||
| _safe_tags(transformer, "preserves_dtype")): | ||
| yield check_estimators_preserve_dtypes | ||
| yield partial(check_transformer_general, readonly_memmap=True) | ||
| if not transformer._get_tags()["stateless"]: | ||
| yield check_transformers_unfitted | ||
|
|
@@ -1422,6 +1431,62 @@ def check_estimators_dtypes(name, estimator_orig): | |
| getattr(estimator, method)(X_train) | ||
|
|
||
|
|
||
| def check_estimators_preserve_dtypes(name, estimator_orig): | ||
|
|
||
| if isinstance(estimator_orig, MetaEstimatorMixin): | ||
| if hasattr(estimator_orig, 'estimator'): | ||
| base_estimator = estimator_orig.estimator | ||
| elif hasattr(estimator_orig, 'base_estimator'): | ||
| base_estimator = estimator_orig.base_estimator | ||
| else: | ||
| base_estimator = estimator_orig | ||
| if is_regressor(base_estimator): | ||
| X, y = make_regression(n_samples=50, n_features=5) | ||
| else: | ||
| X, y = make_classification(n_samples=50, n_features=5) | ||
| # SkewedChi2Sampler requires values larger than -skewedness | ||
| if (_safe_tags(base_estimator, "requires_positive_X") or | ||
| isinstance(base_estimator, SkewedChi2Sampler)): | ||
| X = np.absolute(X) | ||
| y = _enforce_estimator_tags_y(base_estimator, y) | ||
| X = _pairwise_estimator_convert_X(X, estimator_orig) | ||
| X = X.astype(np.float32) | ||
|
|
||
| Xts = [] | ||
| in_out_types = _safe_tags(estimator_orig, 'preserves_dtype') | ||
| for dtype in in_out_types: | ||
| X_cast = X.astype(dtype) | ||
| estimator = clone(estimator_orig) | ||
| set_random_state(estimator) | ||
| if hasattr(estimator, 'fit_transform'): | ||
| X_trans = estimator.fit_transform(X_cast, y) | ||
| elif hasattr(estimator, 'fit'): | ||
| estimator.fit(X_cast, y) | ||
| X_trans = estimator.transform(X_cast) | ||
|
|
||
| if sparse.issparse(X_trans): | ||
| X_trans = X_trans.toarray() | ||
| # Cross Decomposition returns a tuple of (x_scores, y_scores) | ||
| # when given y with fit_transform | ||
| if isinstance(X_trans, tuple): | ||
| X_trans = X_trans[0] | ||
| # FIXME: should we check that the dtype of some attributes are the | ||
| # same than dtype and check that the value of attributes | ||
| # between 32bit and 64bit are close | ||
| assert X_trans.dtype == dtype, \ | ||
| ('Estimator transform dtype: {} - original/expected dtype: {}' | ||
| .format(X_trans.dtype, dtype.__name__)) | ||
| Xts.append(X_trans) | ||
|
|
||
| # We assume the transformer is correct on float64 input and | ||
| # compare all other inputs against them. | ||
| for i in range(1, len(Xts)): | ||
| assert_allclose(Xts[i], Xts[0], rtol=1e-4, | ||
| err_msg='dtype_in: {} dtype_ground_truth: {}\n' | ||
| .format(in_out_types[i].__name__, | ||
| in_out_types[0].__name__)) | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In the end maybe we should add the check comparing transforms in this PR, say using assert_allclose(X_trans_32, X_trans_64, rtol=1e-2) with a high enough tolerance, before we start modifying other estimators to pass this.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why such a high tolerance?
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK, maybe that was too pessimistic. Maybe rtol=1e-4 then — what value do you think would be good?
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I really don't know :/ Ideally it would be around 1e-6, but it's not realistic. |
||
|
|
||
| @ignore_warnings(category=FutureWarning) | ||
| def check_estimators_empty_data_messages(name, estimator_orig): | ||
| e = clone(estimator_orig) | ||
|
|
@@ -1450,7 +1515,7 @@ def check_estimators_nan_inf(name, estimator_orig): | |
| # Checks that Estimator X's do not contain NaN or inf. | ||
| rnd = np.random.RandomState(0) | ||
| X_train_finite = _pairwise_estimator_convert_X(rnd.uniform(size=(10, 3)), | ||
| estimator_orig) | ||
| estimator_orig) | ||
| X_train_nan = rnd.uniform(size=(10, 3)) | ||
| X_train_nan[0, 0] = np.nan | ||
| X_train_inf = rnd.uniform(size=(10, 3)) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please skip this test for all estimators that currently fail,
so we can merge this.
The XFAIL support from #16306 could be used once that PR is merged.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They are already skipped here for those who fail. https://github.com/scikit-learn/scikit-learn/pull/16290/files#diff-a95fe0e40350c536a5e303e87ac979c4R194-R213