
torch.compile fails when compiling a T5-style model with HF interfaces #96130

@ani300

Description

🐛 Describe the bug

When using a Hugging Face T5 model to generate text, one usually calls generate(). This ends up calling the model's forward() method with a set of options under which it returns a Seq2SeqLMOutput. If this dataclass is constructed with only one argument, torch.compile() fails with an InternalTorchDynamoError caused by `KeyError: loss` (see the error logs below). If more than one argument is passed to the constructor, the code works as expected.

cc: @HamidShojanazeri @ezyang @raghukiran1224 @mudhakar

Error logs

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 530, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1862, in run
    super().run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 619, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 583, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 349, in wrapper
    return inner_fn(self, inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1063, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 517, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 127, in call_function
    return variables.DataClassVariable.create(self.value, args, kwargs, options)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/variables/dicts.py", line 344, in create
    if len(items) == 1 and not isinstance(items[keys[0]], variables.TensorVariable):
KeyError: loss

from user code:
   File "/workspace/foundation-model-stack/nlp/scripts/inference/torch_compile_repro.py", line 8, in forward
    return Seq2SeqLMOutput(logits=inputs)

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/workspace/foundation-model-stack/nlp/scripts/inference/torch_compile_repro.py", line 15, in <module>
    main()
  File "/workspace/foundation-model-stack/nlp/scripts/inference/torch_compile_repro.py", line 13, in main
    model(torch.tensor([0.1, 0.2]))
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 95, in __call__
    return self.dynamo_ctx(self._orig_mod.__call__)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 368, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 164, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 394, in _compile
    raise InternalTorchDynamoError() from e
torch._dynamo.exc.InternalTorchDynamoError
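To make the failure mode concrete, here is a simplified pure-Python model of the condition on the `dicts.py` line in the traceback above (`if len(items) == 1 and not isinstance(items[keys[0]], ...)`). It is a hypothetical sketch, not Dynamo's actual code. `ToyOutput`, `buggy_check` and `guarded_check` are made-up names, and a plain list stands in for a `TensorVariable`. The sketch assumes that `keys` holds the dataclass's declared field names (the first of which is `loss` for `Seq2SeqLMOutput`) while `items` holds only the arguments actually passed. That assumption would explain both the `KeyError: loss` and why passing two or more arguments avoids it.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


@dataclass
class ToyOutput:
    # Hypothetical stand-in for Seq2SeqLMOutput: like the real class,
    # it declares `loss` before `logits`.
    loss: Optional[Any] = None
    logits: Optional[Any] = None


def buggy_check(cls: type, passed: Dict[str, Any]) -> str:
    keys = [f.name for f in fields(cls)]  # declared field order: ["loss", "logits"]
    items = dict(passed)                  # only the fields the caller supplied
    # Mirrors `items[keys[0]]`: it assumes a lone item is the first declared
    # field, so passing only `logits` raises KeyError('loss').
    # (A list plays the role of a TensorVariable here.)
    if len(items) == 1 and not isinstance(items[keys[0]], list):
        return "single-field path"
    return "regular path"


def guarded_check(cls: type, passed: Dict[str, Any]) -> str:
    keys = [f.name for f in fields(cls)]
    items = dict(passed)
    # Hypothetical guard: only inspect the first field if it was supplied.
    if len(items) == 1 and keys[0] in items and not isinstance(items[keys[0]], list):
        return "single-field path"
    return "regular path"


if __name__ == "__main__":
    logits = [0.1, 0.2]  # plain list standing in for a tensor
    try:
        buggy_check(ToyOutput, {"logits": logits})
    except KeyError as err:
        print("KeyError:", err)  # prints: KeyError: 'loss'
    print(buggy_check(ToyOutput, {"loss": None, "logits": logits}))  # regular path
    print(guarded_check(ToyOutput, {"logits": logits}))  # regular path
```

With two entries, `len(items) == 1` is false and the lookup is never reached, which matches the report that multi-argument construction compiles.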

Minified repro

python mini_repro.py

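import torch
import torch.nn as nn
from transformers.modeling_outputs import Seq2SeqLMOutput

class ReproError(nn.Module):

    def forward(self, inputs):
        # Constructing the output dataclass with a single argument (logits only)
        # triggers the failure; passing more than one argument works as expected.
        return Seq2SeqLMOutput(logits=inputs)


def main():
    model = torch.compile(ReproError())
    # Raises InternalTorchDynamoError (caused by KeyError: loss) on the first call.
    model(torch.tensor([0.1, 0.2]))

if __name__ == "__main__":
    main()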

Versions

Collecting environment information...
PyTorch version: 2.1.0.dev20230304+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.25.2
Libc version: glibc-2.31

Python version: 3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-4.18.0-372.19.1.el8_6.x86_64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 515.48.07
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 80
On-line CPU(s) list: 0-79
Thread(s) per core: 2
Core(s) per socket: 20
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel Xeon Processor (Cascadelake)
Stepping: 5
CPU MHz: 2399.998
BogoMIPS: 4799.99
Virtualization: VT-x
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.3 MiB
L1i cache: 1.3 MiB
L2 cache: 160 MiB
L3 cache: 32 MiB
NUMA node0 CPU(s): 0-39
NUMA node1 CPU(s): 40-79
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat pku ospke avx512_vnni md_clear arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.24.2
[pip3] pytorch-triton==2.0.0+b8b470bc59
[pip3] torch==2.1.0.dev20230304+cu117
[pip3] torchvision==0.15.0.dev20230304+cu117
[conda] numpy 1.24.2 pypi_0 pypi
[conda] pytorch-triton 2.0.0+b8b470bc59 pypi_0 pypi
[conda] torch 2.1.0.dev20230304+cu117 pypi_0 pypi
[conda] torchvision 0.15.0.dev20230304+cu117 pypi_0 pypi

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire @davidberard98

Metadata

Labels

module: dynamo, oncall: pt2, triaged
