ENH added original version of pickled estimator in state dict #22094
thomasjpfan
left a comment
Thank you for the PR @Edern76 !
ogrisel
left a comment
I agree with @thomasjpfan's suggestion above.
thomasjpfan
left a comment
Thanks for the update!
sklearn/base.py
Outdated
    if (
        type(self).__module__.startswith("sklearn.")
        and "_sklearn_pickle_version" not in state.keys()
    ):
I think _sklearn_pickle_version should be overridden every time. This way _sklearn_pickle_version means: "the runtime version of the model when it was pickled".
- if (
-     type(self).__module__.startswith("sklearn.")
-     and "_sklearn_pickle_version" not in state.keys()
- ):
+ if type(self).__module__.startswith("sklearn."):
Let's say a model was trained on 0.24.2 and pickled. Then it is loaded into 1.0.2 and we correctly have _sklearn_pickle_version=0.24.2 and raise a warning. If the model was fitted again on 1.0.2, the state would be reset and require the 1.0.2 to run. This means _sklearn_pickle_version should be updated to 1.0.2 when pickled.
There is an edge case where a model was trained on 0.24.2, pickled, loaded on 1.0.2, and repickled on 1.0.2, and the user expects _sklearn_pickle_version to still be 0.24.2. I consider this a usage error.
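The "overwrite on every dump" semantics described above can be sketched with a toy class. This is only an illustration under stated assumptions, not scikit-learn's implementation: `ToyEstimator`, its `runtime_version` class attribute (a stand-in for `sklearn.__version__`), the warning text, and the `roundtrip` helper are all hypothetical. The helper calls `__getstate__` and `__setstate__` directly, the same hooks that `pickle.dumps` and `pickle.loads` would invoke.

```python
import warnings


class ToyEstimator:
    """Toy stand-in for an estimator; not scikit-learn's BaseEstimator."""

    # Hypothetical stand-in for sklearn.__version__ at runtime.
    runtime_version = "0.24.2"

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def __getstate__(self):
        # Overwrite the key on every dump, so it always records the
        # version that was running when the state was produced.
        state = dict(self.__dict__)
        state["_sklearn_pickle_version"] = type(self).runtime_version
        return state

    def __setstate__(self, state):
        pickled_with = state.get("_sklearn_pickle_version", "unknown")
        current = type(self).runtime_version
        if pickled_with != current:
            warnings.warn(
                f"Loading state from version {pickled_with} "
                f"while running version {current}.",
                UserWarning,
            )
        self.__dict__.update(state)


def roundtrip(estimator):
    """Mimic pickle.loads(pickle.dumps(...)) by calling the two hooks."""
    clone = type(estimator).__new__(type(estimator))
    clone.__setstate__(estimator.__getstate__())
    return clone
```

With this sketch, a state dumped under 0.24.2 and loaded under 1.0.2 triggers the warning and keeps 0.24.2 on the loaded object, while a further round trip under 1.0.2 records 1.0.2 and loads silently.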
sklearn/tests/test_base.py
Outdated
    "Use at your own risk. "
    "For more info please refer to:\\n"
    "https://scikit-learn.org/stable/modules/model_persistence"
    ".html#security-maintainability-limitations"
Thanks!
I think @thomasjpfan's suggestion https://github.com/scikit-learn/scikit-learn/pull/22094/files#r777246976 is valid.
The test can be updated accordingly as follows:
sklearn/tests/test_base.py
Outdated
    def test_base_estimator_pickle_version(monkeypatch):
        """Check that the original sklearn version with which a base estimator
        has been pickled with is present"""
        old_pickle_version = "0.21.3"
        monkeypatch.setattr(sklearn.base, "__version__", old_pickle_version)
        original_estimator = MyEstimator()

        first_pickle_estimator = pickle.loads(pickle.dumps(original_estimator))
        assert hasattr(first_pickle_estimator, "_sklearn_pickle_version")
        assert first_pickle_estimator._sklearn_pickle_version == old_pickle_version

        new_pickle_version = "1.1.0"
        monkeypatch.setattr(sklearn.base, "__version__", new_pickle_version)
        message = pickle_error_message.format(
            estimator="MyEstimator",
            old_version=old_pickle_version,
            current_version=new_pickle_version,
        )
        with pytest.warns(UserWarning, match=message):
            second_pickle_estimator = pickle.loads(pickle.dumps(first_pickle_estimator))
        assert hasattr(second_pickle_estimator, "_sklearn_pickle_version")
        assert second_pickle_estimator._sklearn_pickle_version == old_pickle_version
I think the main usage scenario we want to handle is rather:
- def test_base_estimator_pickle_version(monkeypatch):
-     """Check that the original sklearn version with which a base estimator
-     has been pickled with is present"""
-     old_pickle_version = "0.21.3"
-     monkeypatch.setattr(sklearn.base, "__version__", old_pickle_version)
-     original_estimator = MyEstimator()
-     first_pickle_estimator = pickle.loads(pickle.dumps(original_estimator))
-     assert hasattr(first_pickle_estimator, "_sklearn_pickle_version")
-     assert first_pickle_estimator._sklearn_pickle_version == old_pickle_version
-     new_pickle_version = "1.1.0"
-     monkeypatch.setattr(sklearn.base, "__version__", new_pickle_version)
-     message = pickle_error_message.format(
-         estimator="MyEstimator",
-         old_version=old_pickle_version,
-         current_version=new_pickle_version,
-     )
-     with pytest.warns(UserWarning, match=message):
-         second_pickle_estimator = pickle.loads(pickle.dumps(first_pickle_estimator))
-     assert hasattr(second_pickle_estimator, "_sklearn_pickle_version")
-     assert second_pickle_estimator._sklearn_pickle_version == old_pickle_version
+ def test_base_estimator_pickle_version(monkeypatch):
+     """The version should be embedded at dump time and checked at load time"""
+     old_version = "0.21.3"
+     monkeypatch.setattr(sklearn.base, "__version__", old_version)
+     original_estimator = MyEstimator()
+     old_pickle = pickle.dumps(original_estimator)
+     loaded_estimator = pickle.loads(old_pickle)
+     assert loaded_estimator._sklearn_pickle_version == old_version
+     assert not hasattr(original_estimator, "_sklearn_pickle_version")
+     new_version = "1.1.0"
+     monkeypatch.setattr(sklearn.base, "__version__", new_version)
+     message = pickle_error_message.format(
+         estimator="MyEstimator",
+         old_version=old_version,
+         current_version=new_version,
+     )
+     with pytest.warns(UserWarning, match=message):
+         reloaded_estimator = pickle.loads(old_pickle)
+     assert reloaded_estimator._sklearn_pickle_version == old_version
Disclaimer: I have not run the code, there might be typos.
doc/whats_new/v1.1.rst
Outdated
    - |Enhancement| All scikit-learn estimators now include the sklearn version
      with which they have first been pickled when saving them with the pickle.
      library.
Here is an updated changelog entry that would reflect the behavior with @thomasjpfan's suggested change.
- - |Enhancement| All scikit-learn estimators now include the sklearn version
-   with which they have first been pickled when saving them with the pickle.
-   library.
+ - |Enhancement| All scikit-learn estimators now save the sklearn version
+   with which they have been pickled on a private attribute to avoid having
+   to parse the warning message to programmatically access this information
+   to introspect this.
It's a bit weird to document a purely private API change, but I am not sure what to do otherwise. Not documenting this change would be even worse, I think.
I think it would be nice to make the pickled version public. There is a need if users are parsing the warning message for this information.
sklearn/base.py
Outdated
  def __setstate__(self, state):
      if type(self).__module__.startswith("sklearn."):
-         pickle_version = state.pop("_sklearn_version", "pre-0.18")
+         pickle_version = state.get("_sklearn_pickle_version", "pre-0.18")
Let's preserve backward compat with old but not too old pickles:
- pickle_version = state.get("_sklearn_pickle_version", "pre-0.18")
+ pickle_version = state.pop("_sklearn_version", "pre-0.18")  # compat
+ pickle_version = state.setdefault("_sklearn_pickle_version", pickle_version)
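To make the effect of the two suggested lines concrete, here is a minimal sketch that isolates them in a standalone function operating on a plain dict. The function name `resolve_pickle_version` is hypothetical and exists only for illustration; the two statements inside mirror the suggestion above.

```python
def resolve_pickle_version(state):
    """Return the recorded version and normalise `state` in place."""
    # Pickles written before the rename carry the legacy key; remove it
    # and remember its value (or the "pre-0.18" fallback if it is absent).
    legacy = state.pop("_sklearn_version", "pre-0.18")
    # Keep the new key if it is already present; otherwise store the
    # legacy value (or the fallback) under the new key.
    return state.setdefault("_sklearn_pickle_version", legacy)
```

Three cases fall out of this: a state with only the legacy key is migrated to the new key, a state with the new key is left unchanged, and a state with neither key is labelled "pre-0.18".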
Thanks for the suggestions, I made a new commit with all these suggested changes
thomasjpfan
left a comment
Minor comment otherwise LGTM
Co-authored-by: Thomas J. Fan <[email protected]>
thomasjpfan
left a comment
Thank you for your patience, I still think this looks good.
doc/whats_new/v1.1.rst
Outdated
      error message when setting invalid hyper-parameters with `set_params`.
      :pr:`21542` by :user:`Olivier Grisel <ogrisel>`.

    - |Enhancement| All scikit-learn estimators now save the sklearn version
This did not make it into v1.1, which means this change log entry needs to move to v1.2.
sklearn/base.py
Outdated
- return dict(state.items(), _sklearn_version=__version__)
+ return dict(
+     state.items(),
+     _sklearn_pickle_version=__version__,
Looking at this again, I am starting to prefer making this more official by naming the attribute:
- _sklearn_pickle_version=__version__,
+ __sklearn_pickle_version__=__version__,
and then documenting this in https://scikit-learn.org/stable/model_persistence.html.
Using custom dunder attribute names is explicitly discouraged by PEP8. I prefer the original _sklearn_pickle_version attribute name.
glemaitre
left a comment
I applied the changes proposed by @thomasjpfan and put my +1.
I will merge if the CIs turn green.
This PR is urgent because CI for our PRs will fail until there is a fix. After scikit-learn/scikit-learn#22094, sklearn estimators will contain an additional key in their __dict__ after loading, namely "__sklearn_pickle_version__". This causes our tests to fail, since they compare objects before and after loading. The quick solution is to pop the item in our tests if it exists and only compare the remaining items. Should the sklearn change be amended, we should remove the fix from this PR. Progress is tracked here: scikit-learn/scikit-learn#25273
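The workaround described in that downstream comment could look roughly like the sketch below. It is an assumption about the approach, not the downstream project's actual test code: the helper name `comparable_state` and the `SimpleNamespace` stand-ins are hypothetical, and only the key name comes from the comment itself.

```python
from types import SimpleNamespace

# Key named in the downstream comment; everything else here is illustrative.
VERSION_KEY = "__sklearn_pickle_version__"


def comparable_state(obj):
    """Return a copy of obj.__dict__ without the version key, if present."""
    state = dict(vars(obj))
    state.pop(VERSION_KEY, None)
    return state


# Stand-ins for an estimator before dumping and after loading.
before = SimpleNamespace(alpha=1.0)
after = SimpleNamespace(alpha=1.0, **{VERSION_KEY: "1.2.0"})
same = comparable_state(before) == comparable_state(after)
```

Popping with a default keeps the comparison valid whether or not the key exists, so the same helper works against scikit-learn versions with and without the change.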
…-learn#22094) Co-authored-by: Thomas J. Fan <[email protected]> Co-authored-by: Guillaume Lemaitre <[email protected]>
Reference Issues/PRs
Fixes #22055
What does this implement/fix? Explain your changes.
As explained in the issue description, pickling an estimator saves the sklearn version with which it was last pickled but not the one with which it was first pickled.
This PR adds this information to the state dictionary of BaseEstimator under the key _sklearn_pickle_version, as suggested in the issue description. The value associated with this key is either the current version of sklearn, if the key is not already present in the previous state dictionary (i.e., if the estimator has not been pickled before), or the previous value of that element, if the key was present in the dictionary (which allows us to "propagate" the version number of the first pickle to every subsequent pickle).
Any other comments?
This PR was carried out as part of the APPC course at INSA Rouen.