-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Introduce torch.sym_add, variadic add #138660
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Partially addresses #128150 When you have big sums of values, we end up computing long chains of binary addition in our FX graph representation. Not only is this ugly, it also is quadratic, as the sympy.Add constructor is O(N) in number of arguments. Instead, ensure that we maintain the summation as a single FX node so we can do the entire addition all in one go. Signed-off-by: Edward Z. Yang <[email protected]> [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138660
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ You can merge normally! (1 Unrelated Failure)As of commit c84b2af with merge base 72dde6e ( FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@pytorchbot merge -i |
Tested internally here: https://www.internalfb.com/diff/D64057744 This is a reland after previous internal failures. main change is ``` if min is None and max is None: torch._check_is_size(size) return ``` Partially addresses #128150 When you have big sums of values, we end up computing long chains of binary addition in our FX graph representation. Not only is this ugly, it also is quadratic, as the sympy.Add constructor is O(N) in number of arguments. Instead, ensure that we maintain the summation as a single FX node so we can do the entire addition all in one go. Signed-off-by: Edward Z. Yang <ezyangmeta.com> cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 ezyang SherlockNoMad EikanWang wenzhe-nrv voznesenskym penguinwu Guobing-Chen zhuhaozhe blzheng jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov rec [ghstack-poisoned]
Partially addresses #128150 When you have big sums of values, we end up computing long chains of binary addition in our FX graph representation. Not only is this ugly, it also is quadratic, as the sympy.Add constructor is O(N) in number of arguments. Instead, ensure that we maintain the summation as a single FX node so we can do the entire addition all in one go. Signed-off-by: Edward Z. Yang <ezyangmeta.com> ghstack-source-id: a170651 Pull Request resolved: #138660
|
rebase |
| # Special case for sum on tuple/list of ints | ||
| if ( | ||
| self.fn is builtins.sum | ||
| and len(args) == 1 | ||
| and not kwargs | ||
| and isinstance(args[0], (variables.ListVariable, variables.TupleVariable)) | ||
| and all( | ||
| (isinstance(x, variables.ConstantVariable) and isinstance(x.value, int)) | ||
| or (isinstance(x, variables.SymNodeVariable) and x.python_type() is int) | ||
| for x in args[0].items | ||
| ) | ||
| ): | ||
| return variables.SymNodeVariable.create( | ||
| tx, | ||
| tx.output.create_proxy( | ||
| "call_function", | ||
| torch.sym_sum, | ||
| (tuple(a.as_proxy() for a in args[0].items),), | ||
| {}, | ||
| ), | ||
| sym_num=torch.sym_sum( | ||
| [ | ||
| ( | ||
| x.value | ||
| if isinstance(x, variables.ConstantVariable) | ||
| else x.sym_num | ||
| ) | ||
| for x in args[0].items | ||
| ] | ||
| ), | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd like to create a dispatch registry to make the code easier to maintain.
check_fn, dispatch_fn = self.shortcuts.get(self.fn, (None, None))
if check_fn is not None and check_fn(args, kwargs):
return dispatch_fn(self, tx, args, kwargs)|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Tested internally here: https://www.internalfb.com/diff/D64057744 This is a reland after previous internal failures. main change is ``` if min is None and max is None: torch._check_is_size(size) return ``` Partially addresses #128150 When you have big sums of values, we end up computing long chains of binary addition in our FX graph representation. Not only is this ugly, it also is quadratic, as the sympy.Add constructor is O(N) in number of arguments. Instead, ensure that we maintain the summation as a single FX node so we can do the entire addition all in one go. Signed-off-by: Edward Z. Yang <[email protected]> Pull Request resolved: #138660 Approved by: https://github.com/ezyang, https://github.com/bobrenjc93
Stack from ghstack (oldest at bottom):
Tested internally here: https://www.internalfb.com/diff/D64057744
This is a reland after previous internal failures.
main change is
Partially addresses #128150
When you have big sums of values, we end up computing long chains of
binary addition in our FX graph representation. Not only is this ugly,
it also is quadratic, as the sympy.Add constructor is O(N) in number
of arguments. Instead, ensure that we maintain the summation as a
single FX node so we can do the entire addition all in one go.
Signed-off-by: Edward Z. Yang [email protected]
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @ezyang @SherlockNoMad @EikanWang @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @zhuhaozhe @blzheng @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @rec