
[XLA] tf.keras.layers.LSTM behaves differently on GPU #83063

@shaoyuyoung

Description


Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

Yes

Source

source

TensorFlow version

nightly

Custom code

Yes

OS platform and distribution

No response

Mobile device

No response

Python version

No response

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current behavior?

When the LSTM layer is executed under XLA (jit_compile=True), it fails.
Without XLA, the same code passes.
The failure only occurs on GPU.
With the CPU backend, the check passes both with and without XLA.
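For comparison, a minimal sketch of the passing path (my reduction, assuming the same layer and input as the repro below, just without jit_compile=True):

import tensorflow as tf

lstm = tf.keras.layers.LSTM(units=64, return_sequences=True)

@tf.function  # no jit_compile: the cuDNN kernel runs directly and the call succeeds
def run(x):
    return lstm(x)

print(run(tf.random.normal(shape=(10, 20, 1))))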

Standalone code to reproduce the issue

import tensorflow as tf

tf.random.set_seed(42)


class RecurrentModel(tf.keras.Model):

    def __init__(self):
        super().__init__()
        self.lstm = tf.keras.layers.LSTM(units=64, return_sequences=True)

    @tf.function(jit_compile=True)  # forces XLA compilation of the call
    def call(self, x):
        return self.lstm(x)


model = RecurrentModel()

# Input of shape (batch=10, timesteps=20, features=1)
x = tf.random.normal(shape=(10, 20, 1))

output = model(x)  # fails on GPU with XLA; passes without jit_compile or on CPU
print(output)

Relevant log output

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-4-0938fdccd1fa> in <cell line: 24>()
     22 inputs = [x]
     23 
---> 24 output = model(*inputs)
     25 print(output)

1 frames
/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     51   try:
     52     ctx.ensure_initialized()
---> 53     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     54                                         inputs, attrs, num_outputs)
     55   except core._NotOkStatusException as e:

InvalidArgumentError: Exception encountered when calling RecurrentModel.call().

Detected unsupported operations when trying to compile graph __inference_call_877[_XlaMustCompile=true,config_proto=6001324581131673121,executor_type=11160318154034397263] on XLA_GPU_JIT: CudnnRNNV3 (No registered 'CudnnRNNV3' OpKernel for XLA_GPU_JIT devices compatible with node {{node lstm_3_1/CudnnRNNV3}}){{node lstm_3_1/CudnnRNNV3}}
The op is created at: 
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
File "/usr/local/lib/python3.10/dist-packages/colab_kernel_launcher.py", line 37, in <module>
File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 992, in launch_instance
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 619, in start
File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 195, in start
File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run
File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 685, in <lambda>
File "/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py", line 738, in _run_callback
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 825, in inner
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 786, in run
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 361, in process_one
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 261, in dispatch_shell
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 539, in execute_request
File "/usr/local/lib/python3.10/dist-packages/tornado/gen.py", line 234, in wrapper
File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 302, in do_execute
File "/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py", line 539, in run_cell
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell
File "/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes
File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
File "<ipython-input-4-0938fdccd1fa>", line 24, in <cell line: 24>
File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py", line 826, in __call__
File "/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py", line 1376, in _maybe_build
File "/usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/core.py", line 212, in compute_output_spec
File "<ipython-input-1-0938fdccd1fa>", line 13, in call
File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py", line 901, in __call__
File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler
File "/usr/local/lib/python3.10/dist-packages/keras/src/ops/operation.py", line 46, in __call__
File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 156, in error_handler
File "/usr/local/lib/python3.10/dist-packages/keras/src/layers/rnn/lstm.py", line 570, in call
File "/usr/local/lib/python3.10/dist-packages/keras/src/layers/rnn/rnn.py", line 406, in call
File "/usr/local/lib/python3.10/dist-packages/keras/src/layers/rnn/lstm.py", line 537, in inner_loop
File "/usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/rnn.py", line 841, in lstm
File "/usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/rnn.py", line 933, in _cudnn_lstm
	tf2xla conversion failed while converting __inference_call_877[_XlaMustCompile=true,config_proto=6001324581131673121,executor_type=11160318154034397263]. Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and --vmodule=xla_compiler=2 to obtain a dump of the compiled functions. [Op:__inference_call_877]

Arguments received by RecurrentModel.call():
  • x=tf.Tensor(shape=(10, 20, 1), dtype=float32)
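
A possible workaround (my assumption, not a confirmed fix): forcing the non-cuDNN LSTM implementation, e.g. with unroll=True, or pinning the call to the CPU, appears to avoid the CudnnRNNV3 op that XLA_GPU_JIT cannot compile. Sketch:

# Hypothetical workaround: unroll=True should bypass the cuDNN kernel
# selection in Keras, so XLA never sees CudnnRNNV3 (assumption, not verified).
import tensorflow as tf

class UnrolledRecurrentModel(tf.keras.Model):

    def __init__(self):
        super().__init__()
        self.lstm = tf.keras.layers.LSTM(units=64, return_sequences=True, unroll=True)

    @tf.function(jit_compile=True)
    def call(self, x):
        return self.lstm(x)

print(UnrolledRecurrentModel()(tf.random.normal(shape=(10, 20, 1))))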

Metadata

Labels

TF 2.18, comp:gpu (GPU related issues), comp:xla (XLA), stale (marks the issue/PR stale, to be closed automatically if no activity), stat:awaiting response (awaiting response from author), type:bug (Bug)
