[GraphPartition] cache get_free_symbol_uses (#166338) #166994
Merged
Graph partition relies on `get_free_symbol_uses()` to collect symbol inputs (`pytorch/torch/_inductor/scheduler.py`, lines 4869 to 4885 in ee7434b).
I empirically observed that `get_free_symbol_uses()` becomes slower for larger graphs. Specifically, I tried aten fallback for torchtitan, which results in 10k+ aten nodes. When processing the 600-th node, a single call to `get_free_symbol_uses()` takes seconds.

Why? Because `get_free_symbol_uses()` may recursively call another `get_free_symbol_uses()`, so the same work can be repeated many times across the graph (`pytorch/torch/_inductor/ir.py`, lines 4541 to 4543 in ee7434b).
This PR fixes the issue by caching the results of `get_free_symbol_uses()`. I validated on torchtitan that the issue is fixed.

Pull Request resolved: #166338
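The idea can be sketched with a minimal memoization example. This is not the actual PR implementation — `Node` and the traversal below are hypothetical stand-ins for the real Inductor IR nodes — but it shows why caching helps: a node reachable through many paths is otherwise re-traversed once per path.

```python
from functools import lru_cache

# Hypothetical IR node: carries its own free symbols plus child nodes.
class Node:
    def __init__(self, own_symbols, children=()):
        self.own_symbols = frozenset(own_symbols)
        self.children = tuple(children)

# Caching the recursive collector means each node is traversed at most
# once; without the cache, shared subgraphs are revisited repeatedly.
@lru_cache(maxsize=None)
def get_free_symbol_uses(node):
    uses = set(node.own_symbols)
    for child in node.children:
        uses |= get_free_symbol_uses(child)
    return frozenset(uses)
```

For example, in a diamond-shaped graph where two intermediate nodes share one leaf, the leaf's symbols are computed once and the second lookup is a cache hit. (In real code a per-instance cache may be preferable, since `lru_cache` keeps strong references to the nodes.)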
(cherry picked from commit dfebdca)
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben