Skip to content

Commit c2f0b31

Browse files
committed
FIX check (and enforce) that estimators can accept different dtypes.
1 parent 31c5497 commit c2f0b31

File tree

5 files changed

+39
-7
lines changed

5 files changed

+39
-7
lines changed

sklearn/cluster/spectral.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def spectral_clustering(affinity, n_clusters=8, n_components=None,
243243
This algorithm solves the normalized cut for k=2: it is a
244244
normalized spectral clustering.
245245
"""
246-
if not assign_labels in ('kmeans', 'discretize'):
246+
if assign_labels not in ('kmeans', 'discretize'):
247247
raise ValueError("The 'assign_labels' parameter should be "
248248
"'kmeans' or 'discretize', but '%s' was given"
249249
% assign_labels)
@@ -415,7 +415,7 @@ def fit(self, X):
415415
OR, if affinity==`precomputed`, a precomputed affinity
416416
matrix of shape (n_samples, n_samples)
417417
"""
418-
X = check_array(X, accept_sparse=['csr', 'csc', 'coo'])
418+
X = check_array(X, accept_sparse=['csr', 'csc', 'coo'], dtype=np.float)
419419
if X.shape[0] == X.shape[1] and self.affinity != "precomputed":
420420
warnings.warn("The spectral clustering API has changed. ``fit``"
421421
"now constructs an affinity matrix from data. To use"

sklearn/linear_model/coordinate_descent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,6 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
389389
if selection not in ['random', 'cyclic']:
390390
raise ValueError("selection should be either random or cyclic.")
391391
random = (selection == 'random')
392-
models = []
393392

394393
if not multi_output:
395394
coefs = np.empty((n_features, n_alphas), dtype=np.float64)
@@ -1016,7 +1015,7 @@ def fit(self, X, y):
10161015
# Let us not impose fortran ordering or float64 so far: it is
10171016
# not useful for the cross-validation loop and will be done
10181017
# by the model fitting itself
1019-
X = check_array(X, 'csc', copy=False)
1018+
X = check_array(X, 'csc', copy=False, dtype=np.float64)
10201019
if sparse.isspmatrix(X):
10211020
if not np.may_share_memory(reference_to_old_X.data, X.data):
10221021
# X is a sparse matrix and has been copied
@@ -1418,6 +1417,7 @@ def __init__(self, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
14181417
self.random_state = random_state
14191418
self.selection = selection
14201419

1420+
14211421
###############################################################################
14221422
# Multi Task ElasticNet and Lasso models (with joint feature selection)
14231423

sklearn/linear_model/omp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ def fit(self, X, y):
805805
self : object
806806
returns an instance of self.
807807
"""
808-
X, y = check_X_y(X, y)
808+
X, y = check_X_y(X, y, dtype=np.float)
809809
cv = check_cv(self.cv, X, y, classifier=False)
810810
max_iter = (min(max(int(0.1 * X.shape[1]), 5), X.shape[1])
811811
if not self.max_iter

sklearn/tests/test_common.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from sklearn.utils.estimator_checks import (
3131
check_parameters_default_constructible,
3232
check_estimator_sparse_data,
33+
check_estimators_dtypes,
3334
check_transformer,
3435
check_clustering,
3536
check_clusterer_compute_labels_predict,
@@ -87,12 +88,14 @@ def test_non_meta_estimators():
8788
estimators = all_estimators(type_filter=['classifier', 'regressor',
8889
'transformer', 'cluster'])
8990
for name, Estimator in estimators:
91+
if name not in CROSS_DECOMPOSITION + ['SelectFdr']:
92+
yield check_estimators_dtypes, name, Estimator
93+
9094
if name not in CROSS_DECOMPOSITION + ['Imputer']:
9195
# Test that all estimators check their input for NaN's and infs
9296
yield check_estimators_nan_inf, name, Estimator
9397

94-
if (name not in ['CCA', '_CCA', 'PLSCanonical', 'PLSRegression',
95-
'PLSSVD', 'GaussianProcess']):
98+
if name not in CROSS_DECOMPOSITION + ['GaussianProcess']:
9699
# FIXME!
97100
# in particular GaussianProcess!
98101
yield check_estimators_overwrite_params, name, Estimator

sklearn/utils/estimator_checks.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,35 @@ def _check_transformer(name, Transformer, X, y):
252252
assert_raises(ValueError, transformer.transform, X.T)
253253

254254

255+
def check_estimators_dtypes(name, Estimator):
256+
rnd = np.random.RandomState(0)
257+
X_train_32 = 4 * rnd.uniform(size=(10, 3)).astype(np.float32)
258+
X_train_64 = X_train_32.astype(np.float64)
259+
X_train_int_64 = X_train_32.astype(np.int64)
260+
X_train_int_32 = X_train_32.astype(np.int32)
261+
y = X_train_int_64[:, 0]
262+
y = multioutput_estimator_convert_y_2d(name, y)
263+
for X_train in [X_train_32, X_train_64, X_train_int_64, X_train_int_32]:
264+
with warnings.catch_warnings(record=True):
265+
estimator = Estimator()
266+
set_fast_parameters(estimator)
267+
set_random_state(estimator, 1)
268+
if issubclass(Estimator, ClusterMixin):
269+
estimator.fit(X_train)
270+
else:
271+
estimator.fit(X_train, y)
272+
273+
for method in ["predict", "transform", "decision_function",
274+
"predict_proba"]:
275+
try:
276+
if hasattr(estimator, method):
277+
getattr(estimator, method)(X_train)
278+
except NotImplementedError:
279+
# FIXME
280+
# non-standard handling of ducktyping in BaggingEstimator
281+
pass
282+
283+
255284
def check_estimators_nan_inf(name, Estimator):
256285
rnd = np.random.RandomState(0)
257286
X_train_finite = rnd.uniform(size=(10, 3))

0 commit comments

Comments
 (0)