[Inductor] Defer reduction split after fusion #141082
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141082
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 active SEV; if your PR is affected, please view it below. ✅ No failures as of commit c5208d1 with merge base ecf3bae. This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D66157846
```
self.name_to_buf.update({
    buf.get_name(): buf for buf in middle_node.get_outputs()
})
self.name_to_fused_node[middle_node.get_name()] = middle_node
```
hope I didn't overlook updating any other necessary fields in the scheduler...
@pytorchbot label "topic: not user facing"
Nice! I tried this on a torchao test case and got the following error:
@vkuzo thanks for sharing the error. Updates: I updated the PR to fix this error, and also attached the trace + generated kernels in the diff's test plan.
f77c862 to 90212ab
This pull request was exported from Phabricator. Differential Revision: D66157846
torch/_inductor/scheduler.py
Outdated
Unrelated to this PR; it's to fix this test error: https://www.internalfb.com/intern/test/844425111960859?ref_report_id=0
Hmm, would you say more about this error?
The main change here is to modify the condition for calling `loop_reordering` from `shared_data_score == 0` to `shared_data_score < config.score_fusion_memory_threshold`.
In the failed test (https://github.com/pytorch/pytorch/blob/main/test/inductor/test_loop_ordering.py#L425), the two pointwise nodes originally share some read data (the scaling factor, which is just a scalar). The shared_data_score is not 0 but is smaller than config.score_fusion_memory_threshold.
Before the change:
`shared_data_score > 0 -> won't loop_reorder -> can't be fused because of shared_data_score < config.score_fusion_memory_threshold`
After the change:
`shared_data_score > 0 -> loop_reorder (shared_data_score < config.score_fusion_memory_threshold) -> get a larger shared_data_score -> fused`
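A minimal sketch of the changed check, assuming the helper names mentioned in this thread (the actual scheduler code may differ):
```
shared_data_score = self.score_fusion_memory(node1, node2)
if (
    config.loop_ordering_after_fusion
    # old condition: shared_data_score == 0
    and shared_data_score < config.score_fusion_memory_threshold
):
    # A tiny shared read (e.g. a scalar scaling factor) now still triggers loop
    # reordering, which can produce a larger shared_data_score and allow fusion.
    shared_data_score = self.shared_data_after_reordering_loop(node1, node2)
```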
Summary:
* Added a config option `defer_reduction_split`. When it's enabled, if `num_splits` gets a `>1` result, return `ReductionHint.DEFERRED_SPLIT, 1` instead.
* In scheduler, when fusing nodes, if a node is a `ReductionHint.DEFERRED_SPLIT`, ignore it for now. After all non-DEFERRED_SPLIT nodes are fused, split the DEFERRED_SPLIT Reduction nodes.
* I thought about a better place to put `split_reduction_nodes`. However, there are cases where the tiled operations will only occur after fusion. So I chose to run `fuse_nodes` first, and then used the fused nodes to split the reduction nodes. For example,
```
# op0. the reduction, read: arg0_1
y = x.abs().max()
# op1. read: arg0_1; write: buf0; no tiling
z = x / 10.0
# op2. read: transposed buf0
z_t = z.t().contiguous().t()

# Before fusion of op1 and op2: Only op1 has shared data with op0, however op1 has no tiling
# After fusion of op1 and op2: fused op0_1: read arg0_1, and has a tiling that the reduction (op0) can try to match
```
* For a DEFERRED_SPLIT Reduction node, find all nodes that have shared data with it (if `loop_ordering_after_fusion` is enabled, `loop_reordering` will also be considered). Use the other nodes' tiling as the hints to split the reduction. If the "hints" are too different from the original num_split, they'll be ignored.
* Create new Scheduler nodes from the split reduction nodes. Remove the DEFERRED_SPLIT Reduction node, add the new Reduction nodes in the Scheduler, and update all related scheduler fields. (Hope I didn't miss any necessary fields in the scheduler...)
* If there are new nodes added, re-run `fuse_nodes`. The second `fuse_nodes` should finish quickly, as only the newly added reduction nodes may be fused into other nodes.

Test Plan:
Some test cases in my testing scripts
```
TORCHINDUCTOR_PROFILE_OUTPUT=/tmp/profile.txt TORCHINDUCTOR_PROFILE=1 TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING=1 TORCHINDUCTOR_DEFER_REDUCTION_SPLIT=1 TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1 TORCH_LOGS="fusion" buck run mode/opt scripts/shuqiyang/test_inductor:test_reduction 2>&1 | tee ~/test_reduction_log_4.txt
```
* Includes both the delayed scaling and luca's rowwise scaling. Next step is to test it in an actual float8 training use case.

------

Unit test:
Defer_reduction_split test:
```
buck2 test 'fbcode//mode/opt' caffe2/test/inductor:defer_reduction_split
```
There was a test failure in loop_reordering, https://fburl.com/test/yo50vt6x. Related fusion logs: P1682598849. Also fixed it in this diff.
```
buck2 test 'fbcode//mode/opt' caffe2/test/inductor:loop_ordering
```

----

Test in Float8 Training
```
TORCHINDUCTOR_DEFER_REDUCTION_SPLIT=1 TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1 buck run mode/opt scripts/shuqiyang/test_inductor:test_float8 -- ~/local/tmp/20241120_test --dtype_filter float8 --scaling_type_input delayed --scaling_type_weight delayed --scaling_type_grad_output delayed 2>&1 | tee ~/test_compile
```
Trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.64960216286.json.gz&bucket=acadia
Codegen kernels: P1683645463, P1683645752

Differential Revision: D66157846
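For illustration, a hedged sketch of the `num_splits` behavior described in the first bullet; the names follow this summary and are not copied from the diff:
```
def num_splits_with_deferral(*args, ignore_defer_reduction_split=False, **kwargs):
    # Hypothetical wrapper around the existing split heuristic, per the summary above.
    hint, split = original_num_splits(*args, **kwargs)  # stand-in for the current heuristic
    if config.defer_reduction_split and split > 1 and not ignore_defer_reduction_split:
        # Defer the decision: keep the reduction unsplit for now so the scheduler
        # can later pick a split that matches the tiling of its fusion neighbors.
        return ReductionHint.DEFERRED_SPLIT, 1
    return hint, split
```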
90212ab to c5208d1
This pull request was exported from Phabricator. Differential Revision: D66157846
thanks @y-sq! Without activation checkpointing, I now see what I think are the optimal kernel patterns generated for fwd+bwd of a Float8Linear, measured with

If we add AC (requires pytorch/ao#1354), then I still see a lot of missing fusions:
@vkuzo thanks for sharing the SAC script. I'll check the SAC case. In SAC, is the following the ideal case for delayed scaling?

Activations:
Weights:
Gradients:
Looks good! A few comments. Would you do a perf run on the dashboard? See https://pytorch.org/docs/stable/torch.compiler_performance_dashboard.html
I still need to review one of the functions you added.
```
if HAS_GPU:
    torch.set_default_device(GPU_TYPE)
```
Hm, I would put this in setUpClass and revert to the prior default device in tearDownClass.
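A hedged sketch of that suggestion, assuming the usual unittest-style structure of these Inductor tests (the class name here is a placeholder):
```
import torch
from torch._inductor.test_case import TestCase
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


class DeferReductionSplitTest(TestCase):  # placeholder class name
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        if HAS_GPU:
            # remember the previous default device so it can be restored later
            cls._prev_default_device = torch.get_default_device()
            torch.set_default_device(GPU_TYPE)

    @classmethod
    def tearDownClass(cls):
        if HAS_GPU:
            torch.set_default_device(cls._prev_default_device)
        super().tearDownClass()
```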
```
# self.dtype represents the dst dtype
src_dtype: torch.dtype
reduction_hint: ReductionHint
input_node: Optional[IRNode] = None
```
You can get this with reduction.data.get_read_names() - don't think we need to add an additional field for it.
```
min_split_size = num_threads * min_elements_per_thread // 8  # 1024
max_split_size = num_threads * max_elements_per_thread * 2  # 262144
```
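For illustration only, one assignment of the surrounding constants (they are not shown in this snippet) that reproduces the inline comments:
```
# Assumed values purely to show how 1024 and 262144 could arise; the real
# constants live elsewhere in the diff.
num_threads = 2048
min_elements_per_thread = 4
max_elements_per_thread = 64
assert num_threads * min_elements_per_thread // 8 == 1024
assert num_threads * max_elements_per_thread * 2 == 262144
```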
Can you say where these numbers are coming from and add comments/docstring?
If there is already a partial reduction occurring, as with #136267, I think it always makes sense to do the same split. This should be a clear win and doesn't require a heuristic. What do you think?
> If there is already a partial reduction occurring, as with https://github.com/pytorch/pytorch/issues/136267, I think it always makes sense to do the same split.

Yes, this makes sense. I think I should have two types of split_hints..? If it comes from a partial reduction, always follow it; if it comes from other tiled ops, only follow it if the split is not too different from the optimal one?
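A hedged sketch of the two kinds of hints being discussed (hypothetical helper, not the PR's `refine_split_num`):
```
def refine_split_sketch(split_num, partial_reduction_hints, tiling_hints):
    # 1) A hint coming from an existing partial reduction is always followed.
    if partial_reduction_hints:
        return partial_reduction_hints[0]
    # 2) A hint coming from a neighboring tiled op is only followed when it is
    #    "not too different" from the heuristically chosen split (threshold assumed).
    for hint in tiling_hints:
        if hint > 1 and 0.5 <= hint / split_num <= 2.0:
            return hint
    return split_num
```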
```
    rvals_per_thread * split_size
)


def refine_split_num(split_num, reduction_numel_hint, split_hints) -> _IntLike:
```
Could you add a comment/docstring here?
```
reduction_type: str,
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
input_node: Optional[IRNode] = None,
ignore_defer_reduction_split: bool = False,
```
I guess we are not yet supporting Welford, that's fine.
```
OUTER = 1
OUTER_TINY = 2
DEFAULT = 3
DEFERRED_SPLIT = 3
```
It's a little weird for this to show up at runtime, since it's a compile time only construct.
```
self.logged_slow_fusion: OrderedSet[Tuple[str, str]] = OrderedSet()
if config._pre_fusion_custom_pass is not None:
    self.nodes = config._pre_fusion_custom_pass(self.nodes)
self.nodes = self.fuse_nodes(self.nodes)
```
Hmm, I understand we want to do some amount of fusion prior to fusing the split nodes. However, the current fusion is greedy: we fuse the most profitable fusions (most global memory saved) first, and we are currently willing to do this for 10 iterations or until a round performs no fusions.
I'm not sure what the best approach is here, but I don't think we want to do 20 total rounds. Potentially some smaller number of fusion rounds first, then split reductions, then fuse. And in the first round we could consider only attempting to fuse pointwise nodes.
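For context, a simplified sketch of the round-based greedy loop described above (not the actual `fuse_nodes` implementation):
```
for _ in range(10):  # the "10 iterations" cap mentioned above
    num_before = len(nodes)
    # hypothetical helper: greedily apply the most profitable (memory-saving)
    # fusions first within this round
    nodes = fuse_one_round(nodes)
    if len(nodes) == num_before:  # stop early once a round fuses nothing
        break
```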
```
def replace_operation_buffer(
    self, orig_node: ir.OperationBuffer, new_node: ir.OperationBuffer
) -> None:
    replaced_buf_name = new_node.get_name()
```
Nice, glad you found this!
```
else:
    raise NotImplementedError(node)


def split_reduction_node(self) -> bool:
```
Still need to review this function
How does this interact with the cooperative reductions from @jansel?
jansel left a comment
Can you talk a bit more about the motivation for this? Longer term I want to remove/disable split reductions in favor of cooperative reductions (which can fuse more things and work better with dynamic shapes).
Setting triton.cooperative_reductions=True might solve the same problem as this.
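A small usage sketch of the flag mentioned above (the performance of this mode is discussed later in the thread):
```
import torch
import torch._inductor.config as inductor_config

# Ask Inductor to generate cooperative (multi-block) reductions instead of
# split reductions; flag name taken from the comment above.
inductor_config.triton.cooperative_reductions = True

@torch.compile
def amax(x):
    return x.abs().max()

out = amax(torch.randn(2**14, 2**14, device="cuda"))
```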
eellison left a comment
```
def split_reduction_single_node(idx, snode, all_nodes) -> None:
    split_hints = []
    for other_node in all_nodes:
```
We shouldn't need to iterate over all_nodes - see here, where we first group by buffer accesses.
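A hedged sketch of that restriction, grouping candidates by the buffers they read instead of scanning all_nodes (the attribute layout here is an assumption):
```
from collections import defaultdict

# Map buffer name -> scheduler nodes that read it, so only nodes sharing a read
# buffer with the deferred-split reduction are considered as tiling-hint sources.
buf_to_readers = defaultdict(list)
for node in all_nodes:
    for dep in node.read_writes.reads:  # assumed attribute layout
        buf_to_readers[dep.name].append(node)

candidates = {
    other
    for dep in snode.read_writes.reads
    for other in buf_to_readers[dep.name]
    if other is not snode
}
```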
```
tiling = SIMDScheduling.select_tiling(
    other_node.get_nodes(), numel1, rnumel1
)
split_hints.append(tiling[0])
```
Why only tiling[0] here?
```
self.name_to_node[snode.get_name()] = new_scheduler_node
self.name_to_fused_node[snode.get_name()] = new_scheduler_node


# It's possible that there is no additional middle-layer nodes
```
Meaning, we didn't split the reduction? Can we check if there would be a split prior to creating a new reduction?
```
assert isinstance(middle_buffer, ir.ComputedBuffer)  # linter
middle_buffer.layout = middle_buffer.layout.as_fixed()
middle_node = self.create_scheduler_node(middle_buffer)
self.nodes.append(middle_node)
```
Why are we appending this in a loop?
```
# The node whose users are the orig_node (un-split reduction) should be updated
# to be the first middle reduction node
for node in self.nodes:
    for buf in node.get_outputs():
        for user in buf.users:
            if user.node.get_name() == snode.get_name():
                user.node = self.nodes[orig_num_scheduler_node]
# Update the dependencies (buf.users) of the new reduction nodes
self.compute_dependencies(
    self.nodes[orig_num_scheduler_node:] + [new_scheduler_node]
```
In each iteration here we're updating nodes/dependencies. This might be slow. Maybe it would make sense to collect all the buffers we're going to split first, then split, then update once?
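A hedged sketch of the batching being suggested (the helpers here are hypothetical, not code from the PR):
```
# Collect every deferred-split reduction first, split them all, and only then
# recompute dependencies once for the whole batch.
to_split = [n for n in self.nodes if is_deferred_split(n)]   # hypothetical predicate
new_nodes = []
for snode in to_split:
    new_nodes.extend(split_one_reduction(snode))             # hypothetical helper
self.nodes = [n for n in self.nodes if n not in to_split] + new_nodes
self.compute_dependencies(new_nodes)                         # single update pass
```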
@Chillee, currently this change should have no interaction with the cooperative reductions, but I'll run some local tests to see if
I'd expect cooperative reductions to be strictly better than split. I don't think we need to wait on the upstream Triton changes, though we will need to add a flag to CUDA kernel launches. I think that can be done by generating our own kernel launch code (similar to AOTI).
Hi @jansel and @eellison, thanks for your comments about the cooperative reductions. I did some tests to see if it can directly help the fp8 cases.

Context of float8 training
This fusion of
This fusion doesn't happen also because

I did some tests with (

Performance of a single reduction
With
Without

It seems that the current performance of cooperative_reduction is worse than split_reduction. Does it need "the triton upstream update" or "a flag to cuda kernel launches" to fix the performance issue?

The two fp8 cases
In the second case, also can't be fused:
(But I am not sure whether the fusion issues are easy or hard to fix. It might be easier to fix than split_reduction, as there will be only one reduction kernel for cooperative_reduction...?)
Updates of cooperative reductions performance (cc @eellison)

I re-ran the single reduction case with
With
Also tried with
With

The perf becomes better if the input shape is larger, but still slower than
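For reference, a sketch of the kind of single-reduction micro-benchmark being compared here; the input shape and the benchmarking helper are assumptions:
```
import torch
import torch._inductor.config as inductor_config
from triton.testing import do_bench

x = torch.randn(2**14, 2**14, device="cuda")

def amax(t):
    return t.abs().amax()

for coop in (False, True):
    inductor_config.triton.cooperative_reductions = coop
    torch._dynamo.reset()  # recompile under the new config
    fn = torch.compile(amax)
    ms = do_bench(lambda: fn(x))
    print(f"cooperative_reductions={coop}: {ms:.3f} ms")
```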
The other thing you could try is increasing/decreasing:
and (set them to be equal): pytorch/torch/_inductor/runtime/hints.py, line 16 at ecbb8a8
That heuristic is the same idea as the split factor in split reductions, and currently it isn't well tuned. I'd expect that if you make RSPLIT match the split factor, the performance will be closer.
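One way to experiment with that constant, assuming it is read from `torch/_inductor/runtime/hints.py` at codegen time (the values are the ones tried in the reply below):
```
# Hedged experiment sketch: patch the RSPLIT cap referenced above before compiling.
import torch._inductor.runtime.hints as hints

hints.TRITON_MAX_RSPLIT = 256  # also tried 512 below
```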
I tried different TRITON_MAX_RSPLIT values. With a proper RSPLIT value, the performance is very close to split_reduction.

TRITON_MAX_RSPLIT = 256,
TRITON_MAX_RSPLIT = 512,

However, besides the performance issue, we still need additional effort to make the fusion work for COOPERATIVE_REDUCTIONS..?
For this case, we create a tiled pointwise kernel and are unable to fuse it to the reduction because the pointwise node tiling is not equal to

That is, the pointwise kernel is not loading over the tensor contiguously, nor does the tiling match the output # of elems and the # of reduced inputs for the reduction - (numel2, rnumel2). We don't currently support tiling reductions, but @blaine-rister has a PR out for it here: #137243. I think in the end state, with cooperative and tiled reductions both turned on, we will be able to fuse this in a single kernel. That said, it might take a little while for this to be on by default.

Since we are mostly targeting a specific workload, maybe we could land an fx pass that is simpler and easier to maintain in the short term, either in tree or as a custom pass. We already have one pass that targets the shared partial reduction. For that case, if we're targeting the joint graph saved activations, we'd need to do it as an fx pass regardless.

For the two-step reduction that enables fusion with the tiled pointwise op, would it be feasible to write it this way in the user code temporarily? Potentially we could write a pass that looks for uses with transposed tensors.
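A hedged sketch of the "write the two-step reduction in user code" idea; the split factor and shape handling are illustrative and assume the element count divides evenly:
```
import torch

def two_step_amax(x: torch.Tensor, split: int = 1024) -> torch.Tensor:
    # Step 1: partial amax over chunks. This produces a small intermediate and is
    # the kernel that has a chance to fuse with the neighboring tiled pointwise op.
    partial = x.reshape(split, -1).abs().amax(dim=1)
    # Step 2: reduce the partial results to the final scalar amax.
    return partial.amax()
```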
@eellison thanks for the quick response. Yes, I'll work on the fx pass solution then, and likely land it in the torchao repo...?
…42273)

**Summary:** (Since I am trying the other solution for #141082, I moved out the test case fixes from that pr to a separate pr to land first.)

Testing the float8 dynamic scaling case with `TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1` didn't make any difference. The test case for fp8 (https://github.com/pytorch/pytorch/blob/main/test/inductor/test_loop_ordering.py#L425) is also failing, https://www.internalfb.com/intern/test/844425111960859?ref_report_id=0

The main change here is to modify the condition of calling `loop_reordering` from `shared_data_score == 0` to `shared_data_score < config.score_fusion_memory_threshold`.

Before the change: `shared_data_score > 0 -> won't loop_reorder -> can't be fused because of shared_data_score < config.score_fusion_memory_threshold`

After the change: `shared_data_score > 0 -> loop_reorder (shared_data_score < config.score_fusion_memory_threshold) -> get a larger shared_data_score -> fused`

It's the same issue as fixed in #136782. But the condition to call loop_reorder might be changed later, causing the test case to fail again.

**Test Plan:**
```
buck2 test 'fbcode//mode/opt' caffe2/test/inductor:loop_ordering
```
And ran a float8 dynamic scaling training script to verify it e2e.

Differential Revision: D66906175
Pull Request resolved: #142273
Approved by: https://github.com/eellison
…42474)

Summary: **Re-land the pr.** The previous one was reverted because of a test failure on SM89. The fix is just removing `xfailIfSM89`.
```
_____________________ LoopOrderingTest.test_fp8_pattern_2 ______________________
Unexpected success
```
(Since I am trying the other solution for #141082, I moved out the test case fixes from that pr to a separate pr to land first.)

Testing the float8 dynamic scaling case with `TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1` didn't make any difference. The test case for fp8 (https://github.com/pytorch/pytorch/blob/main/test/inductor/test_loop_ordering.py#L425) is also failing, https://www.internalfb.com/intern/test/844425111960859?ref_report_id=0

The main change here is to modify the condition of calling `loop_reordering` from `shared_data_score == 0` to `shared_data_score < config.score_fusion_memory_threshold`.

Before the change: `shared_data_score > 0 -> won't loop_reorder -> can't be fused because of shared_data_score < config.score_fusion_memory_threshold`

After the change: `shared_data_score > 0 -> loop_reorder (shared_data_score < config.score_fusion_memory_threshold) -> get a larger shared_data_score -> fused`

It's the same issue as fixed in #136782. But the condition to call loop_reorder might be changed later, causing the test case to fail again.

Test Plan:
```
buck2 test 'fbcode//mode/opt' caffe2/test/inductor:loop_ordering
```
Ran a float8 dynamic scaling training script to verify it e2e.

Differential Revision: D67012816
Pull Request resolved: #142474
Approved by: https://github.com/eellison, https://github.com/sijiac, https://github.com/shunting314
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as


Context

We mainly want to address the two cases that are commonly used in float8 training:

* This fusion of `y` and `z_z_t` failed due to "invalid tiling" (#128063)
* This fusion doesn't happen also because `y` and the first-level reduction of `z` have a different tiling (#136267)

To fix the two cases, we attempted to defer the split of the reduction until after fusion, so that the reduction split can take its neighbor op into consideration and allow the fusions to happen.

Summary of the changes:

* Added a config option `defer_reduction_split`. When it's enabled, if `num_splits` gets a `>1` result, return `ReductionHint.DEFERRED_SPLIT, 1` instead.
* In scheduler, when fusing nodes, if a node is a `ReductionHint.DEFERRED_SPLIT`, ignore it for now. After all non-DEFERRED_SPLIT nodes are fused, split the DEFERRED_SPLIT Reduction nodes.
* I thought about a better place to put `split_reduction_nodes`. However, there are cases where the tiled operations will only occur after fusion. So I chose to run `fuse_nodes` first, and then used the fused nodes to split the reduction nodes (see the `op0`/`op1`/`op2` example in the exported summary above).
* For a DEFERRED_SPLIT Reduction node, find all nodes that have shared data with it (if `loop_ordering_after_fusion` is enabled, `loop_reordering` will also be considered). Use the other nodes' tiling as the hints to split the reduction. If the "hints" are too different from the original num_split, they'll be ignored.
* Create new Scheduler nodes from the split reduction nodes. Remove the DEFERRED_SPLIT Reduction node, add the new Reduction nodes in the Scheduler, and update all related scheduler fields.
* If there are new nodes added, re-run `fuse_nodes`. The second `fuse_nodes` should finish quickly, as only the newly added reduction nodes may be fused into other nodes.

Test Plan:

Some test cases in my testing scripts in fbcode (in the diff)

Unit test:

Defer_reduction_split test:
Differential Revision: D66157846
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov