
Cannot pickle predict_proba function with 1.0 release #21344

@imatiach-msft


Describe the bug

When trying to pickle the predict_proba method of a fitted scikit-learn model, I now see the following error with the latest release:

_pickle.PicklingError: Can't pickle <function BaseSVC.predict_proba at 0x000001F3460AAEE8>: it's not the same object as sklearn.svm._base.BaseSVC.predict_proba

This is probably due to this PR:
#19948

Specifically, I believe this is because we now return a lambda here, which cannot be pickled:

return lambda fn: _AvailableIfDescriptor(fn, check, attribute_name=fn.__name__)

This could easily be fixed by turning the lambda into a regular function defined in that file.
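The error can be reproduced without scikit-learn. The sketch below is a simplified, hypothetical stand-in (`WrapperReturningDescriptor`, `MethodTypeDescriptor`, `BrokenModel` and `FixedModel` are made-up names, not scikit-learn's `_AvailableIfDescriptor`), assuming the descriptor's `__get__` builds a new wrapper function on every access. Because `functools.update_wrapper` copies `__qualname__` onto the wrapper, pickle tries to save it by reference as `BrokenModel.predict_proba`, looks that name up again, receives a different freshly built function, and raises the same "it's not the same object" error. For comparison, the second descriptor returns a `types.MethodType` bound method, which pickle serializes as `getattr(instance, "predict_proba")`. That is a different remedy from the one proposed above and is included only to illustrate the mechanism.

```python
import pickle
from functools import update_wrapper
from types import MethodType


class WrapperReturningDescriptor:
    """Hypothetical descriptor whose __get__ builds a new wrapper on every access."""

    def __init__(self, fn):
        self.fn = fn

    def __get__(self, obj, owner=None):
        if obj is None:
            out = lambda *args, **kwargs: self.fn(*args, **kwargs)
        else:
            out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs)
        # Copies __module__, __name__ and __qualname__ from the wrapped method,
        # so pickle tries to save the wrapper by name as Owner.method.
        update_wrapper(out, self.fn)
        return out


class MethodTypeDescriptor:
    """Hypothetical descriptor that returns a real bound method on instance access."""

    def __init__(self, fn):
        self.fn = fn

    def __get__(self, obj, owner=None):
        if obj is None:
            return self.fn
        # Bound methods pickle as getattr(instance, name), so the function
        # object itself never has to be found again by its qualified name.
        return MethodType(self.fn, obj)


class BrokenModel:
    @WrapperReturningDescriptor
    def predict_proba(self, x):
        return [x, 1 - x]


class FixedModel:
    @MethodTypeDescriptor
    def predict_proba(self, x):
        return [x, 1 - x]


def try_pickle(bound):
    """Round-trip through pickle; return None if pickling fails."""
    try:
        return pickle.loads(pickle.dumps(bound))
    except pickle.PicklingError as exc:
        print("PicklingError:", exc)
        return None


broken = try_pickle(BrokenModel().predict_proba)  # prints the PicklingError
fixed = try_pickle(FixedModel().predict_proba)
print(fixed(0.25))  # [0.25, 0.75]
```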

I suppose there is a philosophical question of whether functions should be picklable at all. I think they should, although it is probably less important than pickling models. In any case, this should be a simple fix.

Steps/Code to Reproduce

from joblib import dump, load
from sklearn import svm
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

def create_scikit_cancer_data():
    breast_cancer_data = load_breast_cancer()
    x_train, x_test, y_train, y_test = train_test_split(breast_cancer_data.data, breast_cancer_data.target, test_size=0.2, random_state=0)
    feature_names = breast_cancer_data.feature_names
    classes = breast_cancer_data.target_names.tolist()
    return x_train, x_test, y_train, y_test, feature_names, classes

def create_sklearn_svm_classifier(X, y, probability=True):
    clf = svm.SVC(gamma=0.001, C=100., probability=probability, random_state=777)
    model = clf.fit(X, y)
    return model

x_train, x_test, y_train, _, feature_names, target_names = create_scikit_cancer_data()
model = create_sklearn_svm_classifier(x_train, y_train)
with open('pickle_model_function', 'wb') as stream:
    dump(model.predict_proba, stream)

Expected Results

model.predict_proba should be picklable, as it was before the 1.0 release.

Actual Results

>>> from joblib import dump, load
>>> from sklearn import svm
>>> from sklearn.datasets import load_breast_cancer
>>> from sklearn.model_selection import train_test_split
>>>
>>> def create_scikit_cancer_data():
...     breast_cancer_data = load_breast_cancer()
...     classes = breast_cancer_data.target_names.tolist()
...     x_train, x_test, y_train, y_test = train_test_split(breast_cancer_data.data, breast_cancer_data.target, test_size=0.2, random_state=0)
...     feature_names = breast_cancer_data.feature_names
...     classes = breast_cancer_data.target_names.tolist()
...     return x_train, x_test, y_train, y_test, feature_names, classes
...
>>> def create_sklearn_svm_classifier(X, y, probability=True):
...     clf = svm.SVC(gamma=0.001, C=100., probability=probability, random_state=777)
...     model = clf.fit(X, y)
...     return model
...
>>> x_train, x_test, y_train, _, feature_names, target_names = create_scikit_cancer_data()
>>> model = create_sklearn_svm_classifier(x_train, y_train)
>>> with open('pickle_model_function', 'wb') as stream:
...     dump(model.predict_proba, stream)
...
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "C:\Users\ilmat\AppData\Local\Continuum\Miniconda3\envs\test\lib\site-packages\joblib\numpy_pickle.py", line 482, in dump
    NumpyPickler(filename, protocol=protocol).dump(value)
  File "C:\Users\ilmat\AppData\Local\Continuum\Miniconda3\envs\test\lib\pickle.py", line 437, in dump
    self.save(obj)
  File "C:\Users\ilmat\AppData\Local\Continuum\Miniconda3\envs\test\lib\site-packages\joblib\numpy_pickle.py", line 282, in save
    return Pickler.save(self, obj)
  File "C:\Users\ilmat\AppData\Local\Continuum\Miniconda3\envs\test\lib\pickle.py", line 504, in save
    f(self, obj) # Call unbound method with explicit self
  File "C:\Users\ilmat\AppData\Local\Continuum\Miniconda3\envs\test\lib\pickle.py", line 965, in save_global
    (obj, module_name, name))
_pickle.PicklingError: Can't pickle <function BaseSVC.predict_proba at 0x000002324C572CA8>: it's not the same object as sklearn.svm._base.BaseSVC.predict_proba

Versions

This only happens with the latest 1.0 release, and it broke our tests/builds. I'm working around it by pickling the model instead here:
interpretml/interpret-community#455
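If downstream code needs a picklable callable rather than the model itself, one option in the spirit of that workaround is to pickle the model together with the method name and look the method up again at call time. The sketch below assumes only that the model object itself pickles (as fitted scikit-learn estimators normally do); `PicklableMethod`, `ToyModel` and `FreshWrapperDescriptor` are hypothetical names, with the descriptor standing in for the 1.0 behaviour.

```python
import pickle
from functools import update_wrapper


class FreshWrapperDescriptor:
    """Hypothetical stand-in for a descriptor whose bound "methods" cannot be pickled."""

    def __init__(self, fn):
        self.fn = fn

    def __get__(self, obj, owner=None):
        if obj is None:
            out = lambda *args, **kwargs: self.fn(*args, **kwargs)
        else:
            out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs)
        update_wrapper(out, self.fn)
        return out


class ToyModel:
    """Hypothetical stand-in for a fitted estimator; its instance state pickles fine."""

    def __init__(self, weight):
        self.weight = weight

    @FreshWrapperDescriptor
    def predict_proba(self, x):
        p = self.weight * x
        return [1 - p, p]


class PicklableMethod:
    """Hypothetical helper: stores the object and the method name, so only the
    object itself has to be picklable."""

    def __init__(self, obj, name):
        self.obj = obj
        self.name = name

    def __call__(self, *args, **kwargs):
        # Resolve the method after unpickling instead of pickling it directly.
        return getattr(self.obj, self.name)(*args, **kwargs)


model = ToyModel(0.5)

try:
    pickle.dumps(model.predict_proba)
    direct_ok = True
except pickle.PicklingError:
    direct_ok = False

restored = pickle.loads(pickle.dumps(PicklableMethod(model, "predict_proba")))
print(direct_ok)      # False
print(restored(1.0))  # [0.5, 0.5]
```

With the reproduction above, the equivalent (untested here) would be passing PicklableMethod(model, "predict_proba") wherever model.predict_proba was previously dumped.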
