Skip to content

Conversation

@adrinjalali
Copy link
Member

@adrinjalali adrinjalali commented Oct 21, 2024

Closes #16469

Moving _estimator_tpye to tags.

Right now this doesn't work.

@github-actions
Copy link

github-actions bot commented Oct 21, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 216f442. Link to the linter CI: here

# Test that the best estimator contains the right value for foo_param
clf = MockClassifier()
grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]}, cv=3, verbose=3)
grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]}, cv=2, verbose=3)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the data has only two samples for each class, and in classification case we use stratified cv, which requires cv <= n_samples per class. This wasn't an issue so far cause our MockClassifier wasn't declaring that it's a classifier.

# The number of samples per class needs to be > n_splits,
# for StratifiedKFold(n_splits=3)
y2 = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3, 3])
y2 = np.array([1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same issue, now we need 5 samples for each class since devault cv is 5.

) not in sig:
continue

for meta_estimator in _construct_instances(Estimator):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

meta estimators' type sometimes depends on their sub-estimator, and we should run the test for all their instances.

@adrinjalali adrinjalali marked this pull request as ready for review October 21, 2024 14:57
@adrinjalali
Copy link
Member Author

cc @Charlie-XIAO @adam2392

@adrinjalali adrinjalali added the Developer API Third party developer API related label Oct 21, 2024
@adrinjalali adrinjalali added this to the 1.6 milestone Oct 21, 2024
Copy link
Member

@adam2392 adam2392 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great to see this getting consolidated!

Copy link
Member

@glemaitre glemaitre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good. One merge, I'll test it in imbalanced-learn. But I think that having the logic in the tag makes things easier sometimes when we want to extend some part.

)


def _get_instance_with_pipeline(meta_estimator, init_params):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uhm codecov is complaining of not covered line here. I would not expect it since this should be some code that we have before.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It isn't complaining anymore, I think.

Copy link
Member Author

@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@adam2392 wanna have another look here? Comments should be resolved now.

exists = True
item += para
lst += item
if est.__sklearn_tags__().input_tags.allow_nan:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noticed that this section wasn't in the scope of the suppress(SkipTest), which it should be

Comment on lines +206 to +208
elif Estimator.__name__ == "FrozenEstimator":
X, y = make_classification(n_samples=20, n_features=5, random_state=0)
est = Estimator(LogisticRegression().fit(X, y))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noticed we didn't have this while checking our usages of _construct_instance

@adam2392 adam2392 self-requested a review October 31, 2024 22:03
Copy link
Member

@adam2392 adam2392 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry few more questions after reading through the code (I wanna do a thorough job :p). These are mostly small comments. Once addressed (assuming those aren't massive changes), then I can approve and merge.

Comment on lines +390 to +396
return Tags(
estimator_type=None,
target_tags=TargetTags(required=False),
transformer_tags=None,
regressor_tags=None,
classifier_tags=None,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why aren't these the default_tags anymore? Since it's BaseEstimator, I assumed that would be "default".

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cause we need to remove default_tags once we're done with the deprecation period. default_tags depends on detecting the type of estimator from the class, and not the instance, which we're removing in this PR.

Comment on lines +289 to +291
est_is_classifier = getattr(estimator, "_estimator_type", None) == "classifier"
est_is_regressor = getattr(estimator, "_estimator_type", None) == "regressor"
target_required = est_is_classifier or est_is_regressor
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once _estimator_type is removed, what will we do here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we'll be removing default_tags alltogether.

Comment on lines +390 to +396
return Tags(
estimator_type=None,
target_tags=TargetTags(required=False),
transformer_tags=None,
regressor_tags=None,
classifier_tags=None,
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cause we need to remove default_tags once we're done with the deprecation period. default_tags depends on detecting the type of estimator from the class, and not the instance, which we're removing in this PR.

Comment on lines +289 to +291
est_is_classifier = getattr(estimator, "_estimator_type", None) == "classifier"
est_is_regressor = getattr(estimator, "_estimator_type", None) == "regressor"
target_required = est_is_classifier or est_is_regressor
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we'll be removing default_tags alltogether.

input_tags: InputTags = field(default_factory=InputTags)


# TODO(1.8): Remove this function
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@adam2392 note this comment 😁

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see 😅

Copy link
Member

@adam2392 adam2392 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. thanks @adrinjalali

@adam2392 adam2392 merged commit 613cff9 into scikit-learn:main Nov 5, 2024
30 checks passed
@adrinjalali adrinjalali deleted the estimator_type branch November 5, 2024 12:01
@glemaitre
Copy link
Member

glemaitre commented Nov 5, 2024

I open #30227 to solve a bug that we did not catch since we did not build the full documentation.

I added an entry in the changelog because things could have gone wrong even before with the wrong ordering of the mixin.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Development

Successfully merging this pull request may close these issues.

Deprecate _estimator_type, replace by estimator tag

3 participants