
[AOTI] AOT Compile NaViT - AttributeError: 'int' object has no attribute 'node' #140625

@yiming0416

🐛 Describe the bug

NaViT (Native Resolution Vision Transformer) is a variant of ViT that accepts images of varying resolutions and aspect ratios. It can be exported with draft export, but aoti_compile_and_package fails on the resulting ExportedProgram.

To repro:

  1. Install vit-pytorch:
pip install vit-pytorch
  2. Run the following:
import os

import torch
from torch.export._draft_export import draft_export
from vit_pytorch.na_vit import NaViT


def main():
    v = NaViT(
        image_size=256,
        patch_size=32,
        num_classes=1000,
        dim=1024,
        depth=6,
        heads=16,
        mlp_dim=2048,
        dropout=0.1,
        emb_dropout=0.1,
        token_dropout_prob=0.1,
    )

    v.eval()
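    with torch.no_grad():
        # NaViT takes a list of image groups. Each inner list is packed into one
        # variable-length sequence, and images may have different resolutions.
        imgs = [
            [torch.randn(3, 256, 256), torch.randn(3, 128, 128)],
            [torch.randn(3, 128, 256), torch.randn(3, 256, 128)],
            [torch.randn(3, 64, 256)],
        ]

        example_inputs = (imgs,)
        print("Running torch export...")
        # draft_export returns the ExportedProgram together with a report.
        draft_ep, _ = draft_export(v, example_inputs)

        print("Running AOT Compile...")
        # Compile the draft ExportedProgram with AOTInductor and package it as navit.pt2.
        aoti_model_path = torch._inductor.aoti_compile_and_package(
            draft_ep,
            example_inputs,
            package_path=os.path.join(os.getcwd(), "navit.pt2"),
        )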


if __name__ == "__main__":
    main()
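
For reference, the example inputs pack into three sequences of different lengths. The Python sketch below computes them, assuming (as NaViT's patching suggests) that each (C, H, W) image contributes (H / 32) * (W / 32) patch tokens, that each inner list is concatenated into one sequence, and that token dropout is inactive in eval mode. The helper packed_lengths is hypothetical and not part of vit-pytorch.

```python
from typing import List, Tuple

Shape = Tuple[int, int, int]  # (C, H, W)


def packed_lengths(groups: List[List[Shape]], patch_size: int = 32) -> List[int]:
    """Hypothetical helper: number of patch tokens per packed sequence, one entry per group."""
    lengths = []
    for group in groups:
        total = 0
        for _, h, w in group:
            # Assumed to match NaViT's requirement that H and W are divisible by patch_size.
            if h % patch_size or w % patch_size:
                raise ValueError(f"{h}x{w} is not divisible by patch size {patch_size}")
            total += (h // patch_size) * (w // patch_size)
        lengths.append(total)
    return lengths


# Shapes of the example inputs from the repro above.
shapes = [
    [(3, 256, 256), (3, 128, 128)],
    [(3, 128, 256), (3, 256, 128)],
    [(3, 64, 256)],
]
lengths = packed_lengths(shapes)
print(lengths)       # [80, 64, 16]
print(max(lengths))  # 80, what lengths.amax().item() would evaluate to
```

Under these assumptions, lengths is derived purely from input shapes, which are static in this export, so the value returned by lengths.amax().item() is effectively a constant, even though export traces .item() as a data-dependent (unbacked) value.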

Draft export succeeds, but AOTI compilation fails with the following error:

Traceback (most recent call last):
  File "/home/yimingzhou/vit-pytorch/navit_demo.py", line 44, in <module>
    main()
  File "/home/yimingzhou/vit-pytorch/navit_demo.py", line 36, in main
    aoti_model_path = torch._inductor.aoti_compile_and_package(
  File "/home/yimingzhou/pytorch/torch/_inductor/__init__.py", line 101, in aoti_compile_and_package
    return aoti_compile_and_package_debug_wrapper(
  File "/home/yimingzhou/pytorch/torch/_inductor/__init__.py", line 192, in aoti_compile_and_package_debug_wrapper
    raise e
  File "/home/yimingzhou/pytorch/torch/_inductor/__init__.py", line 170, in aoti_compile_and_package_debug_wrapper
    return _aoti_compile_and_package_inner(
  File "/home/yimingzhou/pytorch/torch/_inductor/__init__.py", line 130, in _aoti_compile_and_package_inner
    aoti_files = aot_compile(m, args, kwargs, options=inductor_configs)  # type: ignore[arg-type]
  File "/home/yimingzhou/pytorch/torch/_inductor/__init__.py", line 300, in aot_compile
    return compile_fx_aot(
  File "/home/yimingzhou/pytorch/torch/_inductor/compile_fx.py", line 1254, in compile_fx_aot
    compiled_artifacts = compile_fx(
  File "/home/yimingzhou/pytorch/torch/_inductor/compile_fx.py", line 1436, in compile_fx
    return compile_fx(
  File "/home/yimingzhou/pytorch/torch/_inductor/compile_fx.py", line 1478, in compile_fx
    return compile_fx(
  File "/home/yimingzhou/pytorch/torch/_inductor/compile_fx.py", line 1683, in compile_fx
    gm, graph_signature = aot_export_module(
  File "/home/yimingzhou/pytorch/torch/_functorch/aot_autograd.py", line 1278, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
  File "/home/yimingzhou/pytorch/torch/_functorch/aot_autograd.py", line 1517, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
  File "/home/yimingzhou/pytorch/torch/_functorch/aot_autograd.py", line 527, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/home/yimingzhou/pytorch/torch/_functorch/aot_autograd.py", line 635, in _create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/home/yimingzhou/pytorch/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 197, in inner
    flat_f_outs = f(*flat_f_args)
  File "/home/yimingzhou/pytorch/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/home/yimingzhou/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 875, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
  File "/home/yimingzhou/pytorch/torch/fx/interpreter.py", line 167, in run
    self.env[node] = self.run_node(node)
  File "/home/yimingzhou/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6571, in run_node
    rebind_unbacked(detect_fake_mode().shape_env, n, result)
  File "/home/yimingzhou/pytorch/torch/fx/experimental/symbolic_shapes.py", line 481, in rebind_unbacked
    if u1.node.hint is not None:
AttributeError: 'int' object has no attribute 'node'

While executing %_local_scalar_dense : [num_users=5] = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%amax,), kwargs = {})
Original traceback:
File "/home/yimingzhou/vit-pytorch/vit_pytorch/na_vit.py", line 324, in forward
    seq_arange = arange(lengths.amax().item())
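
The AttributeError is raised in rebind_unbacked, which expects the re-traced result of _local_scalar_dense (the op behind .item()) to be a symbolic int with a .node attribute, but receives a plain Python int. The toy model below illustrates that failure mode and a defensive isinstance check. ToySymNode, ToySymInt, rebind_naive and rebind_guarded are hypothetical stand-ins, not PyTorch code, and the check is only one possible mitigation.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class ToySymNode:
    """Hypothetical stand-in for the node behind a symbolic int."""
    expr: str            # symbol name, e.g. "u0" for an unbacked symbol
    hint: Optional[int]  # assumed None for unbacked (data-dependent) symbols


@dataclass
class ToySymInt:
    """Hypothetical stand-in for torch.SymInt: it carries a .node attribute."""
    node: ToySymNode


def rebind_naive(u1: Union[int, ToySymInt]) -> Optional[str]:
    # Mirrors the failing check from the traceback: `if u1.node.hint is not None:`.
    # Values with a hint are treated as backed and skipped; hint-less ones are
    # returned as symbols to rebind. A plain int has no .node at all, so this
    # raises AttributeError.
    if u1.node.hint is not None:
        return None
    return u1.node.expr


def rebind_guarded(u1: Union[int, ToySymInt]) -> Optional[str]:
    # Defensive variant: a value that is already a concrete int has no symbol
    # left to rebind, so it is skipped instead of dereferenced.
    if isinstance(u1, int):
        return None
    return rebind_naive(u1)


if __name__ == "__main__":
    print(rebind_naive(ToySymInt(ToySymNode(expr="u0", hint=None))))  # u0
    print(rebind_guarded(80))  # None
    try:
        rebind_naive(80)
    except AttributeError as e:
        print(e)  # 'int' object has no attribute 'node'
```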

Versions

PyTorch main branch

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @bobrenjc93 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @desertfire @chenyang78
