Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/model_persistence.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ with::
available `here
<https://joblib.readthedocs.io/en/latest/persistence.html>`_.

Note that you can access to the attribute `__sklearn_pickle_version__` to check the
version of scikit-learn used to pickle the model.

.. _persistence_limitations:

Security & maintainability limitations
Expand Down
6 changes: 6 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ random sampling procedures.
Changes impacting all modules
-----------------------------

- |Enhancement| All scikit-learn estimators now save the scikit-learn version
with which they have been pickled into a private attribute,
`__sklearn_pickle_version__`. This allows access to this information without
having to parse the warning message.
:pr:`22094` by :user:`Gawein Le Goff <Edern76>`.

Changelog
---------

Expand Down
10 changes: 9 additions & 1 deletion sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,13 +288,21 @@ def __getstate__(self):
state = self.__dict__.copy()

if type(self).__module__.startswith("sklearn."):
return dict(state.items(), _sklearn_version=__version__)
return dict(
state.items(),
__sklearn_pickle_version__=__version__,
)
else:
return state

def __setstate__(self, state):
if type(self).__module__.startswith("sklearn."):
# Before 1.3, `_sklearn_version` was used to store the sklearn version
# when the estimator was pickled
pickle_version = state.pop("_sklearn_version", "pre-0.18")
pickle_version = state.setdefault(
"__sklearn_pickle_version__", pickle_version
)
if pickle_version != __version__:
warnings.warn(
"Trying to unpickle estimator {0} from version {1} when "
Expand Down
36 changes: 31 additions & 5 deletions sklearn/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,15 +376,18 @@ def test_pickle_version_warning_is_not_raised_with_matching_version():

class TreeBadVersion(DecisionTreeClassifier):
def __getstate__(self):
return dict(self.__dict__.items(), _sklearn_version="something")
return dict(self.__dict__.items(), __sklearn_pickle_version__="something")


pickle_error_message = (
"Trying to unpickle estimator {estimator} from "
"version {old_version} when using version "
"{current_version}. This might "
"lead to breaking code or invalid results. "
"Use at your own risk."
"Use at your own risk. "
"For more info please refer to:\n"
"https://scikit-learn.org/stable/model_persistence.html"
"#security-maintainability-limitations"
)


Expand All @@ -397,7 +400,7 @@ def test_pickle_version_warning_is_issued_upon_different_version():
old_version="something",
current_version=sklearn.__version__,
)
with pytest.warns(UserWarning, match=message):
with pytest.warns(UserWarning, match=re.escape(message)):
pickle.loads(tree_pickle_other)


Expand All @@ -419,7 +422,7 @@ def test_pickle_version_warning_is_issued_when_no_version_info_in_pickle():
current_version=sklearn.__version__,
)
# check we got the warning about using pre-0.18 pickle
with pytest.warns(UserWarning, match=message):
with pytest.warns(UserWarning, match=re.escape(message)):
pickle.loads(tree_pickle_noversion)


Expand Down Expand Up @@ -666,6 +669,29 @@ def transform(self, X):
trans.transform(df_mixed)


def test_base_estimator_pickle_version(monkeypatch):
"""Check the version is embedded when pickled and checked when unpickled."""
old_version = "0.21.3"
monkeypatch.setattr(sklearn.base, "__version__", old_version)
original_estimator = MyEstimator()
old_pickle = pickle.dumps(original_estimator)
loaded_estimator = pickle.loads(old_pickle)
assert loaded_estimator.__sklearn_pickle_version__ == old_version
assert not hasattr(original_estimator, "__sklearn_pickle_version__")

# Loading pickle with newer version will raise a warning
new_version = "1.3.0"
monkeypatch.setattr(sklearn.base, "__version__", new_version)
message = pickle_error_message.format(
estimator="MyEstimator",
old_version=old_version,
current_version=new_version,
)
with pytest.warns(UserWarning, match=re.escape(message)):
reloaded_estimator = pickle.loads(old_pickle)
assert reloaded_estimator.__sklearn_pickle_version__ == old_version


def test_clone_keeps_output_config():
"""Check that clone keeps the set_output config."""

Expand Down Expand Up @@ -694,7 +720,7 @@ def test_estimator_empty_instance_dict(estimator):
``AttributeError``. Non-regression test for gh-25188.
"""
state = estimator.__getstate__()
expected = {"_sklearn_version": sklearn.__version__}
expected = {"__sklearn_pickle_version__": sklearn.__version__}
assert state == expected

# this should not raise
Expand Down