Conversation

@bohnstingl
Collaborator

@bohnstingl bohnstingl commented Aug 21, 2024

This operation is intended as the counterpart to associative_scan, but it can operate with non-associative functions.

@ydwu4

cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec
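To make the description concrete, here is a minimal pure-Python sketch of the semantics such a scan is usually understood to have. It assumes the `(carry, x) -> (next_carry, output)` contract suggested by the `combine_fn` type signature quoted later in this review. It is not the PR's implementation, and `scan_reference` and `running_difference` are hypothetical names used only for illustration.

```python
def scan_reference(combine_fn, init, xs):
    """Sequential reference: thread a carry through xs, collecting one output per step."""
    carry = init
    ys = []
    for x in xs:
        # combine_fn is assumed to return (next_carry, output_for_this_step).
        carry, y = combine_fn(carry, x)
        ys.append(y)
    return carry, ys


def running_difference(carry, x):
    # Subtraction is not associative: (a - b) - c != a - (b - c),
    # so this step function needs a strictly sequential scan.
    new_carry = carry - x
    return new_carry, new_carry


final_carry, outputs = scan_reference(running_difference, 10, [1, 2, 3])
print(final_carry, outputs)
```

Because each step depends on the previous carry, the loop cannot be reordered the way an associative scan can, which is the distinction the description draws.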

@pytorch-bot

pytorch-bot bot commented Aug 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134102

✅ No Failures

As of commit 2d99b84 with merge base 32f3af7:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@bohnstingl bohnstingl marked this pull request as ready for review August 28, 2024 13:11
@janeyx99 janeyx99 requested a review from ydwu4 August 30, 2024 04:45
@janeyx99 janeyx99 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Aug 30, 2024
Contributor

@ydwu4 ydwu4 left a comment

Please take a look at the comments. Overall, the diff is a bit overwhelming for me to review. Let's clean it up according to the comments, and let me know if you need another round of review.

@bohnstingl
Collaborator Author

@ydwu4, I incorporated your suggestions and updated the code.
I also merged it with the latest main, which includes the associative_scan feature. That brought in many new test cases, and I added scan to all of them.

I would be very happy about another round of review if you have time. Thank you very much!

@bohnstingl bohnstingl requested a review from ydwu4 August 31, 2024 16:24
@bohnstingl
Collaborator Author

Thank you @ydwu4 for your reviews and your help on this. I fixed your last comments and merged with the latest main.

@albanD, the current PR has been a joint effort with @ydwu4 and went through several iterations of code review. It is rather lengthy, but a large portion of that is because we separated scan from associative_scan at a higher level, which added quite a few lines of code. Could you please take a look and maybe approve the PR length as is?

@ydwu4 ydwu4 added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 7, 2024
[pytree.PyTree, pytree.PyTree], Tuple[pytree.PyTree, pytree.PyTree]
],
init: pytree.PyTree,
input: pytree.PyTree,
Contributor

I would recommend making everything after input kwarg-only
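For readers unfamiliar with the suggestion, a bare `*` in a Python signature makes every following parameter keyword-only. The sketch below is a hypothetical stand-in (`scan_signature_sketch` is not the PR's function) that only echoes its arguments to show the calling convention.

```python
def scan_signature_sketch(combine_fn, init, input, *, dim=0):
    """Hypothetical stand-in: everything after the bare `*` must be passed by keyword."""
    return {"init": init, "input": input, "dim": dim}


def identity_step(carry, x):
    return carry, x


# Passing dim by keyword works.
ok = scan_signature_sketch(identity_step, 0, [1, 2], dim=1)

# Passing dim positionally is rejected with a TypeError.
try:
    scan_signature_sketch(identity_step, 0, [1, 2], 1)
    positional_dim_rejected = False
except TypeError:
    positional_dim_rejected = True

print(ok["dim"], positional_dim_rejected)
```

Keyword-only options let later arguments be added or reordered without silently changing the meaning of existing positional calls.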

],
init: pytree.PyTree,
input: pytree.PyTree,
dim: int = 0,
Contributor

Is anyone asking for dim != 0? The API might be simpler if we don't have this initially.

Contributor

nvm, I see that associative_scan has a dim argument so this makes sense to have

sub_args,
sub_kwargs={},
description="scan_combine",
description="associativ_scan_combine_fn",
Contributor

typo

Contributor

associative_scan_combine_fn

Collaborator Author

Done


combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph)
combine_fn_name = add_subgraph(tx, "scan_combine", combine_gm)
combine_fn_name = add_subgraph(tx, "associatie_scan_combine_fn", combine_gm)
Contributor

typo

Contributor

This should be "associative_scan_combine_fn"

Collaborator Author

Done

Comment on lines 1483 to 1491
torch.float16,
torch.float32,
torch.float64,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.complex64,
torch.complex128,
Contributor

Why test so much? This feels like overkill

Contributor

@ydwu4 ydwu4 Sep 9, 2024

Yeah, let's keep the commonly seen types, e.g. float32, int8, and maybe torch.complex64?

Collaborator Author

Fixed.

Contributor

@zou3519 zou3519 left a comment

From an API perspective:

  1. Please make everything after input keyword-only.
  2. Consider not using "input" as the name of the argument: it shadows Python's built-in input() function.
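On the naming point: strictly speaking, `input` is a built-in function rather than a reserved keyword, so using it as a parameter name is legal but hides the built-in inside the function body. The check below illustrates this with the standard library only. The replacement name `xs` is an assumption chosen for illustration, not a name taken from this review.

```python
import builtins
import keyword

# `input` is not a reserved keyword, but it is a name defined in builtins.
is_keyword = keyword.iskeyword("input")
is_builtin = hasattr(builtins, "input")


def scan_signature_sketch(combine_fn, init, xs, *, dim=0):
    """Hypothetical signature: with the argument named `xs`, `input` still means the built-in."""
    return input is builtins.input


builtin_still_visible = scan_signature_sketch(None, 0, [])
print(is_keyword, is_builtin, builtin_still_visible)
```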

@albanD
Collaborator

albanD commented Sep 9, 2024

Waiving PR length check since every line is carefully reviewed and cannot be logically broken down

@ydwu4
Contributor

ydwu4 commented Sep 9, 2024

Talked to @bohnstingl offline: we'll try to stash the changes to all associative_scan tests and add them back in a follow-up PR.

Refactored generic_scan implementation to increase code re-use
@ydwu4
Contributor

ydwu4 commented Sep 10, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
This operation is intended as the counterpart to the `associative_scan`, but it can operate with non-associative functions.

@ydwu4

Pull Request resolved: pytorch#134102
Approved by: https://github.com/ydwu4
"""
if not callable(combine_fn):
    raise RuntimeError("Combine_fn must be a callable, but got {combine_fn}")

should this be an f-string?
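The difference the question points at can be seen in isolation. Without the `f` prefix, the braces are kept literally and the offending value never appears in the message. The value `42` below is an arbitrary non-callable chosen for illustration.

```python
combine_fn = 42  # arbitrary non-callable value, for illustration only

# Plain string: "{combine_fn}" stays in the message verbatim.
plain = "Combine_fn must be a callable, but got {combine_fn}"

# f-string: the placeholder is replaced by the value.
formatted = f"Combine_fn must be a callable, but got {combine_fn}"

print(plain)
print(formatted)
```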

Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: dynamo module: inductor open source skip-pr-sanity-checks topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

8 participants