-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Implementation of scan #134102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implementation of scan #134102
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134102
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 2d99b84 with merge base 32f3af7 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Reverted while_loop for scan
ydwu4
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please take a look at the comments. Overall it's a bit overwhelming for me to review the diff. Let's clean it up according to the comments and let me know if need another round of reviewing.
|
@ydwu4, I incorporated your suggestions and updated the code. I would be more than happy about another round of review if you can. Thank you verymuch! |
|
Thank you @ydwu4 for your reviews and your help on this. I fixed your last comments and merged with the latest @albanD, the current PR has been a joint effort with @ydwu4 and went through several iterations of code reviews. It is rather lengthy, but a large portion of that is because we separated |
torch/_higher_order_ops/scan.py
Outdated
| [pytree.PyTree, pytree.PyTree], Tuple[pytree.PyTree, pytree.PyTree] | ||
| ], | ||
| init: pytree.PyTree, | ||
| input: pytree.PyTree, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would recommend making everything after input kwarg-only
| ], | ||
| init: pytree.PyTree, | ||
| input: pytree.PyTree, | ||
| dim: int = 0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is anyone asking for dim != 0? The API might be simpler if we don't have this initially.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nvm, I see that associative_scan has a dim argument so this makes sense to have
| sub_args, | ||
| sub_kwargs={}, | ||
| description="scan_combine", | ||
| description="associativ_scan_combine_fn", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
associative_scan_combine_fn
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
|
||
| combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph) | ||
| combine_fn_name = add_subgraph(tx, "scan_combine", combine_gm) | ||
| combine_fn_name = add_subgraph(tx, "associatie_scan_combine_fn", combine_gm) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This "associative_scan_combine_fn"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
test/functorch/test_control_flow.py
Outdated
| torch.float16, | ||
| torch.float32, | ||
| torch.float64, | ||
| torch.int8, | ||
| torch.int16, | ||
| torch.int32, | ||
| torch.int64, | ||
| torch.complex64, | ||
| torch.complex128, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why test so much? This feels like overkill
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, lets keep the commonly seen types e.g. float32, int8, maybe a torch.complex64?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
zou3519
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from an API perspective:
- please make everything after
inputkwarg only - consider not using "input" as the name of the argument -- input is a special python keyword.
|
Waiving PR length check since every line is carefully reviewed and cannot be logically broken down |
|
Talk to @bohnstingl offiline: we'll try to stash the changes to all associative_scan tests and added it back as a follow-up pr. |
Refactored generic_scan implementation to increase code re-use
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
This operation is supposed to be the pendant to the `associative_scan`, but can operate with non-associative functions. @ydwu4 Pull Request resolved: pytorch#134102 Approved by: https://github.com/ydwu4
| """ | ||
| if not callable(combine_fn): | ||
| raise RuntimeError("Combine_fn must be a callable, but got {combine_fn}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should this be an f-string?
This operation is supposed to be the pendant to the
associative_scan, but can operate with non-associative functions.@ydwu4
cc @albanD @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec