Skip to content
4 changes: 4 additions & 0 deletions doc/developers/develop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,10 @@ no_validation (default=False)
non_deterministic (default=False)
whether the estimator is not deterministic given a fixed ``random_state``

preserves_dtype (default=``[np.float64]``)
whether an estimator preserves the specified dtypes if given as input. Possible
options are a combination of `np.float16`, `np.float32` and `np.float64`.

poor_score (default=False)
whether the estimator fails to provide a "reasonable" test-set score, which
currently for regression is an R2 of 0.5 on a subset of the boston housing
Expand Down
1 change: 1 addition & 0 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
'multioutput_only': False,
'binary_only': False,
'requires_fit': True,
'preserves_dtype': [np.float64, np.float32],
'requires_y': False,
}

Expand Down
3 changes: 3 additions & 0 deletions sklearn/decomposition/_kernel_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,6 @@ def inverse_transform(self, X):
n_samples = self.X_transformed_fit_.shape[0]
K.flat[::n_samples + 1] += self.alpha
return np.dot(K, self.dual_coef_)

def _more_tags(self):
return {'preserves_dtype': [np.float64, np.float32]}
3 changes: 3 additions & 0 deletions sklearn/decomposition/_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,3 +613,6 @@ def score(self, X, y=None):
Average log-likelihood of the samples under the current model.
"""
return np.mean(self.score_samples(X))

def _more_tags(self):
return {'preserves_dtype': [np.float64, np.float32]}
3 changes: 3 additions & 0 deletions sklearn/decomposition/_truncated_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,6 @@ def inverse_transform(self, X):
"""
X = check_array(X)
return np.dot(X, self.components_)

def _more_tags(self):
return {'preserves_dtype': [np.float64, np.float32]}
3 changes: 3 additions & 0 deletions sklearn/kernel_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,3 +641,6 @@ def _get_kernel_params(self):
"or precomputed kernel")

return params

def _more_tags(self):
return {'preserves_dtype': [np.float64, np.float32]}
3 changes: 2 additions & 1 deletion sklearn/preprocessing/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,8 @@ def inverse_transform(self, X, copy=None):
return X

def _more_tags(self):
return {'allow_nan': True}
return {'allow_nan': True,
'preserves_dtype': [np.float64, np.float32]}


class MaxAbsScaler(TransformerMixin, BaseEstimator):
Expand Down
69 changes: 67 additions & 2 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
from ..linear_model import Ridge

from ..base import (clone, ClusterMixin, is_classifier, is_regressor,
RegressorMixin, is_outlier_detector)
RegressorMixin, is_outlier_detector,
MetaEstimatorMixin)
from ..impute import MissingIndicator
from ..kernel_approximation import SkewedChi2Sampler

from ..metrics import accuracy_score, adjusted_rand_score, f1_score
from ..random_projection import BaseRandomProjection
Expand All @@ -58,6 +61,7 @@
CROSS_DECOMPOSITION = ['PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD']



def _yield_checks(estimator):
name = estimator.__class__.__name__
tags = estimator._get_tags()
Expand Down Expand Up @@ -195,6 +199,11 @@ def _yield_transformer_checks(transformer):
yield check_transformer_data_not_an_array
# these don't actually fit the data, so don't raise errors
yield check_transformer_general
# it's not possible to preserve dtypes in transform with clustering
# same for MissingIndicator
if (not isinstance(transformer, (ClusterMixin, MissingIndicator)) and
_safe_tags(transformer, "preserves_dtype")):
yield check_estimators_preserve_dtypes
yield partial(check_transformer_general, readonly_memmap=True)
if not transformer._get_tags()["stateless"]:
yield check_transformers_unfitted
Expand Down Expand Up @@ -1422,6 +1431,62 @@ def check_estimators_dtypes(name, estimator_orig):
getattr(estimator, method)(X_train)


def check_estimators_preserve_dtypes(name, estimator_orig):
    """Check that a transformer honors its 'preserves_dtype' tag.

    For every dtype listed in the tag, fitting/transforming an input cast
    to that dtype must return an output of the same dtype, and the outputs
    obtained with the different dtypes must be numerically close to the
    output obtained with the first (assumed float64) dtype.

    Parameters
    ----------
    name : str
        Estimator name (unused here, kept for check-API consistency).
    estimator_orig : estimator instance
        The estimator to check; it is cloned before fitting.
    """
    # For meta-estimators, look at the wrapped base estimator to decide
    # which kind of dataset (regression vs classification) to generate.
    if isinstance(estimator_orig, MetaEstimatorMixin):
        if hasattr(estimator_orig, 'estimator'):
            base_estimator = estimator_orig.estimator
        elif hasattr(estimator_orig, 'base_estimator'):
            base_estimator = estimator_orig.base_estimator
        else:
            # Fall back to the meta-estimator itself rather than leaving
            # ``base_estimator`` unbound (bug in the original: NameError
            # for a MetaEstimatorMixin exposing neither attribute).
            base_estimator = estimator_orig
    else:
        base_estimator = estimator_orig
    if is_regressor(base_estimator):
        X, y = make_regression(n_samples=50, n_features=5)
    else:
        X, y = make_classification(n_samples=50, n_features=5)
    # SkewedChi2Sampler requires values larger than -skewedness
    if (_safe_tags(base_estimator, "requires_positive_X") or
            isinstance(base_estimator, SkewedChi2Sampler)):
        X = np.absolute(X)
    y = _enforce_estimator_tags_y(base_estimator, y)
    X = _pairwise_estimator_convert_X(X, estimator_orig)
    # Cast to float32 first so every dtype below starts from values that
    # are exactly representable in the narrowest dtype being compared.
    X = X.astype(np.float32)

    Xts = []
    in_out_types = _safe_tags(estimator_orig, 'preserves_dtype')
    for dtype in in_out_types:
        X_cast = X.astype(dtype)
        estimator = clone(estimator_orig)
        set_random_state(estimator)
        if hasattr(estimator, 'fit_transform'):
            X_trans = estimator.fit_transform(X_cast, y)
        else:
            # Original used ``elif hasattr(estimator, 'fit')`` which could
            # leave X_trans unbound; every estimator has ``fit``, so a
            # plain else is equivalent and safe.
            estimator.fit(X_cast, y)
            X_trans = estimator.transform(X_cast)

        if sparse.issparse(X_trans):
            X_trans = X_trans.toarray()
        # Cross decomposition returns a tuple of (x_scores, y_scores)
        # when given y with fit_transform
        if isinstance(X_trans, tuple):
            X_trans = X_trans[0]
        # FIXME: should we check that the dtype of some attributes are the
        # same than dtype and check that the value of attributes
        # between 32bit and 64bit are close
        assert X_trans.dtype == dtype, \
            ('Estimator transform dtype: {} - original/expected dtype: {}'
             .format(X_trans.dtype, dtype.__name__))
        Xts.append(X_trans)

    # We assume the transformer output on float64 input (first tag entry)
    # is correct and compare all other inputs against it.
    for i in range(1, len(Xts)):
        assert_allclose(Xts[i], Xts[0], rtol=1e-4,
                        err_msg='dtype_in: {} dtype_ground_truth: {}\n'
                        .format(in_out_types[i].__name__,
                                in_out_types[0].__name__))

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the end maybe we should add the check comparing transforms in this PR, say using,

   assert_allclose(X_trans_32, X_trans_64, rtol=1e-2)

with a high enough tolerance. Before we start modifying other estimators to pass this.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why such a high tolerance ?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, maybe that was too pessimistic. Maybe rtol=1e-4 then — what value do you think would be good?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really don't know :/ Ideally it would be around 1e-6, but it's not realistic.
The output precision depends on the algorithm, and we probably can't expect machine precision for all algorithms even with the best possible implementation.
I put 1e-5 in the tests of our wrappers of scipy blas but I'm fine with 1e-4


@ignore_warnings(category=FutureWarning)
def check_estimators_empty_data_messages(name, estimator_orig):
e = clone(estimator_orig)
Expand Down Expand Up @@ -1450,7 +1515,7 @@ def check_estimators_nan_inf(name, estimator_orig):
# Checks that Estimator X's do not contain NaN or inf.
rnd = np.random.RandomState(0)
X_train_finite = _pairwise_estimator_convert_X(rnd.uniform(size=(10, 3)),
estimator_orig)
estimator_orig)
X_train_nan = rnd.uniform(size=(10, 3))
X_train_nan[0, 0] = np.nan
X_train_inf = rnd.uniform(size=(10, 3))
Expand Down