dynamo export fails on einops.layers.torch.Rearrange (einops.rearrange with same pattern works) #137629

@borisfom

Description

🐛 Describe the bug

I encountered this issue while trying to export ControlNetMaisi from MONAI.
Dynamo export fails on this code (legacy export works):
https://github.com/Project-MONAI/MONAI/blob/d2d492ec045848b5872d8333e5174086c9cdfead/monai/networks/blocks/spatialattention.py#L88
I realize there are a few workarounds here, and the non-functional version was probably never intended to be constructed in the forward path, but given the wide use of einops, it may be worth looking into supporting Rearrange in dynamo the same way rearrange() is supported.

Here is a repro:

```python
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, embed_dim, num_heads, head_width, gated=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

    def forward(self, x, mask=None, bias=None, indices=None):

        # This works:
        # t = rearrange(x, "... l (h c) -> ... h l c", h=self.num_heads)

        # This fails with "torch._dynamo.exc.Unsupported: call_method BuiltinVariable(str) count [ConstantVariable(), ConstantVariable()] {}":
        t = Rearrange("... l (h c) -> ... h l c", h=self.num_heads)(x)

        q, k, v = t.chunk(3, dim=-1)
        a = torch.einsum("...qc,...kc->...qk", q, k)
        a = F.softmax(a, dim=-1)
        y = torch.einsum("...hqk,...hkc->...qhc", a, v)
        y = rearrange(y, "... h c -> ... (h c)", h=self.num_heads)
        return y


mod = Attention(1024, 32, 32, gated=True)
inp = torch.randn(2, 79, 1024)

# This works:
torch.onnx.export(mod, (inp,), 'rearrange_legacy.onnx', dynamo=False, verbose=True)
# This fails:
torch.onnx.export(mod, (inp,), 'rearrange.onnx', dynamo=True)
```
 


### Versions

PyTorch 2.5.0a0+b465a5843b.nv24.9
