@@ -68,7 +68,8 @@ def _yield_checks(name, estimator):
6868 yield check_sample_weights_not_an_array
6969 yield check_sample_weights_list
7070 yield check_sample_weights_shape
71- yield check_sample_weights_invariance
71+ yield partial (check_sample_weights_invariance , kind = 'ones' )
72+ yield partial (check_sample_weights_invariance , kind = 'zeros' )
7273 yield check_estimators_fit_returns_self
7374 yield partial (check_estimators_fit_returns_self , readonly_memmap = True )
7475
@@ -488,6 +489,7 @@ def check_estimator(Estimator, generate_only=False):
488489 warnings .warn (msg , FutureWarning )
489490
490491 checks_generator = _generate_class_checks (Estimator )
492+ estimator = _construct_instance (Estimator )
491493 else :
492494 # got an instance
493495 estimator = Estimator
@@ -497,12 +499,19 @@ def check_estimator(Estimator, generate_only=False):
497499 if generate_only :
498500 return checks_generator
499501
502+ xfail_checks = _safe_tags (estimator , '_xfail_test' )
503+
500504 for estimator , check in checks_generator :
505+ check_name = _set_check_estimator_ids (check )
506+ if xfail_checks and check_name in xfail_checks :
507+ # skip tests marked as a known failure and raise a warning
508+ msg = xfail_checks [check_name ]
509+ warnings .warn (f'Skipping { check_name } : { msg } ' , SkipTestWarning )
510+ continue
501511 try :
502512 check (estimator )
503513 except SkipTest as exception :
504- # the only SkipTest thrown currently results from not
505- # being able to import pandas.
514+ # raise warning for tests that are are skipped
506515 warnings .warn (str (exception ), SkipTestWarning )
507516
508517
@@ -861,7 +870,7 @@ def check_sample_weights_shape(name, estimator_orig):
861870
862871
863872@ignore_warnings (category = FutureWarning )
864- def check_sample_weights_invariance (name , estimator_orig ):
873+ def check_sample_weights_invariance (name , estimator_orig , kind = "ones" ):
865874 # check that the estimators yield same results for
866875 # unit weights and no weights
867876 if (has_fit_parameter (estimator_orig , "sample_weight" ) and
@@ -877,25 +886,45 @@ def check_sample_weights_invariance(name, estimator_orig):
877886 X = np .array ([[1 , 3 ], [1 , 3 ], [1 , 3 ], [1 , 3 ],
878887 [2 , 1 ], [2 , 1 ], [2 , 1 ], [2 , 1 ],
879888 [3 , 3 ], [3 , 3 ], [3 , 3 ], [3 , 3 ],
880- [4 , 1 ], [4 , 1 ], [4 , 1 ], [4 , 1 ]], dtype = np .dtype ( 'float' ) )
889+ [4 , 1 ], [4 , 1 ], [4 , 1 ], [4 , 1 ]], dtype = np .float64 )
881890 y = np .array ([1 , 1 , 1 , 1 , 2 , 2 , 2 , 2 ,
882- 1 , 1 , 1 , 1 , 2 , 2 , 2 , 2 ], dtype = np .dtype ('int' ))
891+ 1 , 1 , 1 , 1 , 2 , 2 , 2 , 2 ], dtype = np .int )
892+
893+ if kind == 'ones' :
894+ X2 = X
895+ y2 = y
896+ sw2 = np .ones (shape = len (y ))
897+ err_msg = (f"For { name } sample_weight=None is not equivalent to "
898+ f"sample_weight=ones" )
899+ elif kind == 'zeros' :
900+ # Construct a dataset that is very different to (X, y) if weights
901+ # are disregarded, but identical to (X, y) given weights.
902+ X2 = np .vstack ([X , X + 1 ])
903+ y2 = np .hstack ([y , 3 - y ])
904+ sw2 = np .ones (shape = len (y ) * 2 )
905+ sw2 [len (y ):] = 0
906+ X2 , y2 , sw2 = shuffle (X2 , y2 , sw2 , random_state = 0 )
907+
908+ err_msg = (f"For { name } sample_weight is not equivalent "
909+ f"to removing samples" )
910+ else :
911+ raise ValueError
912+
883913 y = _enforce_estimator_tags_y (estimator1 , y )
914+ y2 = _enforce_estimator_tags_y (estimator2 , y2 )
884915
885- estimator1 .fit (X , y = y , sample_weight = np . ones ( shape = len ( y )) )
886- estimator2 .fit (X , y = y , sample_weight = None )
916+ estimator1 .fit (X , y = y , sample_weight = None )
917+ estimator2 .fit (X2 , y = y2 , sample_weight = sw2 )
887918
888- for method in ["predict" , "transform" ]:
919+ for method in ["predict" , "predict_proba" ,
920+ "decision_function" , "transform" ]:
889921 if hasattr (estimator_orig , method ):
890922 X_pred1 = getattr (estimator1 , method )(X )
891923 X_pred2 = getattr (estimator2 , method )(X )
892924 if sparse .issparse (X_pred1 ):
893925 X_pred1 = X_pred1 .toarray ()
894926 X_pred2 = X_pred2 .toarray ()
895- assert_allclose (X_pred1 , X_pred2 ,
896- err_msg = "For %s sample_weight=None is not"
897- " equivalent to sample_weight=ones"
898- % name )
927+ assert_allclose (X_pred1 , X_pred2 , err_msg = err_msg )
899928
900929
901930@ignore_warnings (category = (FutureWarning , UserWarning ))
0 commit comments