@@ -252,6 +252,35 @@ def _check_transformer(name, Transformer, X, y):
252252 assert_raises (ValueError , transformer .transform , X .T )
253253
254254
255+ def check_estimators_dtypes (name , Estimator ):
256+ rnd = np .random .RandomState (0 )
257+ X_train_32 = 4 * rnd .uniform (size = (10 , 3 )).astype (np .float32 )
258+ X_train_64 = X_train_32 .astype (np .float64 )
259+ X_train_int_64 = X_train_32 .astype (np .int64 )
260+ X_train_int_32 = X_train_32 .astype (np .int32 )
261+ y = X_train_int_64 [:, 0 ]
262+ y = multioutput_estimator_convert_y_2d (name , y )
263+ for X_train in [X_train_32 , X_train_64 , X_train_int_64 , X_train_int_32 ]:
264+ with warnings .catch_warnings (record = True ):
265+ estimator = Estimator ()
266+ set_fast_parameters (estimator )
267+ set_random_state (estimator , 1 )
268+ if issubclass (Estimator , ClusterMixin ):
269+ estimator .fit (X_train )
270+ else :
271+ estimator .fit (X_train , y )
272+
273+ for method in ["predict" , "transform" , "decision_function" ,
274+ "predict_proba" ]:
275+ try :
276+ if hasattr (estimator , method ):
277+ getattr (estimator , method )(X_train )
278+ except NotImplementedError :
279+ # FIXME
280+ # non-standard handling of ducktyping in BaggingEstimator
281+ pass
282+
283+
255284def check_estimators_nan_inf (name , Estimator ):
256285 rnd = np .random .RandomState (0 )
257286 X_train_finite = rnd .uniform (size = (10 , 3 ))
0 commit comments