-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Status: Closed
Labels: empathy-day (Label for issues from user empathy days), high priority, module: aotinductor (aot inductor), module: dynamic shapes, oncall: export, oncall: pt2, triage review
Description
🐛 Describe the bug
NaViT (Native Resolution Vision Transformer) is a variant of ViT. It can be exported with `draft_export`, but compilation fails when running `aoti_compile_and_package`.
To reproduce:
- Install vit-pytorch:
  pip install vit-pytorch
- Run the following:
import os
import torch
from torch.export._draft_export import draft_export
from vit_pytorch.na_vit import NaViT
def main():
    """Reproduce the NaViT AOTInductor compilation failure.

    Builds a ``NaViT`` model in eval mode and draft-exports it on a ragged
    example input: a list of lists of 3-channel images whose heights and
    widths differ, with inner lists of different lengths. It then compiles the
    exported program with ``torch._inductor.aoti_compile_and_package``, writing
    the package to ``navit.pt2`` in the current working directory.

    As reported, ``draft_export`` succeeds, but the AOTI step raises
    ``AttributeError: 'int' object has no attribute 'node'`` inside
    ``rebind_unbacked``. This happens while executing the
    ``aten._local_scalar_dense`` node produced by ``lengths.amax().item()``
    in ``NaViT.forward``.
    """
    model = NaViT(
        image_size=256,
        patch_size=32,
        num_classes=1000,
        dim=1024,
        depth=6,
        heads=16,
        mlp_dim=2048,
        dropout=0.1,
        emb_dropout=0.1,
        token_dropout_prob=0.1,
    )
    model.eval()

    with torch.no_grad():
        # Ragged input: three inner lists (lengths 2, 2, 1) of 3-channel
        # images with varying (H, W) sizes.
        imgs = [
            [torch.randn(3, 256, 256), torch.randn(3, 128, 128)],
            [torch.randn(3, 128, 256), torch.randn(3, 256, 128)],
            [torch.randn(3, 64, 256)],
        ]
        # The model takes a single positional argument (the nested list).
        example_inputs = (imgs,)

        print("Running torch export...")
        # draft_export returns a pair; only the exported program is used here.
        draft_ep, _ = draft_export(model, example_inputs)

        print("Running AOT Compile...")
        aoti_model_path = torch._inductor.aoti_compile_and_package(
            draft_ep,
            example_inputs,
            package_path=os.path.join(os.getcwd(), "navit.pt2"),
        )
        # Reached only if the AOTI compile succeeds; report where the
        # package was written instead of discarding the returned path.
        print(f"AOTI package written to {aoti_model_path}")
if __name__ == "__main__":
    main()

Draft-mode export works, but AOTI compilation fails with the following error:
Traceback (most recent call last):
File "/home/yimingzhou/vit-pytorch/navit_demo.py", line 44, in <module>
main()
File "/home/yimingzhou/vit-pytorch/navit_demo.py", line 36, in main
aoti_model_path = torch._inductor.aoti_compile_and_package(
File "/home/yimingzhou/pytorch/torch/_inductor/__init__.py", line 101, in aoti_compile_and_package
return aoti_compile_and_package_debug_wrapper(
File "/home/yimingzhou/pytorch/torch/_inductor/__init__.py", line 192, in aoti_compile_and_package_debug_wrapper
raise e
File "/home/yimingzhou/pytorch/torch/_inductor/__init__.py", line 170, in aoti_compile_and_package_debug_wrapper
return _aoti_compile_and_package_inner(
File "/home/yimingzhou/pytorch/torch/_inductor/__init__.py", line 130, in _aoti_compile_and_package_inner
aoti_files = aot_compile(m, args, kwargs, options=inductor_configs) # type: ignore[arg-type]
File "/home/yimingzhou/pytorch/torch/_inductor/__init__.py", line 300, in aot_compile
return compile_fx_aot(
File "/home/yimingzhou/pytorch/torch/_inductor/compile_fx.py", line 1254, in compile_fx_aot
compiled_artifacts = compile_fx(
File "/home/yimingzhou/pytorch/torch/_inductor/compile_fx.py", line 1436, in compile_fx
return compile_fx(
File "/home/yimingzhou/pytorch/torch/_inductor/compile_fx.py", line 1478, in compile_fx
return compile_fx(
File "/home/yimingzhou/pytorch/torch/_inductor/compile_fx.py", line 1683, in compile_fx
gm, graph_signature = aot_export_module(
File "/home/yimingzhou/pytorch/torch/_functorch/aot_autograd.py", line 1278, in aot_export_module
fx_g, metadata, in_spec, out_spec = _aot_export_function(
File "/home/yimingzhou/pytorch/torch/_functorch/aot_autograd.py", line 1517, in _aot_export_function
fx_g, meta = create_aot_dispatcher_function(
File "/home/yimingzhou/pytorch/torch/_functorch/aot_autograd.py", line 527, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
File "/home/yimingzhou/pytorch/torch/_functorch/aot_autograd.py", line 635, in _create_aot_dispatcher_function
fw_metadata = run_functionalized_fw_and_collect_metadata(
File "/home/yimingzhou/pytorch/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 197, in inner
flat_f_outs = f(*flat_f_args)
File "/home/yimingzhou/pytorch/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
tree_out = fn(*args, **kwargs)
File "/home/yimingzhou/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 875, in functional_call
out = PropagateUnbackedSymInts(mod).run(
File "/home/yimingzhou/pytorch/torch/fx/interpreter.py", line 167, in run
self.env[node] = self.run_node(node)
File "/home/yimingzhou/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6571, in run_node
rebind_unbacked(detect_fake_mode().shape_env, n, result)
File "/home/yimingzhou/pytorch/torch/fx/experimental/symbolic_shapes.py", line 481, in rebind_unbacked
if u1.node.hint is not None:
AttributeError: 'int' object has no attribute 'node'
While executing %_local_scalar_dense : [num_users=5] = call_function[target=torch.ops.aten._local_scalar_dense.default](args = (%amax,), kwargs = {})
Original traceback:
File "/home/yimingzhou/vit-pytorch/vit_pytorch/na_vit.py", line 324, in forward
seq_arange = arange(lengths.amax().item())
Versions
main
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @bobrenjc93 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @desertfire @chenyang78
Metadata
Assignees: (none shown)
Labels: empathy-day (Label for issues from user empathy days), high priority, module: aotinductor (aot inductor), module: dynamic shapes, oncall: export, oncall: pt2, triage review