Conversation

@Edern76
Contributor

@Edern76 Edern76 commented Dec 29, 2021

Reference Issues/PRs

Fixes #22055

What does this implement/fix? Explain your changes.

As explained in the issue description, pickling an estimator saves the sklearn version with which it was last pickled but not the one with which it was first pickled.

This PR adds this information to the state dictionary of BaseEstimator under the key _sklearn_pickle_version, as suggested in the issue description. If the key is not already present in the previous state dictionary (i.e., the estimator has not been pickled before), its value is the current version of sklearn. Otherwise the previous value is kept, which "propagates" the version number of the first pickle to every subsequent pickle.
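
The behavior described above can be illustrated with a minimal sketch. It is not the code from this PR: `SketchEstimator` and `CURRENT_VERSION` are hypothetical stand-ins for `BaseEstimator` and `sklearn.__version__`, and the sketch only assumes that the key is added when absent and left alone otherwise.

```python
import pickle

# Hypothetical stand-in for sklearn.__version__.
CURRENT_VERSION = "1.0.2"


class SketchEstimator:
    """Toy class for illustration only, not sklearn's BaseEstimator."""

    def __getstate__(self):
        state = dict(self.__dict__)
        # Record the version only if an earlier pickle has not done so.
        state.setdefault("_sklearn_pickle_version", CURRENT_VERSION)
        return state

    def __setstate__(self, state):
        # The key ends up in __dict__, so the next pickle sees it.
        self.__dict__.update(state)


first = pickle.loads(pickle.dumps(SketchEstimator()))

# Pretend the library was upgraded before pickling again.
CURRENT_VERSION = "1.1.0"
second = pickle.loads(pickle.dumps(first))
```

In this sketch both `first` and `second` carry the version of the first pickle, because the second dump finds the key already present in the instance dictionary.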

Any other comments?

This PR was realized as part of the APPC course at INSA Rouen.

Member

@thomasjpfan thomasjpfan left a comment

Thank you for the PR @Edern76 !

Member

@ogrisel ogrisel left a comment

I agree with @thomasjpfan's suggestion above.

Member

@thomasjpfan thomasjpfan left a comment

Thanks for the update!

sklearn/base.py Outdated
Comment on lines 321 to 324
if (
    type(self).__module__.startswith("sklearn.")
    and "_sklearn_pickle_version" not in state.keys()
):
Member

@thomasjpfan thomasjpfan Jan 2, 2022

I think _sklearn_pickle_version should be overridden every time. This way _sklearn_pickle_version means: "the runtime version of the model when it was pickled".

Suggested change
-if (
-    type(self).__module__.startswith("sklearn.")
-    and "_sklearn_pickle_version" not in state.keys()
-):
+if type(self).__module__.startswith("sklearn."):

Let's say a model was trained on 0.24.2 and pickled. Then it is loaded into 1.0.2 and we correctly have _sklearn_pickle_version=0.24.2 and raise a warning. If the model was fitted again on 1.0.2, the state would be reset and require 1.0.2 to run. This means _sklearn_pickle_version should be updated to 1.0.2 when pickled.

There is an edge case where a model was trained on 0.24.2, pickled, loaded on 1.0.2, repickled on 1.0.2, and a user expects _sklearn_pickle_version to be 0.24.2. I consider this a usage error.
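
The "override every time" semantics proposed in this comment can be sketched as follows. This is an illustration under assumptions, not the merged implementation: `OverridingEstimator` and `RUNTIME_VERSION` are hypothetical names, and the version-mismatch warning is left out.

```python
import pickle

# Hypothetical stand-in for sklearn.__version__ at training time.
RUNTIME_VERSION = "0.24.2"


class OverridingEstimator:
    """Toy class for illustration only, not sklearn's BaseEstimator."""

    def __getstate__(self):
        # Always stamp the version of the runtime doing the pickling,
        # replacing any value restored from an earlier pickle.
        return dict(self.__dict__, _sklearn_pickle_version=RUNTIME_VERSION)

    def __setstate__(self, state):
        self.__dict__.update(state)


old_payload = pickle.dumps(OverridingEstimator())

# Pretend the pickle is now loaded in a newer environment.
RUNTIME_VERSION = "1.0.2"
loaded = pickle.loads(old_payload)
repickled = pickle.loads(pickle.dumps(loaded))
```

Here `loaded` still reports the version that wrote the original payload, while `repickled` reports the newer runtime, which is the edge case described above.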

"Use at your own risk. "
"For more info please refer to:\\n"
"https://scikit-learn.org/stable/modules/model_persistence"
".html#security-maintainability-limitations"
Member

Thanks!

Member

@ogrisel ogrisel left a comment

Comment on lines 699 to 720
def test_base_estimator_pickle_version(monkeypatch):
    """Check that the original sklearn version with which a base estimator
    has been pickled with is present"""
    old_pickle_version = "0.21.3"
    monkeypatch.setattr(sklearn.base, "__version__", old_pickle_version)
    original_estimator = MyEstimator()

    first_pickle_estimator = pickle.loads(pickle.dumps(original_estimator))
    assert hasattr(first_pickle_estimator, "_sklearn_pickle_version")
    assert first_pickle_estimator._sklearn_pickle_version == old_pickle_version

    new_pickle_version = "1.1.0"
    monkeypatch.setattr(sklearn.base, "__version__", new_pickle_version)
    message = pickle_error_message.format(
        estimator="MyEstimator",
        old_version=old_pickle_version,
        current_version=new_pickle_version,
    )
    with pytest.warns(UserWarning, match=message):
        second_pickle_estimator = pickle.loads(pickle.dumps(first_pickle_estimator))
    assert hasattr(second_pickle_estimator, "_sklearn_pickle_version")
    assert second_pickle_estimator._sklearn_pickle_version == old_pickle_version
Member

I think the main usage scenario we want to handle is rather:

Suggested change
-def test_base_estimator_pickle_version(monkeypatch):
-    """Check that the original sklearn version with which a base estimator
-    has been pickled with is present"""
-    old_pickle_version = "0.21.3"
-    monkeypatch.setattr(sklearn.base, "__version__", old_pickle_version)
-    original_estimator = MyEstimator()
-    first_pickle_estimator = pickle.loads(pickle.dumps(original_estimator))
-    assert hasattr(first_pickle_estimator, "_sklearn_pickle_version")
-    assert first_pickle_estimator._sklearn_pickle_version == old_pickle_version
-    new_pickle_version = "1.1.0"
-    monkeypatch.setattr(sklearn.base, "__version__", new_pickle_version)
-    message = pickle_error_message.format(
-        estimator="MyEstimator",
-        old_version=old_pickle_version,
-        current_version=new_pickle_version,
-    )
-    with pytest.warns(UserWarning, match=message):
-        second_pickle_estimator = pickle.loads(pickle.dumps(first_pickle_estimator))
-    assert hasattr(second_pickle_estimator, "_sklearn_pickle_version")
-    assert second_pickle_estimator._sklearn_pickle_version == old_pickle_version
+def test_base_estimator_pickle_version(monkeypatch):
+    """The version should be embedded at dump time and checked at load time"""
+    old_version = "0.21.3"
+    monkeypatch.setattr(sklearn.base, "__version__", old_version)
+    original_estimator = MyEstimator()
+    old_pickle = pickle.dumps(original_estimator)
+    loaded_estimator = pickle.loads(old_pickle)
+    assert loaded_estimator._sklearn_pickle_version == old_version
+    assert not hasattr(original_estimator, "_sklearn_pickle_version")
+    new_version = "1.1.0"
+    monkeypatch.setattr(sklearn.base, "__version__", new_version)
+    message = pickle_error_message.format(
+        estimator="MyEstimator",
+        old_version=old_version,
+        current_version=new_version,
+    )
+    with pytest.warns(UserWarning, match=message):
+        reloaded_estimator = pickle.loads(old_pickle)
+    assert reloaded_estimator._sklearn_pickle_version == old_version

Disclaimer: I have not run the code, there might be typos.

Comment on lines 73 to 75
- |Enhancement| All scikit-learn estimators now include the sklearn version
  with which they have first been pickled when saving them with the pickle
  library.
Member

@ogrisel ogrisel Jan 3, 2022

Here is an updated changelog entry that would reflect the behavior with @thomasjpfan's suggested change.

Suggested change
-- |Enhancement| All scikit-learn estimators now include the sklearn version
-  with which they have first been pickled when saving them with the pickle
-  library.
+- |Enhancement| All scikit-learn estimators now save the sklearn version
+  with which they have been pickled on a private attribute to avoid having
+  to parse the warning message to programmatically access this information.

It's a bit weird to document a purely private API change, but I am not sure what to do otherwise. Not documenting this change would be even worse, I think.

Member

I think it would be nice to make the pickled version public. There is a need if users are parsing the warning message for this information.

sklearn/base.py Outdated
 def __setstate__(self, state):
     if type(self).__module__.startswith("sklearn."):
-        pickle_version = state.pop("_sklearn_version", "pre-0.18")
+        pickle_version = state.get("_sklearn_pickle_version", "pre-0.18")
Member

@ogrisel ogrisel Jan 3, 2022

Let's preserve backward compat with old but not too old pickles:

Suggested change
-pickle_version = state.get("_sklearn_pickle_version", "pre-0.18")
+pickle_version = state.pop("_sklearn_version", "pre-0.18")  # compat
+pickle_version = state.setdefault("_sklearn_pickle_version", pickle_version)
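
To see how the two suggested lines behave on different pickles, here is a small sketch that applies them to plain dictionaries. The helper name `read_pickle_version` and the sample states are hypothetical; only the `pop` and `setdefault` calls come from the suggestion.

```python
def read_pickle_version(state):
    """Return the recorded version, migrating the legacy key if present.

    Hypothetical helper wrapping the two suggested lines; it mutates
    ``state`` in place, as ``__setstate__`` would.
    """
    # Remove the legacy key written by older releases, if any.
    legacy = state.pop("_sklearn_version", "pre-0.18")
    # Keep the new key when it exists, otherwise fall back to the legacy value.
    return state.setdefault("_sklearn_pickle_version", legacy)


legacy_state = {"_sklearn_version": "0.24.2"}
new_state = {"_sklearn_pickle_version": "1.1.0"}
ancient_state = {}

legacy_result = read_pickle_version(legacy_state)
new_result = read_pickle_version(new_state)
ancient_result = read_pickle_version(ancient_state)
```

A state written with the old key is migrated to the new key, a state that already has the new key is left untouched, and a state with neither key falls back to "pre-0.18".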

@Edern76
Contributor Author

Edern76 commented Jan 8, 2022

Thanks for the suggestions. I made a new commit with all of the suggested changes.

Member

@thomasjpfan thomasjpfan left a comment

Minor comment, otherwise LGTM.

Co-authored-by: Thomas J. Fan <[email protected]>
Member

@thomasjpfan thomasjpfan left a comment

Thank you for your patience, I still think this looks good.

error message when setting invalid hyper-parameters with `set_params`.
:pr:`21542` by :user:`Olivier Grisel <ogrisel>`.

- |Enhancement| All scikit-learn estimators now save the sklearn version
Member

This did not make it into v1.1, which means this change log entry needs to move to v1.2.

sklearn/base.py Outdated
-return dict(state.items(), _sklearn_version=__version__)
+return dict(
+    state.items(),
+    _sklearn_pickle_version=__version__,
Member

Looking at this again, I am starting to prefer making this more official by naming the attribute:

Suggested change
-_sklearn_pickle_version=__version__,
+__sklearn_pickle_version__=__version__,

and then documenting this in https://scikit-learn.org/stable/model_persistence.html.

Member

Using custom dunder attribute names is explicitly discouraged by PEP 8. I prefer the original _sklearn_pickle_version attribute name.

@cmarmo cmarmo added the "Waiting for Second Reviewer" label (First reviewer is done, need a second one!) and removed the "Waiting for Reviewer" label Oct 20, 2022
@glemaitre glemaitre self-requested a review December 28, 2022 17:19
Member

@glemaitre glemaitre left a comment

I applied the changes proposed by @thomasjpfan and put my +1.
I will merge if the CIs turn green.

@glemaitre glemaitre enabled auto-merge (squash) December 28, 2022 18:12
@glemaitre glemaitre merged commit affaa62 into scikit-learn:main Dec 28, 2022
BenjaminBossan added a commit to BenjaminBossan/skops that referenced this pull request Jan 2, 2023
This PR is urgent because CI for our PRs will fail until there is a fix.

After scikit-learn/scikit-learn#22094, sklearn
estimators will contain an additional key in their __dict__ after
loading, namely "__sklearn_pickle_version__". This causes our tests to
fail, since they compare objects before and after loading.

The quick solution is to pop the item in our tests if it exists and only
compare the remaining items.

Should the sklearn change be amended, we should remove the fix from this
PR. Progress is tracked here:

scikit-learn/scikit-learn#25273
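
The workaround described in this commit message could look roughly like the sketch below. It is not the skops test code: `comparable_state` is a hypothetical helper, `SimpleNamespace` stands in for an estimator before and after loading, the version string is made up, and the key is filtered out of a copy instead of being popped.

```python
from types import SimpleNamespace

VERSION_KEY = "__sklearn_pickle_version__"


def comparable_state(obj, ignored=(VERSION_KEY,)):
    """Return a copy of obj.__dict__ without the ignored keys."""
    return {key: value for key, value in vars(obj).items() if key not in ignored}


# Stand-ins for an estimator before dumping and after loading.
before = SimpleNamespace(alpha=0.5, fit_intercept=True)
after = SimpleNamespace(alpha=0.5, fit_intercept=True, **{VERSION_KEY: "1.3.dev0"})

raw_equal = vars(before) == vars(after)
filtered_equal = comparable_state(before) == comparable_state(after)
```

Comparing the raw dictionaries fails because of the extra key, while the filtered comparison only looks at the attributes the estimator had before the round trip.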
adrinjalali pushed a commit to skops-dev/skops that referenced this pull request Jan 2, 2023
jjerphan pushed a commit to jjerphan/scikit-learn that referenced this pull request Jan 3, 2023
glemaitre added a commit that referenced this pull request Jan 3, 2023
Labels

Waiting for Second Reviewer First reviewer is done, need a second one!

Development

Successfully merging this pull request may close these issues.

Add an attribute to determine which version of sklearn a model was pickled with

5 participants