@@ -607,21 +607,23 @@ def check_estimator(estimator=None, generate_only=False, *, legacy: bool = True)
607607
608608 name = type (estimator ).__name__
609609
610- def checks_generator ():
610+ def checks_generator (reference_estimator ):
611611 # we first need to check if the estimator is cloneable for the rest of the tests
612612 # to run
613613 yield estimator , partial (check_estimator_cloneable , name )
614614 for check in _yield_all_checks (estimator , legacy = legacy ):
615- for check_instance in _yield_instances_for_check (check , estimator ):
616- maybe_skipped_check = _maybe_skip (check_instance , check )
617- yield check_instance , partial (maybe_skipped_check , name )
615+ for check_specific_estimator in _yield_instances_for_check (
616+ check , reference_estimator
617+ ):
618+ maybe_skipped_check = _maybe_skip (check_specific_estimator , check )
619+ yield check_specific_estimator , partial (maybe_skipped_check , name )
618620
619621 if generate_only :
620- return checks_generator ()
622+ return checks_generator (estimator )
621623
622- for estimator , check in checks_generator ():
624+ for check_specific_estimator , check in checks_generator (estimator ):
623625 try :
624- check (estimator )
626+ check (check_specific_estimator )
625627 except SkipTest as exception :
626628 # SkipTest is thrown when pandas can't be imported, or by checks
627629 # that are in the xfail_checks tag
0 commit comments