MultiheadAttention is not scriptable #20722

@ngimel

Description

🐛 Bug

Using nn.MultiheadAttention inside a torch.jit.ScriptModule fails with a RuntimeError at construction time.

To Reproduce

import torch
import torch.nn as nn


class MyModule(torch.jit.ScriptModule):
# class MyModule(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MyModule, self).__init__()
        self.mod = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, q, k, v):
        return self.mod(q, k, v)


embed_dim = 1024
num_heads = 16
sl = 30
bs = 20
model = MyModule(embed_dim, num_heads).cuda()
q = torch.randn(sl, bs, embed_dim, device="cuda")
k = torch.randn(sl, bs, embed_dim, device="cuda")
v = torch.randn(sl, bs, embed_dim, device="cuda")

out = model(q, k, v)
print(out[0].size())

Running this produces:

Traceback (most recent call last):
  File "/workspace/ALL/playground/attentionscripting.py", line 19, in <module>
    model = MyModule(embed_dim, num_heads).cuda()
  File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 1202, in init_then_register
    original_init(self, *args, **kwargs)
  File "/workspace/ALL/playground/attentionscripting.py", line 9, in __init__
    self.attn = nn.MultiheadAttention(embed_dim, num_heads)
  File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 1397, in __setattr__
    value = _make_strong(value)
  File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 1586, in _make_strong
    proxy = weak_type(mod, stubs)
  File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 1202, in init_then_register
    original_init(self, *args, **kwargs)
  File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 1202, in init_then_register
    original_init(self, *args, **kwargs)
  File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 1512, in __init__
    _create_methods_from_stubs(self, stubs)
  File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 1163, in _create_methods_from_stubs
    self._c._create_methods(self, defs, rcbs, defaults)
  File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 898, in _try_compile_weak_script
    compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"])
  File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 974, in script
    ast = get_jit_def(obj)
  File "/workspace/ALL/pytorch_upstream/torch/jit/frontend.py", line 156, in get_jit_def
    type_line = torch.jit.annotations.get_type_line(source)
  File "/workspace/ALL/pytorch_upstream/torch/jit/annotations.py", line 136, in get_type_line
    raise RuntimeError("Return type line '# type: (...) -> ...' not found on multiline "
RuntimeError: Return type line '# type: (...) -> ...' not found on multiline type annotation
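
For reference, the frontend's torch.jit.annotations.get_type_line treats a function with several "# type:" comments as a multiline annotation and insists on a closing "# type: (...) -> ..." return line. A minimal sketch (a hypothetical function, not the MultiheadAttention source) that appears to hit the same check:

import torch

# Two per-argument type comments make this a "multiline" annotation; with
# no closing return-type comment, scripting raises the RuntimeError above.
@torch.jit.script
def bad(x,  # type: Tensor
        y,  # type: Tensor
        ):
    return x + y

For user code, adding "# type: (...) -> Tensor" as the first line of the body satisfies the check; the bug here is that the annotation lives inside MultiheadAttention itself.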

Expected behavior

The MultiheadAttention module should be usable inside a ScriptModule. Once the error above is resolved, I think the next blocker will be .new_zeros, which is not scriptable.
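
As a stopgap until scripting works, tracing bypasses the Python frontend entirely. A minimal sketch (note that torch.jit.trace specializes the recorded graph on the example inputs' shapes, so it is not a full substitute for scripting):

import torch
import torch.nn as nn

embed_dim, num_heads, sl, bs = 1024, 16, 30, 20
mha = nn.MultiheadAttention(embed_dim, num_heads).cuda()
q = torch.randn(sl, bs, embed_dim, device="cuda")
k = torch.randn(sl, bs, embed_dim, device="cuda")
v = torch.randn(sl, bs, embed_dim, device="cuda")

# trace records the ops executed on the example inputs instead of compiling
# the Python source, so the type-annotation check is never reached.
traced = torch.jit.trace(mha, (q, k, v))
out, attn_weights = traced(q, k, v)
print(out.size())  # torch.Size([30, 20, 1024])

On the .new_zeros point: where it blocks scripting, torch.zeros(shape, dtype=t.dtype, device=t.device) should be an equivalent spelling of t.new_zeros(shape) that the compiler accepts.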

Environment

Fresh build from master (09f22d1), Python 3.6.

cc @zhangguanheng66

Labels

module: nn, oncall: jit
