Improvements for associative_scan - combine_mode #133012
…torch into generic_associative_scan_2
…torch into generic_associative_scan_2
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133012
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 6d1d698 with merge base bb22132. This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/functorch/test_control_flow.py
Outdated
@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
@parametrize("reverse", [False, True])
@parametrize("combine_mode", ["pointwise", "generic"])
@parametrize("device", [torch.device("cuda")])
Is the device argument necessary? Can we delete it? Same for other tests.
Yes, the device argument should test CPU and CUDA tensors. I updated the testcases to reflect this. In case of combine_mode=pointwise and the CPU device, the test is "skipped".
Can you figure out why combine_mode=pointwise x CPU fails? It doesn't need to be solved in this PR. Maybe instead of skipping, we can xfail it with a proper reason.
Well, my understanding was that, because of the lowering to Triton, only the CUDA device is supported. I double-checked and this seems to be the case.
Regarding the xfail: I don't know how to do it properly. My specific problem is that only a subset of a test's parameter combinations fails, e.g., CPU x pointwise. Is there an example that I can look at?
I tried xfail_inherited_tests and xfail, without much success.
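One way to mark only some parameter combinations as expected failures can be sketched with the standard-library `unittest` alone. This is an illustration, not the approach taken in this PR: `run_scan`, `UNSUPPORTED`, and `ScanParamTests` are hypothetical names, and the set of unsupported combinations is an assumption taken from the discussion above (pointwise on CPU).

```python
import itertools
import unittest

# Assumption for illustration: the combinations known not to work.
UNSUPPORTED = {("pointwise", "cpu")}


def run_scan(combine_mode, device):
    """Hypothetical stand-in for the operation under test."""
    if (combine_mode, device) in UNSUPPORTED:
        raise NotImplementedError(f"{combine_mode} is not supported on {device}")
    return "ok"


class ScanParamTests(unittest.TestCase):
    """Test methods are attached below, one per parameter combination."""


def _make_test(combine_mode, device):
    def test(self):
        self.assertEqual(run_scan(combine_mode, device), "ok")

    # Only the unsupported combinations are wrapped as expected failures.
    if (combine_mode, device) in UNSUPPORTED:
        test = unittest.expectedFailure(test)
    return test


for mode, dev in itertools.product(["pointwise", "generic"], ["cpu", "cuda"]):
    setattr(ScanParamTests, f"test_scan_{mode}_{dev}", _make_test(mode, dev))


def run_suite():
    """Run all generated tests and return the unittest.TestResult."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(ScanParamTests)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

Compared with skipping, an expected failure still executes the test body, so the suite reports an unexpected success once the combination starts working.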
Wait, if only CUDA is supported for "pointwise", how did the first few runs pass? We could list it out as a separate test.
I've marked the combine_mode='pointwise' x CPU testcases as skipped and added a comment.
It seems that ROCm is also failing these pointwise tests, even though we are using cuda device. Any ideas here?
test/inductor/test_control_flow.py
Outdated
class AssociativeScanTests(TestCase):
    @requires_gpu
    @parametrize("device", [torch.device("cuda")])
    @parametrize("combine_mode", ["generic"])
Can we add "pointwise" to combine_mode?
Sure, this can be done. I extended the testcase. However, this is the test that currently fails with the weird behavior of the flip operation that I mentioned.
Looks good! Let's wait for CI.
Fixed lintrunner issues
Added skip and expected fail decorators to flip tests
@pytorchbot merge
Merge failed
Reason: This PR needs a If not, please add the To add a label, you can comment to pytorchbot, for example For more information, see
Details for Dev Infra team: Raised by workflow job
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Hey @bohnstingl @ydwu4 @Chillee @zou3519 cc: @jeffdaily This change seems to have broken some ROCm tests. Could you help us pinpoint what the issue may be, or give us some pointers on how to debug this? Snippet: After running this locally I can see that we are only failing the pointwise combine_fn:
Hi @jataylo, on my local machine the tests are running fine. However, on a different note, @ydwu4 and I are currently working on a related feature to the
We can probably skip the associative_scan tests on ROCm. From the error log, ir.Scan.Create in the lowering logic returns None. It seems like a Triton x ROCm issue or something similar.
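A skip restricted to ROCm builds could look like the sketch below. It assumes that `torch.version.hip` is set to a version string on ROCm builds and is `None` otherwise; `is_rocm_build` and the test class are hypothetical names, and the test body is only a placeholder. PyTorch's own test utilities also provide ROCm skip decorators, which would be the natural choice inside the PyTorch test suite.

```python
import unittest


def is_rocm_build() -> bool:
    """Return True when the installed PyTorch reports a ROCm (HIP) build."""
    try:
        import torch
    except ImportError:
        # No PyTorch available: treat the environment as not ROCm.
        return False
    return getattr(torch.version, "hip", None) is not None


class PointwiseScanRocmSketch(unittest.TestCase):
    @unittest.skipIf(
        is_rocm_build(),
        "pointwise associative_scan fails on ROCm; skipped while triage is underway",
    )
    def test_placeholder(self):
        # Hypothetical placeholder: the real test would run the pointwise scan.
        self.assertTrue(True)
```

Putting the reason string in the decorator keeps the open triage visible in the test report, so the skip is easy to find and remove later.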
Hey @bohnstingl, @ydwu4, the full error log is here: https://ossci-raw-job-status.s3.amazonaws.com/log/30067645445 I don't see this getting to any Triton lowering before failing, so this seems to be more of a PyTorch logic issue than a Triton issue. We were passing the scan UTs before this change too, so we would like to figure out what is going on. EDIT: this also fails with the eager compile backend.
#133012 caused a regression on ROCm causing pointwise scan tests to fail
```
ERROR: test_pointwise_associative_scan_tuple_reverse_True_combine_mode_pointwise_cuda
ERROR: test_pointwise_associative_scan_tuple_reverse_False_combine_mode_pointwise_cuda
ERROR: test_pointwise_associative_scan_complex_pytree_reverse_True_combine_mode_pointwise_cuda
ERROR: test_pointwise_associative_scan_complex_pytree_reverse_False_combine_mode_pointwise_cuda
ERROR: test_pointwise_associative_scan_binary_operator_reverse_True_combine_mode_pointwise_cuda
ERROR: test_pointwise_associative_scan_binary_operator_reverse_False_combine_mode_pointwise_cuda
```
Skipping temporarily while triage is underway.
Full log: https://ossci-raw-job-status.s3.amazonaws.com/log/30067645445
```
  File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/graph.py", line 1020, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
  File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 363, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 6245, in associative_scan
    raise RuntimeError("Unable to generate code for associative_scan op")
torch._inductor.exc.LoweringException: RuntimeError: Unable to generate code for associative_scan op
```
NOTE: even "eager" backend fails
```
  File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_higher_order_ops/associative_scan.py", line 338, in associative_scan_op_dense
    raise NotImplementedError("associative_scan is not implemented for eager")
NotImplementedError: associative_scan is not implemented for eager
```
Pull Request resolved: #135995
Approved by: https://github.com/malfet
This is part of a series of PRs to improve the `associative_scan` functionality. This specific PR introduces a `combine_mode`, which can be either `pointwise` (default) or `generic`. With `generic`, `associative_scan` is more flexible and also allows non-pointwise functions. This PR has been derived from pytorch#129307. @ydwu4 @Chillee @zou3519
Pull Request resolved: pytorch#133012
Approved by: https://github.com/ydwu4
This is part of a series of PRs to improve the `associative_scan` functionality. This specific PR introduces a `combine_mode`, which can be either `pointwise` (default) or `generic`. With `generic`, `associative_scan` is more flexible and also allows non-pointwise functions. This PR has been derived from #129307.
@ydwu4 @Chillee @zou3519
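To make the scan semantics concrete, here is a pure-Python reference sketch of an inclusive scan with an arbitrary associative combine function. It is not the PyTorch implementation: `reference_scan` and `matmul2` are hypothetical names, and `reverse` is modeled as flip, scan, flip, which is an assumption based on the flip operation mentioned in the review. The 2x2 matrix product stands in for a non-pointwise combine function, since each output entry depends on several input entries.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def reference_scan(
    combine_fn: Callable[[T, T], T], xs: Sequence[T], reverse: bool = False
) -> List[T]:
    """Inclusive scan: out[i] combines all elements up to and including i."""
    items = list(xs)
    if reverse:
        items.reverse()  # assumption: reverse means flip, scan, flip
    out: List[T] = []
    for x in items:
        out.append(x if not out else combine_fn(out[-1], x))
    if reverse:
        out.reverse()
    return out


def matmul2(a, b):
    """2x2 matrix product on nested tuples: associative, but not elementwise."""
    return (
        (a[0][0] * b[0][0] + a[0][1] * b[1][0], a[0][0] * b[0][1] + a[0][1] * b[1][1]),
        (a[1][0] * b[0][0] + a[1][1] * b[1][0], a[1][0] * b[0][1] + a[1][1] * b[1][1]),
    )


# Elementwise-style combine (addition) gives a cumulative sum.
cumsum = reference_scan(lambda a, b: a + b, [1, 2, 3, 4])
# cumsum == [1, 3, 6, 10]

# Matrix-product combine: running products of the same shear matrix.
shear = ((1, 1), (0, 1))
powers = reference_scan(matmul2, [shear, shear, shear])
# powers[-1] == ((1, 3), (0, 1))
```

A sequential loop like this is only a correctness reference. The point of requiring an associative combine function is that an implementation may regroup the combinations and evaluate them in parallel.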
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec