Skip to content

Commit b5c962c

Browse files
committed
handling backward compatibility and deprecation prototype
1 parent 11649d9 commit b5c962c

File tree

2 files changed

+135
-6
lines changed

2 files changed

+135
-6
lines changed

examples/plot_metadata_routing.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,18 @@
2222
# %%
2323

2424
import numpy as np
25+
import warnings
2526
from sklearn.base import BaseEstimator
2627
from sklearn.base import ClassifierMixin
28+
from sklearn.base import RegressorMixin
2729
from sklearn.base import MetaEstimatorMixin
2830
from sklearn.base import TransformerMixin
2931
from sklearn.base import clone
3032
from sklearn.utils.metadata_requests import RequestType
3133
from sklearn.utils.metadata_requests import metadata_request_factory
3234
from sklearn.utils.metadata_requests import MetadataRouter
3335
from sklearn.utils.validation import check_is_fitted
36+
from sklearn.linear_model import LinearRegression
3437

3538
N, M = 100, 4
3639
X = np.random.rand(N, M)
@@ -519,3 +522,103 @@ def transform(self, X, bar=None):
519522
),
520523
)
521524
est.fit(X, y, foo=my_weights, bar=my_groups).predict(X[:3], bar=my_groups)
525+
526+
# %%
527+
# Deprechation / Default Value Change
528+
# -----------------------------------
529+
# In this section we show how one should handle the case where a router becomes
530+
# also a consumer, especially when it consumes the same metadata as its
531+
# sub-estimator. In this case, a warning should be raised for a while, to let
532+
# users know the behavior is changed from previous versions.
533+
534+
535+
class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
536+
def __init__(self, estimator):
537+
self.estimator = estimator
538+
539+
def fit(self, X, y, **fit_params):
540+
metadata_request_factory(self).fit.validate_metadata(
541+
ignore_extras=False, self_metadata=super(), kwargs=fit_params
542+
)
543+
fit_params_ = metadata_request_factory(self.estimator).fit.get_method_input(
544+
ignore_extras=False, kwargs=fit_params
545+
)
546+
self.estimator_ = clone(self.estimator).fit(X, y, **fit_params_)
547+
548+
def get_metadata_request(self):
549+
router = MetadataRouter().add(
550+
self.estimator, mapping="one-to-one", overwrite=False, mask=True
551+
)
552+
return router.get_metadata_request()
553+
554+
555+
# %%
556+
# As explained above, this is now a valid usage:
557+
558+
reg = MetaRegressor(estimator=LinearRegression().fit_requests(sample_weight=True))
559+
reg.fit(X, y, sample_weight=my_weights)
560+
561+
562+
# %%
563+
# Now imagine we further develop ``MetaRegressor`` and it now also *consumes*
564+
# ``sample_weight``:
565+
566+
567+
class SampledMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
568+
__metadata_request__sample_weight = {"fit": {"sample_weight": RequestType.WARN}}
569+
570+
def __init__(self, estimator):
571+
self.estimator = estimator
572+
573+
def fit(self, X, y, sample_weight=None, **fit_params):
574+
if sample_weight is not None:
575+
fit_params["sample_weight"] = sample_weight
576+
metadata_request_factory(self).fit.validate_metadata(
577+
ignore_extras=False, self_metadata=super(), kwargs=fit_params
578+
)
579+
estimator_fit_params = metadata_request_factory(
580+
self.estimator
581+
).fit.get_method_input(ignore_extras=True, kwargs=fit_params)
582+
self.estimator_ = clone(self.estimator).fit(X, y, **estimator_fit_params)
583+
584+
def get_metadata_request(self):
585+
router = (
586+
MetadataRouter()
587+
.add(super(), mapping="one-to-one", overwrite=False, mask=False)
588+
.add(self.estimator, mapping="one-to-one", overwrite="smart", mask=True)
589+
)
590+
return router.get_metadata_request()
591+
592+
593+
# %%
594+
# The above implementation is almost no different than ``MetaRegressor``, and
595+
# because of the default request value defined in `__metadata_request__sample_weight``
596+
# there is a warning raised.
597+
598+
with warnings.catch_warnings(record=True) as record:
599+
SampledMetaRegressor(
600+
estimator=LinearRegression().fit_requests(sample_weight=False)
601+
).fit(X, y, sample_weight=my_weights)
602+
for w in record:
603+
print(w.message)
604+
605+
606+
# %%
607+
# When an estimator suports a metadata which wasn't supported before, the
608+
# following pattern can be used to warn the users about it.
609+
610+
611+
class ExampleRegressor(RegressorMixin, BaseEstimator):
612+
__metadata_request__sample_weight = {"fit": {"sample_weight": RequestType.WARN}}
613+
614+
def fit(self, X, y, sample_weight=None):
615+
return self
616+
617+
def predict(self, X):
618+
return np.zeros(shape=(len(X)))
619+
620+
621+
with warnings.catch_warnings(record=True) as record:
622+
MetaRegressor(estimator=ExampleRegressor()).fit(X, y, sample_weight=my_weights)
623+
for w in record:
624+
print(w.message)

sklearn/utils/metadata_requests.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from enum import Enum
33
from collections import defaultdict
44
from typing import Union, Optional
5+
from warnings import warn
56
from ..externals._sentinels import sentinel # type: ignore # mypy error!!!
67

78

@@ -13,6 +14,11 @@ class RequestType(Enum):
1314
# that a metadata is not present even though it may be present in the
1415
# corresponding method's signature.
1516
UNUSED = sentinel("UNUSED")
17+
# this sentinel is used whenever a default value is changed, and therefore
18+
# the user should explicitly set the value, otherwise a warning is shown.
19+
# An example is when a meta-estimator is only a router, but then becomes
20+
# also a consumer.
21+
WARN = sentinel("WARN")
1622

1723

1824
# this sentinel is the default used in `{method}_requests` methods to indicate
@@ -173,13 +179,14 @@ def add_request(
173179
if alias == RequestType.REQUESTED and current in {
174180
RequestType.ERROR_IF_PASSED,
175181
RequestType.UNREQUESTED,
182+
RequestType.WARN,
176183
}:
177184
self.requests[prop] = alias
178-
elif (
179-
alias == RequestType.UNREQUESTED
180-
and current == RequestType.ERROR_IF_PASSED
181-
):
182-
self.requests[prop] = alias
185+
elif alias in {RequestType.UNREQUESTED, RequestType.WARN} and current in {
186+
RequestType.ERROR_IF_PASSED,
187+
RequestType.WARN,
188+
}:
189+
self.requests[prop] = RequestType.UNREQUESTED
183190
elif self.requests[prop] != alias:
184191
raise ValueError(
185192
f"{prop} is already requested as {self.requests[prop]}, "
@@ -264,6 +271,17 @@ def validate_metadata(self, ignore_extras=False, self_metadata=None, kwargs=None
264271
self_metadata = getattr(
265272
metadata_request_factory(self_metadata), self.name
266273
).requests
274+
warn_metadata = [k for k, v in self_metadata.items() if v == RequestType.WARN]
275+
warn_kwargs = [k for k in kwargs.keys() if k in warn_metadata]
276+
if warn_kwargs:
277+
warn(
278+
"The following metadata are provided, which are now supported by this "
279+
f"class: {warn_kwargs}. These metadata were not processed in previous "
280+
"versions. Set their requested value to RequestType.UNREQUESTED "
281+
"to maintain previous behavior, or to RequestType.REQUESTED to "
282+
"consume and use the metadata.",
283+
UserWarning,
284+
)
267285
# we then remove self metadata from kwargs, since they should not be
268286
# validated.
269287
kwargs = {v: k for v, k in kwargs.items() if v not in self_metadata}
@@ -324,7 +342,15 @@ def get_method_input(self, ignore_extras=False, kwargs=None):
324342
if not isinstance(alias, str):
325343
alias = RequestType(alias)
326344

327-
if alias == RequestType.UNREQUESTED:
345+
if alias == RequestType.WARN:
346+
warn(
347+
f"Support for {prop} has recently been added to this class. "
348+
"To maintain backward compatibility, it is ignored now. "
349+
"You can set the request value to RequestType.UNREQUESTED "
350+
"to silence this warning, or to RequestType.REQUESTED to "
351+
"consume and use the metadata."
352+
)
353+
elif alias == RequestType.UNREQUESTED:
328354
continue
329355
elif alias == RequestType.REQUESTED and prop in args:
330356
res[prop] = args[prop]

0 commit comments

Comments
 (0)