|
22 | 22 | # %% |
23 | 23 |
|
24 | 24 | import numpy as np |
| 25 | +import warnings |
25 | 26 | from sklearn.base import BaseEstimator |
26 | 27 | from sklearn.base import ClassifierMixin |
| 28 | +from sklearn.base import RegressorMixin |
27 | 29 | from sklearn.base import MetaEstimatorMixin |
28 | 30 | from sklearn.base import TransformerMixin |
29 | 31 | from sklearn.base import clone |
30 | 32 | from sklearn.utils.metadata_requests import RequestType |
31 | 33 | from sklearn.utils.metadata_requests import metadata_request_factory |
32 | 34 | from sklearn.utils.metadata_requests import MetadataRouter |
33 | 35 | from sklearn.utils.validation import check_is_fitted |
| 36 | +from sklearn.linear_model import LinearRegression |
34 | 37 |
|
35 | 38 | N, M = 100, 4 |
36 | 39 | X = np.random.rand(N, M) |
@@ -519,3 +522,103 @@ def transform(self, X, bar=None): |
519 | 522 | ), |
520 | 523 | ) |
521 | 524 | est.fit(X, y, foo=my_weights, bar=my_groups).predict(X[:3], bar=my_groups) |
| 525 | + |
| 526 | +# %% |
# Deprecation / Default Value Change
# ----------------------------------
| 529 | +# In this section we show how one should handle the case where a router becomes |
| 530 | +# also a consumer, especially when it consumes the same metadata as its |
| 531 | +# sub-estimator. In this case, a warning should be raised for a while, to let |
| 532 | +# users know the behavior is changed from previous versions. |
| 533 | + |
| 534 | + |
class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    """A minimal meta-regressor that routes fit-time metadata to a sub-estimator.

    Parameters
    ----------
    estimator : estimator object
        The sub-estimator to be cloned and fitted. Metadata requested by it
        (e.g. ``sample_weight``) is routed through this meta-estimator.
    """

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, **fit_params):
        """Fit the sub-estimator, forwarding only the metadata it requested.

        Returns
        -------
        self : MetaRegressor
            The fitted meta-estimator.
        """
        # Raise if the caller passed metadata that no consumer requested
        # (ignore_extras=False).
        metadata_request_factory(self).fit.validate_metadata(
            ignore_extras=False, self_metadata=super(), kwargs=fit_params
        )
        # Select exactly the kwargs the sub-estimator's ``fit`` requested.
        fit_params_ = metadata_request_factory(self.estimator).fit.get_method_input(
            ignore_extras=False, kwargs=fit_params
        )
        self.estimator_ = clone(self.estimator).fit(X, y, **fit_params_)
        # BUG FIX: ``fit`` must return ``self`` per the scikit-learn API
        # convention so calls can be chained, e.g. ``est.fit(...).predict(...)``
        # as done elsewhere in this file.
        return self

    def get_metadata_request(self):
        """Return this estimator's metadata request, built from the sub-estimator's."""
        router = MetadataRouter().add(
            self.estimator, mapping="one-to-one", overwrite=False, mask=True
        )
        return router.get_metadata_request()
| 553 | + |
| 554 | + |
| 555 | +# %% |
| 556 | +# As explained above, this is now a valid usage: |
| 557 | + |
# The inner estimator explicitly requests ``sample_weight``; the router
# therefore accepts and forwards it.
weighted_lr = LinearRegression().fit_requests(sample_weight=True)
reg = MetaRegressor(estimator=weighted_lr)
reg.fit(X, y, sample_weight=my_weights)
| 560 | + |
| 561 | + |
| 562 | +# %% |
| 563 | +# Now imagine we further develop ``MetaRegressor`` and it now also *consumes* |
| 564 | +# ``sample_weight``: |
| 565 | + |
| 566 | + |
class SampledMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    """A meta-regressor which *also consumes* ``sample_weight`` itself.

    The ``RequestType.WARN`` default request value makes routing warn users
    that ``sample_weight`` is now consumed by this estimator as well — a
    behavior change from previous versions.

    Parameters
    ----------
    estimator : estimator object
        The sub-estimator to be cloned and fitted.
    """

    # WARN (rather than True/False) triggers a warning when ``sample_weight``
    # is passed without being explicitly requested.
    __metadata_request__sample_weight = {"fit": {"sample_weight": RequestType.WARN}}

    def fit(self, X, y, sample_weight=None, **fit_params):
        """Fit the sub-estimator, consuming and routing ``sample_weight``.

        Returns
        -------
        self : SampledMetaRegressor
            The fitted meta-estimator.
        """
        # Fold the explicitly-named parameter back into the metadata dict so
        # validation and routing see one uniform set of kwargs.
        if sample_weight is not None:
            fit_params["sample_weight"] = sample_weight
        metadata_request_factory(self).fit.validate_metadata(
            ignore_extras=False, self_metadata=super(), kwargs=fit_params
        )
        # ignore_extras=True: metadata consumed only by *this* estimator need
        # not also be requested by the sub-estimator.
        estimator_fit_params = metadata_request_factory(
            self.estimator
        ).fit.get_method_input(ignore_extras=True, kwargs=fit_params)
        self.estimator_ = clone(self.estimator).fit(X, y, **estimator_fit_params)
        # BUG FIX: return ``self`` per the scikit-learn ``fit`` convention so
        # calls can be chained.
        return self

    def __init__(self, estimator):
        self.estimator = estimator

    def get_metadata_request(self):
        """Compose this estimator's own request with the sub-estimator's."""
        router = (
            MetadataRouter()
            .add(super(), mapping="one-to-one", overwrite=False, mask=False)
            .add(self.estimator, mapping="one-to-one", overwrite="smart", mask=True)
        )
        return router.get_metadata_request()
| 591 | + |
| 592 | + |
| 593 | +# %% |
# The above implementation is almost no different from ``MetaRegressor``, and
# because of the default request value defined in
# ``__metadata_request__sample_weight``, a warning is raised.
| 597 | + |
inner_est = LinearRegression().fit_requests(sample_weight=False)
with warnings.catch_warnings(record=True) as captured:
    SampledMetaRegressor(estimator=inner_est).fit(X, y, sample_weight=my_weights)
for warning in captured:
    print(warning.message)
| 604 | + |
| 605 | + |
| 606 | +# %% |
# When an estimator supports metadata which wasn't supported before, the
# following pattern can be used to warn users about it.
| 609 | + |
| 610 | + |
class ExampleRegressor(RegressorMixin, BaseEstimator):
    """A toy regressor whose ``fit`` newly accepts ``sample_weight``.

    The ``RequestType.WARN`` default request value makes routers warn users
    that ``sample_weight`` is supported but was not explicitly requested.
    """

    __metadata_request__sample_weight = {"fit": {"sample_weight": RequestType.WARN}}

    def fit(self, X, y, sample_weight=None):
        """No-op fit; returns the estimator itself."""
        return self

    def predict(self, X):
        """Predict zero for every input sample."""
        n_samples = len(X)
        return np.zeros(n_samples)
| 619 | + |
| 620 | + |
with warnings.catch_warnings(record=True) as captured:
    meta = MetaRegressor(estimator=ExampleRegressor())
    meta.fit(X, y, sample_weight=my_weights)
for warning in captured:
    print(warning.message)
0 commit comments