Skip to content

Commit affaa62

Browse files
Edern76thomasjpfanglemaitre
authored
ENH added original version of pickled estimator in state dict (#22094)
Co-authored-by: Thomas J. Fan <[email protected]> Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent 5aa9b99 commit affaa62

File tree

4 files changed

+49
-6
lines changed

4 files changed

+49
-6
lines changed

doc/model_persistence.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ with::
5555
available `here
5656
<https://joblib.readthedocs.io/en/latest/persistence.html>`_.
5757

58+
Note that you can access the attribute `__sklearn_pickle_version__` to check the
59+
version of scikit-learn used to pickle the model.
60+
5861
.. _persistence_limitations:
5962

6063
Security & maintainability limitations

doc/whats_new/v1.3.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ random sampling procedures.
3232
Changes impacting all modules
3333
-----------------------------
3434

35+
- |Enhancement| All scikit-learn estimators now save the scikit-learn version
36+
with which they have been pickled into a private attribute,
37+
`__sklearn_pickle_version__`. This allows access to this information without
38+
having to parse the warning message.
39+
:pr:`22094` by :user:`Gawein Le Goff <Edern76>`.
40+
3541
Changelog
3642
---------
3743

sklearn/base.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,13 +288,21 @@ def __getstate__(self):
288288
state = self.__dict__.copy()
289289

290290
if type(self).__module__.startswith("sklearn."):
291-
return dict(state.items(), _sklearn_version=__version__)
291+
return dict(
292+
state.items(),
293+
__sklearn_pickle_version__=__version__,
294+
)
292295
else:
293296
return state
294297

295298
def __setstate__(self, state):
296299
if type(self).__module__.startswith("sklearn."):
300+
# Before 1.3, `_sklearn_version` was used to store the sklearn version
301+
# when the estimator was pickled
297302
pickle_version = state.pop("_sklearn_version", "pre-0.18")
303+
pickle_version = state.setdefault(
304+
"__sklearn_pickle_version__", pickle_version
305+
)
298306
if pickle_version != __version__:
299307
warnings.warn(
300308
"Trying to unpickle estimator {0} from version {1} when "

sklearn/tests/test_base.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -376,15 +376,18 @@ def test_pickle_version_warning_is_not_raised_with_matching_version():
376376

377377
class TreeBadVersion(DecisionTreeClassifier):
378378
def __getstate__(self):
379-
return dict(self.__dict__.items(), _sklearn_version="something")
379+
return dict(self.__dict__.items(), __sklearn_pickle_version__="something")
380380

381381

382382
pickle_error_message = (
383383
"Trying to unpickle estimator {estimator} from "
384384
"version {old_version} when using version "
385385
"{current_version}. This might "
386386
"lead to breaking code or invalid results. "
387-
"Use at your own risk."
387+
"Use at your own risk. "
388+
"For more info please refer to:\n"
389+
"https://scikit-learn.org/stable/model_persistence.html"
390+
"#security-maintainability-limitations"
388391
)
389392

390393

@@ -397,7 +400,7 @@ def test_pickle_version_warning_is_issued_upon_different_version():
397400
old_version="something",
398401
current_version=sklearn.__version__,
399402
)
400-
with pytest.warns(UserWarning, match=message):
403+
with pytest.warns(UserWarning, match=re.escape(message)):
401404
pickle.loads(tree_pickle_other)
402405

403406

@@ -419,7 +422,7 @@ def test_pickle_version_warning_is_issued_when_no_version_info_in_pickle():
419422
current_version=sklearn.__version__,
420423
)
421424
# check we got the warning about using pre-0.18 pickle
422-
with pytest.warns(UserWarning, match=message):
425+
with pytest.warns(UserWarning, match=re.escape(message)):
423426
pickle.loads(tree_pickle_noversion)
424427

425428

@@ -666,6 +669,29 @@ def transform(self, X):
666669
trans.transform(df_mixed)
667670

668671

672+
def test_base_estimator_pickle_version(monkeypatch):
673+
"""Check the version is embedded when pickled and checked when unpickled."""
674+
old_version = "0.21.3"
675+
monkeypatch.setattr(sklearn.base, "__version__", old_version)
676+
original_estimator = MyEstimator()
677+
old_pickle = pickle.dumps(original_estimator)
678+
loaded_estimator = pickle.loads(old_pickle)
679+
assert loaded_estimator.__sklearn_pickle_version__ == old_version
680+
assert not hasattr(original_estimator, "__sklearn_pickle_version__")
681+
682+
# Loading pickle with newer version will raise a warning
683+
new_version = "1.3.0"
684+
monkeypatch.setattr(sklearn.base, "__version__", new_version)
685+
message = pickle_error_message.format(
686+
estimator="MyEstimator",
687+
old_version=old_version,
688+
current_version=new_version,
689+
)
690+
with pytest.warns(UserWarning, match=re.escape(message)):
691+
reloaded_estimator = pickle.loads(old_pickle)
692+
assert reloaded_estimator.__sklearn_pickle_version__ == old_version
693+
694+
669695
def test_clone_keeps_output_config():
670696
"""Check that clone keeps the set_output config."""
671697

@@ -694,7 +720,7 @@ def test_estimator_empty_instance_dict(estimator):
694720
``AttributeError``. Non-regression test for gh-25188.
695721
"""
696722
state = estimator.__getstate__()
697-
expected = {"_sklearn_version": sklearn.__version__}
723+
expected = {"__sklearn_pickle_version__": sklearn.__version__}
698724
assert state == expected
699725

700726
# this should not raise

0 commit comments

Comments
 (0)