-
Notifications
You must be signed in to change notification settings - Fork 26.3k
use a fast expand algorithm #135999
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
use a fast expand algorithm #135999
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135999
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit a097c35 with merge base 7647c39 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
| try: | ||
| return sympy.expand(r) | ||
| return fast_expand(r) | ||
| except RecursionError: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
btw it might be OK to remove the hasattr/RecursionError guarding with the change now, if it's now trivially obvious we won't blow out the stack expanding with the new code
| IndicatorTypes = (IsNonOverlappingAndDenseIndicator,) | ||
|
|
||
|
|
||
| def _expandsums(args): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
at least a teeny type signature would be nice lol
| for add in adds: | ||
| result = [a * b for a, b in itertools.product(result, add.args)] | ||
|
|
||
| result = sympy.Add(*result) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't do make_args I guess, since we have to combine common terms
| # checks when re-creating objects. | ||
| new_args = [fast_expand(arg) for arg in expr.args] | ||
| if any(arg is not new_arg for arg, new_arg in zip(expr.args, new_args)): | ||
| return fast_expand(expr.func(*new_args)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, RecursionError danger here, I guess
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. However I'm not sure if replacing this line with expr = expr.func(*new_args) would work.
| if exp > 1: | ||
| return sympy.expand_multinomial(expr, deep=False) | ||
| elif exp < 0: | ||
| return 1 / sympy.expand_multinomial(1 / expr, deep=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is impossible, because I added a FloatPow for this case
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Pull Request resolved: pytorch#135999 Approved by: https://github.com/ezyang
Stack from ghstack (oldest at bottom):