@bohnstingl
Collaborator

@bohnstingl bohnstingl commented Aug 8, 2024

This is part of a series of PRs to improve associative_scan. This specific PR introduces a reverse flag to associative_scan to establish an interface similar to that of jax.lax.associative_scan. This PR has been derived from #129307.
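As a rough illustration of what a reverse flag means for an associative scan, the sketch below runs a plain sequential scan over a Python list and handles reverse by flipping the input, scanning, and flipping the result back, which mirrors the torch.flip approach in the diff under review. The helper name scan_sketch and the list-based input are assumptions for illustration, not the PR's implementation.

```python
from itertools import accumulate
import operator


def scan_sketch(combine_fn, values, reverse=False):
    """Sequential inclusive scan over a list (hypothetical helper, not the PyTorch API)."""
    # For a reverse scan, flip the input first ...
    items = list(reversed(values)) if reverse else list(values)
    # ... run an ordinary left-to-right inclusive scan ...
    scanned = list(accumulate(items, combine_fn))
    # ... and flip the result back so positions line up with the input.
    return list(reversed(scanned)) if reverse else scanned


forward = scan_sketch(operator.add, [1, 2, 3, 4])
backward = scan_sketch(operator.add, [1, 2, 3, 4], reverse=True)
print(forward)   # [1, 3, 6, 10]
print(backward)  # [10, 9, 7, 4]
```

With reverse=True, each output position holds the combination of its own element and everything after it, rather than everything before it.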

@ydwu4 @Chillee @zou3519

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

@pytorch-bot

pytorch-bot bot commented Aug 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133011

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1e3391a with merge base 161cc13:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

leaves, spec = pytree.tree_flatten(input)

if reverse:
    leaves = [torch.flip(elem, [dim]) for elem in leaves]
Contributor

@vadimkantorov vadimkantorov Aug 8, 2024

In practice, elem.flip(dim) also works (there is no need to construct a list [dim]), but unfortunately this is not documented properly.

Contributor

@ydwu4 ydwu4 left a comment

Overall looks good to me. I left some comments on test organization. We generally don't want a unit test to be too complicated or to take too long to run, so a large test should be split into several tests that each check one thing.

print("Flip test fails for backends: " + str(fails_for_backend))
self.assertEqual(len(fails_for_backend), 0)

for n in range(20):
Contributor

20 might be too large? We don't want the unit tests to take too long.

Collaborator Author

Well, the problem is that with combine_mode='generic', which will follow, the test randomly fails at some point, so I wanted to keep it like this. However, if you feel that it takes too long, I can reduce it.

Contributor

Do we know why it fails randomly?

Collaborator Author

This test fails for n=9 in the case combine_mode='generic'. The output of the compiled version

torch.compile(
    associative_scan, backend=backend, fullgraph=True
)

is
tensor([36, 36, 35, 33, 30, 26, 21, 15, 8], device='cuda:0'), which is correct, while the output of the uncompiled version

associative_scan2 = associative_scan

is
tensor([36, 15, 35, 0, 30, 0, 21, 0, 8], device='cuda:0'), which is incorrect.

This problem only appears for n=9; for other values of n, the outputs are correct.
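
The expected values can be double-checked by hand. The sketch below assumes the test input is the integers 0 through 8 combined with addition (an assumption; the input is not shown in this thread), computes the reverse cumulative sum in plain Python, and lists the positions where the incorrect output differs.

```python
from itertools import accumulate

# Assumption: the n=9 test input is 0, 1, ..., 8 and combine_fn is addition.
values = list(range(9))

# Reverse inclusive cumulative sum: flip, scan, flip back.
expected = list(accumulate(reversed(values)))[::-1]

# Outputs quoted in the comment above.
reported_correct = [36, 36, 35, 33, 30, 26, 21, 15, 8]
reported_incorrect = [36, 15, 35, 0, 30, 0, 21, 0, 8]

# Indices at which the incorrect output deviates from the expected result.
mismatches = [
    i for i, (a, b) in enumerate(zip(expected, reported_incorrect)) if a != b
]

print(expected == reported_correct)
print(mismatches)
```

Under that assumption, only the odd-indexed positions differ, while the even-indexed ones match the expected reverse sum.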

@ydwu4
Contributor

ydwu4 commented Aug 8, 2024

You can skip the failing test with unittest.skipIf(not SM70OrLater), after importing the flag via from torch.testing._internal.common_cuda import SM70OrLater.
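
A minimal sketch of that skip pattern follows. To stay runnable without a GPU build, it replaces the imported SM70OrLater flag with a hypothetical local boolean; the class and test names are also made up for illustration. Note that unittest.skipIf takes a reason string as its second argument.

```python
import unittest

# Stand-in for `from torch.testing._internal.common_cuda import SM70OrLater`;
# hard-coded here (an assumption) so the sketch runs without PyTorch.
SM70_OR_LATER = False


class ReverseScanTests(unittest.TestCase):  # hypothetical test class
    @unittest.skipIf(not SM70_OR_LATER, "requires a GPU with SM70 or later")
    def test_reverse_generic(self):
        # The real test body would call associative_scan; omitted here.
        self.fail("not reached when the test is skipped")


# Run the suite programmatically and inspect the outcome.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(ReverseScanTests)
result = unittest.TestResult()
suite.run(result)
print(len(result.skipped), len(result.failures))
```

When the flag is false, the test is recorded as skipped rather than failed, so the rest of the suite can still gate the merge.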

@zou3519 zou3519 added the triaged label Aug 9, 2024
bohnstingl added a commit to bohnstingl/pytorch that referenced this pull request Aug 10, 2024
Contributor

@ydwu4 ydwu4 left a comment

Looks good! Wait for CI.

@ydwu4 ydwu4 added the ciflow/trunk label Aug 15, 2024
@ydwu4
Contributor

ydwu4 commented Aug 16, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team Raised by workflow job

@ydwu4 ydwu4 added the topic: not user facing label Aug 16, 2024
@ydwu4
Contributor

ydwu4 commented Aug 16, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Failing merge rule: Core Maintainers

@ydwu4
Contributor

ydwu4 commented Aug 16, 2024

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 0 checks:

pytorch-bot bot pushed a commit that referenced this pull request Sep 13, 2024
This is part of a series of PRs to improve `associative_scan`. This specific PR introduces a `reverse` flag to `associative_scan` to establish an interface similar to that of `jax.lax.associative_scan`. This PR has been derived from #129307.

@ydwu4 @Chillee @zou3519

Pull Request resolved: #133011
Approved by: https://github.com/ydwu4

Labels

ciflow/trunk, Merged, module: inductor, open source, topic: not user facing, triaged

6 participants