[XLA] can't compile the tf.keras.layers.Conv2D when padding='valid' #84205

@shaoyuyoung

Description

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

Yes

Source

source

TensorFlow version

nightly

Custom code

Yes

OS platform and distribution

No response

Mobile device

No response

Python version

No response

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current behavior?

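With XLA (tf.function(jit_compile=True)), a tf.keras.layers.Conv2D layer with padding='valid' fails to compile when its 4x4 kernel is larger than the 3x3 input. Eager execution runs the same model without raising an error, so the two execution modes behave inconsistently.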
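For reference, the spatial output size of a 'VALID' (unpadded) convolution can be sketched as below. This is a minimal pure-Python sketch assuming the standard formula floor((input - effective_kernel) / stride) + 1; the function name valid_conv_output_size is hypothetical and not a TensorFlow API. With the repro's 3x3 input and 4x4 kernel it yields 0, meaning no kernel position fits entirely inside the input, so the question is whether that case is reported as an error or as an empty result.

```python
def valid_conv_output_size(input_size: int, kernel_size: int,
                           stride: int = 1, dilation: int = 1) -> int:
    """Hypothetical helper: output length of a 'VALID' convolution along one axis.

    Assumes the usual formula floor((input - effective_kernel) / stride) + 1,
    where a dilated kernel spans (kernel - 1) * dilation + 1 input positions.
    A result of 0 means no kernel position fits; a negative result means the
    configuration is invalid.
    """
    effective_kernel = (kernel_size - 1) * dilation + 1
    return (input_size - effective_kernel) // stride + 1


# The issue's case: 3x3 input, 4x4 kernel, stride 1.
print(valid_conv_output_size(3, 4))   # 0
# A typical case for comparison: 28x28 input, 3x3 kernel.
print(valid_conv_output_size(28, 3))  # 26
```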

Standalone code to reproduce the issue

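import os

# Set these before importing TensorFlow so they reliably take effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"   # suppress TensorFlow C++ log messages
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # hide all GPUs and run on CPU

import tensorflow as tf

tf.keras.utils.set_random_seed(42)
tf.random.set_seed(42)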


# Nine values, reshaped inside call() to one 3x3 single-channel NHWC image.
x = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], dtype=tf.float32)
inputs = [x]



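# Eager execution: the model runs without error.
class Model(tf.keras.Model):

    def __init__(self):
        super().__init__()
        # 4x4 kernel with no padding, applied to a 3x3 input (kernel larger than input).
        self.conv = tf.keras.layers.Conv2D(filters=1, kernel_size=4, padding='valid', activation='relu')

    def call(self, x):
        x = tf.reshape(x, [1, 3, 3, 1])  # NHWC: batch 1, height 3, width 3, 1 channel
        x = self.conv(x)
        return x


model = Model()
model(*inputs)
print("succeed on eager")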



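# XLA: the same model, with call() compiled via tf.function(jit_compile=True), fails.
class Model(tf.keras.Model):

    def __init__(self):
        super().__init__()
        # Same layer configuration as the eager model above.
        self.conv = tf.keras.layers.Conv2D(filters=1, kernel_size=4, padding='valid', activation='relu')

    @tf.function(jit_compile=True)
    def call(self, x):
        x = tf.reshape(x, [1, 3, 3, 1])  # NHWC: batch 1, height 3, width 3, 1 channel
        x = self.conv(x)
        return x


model = Model()
model(*inputs)
print("succeed on XLA")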

Relevant log output

succeed on eager
Negative dimension size caused by subtracting 4 from 3 for '{{node conv2d_1_1/convolution}} = Conv2D[T=DT_FLOAT, data_format="NHWC", dilations=[1, 1, 1, 1], explicit_paddings=[], padding="VALID", strides=[1, 1, 1, 1], use_cudnn_on_gpu=true](Reshape, conv2d_1_1/convolution/ReadVariableOp)' with input shapes: [1,3,3,1], [4,4,1,1].
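One plausible reading of this log, offered as an assumption rather than a confirmed diagnosis: the message has the form of a graph-mode shape-inference error, which would be raised while tf.function traces call(), before XLA compiles anything. A check that subtracts first rejects the intermediate 3 - 4 = -1, whereas a check that only rejects a negative final size accepts 0 and would presumably return an empty tensor. The two functions below are hypothetical, simplified models of those two checks, not TensorFlow source code.

```python
def graph_shape_check(input_size: int, kernel_size: int, stride: int = 1) -> int:
    """Hypothetical model of a subtract-first check (as the log message suggests).

    The intermediate (input - kernel) is rejected as soon as it is negative.
    """
    diff = input_size - kernel_size
    if diff < 0:
        raise ValueError(
            f"Negative dimension size caused by subtracting {kernel_size} from {input_size}"
        )
    return diff // stride + 1


def eager_kernel_check(input_size: int, kernel_size: int, stride: int = 1) -> int:
    """Hypothetical model of a check that only rejects a negative final size."""
    out = (input_size - kernel_size + stride) // stride
    if out < 0:
        raise ValueError(f"Computed output size would be negative: {out}")
    return out


print(eager_kernel_check(3, 4))  # 0: empty output, no error
try:
    graph_shape_check(3, 4)
except ValueError as err:
    print(err)  # Negative dimension size caused by subtracting 4 from 3
```

Both models agree whenever the kernel fits (for example, input 5 and kernel 4 give 2 in each); they only diverge in the boundary case where the output size is exactly 0, which is the case this issue hits.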

Metadata

Labels

TF 2.18, comp:xla, stale, stat:awaiting response, type:bug
