Contributor

@drisspg drisspg commented Dec 26, 2024

Stack from ghstack (oldest at bottom):

Summary

Addresses: #143840

Current dynamic failing test: test/inductor/test_flex_attention.py::TestBlockMask::test_compiling_create_block_mask_no_recompile - torch._dynamo.exc.TorchRuntimeError: Failed running call_method scatter_(*(BatchedTensor(lvl=2,...

CC @zou3519 for ideas on why this is failing. The full traceback:

  File "/home/drisspg/meta/pytorch/torch/_dynamo/utils.py", line 2694, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/home/drisspg/meta/pytorch/torch/_dynamo/utils.py", line 2678, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.TorchRuntimeError: Failed running call_method scatter_(*(BatchedTensor(lvl=2, bdim=0, value=
    BatchedTensor(lvl=1, bdim=0, value=
        FakeTensor(..., device='cuda:0',
                   size=(s0, s1, (s2 + 127//128), ((s3 + 127//128)) + 1),
                   dtype=torch.int32)
    )
), 1, BatchedTensor(lvl=2, bdim=0, value=
    BatchedTensor(lvl=1, bdim=0, value=
        FakeTensor(..., device='cuda:0',
                   size=(s0, s1, (s2 + 127//128), (s3 + 127//128)), dtype=torch.int64)
    )
), BatchedTensor(lvl=2, bdim=0, value=
    BatchedTensor(lvl=1, bdim=0, value=
        FakeTensor(..., device='cuda:0',
                   size=(s0, s1, (s2 + 127//128), (s3 + 127//128)), dtype=torch.int32)
    )
)), **{}):
Cannot call sizes() on tensor with symbolic sizes/strides
Exception raised from throw_cannot_call_with_symbolic at /home/drisspg/meta/pytorch/c10/core/TensorImpl.cpp:291 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x98 (0x7fc0fd78fbe8 in /home/drisspg/meta/pytorch/torch/lib/libc10.so)
frame #1: c10::TensorImpl::throw_cannot_call_with_symbolic(char const*) const + 0x8d (0x7fc0fd738181 in /home/drisspg/meta/pytorch/torch/lib/libc10.so)
frame #2: at::functorch::BatchedTensorImpl::sizes_custom() const + 0x5c (0x7fc0ec1a0e0c in /home/drisspg/meta/pytorch/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0x179d11f (0x7fc0ec19d11f in /home/drisspg/meta/pytorch/torch/lib/libtorch_cpu.so)
frame #4: <unknown function> + 0x64036c (0x7fc0fde4036c in /home/drisspg/meta/pytorch/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0x63ceed (0x7fc0fde3ceed in /home/drisspg/meta/pytorch/torch/lib/libtorch_python.so)
frame #6: <unknown function> + 0x17c145b (0x7fc0ec1c145b in /home/drisspg/meta/pytorch/torch/lib/libtorch_cpu.so)
frame #7: <unknown function> + 0x17ab401 (0x7fc0ec1ab401 in /home/drisspg/meta/pytorch/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0x17a614c (0x7fc0ec1a614c in /home/drisspg/meta/pytorch/torch/lib/libtorch_cpu.so)
frame #9: <unknown function> + 0x64036c (0x7fc0fde4036c in /home/drisspg/meta/pytorch/torch/lib/libtorch_python.so)
frame #10: <unknown function> + 0x63ceed (0x7fc0fde3ceed in /home/drisspg/meta/pytorch/torch/lib/libtorch_python.so)
frame #11: at::_ops::scatter__src::call(at::Tensor&, long, at::Tensor const&, at::Tensor const&) + 0x3d1 (0x7fc0ecec4281 in /home/drisspg/meta/pytorch/torch/lib/libtorch_cpu.so)
frame #12: <unknown function> + 0x41b9c2 (0x7fc0fdc1b9c2 in /home/drisspg/meta/pytorch/torch/lib/libtorch_python.so)
frame #13: <unknown function> + 0x2240a8 (0x55944f3cc0a8 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #14: _PyObject_Call + 0xb5 (0x55944f3dcb35 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #15: <unknown function> + 0x11350a (0x55944f2bb50a in /home/drisspg/.conda/envs/dev/bin/python3)
frame #16: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #17: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #18: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #19: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #20: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #21: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #22: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #23: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #24: <unknown function> + 0x11350a (0x55944f2bb50a in /home/drisspg/.conda/envs/dev/bin/python3)
frame #25: _PyObject_Call + 0x12b (0x55944f3dcbab in /home/drisspg/.conda/envs/dev/bin/python3)
frame #26: <unknown function> + 0x11350a (0x55944f2bb50a in /home/drisspg/.conda/envs/dev/bin/python3)
frame #27: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #28: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #29: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #30: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #31: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #32: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #33: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #34: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #35: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #36: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #37: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #38: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #39: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #40: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #41: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #42: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #43: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #44: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #45: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #46: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #47: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #48: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #49: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #50: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #51: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #52: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #53: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #54: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #55: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #56: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #57: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #58: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #59: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #60: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)
frame #61: PyObject_Vectorcall + 0x2e (0x55944f3c0cbe in /home/drisspg/.conda/envs/dev/bin/python3)
frame #62: <unknown function> + 0x112892 (0x55944f2ba892 in /home/drisspg/.conda/envs/dev/bin/python3)


from user code:
   File "/home/drisspg/meta/pytorch/torch/nn/attention/flex_attention.py", line 890, in create_block_mask
    block_mask = _create_sparse_block_from_block_mask(
  File "/home/drisspg/meta/pytorch/torch/nn/attention/flex_attention.py", line 762, in _create_sparse_block_from_block_mask
    return BlockMask.from_kv_blocks(
  File "/home/drisspg/meta/pytorch/torch/nn/attention/flex_attention.py", line 350, in from_kv_blocks
    q_num_blocks, q_indices = _transpose_ordered(kv_num_blocks, kv_indices)
  File "/home/drisspg/meta/pytorch/torch/nn/attention/flex_attention.py", line 184, in _transpose_ordered
    dense = _ordered_to_dense(num_blocks_in_row, col_indices)
  File "/home/drisspg/meta/pytorch/torch/nn/attention/flex_attention.py", line 169, in _ordered_to_dense
    out = create_dense_batched(num_blocks_in_row, col_indices)
  File "/home/drisspg/meta/pytorch/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/home/drisspg/meta/pytorch/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/home/drisspg/meta/pytorch/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/drisspg/meta/pytorch/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
  File "/home/drisspg/meta/pytorch/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
  File "/home/drisspg/meta/pytorch/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/drisspg/meta/pytorch/torch/nn/attention/flex_attention.py", line 162, in create_dense_one
    dense_mask.scatter_(1, valid_indices.to(torch.int64), values)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True


To execute this test, run the following from the base repo dir:
    python test/inductor/test_flex_attention.py TestBlockMask.test_compiling_create_block_mask_no_recompile
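
For readers unfamiliar with the failing call, here is a minimal pure-Python sketch of what the `scatter_` step in the traceback appears to compute. It is an assumption inferred from the shapes above (the destination tensor has one more column than the index tensor), not the actual body of `create_dense_one`; the function name `ordered_to_dense` and the sentinel-column handling are illustrative only.

```python
def ordered_to_dense(num_blocks_in_row, col_indices):
    """Turn an ordered (count + indices) row layout into a dense 0/1 mask.

    num_blocks_in_row[r] is how many leading entries of col_indices[r] are valid.
    Assumed behavior for illustration; not the PyTorch implementation.
    """
    num_rows = len(col_indices)
    num_cols = len(col_indices[0])
    # One extra "sentinel" column absorbs writes coming from padding entries.
    dense = [[0] * (num_cols + 1) for _ in range(num_rows)]
    for r in range(num_rows):
        for j in range(num_cols):
            # Padding slots (j >= count) are redirected to the sentinel column.
            target = col_indices[r][j] if j < num_blocks_in_row[r] else num_cols
            # Plays the role of scatter_(1, index, src) with src == 1 everywhere.
            dense[r][target] = 1
    # Drop the sentinel column before returning.
    return [row[:num_cols] for row in dense]


mask = ordered_to_dense([2, 1, 0], [[0, 2, 0], [1, 0, 0], [0, 0, 0]])
print(mask)
```

In the real code path this per-row logic runs under two nested `vmap` calls, which is where the symbolic-shape failure is raised.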

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @Chillee @yanboliang @BoyuanFeng

[ghstack-poisoned]

pytorch-bot bot commented Dec 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143872

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit 516bd4f with merge base a174ee2:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

[ghstack-poisoned]
drisspg added a commit that referenced this pull request Dec 26, 2024
ghstack-source-id: 182ac2d
Pull Request resolved: #143872
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Jan 2, 2025
ghstack-source-id: 1dc5c75
Pull Request resolved: #143872
Collaborator

Chillee commented Jan 7, 2025

Why is this not cudagraphable? 🤔

Collaborator

Chillee commented Jan 17, 2025

Yeah, this is the wrong way to solve this problem. The issue is in vmap.

Contributor

zou3519 commented Jan 23, 2025

I think we're missing a .sym_sizes call somewhere in the C++ implementation for vmap. @drisspg can you get a C++ stack trace in debug mode? (so we can see function names)
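
As a purely illustrative toy model (these are not PyTorch classes, and the names `ToyTensorImpl`, `rank_via_sizes`, and `rank_via_sym_sizes` are hypothetical), the difference between a concrete-only size accessor and a symbolic-aware one might be sketched like this. It mirrors the error text in the traceback, where `sizes_custom()` throws once the shape contains symbols such as `s0`.

```python
class ToyTensorImpl:
    """Toy stand-in for a tensor implementation; not a PyTorch class."""

    def __init__(self, shape):
        # Each dim is either a concrete int or a symbol name such as "s0".
        self._shape = tuple(shape)

    def sym_sizes(self):
        # Always safe: returns the dims as-is, symbolic or concrete.
        return self._shape

    def sizes(self):
        # Only valid when every dim is a concrete int.
        if any(not isinstance(d, int) for d in self._shape):
            raise RuntimeError(
                "Cannot call sizes() on tensor with symbolic sizes/strides"
            )
        return self._shape


def rank_via_sizes(t):
    # A helper that assumes concrete shapes; breaks on symbolic ones.
    return len(t.sizes())


def rank_via_sym_sizes(t):
    # The same helper written against the symbolic-aware accessor.
    return len(t.sym_sizes())


symbolic = ToyTensorImpl(["s0", "s1", 4])
try:
    rank_via_sizes(symbolic)
    failed = False
except RuntimeError:
    failed = True
```

Under this toy model, any helper on the call path that uses the concrete accessor fails as soon as dynamic shapes are traced, even if it only needs the rank.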

Contributor Author

drisspg commented Feb 4, 2025

Fixed in internals of vmap

@drisspg drisspg closed this Feb 4, 2025
@github-actions github-actions bot deleted the gh/drisspg/106/head branch March 7, 2025 02:07