
Keras scikit-learn wrapper not compatible with Keras functional model #34689

@karimmohraz

Description



System information

  • TensorFlow version: 2.0 and 1.15
  • Python version: 3.6 and 3.7

Describe the current behavior
When a Keras functional API model is used through the Keras scikit-learn wrapper (KerasClassifier), calling predict() crashes with an AttributeError raised at this line of the wrapper:

classes = self.model.predict_classes(x, **kwargs)
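
For reference, Sequential.predict_classes in tf.keras is a thin layer over predict(): it takes the argmax of the output probabilities when the last axis has more than one unit, and thresholds at 0.5 for a single sigmoid unit. A functional Model returns the same probabilities from predict(), so the class indices can be recovered with a few lines of NumPy. The helper name proba_to_class_indices is hypothetical; this is a minimal sketch of that logic, not the library's code.

```python
import numpy as np


def proba_to_class_indices(proba):
    """Hypothetical helper mirroring what Sequential.predict_classes does
    with the output of predict(): argmax for multi-unit outputs, a 0.5
    threshold for a single sigmoid unit."""
    proba = np.asarray(proba)
    if proba.shape[-1] > 1:
        # Softmax-style output: pick the most probable class per row.
        return proba.argmax(axis=-1)
    # Single-unit (binary) output: threshold at 0.5.
    return (proba > 0.5).astype("int32")


# Probabilities printed by model.predict(X) in the output of this issue.
proba = np.array([[0.62487346, 0.37512657],
                  [0.8205698, 0.17943017]])
print(proba_to_class_indices(proba))  # [0 0]
```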

Describe the expected behavior
KerasClassifier.predict() should work with every Keras model type, not only Sequential models.

Code to reproduce the issue

import numpy as np
import tensorflow as tf
import tensorflow.keras
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier

def build_model():
  inputs = tf.keras.layers.Input(shape=(2,))
  pred = tf.keras.layers.Dense(2, activation='softmax')(inputs)
  model = tf.keras.models.Model(inputs=inputs, outputs=pred)  # functional API model, not Sequential
  model.compile(loss='categorical_crossentropy', metrics=['accuracy'])  # no optimizer given, so the default 'rmsprop' is used
  return model

X = np.array([[1,2],[3,1]])
Y = np.array([[1,0], [0,1]])
model = build_model()
model.fit(X, Y)
print(model.predict(X))  # this works

model_wrapped = KerasClassifier(build_model)
model_wrapped.fit(X, Y)
model_wrapped.predict(X)  # this crashes

Output:
Train on 2 samples
2/2 [==============================] - 0s 62ms/sample - loss: 1.1024 - acc: 0.5000
[[0.62487346 0.37512657]
[0.8205698 0.17943017]]
Train on 2 samples
2/2 [==============================] - 0s 64ms/sample - loss: 0.2733 - acc: 1.0000

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-14-48bacae97b80> in <module>()
     19 model_wrapped = KerasClassifier(build_model)
     20 model_wrapped.fit(X, Y)
---> 21 model_wrapped.predict(X)  # this crashes
     22 
     23 

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/wrappers/scikit_learn.py in predict(self, x, **kwargs)
    239     """
    240     kwargs = self.filter_sk_params(Sequential.predict_classes, kwargs)
--> 241     classes = self.model.predict_classes(x, **kwargs)
    242     return self.classes_[classes]
    243 

AttributeError: 'Model' object has no attribute 'predict_classes'
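
Lines 240 and 241 of the traceback show the two steps behind the crash: the wrapper first filters keyword arguments against the signature of Sequential.predict_classes, then calls predict_classes on whatever object build_fn returned. The filtering step only inspects a function signature, so it succeeds for a functional Model too, and the missing method surfaces only at the call. The sketch below illustrates that filtering idea with inspect.signature; filter_params_for is a hypothetical, simplified stand-in for the wrapper's filter_sk_params, not its actual code.

```python
import inspect


def filter_params_for(fn, sk_params, override=None):
    """Hypothetical, simplified version of the wrapper's filter_sk_params:
    keep only the stored sk_params that fn's signature accepts, then apply
    explicit overrides on top."""
    accepted = inspect.signature(fn).parameters
    res = {name: value for name, value in sk_params.items() if name in accepted}
    res.update(override or {})
    return res


# Stand-in with the signature of Sequential.predict_classes(self, x, batch_size=32, verbose=0).
def predict_classes(self, x, batch_size=32, verbose=0):
    ...


# 'epochs' is dropped because predict_classes does not accept it.
print(filter_params_for(predict_classes, {"epochs": 5, "batch_size": 8}))  # {'batch_size': 8}
```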

Metadata


Labels

  • TF 1.15: for issues seen on TF 1.15
  • TF 2.0: Issues relating to TensorFlow 2.0
  • comp:keras: Keras related issues
  • stat:awaiting response: Status - Awaiting response from author
  • type:feature: Feature requests
