Labels: module: nn, oncall: jit
🐛 Bug
Using nn.MultiheadAttention inside a torch.jit.ScriptModule raises a RuntimeError during module construction.
To Reproduce
import torch
import torch.nn as nn

class MyModule(torch.jit.ScriptModule):
# class MyModule(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MyModule, self).__init__()
        self.mod = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, q, k, v):
        return self.mod(q, k, v)

embed_dim = 1024
num_heads = 16
sl = 30  # sequence length
bs = 20  # batch size

model = MyModule(embed_dim, num_heads).cuda()
q = torch.randn(sl, bs, embed_dim, device="cuda")
k = torch.randn(sl, bs, embed_dim, device="cuda")
v = torch.randn(sl, bs, embed_dim, device="cuda")
out = model(q, k, v)
print(out[0].size())  # nn.MultiheadAttention returns (attn_output, attn_output_weights)

Running this script fails while constructing MyModule with:
Traceback (most recent call last):
File "/workspace/ALL/playground/attentionscripting.py", line 19, in <module>
model = MyModule(embed_dim, num_heads).cuda()
File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 1202, in init_then_register
original_init(self, *args, **kwargs)
File "/workspace/ALL/playground/attentionscripting.py", line 9, in __init__
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 1397, in __setattr__
value = _make_strong(value)
File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 1586, in _make_strong
proxy = weak_type(mod, stubs)
File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 1202, in init_then_register
original_init(self, *args, **kwargs)
File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 1202, in init_then_register
original_init(self, *args, **kwargs)
File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 1512, in __init__
_create_methods_from_stubs(self, stubs)
File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 1163, in _create_methods_from_stubs
self._c._create_methods(self, defs, rcbs, defaults)
File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 898, in _try_compile_weak_script
compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"])
File "/workspace/ALL/pytorch_upstream/torch/jit/__init__.py", line 974, in script
ast = get_jit_def(obj)
File "/workspace/ALL/pytorch_upstream/torch/jit/frontend.py", line 156, in get_jit_def
type_line = torch.jit.annotations.get_type_line(source)
File "/workspace/ALL/pytorch_upstream/torch/jit/annotations.py", line 136, in get_type_line
raise RuntimeError("Return type line '# type: (...) -> ...' not found on multiline "
RuntimeError: Return type line '# type: (...) -> ...' not found on multiline type annotation
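For context, the failure comes from the JIT frontend's handling of PEP 484 comment-style annotations: when a function annotates each parameter with its own # type: comment, get_type_line expects a separate # type: (...) -> ... return-type line. A minimal sketch of the multiline layout it is looking for (hypothetical function, shown only to illustrate the annotation format):

from typing import Tuple
from torch import Tensor

def attend(query,   # type: Tensor
           key,     # type: Tensor
           value,   # type: Tensor
           ):
    # type: (...) -> Tuple[Tensor, Tensor]
    # body elided; only the annotation layout matters here
    ...

Presumably MultiheadAttention.forward (or a helper it calls) carries multiline parameter annotations without such a return-type line, which is what trips this error.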
Expected behavior
The MultiheadAttention module should be usable inside a ScriptModule. Once the error above is fixed, I suspect the next failure will be that .new_zeros is not scriptable.
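Until that lands, one possible workaround (an untested sketch, not verified against this master build) is to trace the module instead of scripting it; tracing records the ops from an eager run and never parses the comment annotations:

import torch
import torch.nn as nn

class MyModule(nn.Module):  # plain nn.Module base so construction succeeds
    def __init__(self, embed_dim, num_heads):
        super(MyModule, self).__init__()
        self.mod = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, q, k, v):
        return self.mod(q, k, v)

model = MyModule(1024, 16).cuda()
q = torch.randn(30, 20, 1024, device="cuda")
traced = torch.jit.trace(model, (q, q, q))

The usual tracing caveats apply: shapes and any data-dependent control flow get baked into the trace.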
Environment
Fresh build of master (commit 09f22d1), Python 3.6.