@@ -376,15 +376,18 @@ def test_pickle_version_warning_is_not_raised_with_matching_version():
376376
377377class TreeBadVersion (DecisionTreeClassifier ):
378378 def __getstate__ (self ):
379- return dict (self .__dict__ .items (), _sklearn_version = "something" )
379+ return dict (self .__dict__ .items (), __sklearn_pickle_version__ = "something" )
380380
381381
382382pickle_error_message = (
383383 "Trying to unpickle estimator {estimator} from "
384384 "version {old_version} when using version "
385385 "{current_version}. This might "
386386 "lead to breaking code or invalid results. "
387- "Use at your own risk."
387+ "Use at your own risk. "
388+ "For more info please refer to:\n "
389+ "https://scikit-learn.org/stable/model_persistence.html"
390+ "#security-maintainability-limitations"
388391)
389392
390393
@@ -397,7 +400,7 @@ def test_pickle_version_warning_is_issued_upon_different_version():
397400 old_version = "something" ,
398401 current_version = sklearn .__version__ ,
399402 )
400- with pytest .warns (UserWarning , match = message ):
403+ with pytest .warns (UserWarning , match = re . escape ( message ) ):
401404 pickle .loads (tree_pickle_other )
402405
403406
@@ -419,7 +422,7 @@ def test_pickle_version_warning_is_issued_when_no_version_info_in_pickle():
419422 current_version = sklearn .__version__ ,
420423 )
421424 # check we got the warning about using pre-0.18 pickle
422- with pytest .warns (UserWarning , match = message ):
425+ with pytest .warns (UserWarning , match = re . escape ( message ) ):
423426 pickle .loads (tree_pickle_noversion )
424427
425428
@@ -666,6 +669,29 @@ def transform(self, X):
666669 trans .transform (df_mixed )
667670
668671
672+ def test_base_estimator_pickle_version (monkeypatch ):
673+ """Check the version is embedded when pickled and checked when unpickled."""
674+ old_version = "0.21.3"
675+ monkeypatch .setattr (sklearn .base , "__version__" , old_version )
676+ original_estimator = MyEstimator ()
677+ old_pickle = pickle .dumps (original_estimator )
678+ loaded_estimator = pickle .loads (old_pickle )
679+ assert loaded_estimator .__sklearn_pickle_version__ == old_version
680+ assert not hasattr (original_estimator , "__sklearn_pickle_version__" )
681+
682+ # Loading pickle with newer version will raise a warning
683+ new_version = "1.3.0"
684+ monkeypatch .setattr (sklearn .base , "__version__" , new_version )
685+ message = pickle_error_message .format (
686+ estimator = "MyEstimator" ,
687+ old_version = old_version ,
688+ current_version = new_version ,
689+ )
690+ with pytest .warns (UserWarning , match = re .escape (message )):
691+ reloaded_estimator = pickle .loads (old_pickle )
692+ assert reloaded_estimator .__sklearn_pickle_version__ == old_version
693+
694+
669695def test_clone_keeps_output_config ():
670696 """Check that clone keeps the set_output config."""
671697
@@ -694,7 +720,7 @@ def test_estimator_empty_instance_dict(estimator):
694720 ``AttributeError``. Non-regression test for gh-25188.
695721 """
696722 state = estimator .__getstate__ ()
697- expected = {"_sklearn_version " : sklearn .__version__ }
723+ expected = {"__sklearn_pickle_version__ " : sklearn .__version__ }
698724 assert state == expected
699725
700726 # this should not raise
0 commit comments