Improvements for associative_scan - vmap fixes #133013
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133013
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 2b29abd with merge base 2ba60a1.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/functorch/test_control_flow.py
Outdated
with torch._dynamo.utils.disable_cache_limit():
    associative_scan1 = torch.compile(
        torch.vmap(associative_scan_fct, in_dims=0), fullgraph=True
Currently, this seems to be testing torch.compile x vmap x associative_scan vs vmap x associative_scan. We probably should test a simpler version: vmap x associative_scan vs associative_scan?
Yes, this is correct. I split this into two test cases: one for torch.compile x vmap x associative_scan vs associative_scan, and one for vmap x associative_scan vs associative_scan. However, there is a problem when using vmap x associative_scan:
torch._dynamo.exc.Unsupported: torch.func.vmap(fn) requires the function to be inlined by dynamo
Do you have any suggestions on how to handle this?
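For illustration, a rough sketch of the two comparisons being discussed, with torch.cumsum and a pointwise add standing in for associative_scan so the snippet runs without the higher-order-op machinery (the stand-in, shapes, and dims are assumptions, not the PR's actual test code):

import torch

def scan_fct(x):
    # "scan" over dim 0 of the slice seen inside vmap (stand-in for associative_scan)
    return torch.cumsum(x, dim=0)

x = torch.randn(4, 10)
reference = torch.cumsum(x, dim=0)  # plain scan, no vmap

# (a) torch.compile x vmap x scan vs plain scan
compiled = torch.compile(torch.vmap(scan_fct, in_dims=1, out_dims=1), fullgraph=True)
torch.testing.assert_close(compiled(x), reference)

# (b) eager vmap x scan vs plain scan -- with the real associative_scan this
# variant currently hits the dynamo inlining error quoted above
torch.testing.assert_close(torch.vmap(scan_fct, in_dims=1, out_dims=1)(x), reference)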
res = associative_scan_op(combine_fn, input_unwrapped, dim + 1)
with interpreter.lower():
    res = associative_scan_op(combine_fn, input_unwrapped, dim + 1)
How are vmap's in_dims and out_dims handled? Can you add some tests to convince me this is correct?
There is now a requirement that the scan dim and the in_dims of vmap need to be different. When vmapping, the batch dimension is moved to the 0-th dimension and thus the associative_scan is performed on dim + 1. The result is then concatenated together and the 0-th dimension is moved to the out_dims specified for the vmap; see the sketch below. I added a dedicated test case to test this behavior. Does this convince you?
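To make this concrete, a minimal sketch of the dimension handling described above, assuming torch.cumsum with add as a stand-in for the actual scan op and using hypothetical names (this is not the PR's batching-rule code):

import torch

def batched_scan_sketch(x, bdim, scan_dim):
    # vmap-style batching: move the vmapped (batch) dimension to the front ...
    x_front = x.movedim(bdim, 0)
    # ... so the scan runs on scan_dim + 1 ...
    res = torch.cumsum(x_front, dim=scan_dim + 1)
    # ... and the batch dimension (now 0) is afterwards moved to out_dims by vmap.
    return res

x = torch.arange(12.0).reshape(3, 4)
out = batched_scan_sketch(x, bdim=1, scan_dim=0)  # scan over dim 0, batch over dim 1
print(out.shape)  # torch.Size([4, 3]); vmap would then move dim 0 to out_dims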
test/functorch/test_control_flow.py
Outdated
@unittest.skipIf(not SM70OrLater, "triton")
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@parametrize("device", [torch.device("cuda")])
def test_pointwise_associative_scan_vmap(self, device):
@bohnstingl Could you also add a test that:
- adds associative_scan to the hop_db (https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/hop_db.py). This involves adding some "sample inputs" to associative_scan
- calls opinfo_vmap_test? (https://github.com/pytorch/pytorch/blob/fb26b843906bbad5e28d1edccf298c74b8e00492/test/functorch/test_vmap.py#L4272C14-L4272C30)

That would be the ultimate vmap test.
if dim in input_bdims:
    raise ValueError("Vmap in_dim may not coincide with dim of associative_scan")
Is it ever possible for a user to hit this error?
Yes, the user can hit this error if the call is invoked as:

def associative_scan_fct(x):
    return associative_scan(add, x, 0, reverse=reverse)

associative_scan1 = torch.vmap(associative_scan_fct, in_dims=0, out_dims=0)
The point here is that the user has full control over what is the scan dimension and what is the vmap dimension. The two may coincide, which is what's prevented here.
The dimension passed to scan inside the vmap is different from the in_dims in vmap. Inside the vmap, dimension 0 actually means "the 0th dimension not including the vmapped dimension". Based on that I am not sure we can actually hit the assertion -- the vmapped dimension should always be different from the scan dimension.
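A small runnable sketch of that semantics, with torch.cumsum standing in for the scanned op (an assumption for illustration): inside vmap the vmapped dimension is hidden, so dim=0 refers to the first remaining dimension.

import torch

x = torch.arange(12.0).reshape(3, 4)

def per_slice_scan(t):
    # t is a 1-D slice here: the vmapped dimension (dim 1 of x) is not visible
    return torch.cumsum(t, dim=0)

out = torch.vmap(per_slice_scan, in_dims=1, out_dims=1)(x)
print(out.shape)  # torch.Size([3, 4]); each column of x was scanned along its dim 0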
Hmm, then I think something may be off here. This code snippet:
x = torch.tile(
    torch.unsqueeze(
        torch.arange(
            0, 10, device=device, dtype=torch.float32, requires_grad=True
        ),
        0,
    ),
    (4, 1),
)

torch.compiler.reset()

def associative_scan_fct(x):
    return associative_scan(add, x, 0, reverse=reverse)

associative_scan1 = torch.compile(
    torch.vmap(associative_scan_fct, in_dims=1, out_dims=1), fullgraph=True
)
result1 = associative_scan1(x)
This calls the vmap with an input x of shape 4x10, in_dims=1, and a scan dim of 0. The result is that input_bdims=[1] in associative_scan_batch_rule. Now, when I change to in_dims=0, then input_bdims=[0] and the scan dim is 0.
zou3519 left a comment
At a high level this looks reasonable to me. I suggested additional test cases.
@zou3519 thank you for looking at the code. I will add the additional test cases as you suggested.
For (2) -- there's no guarantee that the bdim passed to the operator is the same, even for the same operator. The value of the bdim is an implementation detail of vmap. Because of this, many vmap rules will permute the bdim to the front of the tensor.
For (1): this is a known issue (#134000). For the additional test cases I suggested, we could refactor opinfo_vmap_test to apply compile. I'm not sure how to resolve this; it needs more thinking.
Thank you @zou3519 for your comments.
If for vmap there is no guarantee that the same bdim is used, then the flip operation would be problematic, as it relies on dim always being the same. So if, with vmap x associative_scan, associative_scan is invoked once with 4x10, bdim=1 and once with 10x4, bdim=0, then the flip with a fixed dim wouldn't work? The second problem I am facing is that, in general, I am struggling quite a bit with test_vmap.py, as it runs thousands of tests and it is difficult for me to isolate my newly added tests and debug them. Is there a way to specifically run only the newly added tests?
python test/test_vmap.py -v -k "your_test_name_here"
@zou3519 since my understanding is still that vmap cannot guarantee that the dimensions are not "shuffled", and I cannot detect the vmap case in associative_scan using …
Furthermore, I would have some questions regarding test_vmap.py. I tried to adjust the … Causes issues for the … Using … raises a shape mismatch, because if the … Could you please take a look and help me out?
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
This is part of a series of PRs to improve the functionality of associative_scan. This specific PR fixes issues with the current vmap implementation. This PR has been derived from #129307.

@ydwu4 @Chillee @zou3519