Description
According to the current documentation, GridSearchCV accepts any object that implements the “fit” and “predict” methods as its estimator parameter.
While this is fine for most purposes, the API makes certain use cases quite unintuitive.
For instance, consider the AdaBoostClassifier API. Essentially, this classifier just wraps boosting around whatever classifier is provided by the base_estimator parameter. Most of the parameter tuning therefore happens in this base_estimator rather than in the booster itself. If I were to use grid search for parameter tuning, I would probably do something along the lines of:
base_estimators = [DecisionTreeClassifier(max_depth=d) for d in range(1, 11)]
grid = GridSearchCV(AdaBoostClassifier(), dict(base_estimator=base_estimators))
This is already quite ugly, and I am only tuning the max_depth parameter. Imagine if I also wanted to tune some other parameter of the DecisionTreeClassifier class.
One way to fix this is to make GridSearchCV accept factory functions for classifiers, and not only the classifiers themselves.
In particular, something along these lines could make things a bit easier:
def ada_factory(*args, **kwargs):
    # Forward all arguments to the wrapped tree, then boost it.
    return AdaBoostClassifier(DecisionTreeClassifier(*args, **kwargs))
grid = GridSearchCV(ada_factory, dict(max_depth=range(1, 11)))
Obviously, the contract that objects returned from the factory function provide fit and predict methods should remain in place.
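To make the proposal concrete, here is a minimal sketch of what a factory-based search could look like. It is not scikit-learn code: factory_grid_search and ThresholdClassifier are hypothetical names invented for this illustration, and the sketch scores on the training data instead of cross-validating, purely to keep it short.

```python
from itertools import product


class ThresholdClassifier:
    """Toy stand-in estimator (hypothetical): predicts 1 when x > threshold."""

    def __init__(self, threshold=0.0):
        self.threshold = threshold

    def fit(self, X, y):
        return self  # nothing to learn; the threshold is the hyperparameter

    def predict(self, X):
        return [1 if x > self.threshold else 0 for x in X]


def factory_grid_search(factory, param_grid, X, y):
    """Try every parameter combination; return (best_params, best_accuracy)."""
    names = sorted(param_grid)
    best_params, best_score = None, -1.0
    for values in product(*(param_grid[n] for n in names)):
        params = dict(zip(names, values))
        model = factory(**params)
        # Enforce the contract: the factory must return an object
        # that provides both fit and predict.
        if not (hasattr(model, "fit") and hasattr(model, "predict")):
            raise TypeError("factory must return an object with fit and predict")
        predictions = model.fit(X, y).predict(X)
        # Accuracy on the training data (no cross-validation in this sketch).
        score = sum(p == t for p, t in zip(predictions, y)) / len(y)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score


X = [0.1, 0.4, 0.6, 0.9]
y = [0, 0, 1, 1]
best_params, best_score = factory_grid_search(
    ThresholdClassifier, {"threshold": [0.0, 0.5, 1.0]}, X, y
)
```

Any callable works as the factory here, including a class, because the search only calls it with keyword arguments and checks the returned object.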
Not only would this solve this particular problem, it would also allow one to test multiple estimators within the same grid search -- just add a parameter to your factory function.
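As a rough illustration of the multiple-estimator idea, the sketch below selects the estimator type through an ordinary grid parameter. All names (multi_factory, kind, AlwaysZero, Threshold) are hypothetical and chosen only for this example; the toy estimators stand in for real classifiers.

```python
from itertools import product


class AlwaysZero:
    """Toy estimator (hypothetical): always predicts 0."""

    def fit(self, X, y):
        return self

    def predict(self, X):
        return [0 for _ in X]


class Threshold:
    """Toy estimator (hypothetical): predicts 1 when x > threshold."""

    def __init__(self, threshold=0.5):
        self.threshold = threshold

    def fit(self, X, y):
        return self

    def predict(self, X):
        return [1 if x > self.threshold else 0 for x in X]


def multi_factory(kind, threshold=0.5):
    """Pick the estimator type from an ordinary grid parameter."""
    if kind == "zero":
        return AlwaysZero()  # threshold is ignored for this estimator
    if kind == "threshold":
        return Threshold(threshold)
    raise ValueError(f"unknown estimator kind: {kind!r}")


# Expand the grid into one parameter dict per candidate.
grid = {"kind": ["zero", "threshold"], "threshold": [0.2, 0.5]}
names = sorted(grid)
candidates = [dict(zip(names, combo)) for combo in product(*(grid[n] for n in names))]
models = [multi_factory(**params) for params in candidates]
```

One trade-off with a single flat grid is that parameters irrelevant to a given estimator still multiply the candidates: here "zero" is built twice, once per threshold value, even though it ignores the threshold.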