[DRAFT] call sympy.Max(a, b, evaluate=False) instead of torch.utils._sympy.functions.Max #137796
Conversation
This pull request was exported from Phabricator. Differential Revision: D64252491
…sympy.functions.Max (pytorch#137796)

Summary:
--num-features=200 compile.compile_inner 16.5991s vs --num-features=400 compile.compile_inner 287.493 vs

num_features=100
```
rank: 0, world_size: 2, num_features: 100, batch_size: 10, time: 20.05s
rank: 0, world_size: 2, num_features: 100, batch_size: 10, time: 40.24s
```
num_features=200
```
rank: 0, world_size: 2, num_features: 200, batch_size: 10, time: 20.66s
rank: 0, world_size: 2, num_features: 200, batch_size: 10, time: 125.05s
```
Differential Revision: D64252491
Force-pushed from 873e16d to 4aa1df4 (Compare)
This pull request was exported from Phabricator. Differential Revision: D64252491
As I was telling @bobrenjc93, you can't do this, because the simplification done in the constructor is actually load-bearing for unbacked reasoning.
And your profile is pre-#133325, which fixed the bulk of the problems.
Great, if #133325 fixes it then there's no need for this.
No, it is a correctness issue. For example, if you have a test Max(1, u0, 256) == Max(u0, 256), you need to return True for it. But the easiest way for this to happen is for the Max constructor to simplify Max(1, u0, 256) into Max(u0, 256); none of our other reasoning mechanisms will work.
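The constructor simplification being described is easy to observe directly in sympy; a minimal sketch, with a plain symbol `u0` standing in for an unbacked symbol:

```python
import sympy

u0 = sympy.Symbol("u0")

# The default Max constructor drops the dominated constant 1 (since
# 256 >= 1), so both expressions become structurally identical.
a = sympy.Max(1, u0, 256)
b = sympy.Max(u0, 256)
print(a == b)

# With evaluate=False the redundant argument survives, so a plain
# structural == comparison no longer sees the expressions as equal.
c = sympy.Max(1, u0, 256, evaluate=False)
print(c == b)
```

This is why skipping evaluation is not a drop-in change: anything relying on structural equality of simplified Max nodes stops working.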
Mhmm, where in the compiler do we do this == check, and why is it not .equals()? Max(1, u0, 256).equals(Max(u0, 256)) should be True. Do you have an e2e example of this failing, like a function that we compile where we end up generating a wrong program, or fail to compile it due to this? I will try to play with some examples to see if I can get one.

So I have another idea that is less risky but needs benchmarking; it's O(N). If we look at the paste https://www.internalfb.com/phabricator/paste/view/P1644273374 we can do something like this in O(N) under some conditions: given max(max(a, b), c), if a, b and c are all summations of symbols only and there is no intersection in the symbols of a, b, c, then generate
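The O(N) idea above hinges on a precondition: every argument is a sum of bare symbols and the symbol sets are pairwise disjoint. A hypothetical sketch of that check (the helper name `disjoint_symbol_sums` is ours, not from the PR):

```python
import sympy

def disjoint_symbol_sums(*exprs):
    # Hypothetical precondition for the proposed O(N) fast path:
    # each argument must be a bare symbol or a sum of bare symbols,
    # and no symbol may appear in more than one argument.
    seen = set()
    for e in exprs:
        terms = e.args if isinstance(e, sympy.Add) else (e,)
        if not all(isinstance(t, sympy.Symbol) for t in terms):
            return False
        syms = set(terms)
        if seen & syms:
            return False  # a symbol is shared across arguments
        seen |= syms
    return True

a, b, c = sympy.symbols("a b c")
print(disjoint_symbol_sums(a + b, c))  # disjoint sums of symbols
print(disjoint_symbol_sums(a + b, a))  # 'a' appears twice
```

Only when this predicate holds could the nested max be rewritten in a single linear pass without risking the simplifications the default constructor performs.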
Force-pushed from 4aa1df4 to e0d98ca (Compare)
This pull request was exported from Phabricator. Differential Revision: D64252491
…sympy.functions.Max (pytorch#137796)

Summary:
I was looking at this profile with bobrenjc93:
```
TORCH_COMPILE_STROBELIGHT=1 COMPILE_STROBELIGHT_MAX_STACK_LENGTH=500 buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=200
```
strobelight profile link: https://fburl.com/scuba/pyperf_experimental/on_demand/lrh6erxx

Most of the time is spent in constructing the max() node. If we pass evaluate=False we no longer have the exponential cost!?

A paste that shows what we compute when we call max: https://www.internalfb.com/phabricator/paste/view/P1644273374
There is clear repetition across calls, and simplification is not doing much other than flattening the inputs of max.

This is just a draft to make sure all tests pass in OSS. I wonder whether avoiding simplification at construction can make other programs slower? If so, maybe we can add this under a flag.

--num-features=200: compile.compile_inner 16.5991s vs 120.824s
--num-features=400: compile.compile_inner 40s vs 918.024s

num_features=100
```
rank: 0, world_size: 2, num_features: 100, batch_size: 10, time: 20.05s
rank: 0, world_size: 2, num_features: 100, batch_size: 10, time: 40.24s
```
num_features=200
```
rank: 0, world_size: 2, num_features: 200, batch_size: 10, time: 20.66s
rank: 0, world_size: 2, num_features: 200, batch_size: 10, time: 125.05s
```
Differential Revision: D64252491
Force-pushed from e0d98ca to 630f330 (Compare)
This pull request was exported from Phabricator. Differential Revision: D64252491
I confirmed that #133325 does fix the regression that this diff tries to fix. Also, calling simplify in statically_known_true can be a two-way sword; I confirmed it works for the benchmark above if
Summary:
I was looking at this profile with @bobrenjc93
strobelight profile link: https://fburl.com/scuba/pyperf_experimental/on_demand/lrh6erxx


Most of the time is spent in constructing the max() node. If we pass evaluate=False we no longer have the exponential cost!?
A paste that shows what we construct when we call max: https://www.internalfb.com/phabricator/paste/view/P1644273374
There is clear repetition across calls, and simplification is not doing much other than flattening the inputs of max.
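The flattening behavior mentioned above is easy to observe in isolation; a minimal sketch with plain symbols (not the unbacked-symbol case from the profile):

```python
import sympy

x, y, z = sympy.symbols("x y z")

# Default construction flattens the nested Max into a single node.
flat = sympy.Max(sympy.Max(x, y), z)

# evaluate=False skips the constructor work and keeps the nesting.
nested = sympy.Max(sympy.Max(x, y), z, evaluate=False)

print(len(flat.args))    # one flat argument list
print(len(nested.args))  # the inner Max survives as an argument
```

With evaluation on, each nested construction re-filters the whole argument list, which is where the cost accumulates as expressions grow.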
This is just a draft to make sure all tests pass in OSS. I wonder whether avoiding simplification at construction can make other programs slower? If so, maybe we can add this under a flag.
Alternatively, we can also define our own max function and customize the automatic simplifications inside it:
https://docs.sympy.org/latest/explanation/best-practices.html#avoid-too-much-automatic-evaluation
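Following that best-practices advice, a custom max could restrict automatic evaluation to cheap flattening only. A hypothetical sketch (the class `FlatMax` is ours for illustration, not PyTorch code):

```python
import sympy

class FlatMax(sympy.Function):
    """Hypothetical max that only flattens nested calls and removes
    duplicate arguments, skipping sympy.Max's costlier constructor
    simplifications."""

    @classmethod
    def eval(cls, *args):
        # Flatten nested FlatMax arguments into one level.
        flat = []
        for a in args:
            flat.extend(a.args if isinstance(a, FlatMax) else (a,))
        # Deduplicate while preserving order.
        seen, uniq = set(), []
        for a in flat:
            if a not in seen:
                seen.add(a)
                uniq.append(a)
        if tuple(uniq) != tuple(args):
            return cls(*uniq)
        # Returning None leaves the expression unevaluated.

x, y = sympy.symbols("x y")
expr = FlatMax(FlatMax(x, y), x)
print(expr)  # flattened and deduplicated
```

The `eval` classmethod is sympy's documented hook for controlled evaluation, so this keeps the cheap structural normalization while avoiding the argument-comparison work that dominates the profile.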
--num-features=200: compile.compile_inner 16.5991s vs 120.824s
--num-features=400: compile.compile_inner 40s vs 918.024s

num_features=100
```
rank: 0, world_size: 2, num_features: 100, batch_size: 10, time: 20.05s
rank: 0, world_size: 2, num_features: 100, batch_size: 10, time: 40.24s
```
num_features=200
```
rank: 0, world_size: 2, num_features: 200, batch_size: 10, time: 20.66s
rank: 0, world_size: 2, num_features: 200, batch_size: 10, time: 125.05s
```
Differential Revision: D64252491