Conversation

@bohnstingl
Collaborator

This is part of a series of PRs to improve the functionality of associative_scan. This specific PR fixes issues with the current vmap implementation. It has been derived from #129307.

@ydwu4 @Chillee @zou3519

@pytorch-bot

pytorch-bot bot commented Aug 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133013

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2b29abd with merge base 2ba60a1:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zou3519 zou3519 requested review from Chillee, ydwu4 and zou3519 August 9, 2024 12:32
@zou3519 zou3519 added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Aug 9, 2024

with torch._dynamo.utils.disable_cache_limit():
associative_scan1 = torch.compile(
torch.vmap(associative_scan_fct, in_dims=0), fullgraph=True
Contributor


Currently, this seems to be testing torch.compile x vmap x associative_scan vs vmap x associative_scan. We should probably also test a simpler version, vmap x associative_scan vs associative_scan?

Collaborator Author


Yes, this is correct. I split this into two test cases: one for torch.compile x vmap x associative_scan vs associative_scan, and one for vmap x associative_scan vs associative_scan. However, there is a problem when using vmap x associative_scan:

torch._dynamo.exc.Unsupported: torch.func.vmap(fn) requires the function to be inlined by dynamo

Do you have any suggestions how to handle this?


res = associative_scan_op(combine_fn, input_unwrapped, dim + 1)
with interpreter.lower():
res = associative_scan_op(combine_fn, input_unwrapped, dim + 1)
Contributor


How are vmap's in_dims and out_dims handled? Can you add some tests to convince me this is correct?

Collaborator Author


There is now a requirement that the scan dim and the in_dims of vmap need to be different. When vmapping, the batch dimension is moved to the 0-th dimension and thus the associative_scan is performed on dim + 1. The result is then concatenated together and the 0-th dimension is moved to the out_dims specified for the vmap. I added a dedicated test case for this behavior. Does this convince you?
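To make that concrete, here is a minimal, self-contained sketch of the dimension handling; torch.cumsum stands in for the real associative_scan_op, and the function name is made up for illustration:

import torch

def scan_batch_rule_sketch(x, bdim, scan_dim):
    # Illustrative only: move the vmapped (batch) dim to the front, run the
    # scan on scan_dim + 1, and report 0 as the new batch dim so that vmap
    # can move it to whatever out_dims requests.
    x = torch.movedim(x, bdim, 0)
    res = torch.cumsum(x, dim=scan_dim + 1)  # cumsum stands in for associative_scan_op
    return res, 0

x = torch.arange(12.0).reshape(3, 4)
out, out_bdim = scan_batch_rule_sketch(x, bdim=1, scan_dim=0)
print(out.shape, out_bdim)  # torch.Size([4, 3]) 0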

@bohnstingl bohnstingl requested a review from ydwu4 August 15, 2024 20:18
@unittest.skipIf(not SM70OrLater, "triton")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@parametrize("device", [torch.device("cuda")])
def test_pointwise_associative_scan_vmap(self, device):

Comment on lines 208 to 209
if dim in input_bdims:
    raise ValueError("Vmap in_dim may not coincide with dim of associative_scan")
Contributor


Is it ever possible for a user to hit this error?

Collaborator Author


Yes, the user can hit this error if the call is invoked as

def associative_scan_fct(x):
    return associative_scan(add, x, 0, reverse=reverse)

associative_scan1 = torch.vmap(associative_scan_fct, in_dims=0, out_dims=0)

The point here is that the user has full control over which dimension is the scan dimension and which is the vmap dimension. The two may coincide, which is what is prevented here.

Contributor

@zou3519 zou3519 Aug 19, 2024


The dimension passed to scan inside the vmap is different from the in_dims in vmap. Inside the vmap, dimension 0 actually means "the 0th dimension not including the vmapped dimension". Based on that I am not sure we can actually hit the assertion -- the vmapped dimension should always be different from the scan dimension.
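For example (with torch.cumsum standing in for associative_scan, since the dimension semantics are the same):

import torch

x = torch.arange(12.0).reshape(3, 4)

def f(row):
    # Inside vmap the batched dim is hidden: `row` is 1-D even though x is 2-D,
    # so dim=0 here refers to the length-4 dimension, not the vmapped one.
    assert row.dim() == 1
    return torch.cumsum(row, dim=0)

out = torch.vmap(f, in_dims=0)(x)
print(out.shape)  # torch.Size([3, 4]); the scan ran along dim 1 of x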

Collaborator Author


Hmm, then I think something may be off here. This code snippet

x = torch.tile(
    torch.unsqueeze(
        torch.arange(
            0, 10, device=device, dtype=torch.float32, requires_grad=True
        ),
        0,
    ),
    (4, 1),
)
torch.compiler.reset()

def associative_scan_fct(x):
    return associative_scan(add, x, 0, reverse=reverse)

associative_scan1 = torch.compile(
    torch.vmap(associative_scan_fct, in_dims=1, out_dims=1), fullgraph=True
)

result1 = associative_scan1(x)

calls vmap with the input x of shape 4x10, in_dims=1, and a scan dim of 0. The result is that input_bdims=[1] in associative_scan_batch_rule. Now, when I change to in_dims=0, then input_bdims=[0] and the scan dim is 0.

Contributor

@zou3519 zou3519 left a comment


At a high level this looks reasonable to me. I suggested additional test cases.

@bohnstingl
Collaborator Author

@zou3519 thank you for looking at the code. I will add the additional test cases as you suggested.
However, I have two questions I would love to get your opinion on:

  1. There is an issue when only using vmap with associative_scan. For example, the test case test_pointwise_associative_scan_vmap fails with the error
torch._dynamo.exc.Unsupported: If you are reaching here, it means dynamo failed for one of the following reasons:
- Calling torch.func.vmap(compiled_fn) function from eager mode is not supported. Ensure that torch.func.vmap is also wrapped within a torch.compile function. For more information, see PyTorch issue #128711.
- torch.func.vmap(fn) requires the function to be inlined by dynamo
  2. I integrated the reverse flag from this PR and I stumbled over the following issue. When using vmap with associative_scan, e.g., in the test case test_pointwise_associative_scan_vmap_comp, the flip operation in the case of reverse=True causes the shape of the leaves to change, and the bdim is not the same as for reverse=False.
    For example, for the test case test_pointwise_associative_scan_vmap_comp, inside associative_scan_batch_rule:
    the input shape is 4x10 and the bdim is 1 for reverse=False, while for reverse=True the shape is 10x4 and the bdim is 0. If I just remove the flip operation and execute leaves = [elem for elem in leaves], the bdim does not change.

@zou3519
Contributor

zou3519 commented Aug 20, 2024

For (2) -- there's no guarantee that the bdim passed to the operator is the same, even for the same operator. The value of the bdim is an implementation detail of vmap. Because of this, many vmap rules will permute the bdim to the front of the tensor.
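Roughly, that normalization looks like this (a sketch, not functorch's actual helper):

import torch

def bdim_to_front(x, bdim):
    # Sketch of the common batch-rule normalization: whatever bdim vmap
    # happened to pass, move it to dimension 0 and report 0 as the new bdim,
    # so the rest of the rule never depends on vmap's choice.
    if bdim is None:
        return x, None
    return torch.movedim(x, bdim, 0), 0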

@zou3519
Contributor

zou3519 commented Aug 20, 2024

For (1): this is a known issue (#134000). For the additional test cases I suggested, we could refactor opinfo_vmap_test to apply compile. I'm not sure how to resolve this; it needs more thinking.

@bohnstingl
Collaborator Author

Thank you @zou3519 for your comments.

For (2) -- there's no guarantee that the bdim passed to the operator is the same, even for the same operator. The value of the bdim is an implementation detail of vmap. Because of this, many vmap rules will permute the bdim to the front of the tensor

If for vmap there is no guarantee that the same bdim is used, then the flip operation would be problematic, as it relies on dim always being the same. So if with vmap x associative_scan, associative_scan is invoked once with 4x10, bdim=1 and once with 10x4, bdim=0, then the flip with a fixed dim wouldn't work?
Is there a way to detect the vmap call in associative_scan, for example with is_batchedtensor? Then the flip could be deferred to associative_scan_batch_rule, where the bdim is always moved to the 0-th dimension. However, this currently fails with
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function <built-in method is_batchedtensor of PyCapsule object at 0x7fb6a0a25f80>
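For reference, the deferral I have in mind would look roughly like this (a sketch only; torch.cumsum stands in for associative_scan_op and the function name is made up):

import torch

def reversed_scan_batch_rule_sketch(x, bdim, scan_dim):
    # Rough sketch: once the batch dim is moved to 0, the flip/scan dimension
    # is always scan_dim + 1, independent of which bdim vmap picked.
    x = torch.movedim(x, bdim, 0)
    x = torch.flip(x, dims=[scan_dim + 1])
    res = torch.cumsum(x, dim=scan_dim + 1)  # stand-in for associative_scan_op
    return torch.flip(res, dims=[scan_dim + 1]), 0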

The second problem I am facing is that test_vmap.py raises:
torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

In general, I am struggling quite a bit with test_vmap.py, as it runs thousands of tests and it is difficult for me to isolate and debug my newly added tests. Is there a way to specifically run only the associative_scan part, to better figure out what's wrong?

@bohnstingl bohnstingl requested a review from zou3519 August 22, 2024 13:36
@zou3519
Contributor

zou3519 commented Aug 26, 2024

In general, I am struggling quite a bit with test_vmap.py, as it runs thousands of tests and it is difficult for me to isolate and debug my newly added tests. Is there a way to specifically run only the associative_scan part, to better figure out what's wrong?

python test/test_vmap.py -v -k "your_test_name_here"

@bohnstingl
Collaborator Author

@zou3519 since my understanding is still that vmap cannot guarantee that the dimensions are not "shuffled", and I cannot detect the vmap case in associative_scan using is_batchedtensor, I don't know how to handle the reverse flag. Thus, I marked the tests that involve reverse and vmap as skipped.

Furthermore, I have some questions regarding test_vmap.py. I tried to adjust opinfo_vmap_test to account for torch.compile and I run python test_vmap.py -v -k test_vmap_exhaustive_associative_scan_cuda_float32. However, I was unsuccessful in implementing the test properly, as this part

vmapvmap_output = torch.compile(vmap(
        vmap(f, inner_in_dims, out_dims=out_dim), outer_in_dims, out_dims=out_dim
    ))(dummy, *batched_args, **kwarg_values)

causes issues for the associative_scan, which I am not sure how to resolve. In particular:

torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method _remove_batch_dim of PyCapsule object at 0x7fd605232d00>(*(BatchedTensor(lvl=1, bdim=0, value=
    FakeTensor(..., device='cuda:0', size=(2, s1, s2, s3))
), 2, 1, 0), **{}):
Cannot call sizes() on tensor with symbolic sizes/strides
Exception raised from throw_cannot_call_with_symbolic at /data_malta3_ssd/pytorch_git/c10/core/TensorImpl.cpp:298 (most recent call first):

Using

vmapvmap_output = torch.compile(vmap(
        torch.compile(vmap(f, inner_in_dims, out_dims=out_dim)), outer_in_dims, out_dims=out_dim
    ))(dummy, *batched_args, **kwarg_values)

raises a shape mismatch, because when associative_scan_batch_rule is called, we don't explicitly add a dimension in case no bdim is given to vmap.
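The kind of fix I imagine (a sketch only, with made-up names) would explicitly materialize a batch dim for unbatched leaves so every input agrees on where the batch dim lives:

import torch

def maybe_add_batch_dim(x, bdim, batch_size):
    # Sketch: leaves that vmap did not batch (bdim is None) get an explicit
    # batch dim of the right size at the front; batched leaves have their
    # bdim moved to the front. Signature and names are illustrative only.
    if bdim is None:
        return x.unsqueeze(0).expand(batch_size, *x.shape)
    return torch.movedim(x, bdim, 0)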

Could you please take a look and help me out?

@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Oct 27, 2024
@github-actions github-actions bot closed this Nov 26, 2024
