Conversation

@y-sq
Contributor

@y-sq y-sq commented Nov 20, 2024

Context
We mainly want to address two cases that commonly appear in float8 training:

  1. reduction with a tiled pointwise, for example:
def test_1(input_x):
    y = input_x.abs().max()

    z = input_x / 10.0
    z_t = z.t().contiguous().t()  # `z` and `z_t` will be fused into a tiled pointwise

    return y, z, z_t

This fusion of y with z/z_t failed due to "invalid tiling" (#128063)

  2. reduction with a "partial reduction", for example:
def test_2(x):
    y = x.abs().max(dim=-1)
    z = x.abs().max()   # we want the first-level reduction of `z` to be fused with `y`.
    return y[0], z

This fusion also doesn't happen because y and the first-level reduction of z have different tilings (#136267)
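
For illustration only, here is the split written out by hand, assuming a simple row-wise first level (the split Inductor actually picks may differ):

def test_2_split(x):
    y = x.abs().max(dim=-1)
    z_partial = x.abs().amax(dim=-1)  # first-level reduction, same per-row shape as `y`
    z = z_partial.max()               # second-level reduction over the per-row partials
    return y[0], z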

To fix the two cases, we defer the reduction split until after fusion, so that the split can take neighboring ops into consideration and allow these fusions to happen.

Summary of the changes:

  • Added a config option defer_reduction_split (a usage sketch follows this list). When it is enabled and num_splits would return a value > 1, we return ReductionHint.DEFERRED_SPLIT, 1 instead.
  • In the scheduler, when fusing nodes, skip any node marked ReductionHint.DEFERRED_SPLIT for now. After all non-DEFERRED_SPLIT nodes are fused, split the DEFERRED_SPLIT Reduction nodes.
    • I considered better places to put split_reduction_nodes, but there are cases where the tiled operations only appear after fusion. So I chose to run fuse_nodes first and then use the fused nodes to split the reduction nodes.
      For example,
# op0. the reduction, read: arg0_1
y = x.abs().max()
# op1. read: arg0_1; write: buf0; no tiling
> z = x / 10.0   
# op2. read: transposed buf0              
z_t = z.t().contiguous().t()  

# Before fusion of op1 and op2:
Only op1 has shared data with op0; however, op1 has no tiling
# After fusion of op1 and op2:
fused op1_op2: reads arg0_1, and has a tiling that the reduction (op0) can try to match
  • For a DEFERRED_SPLIT Reduction node, find all nodes that have shared data with it (if loop_ordering_after_fusion is enabled, loop reordering is also considered).
    Use the other nodes' tilings as hints to split the reduction. If the hints are too different from the original num_split, they are ignored.
  • Create new scheduler nodes from the split reduction nodes: remove the DEFERRED_SPLIT Reduction node, add the new Reduction nodes to the scheduler, and update all related scheduler fields. (Hopefully I didn't miss any necessary fields in the scheduler...)
  • If new nodes were added, re-run fuse_nodes. The second fuse_nodes should finish quickly, since only the newly added reduction nodes can be fused into other nodes.
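
A minimal usage sketch, for illustration only: defer_reduction_split is the option added here, loop_ordering_after_fusion is assumed to be the existing loop-reordering knob, and the env-var forms in the test plan below set the same options.

import torch
import torch._inductor.config as inductor_config

# Assumed config names; TORCHINDUCTOR_DEFER_REDUCTION_SPLIT=1 and
# TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1 are the equivalent env vars.
inductor_config.defer_reduction_split = True
inductor_config.loop_ordering_after_fusion = True

@torch.compile
def scale_and_amax(x):
    y = x.abs().max()             # reduction we want fused with the tiled pointwise below
    z = x / 10.0
    z_t = z.t().contiguous().t()
    return y, z, z_t

x = torch.randn(3072, 4096, device="cuda") / 10.0
y, z, z_t = scale_and_amax(x)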

Test Plan:
Some test cases in my testing scripts in fbcode (in the diff)

TORCHINDUCTOR_PROFILE_OUTPUT=/tmp/profile.txt TORCHINDUCTOR_PROFILE=1 TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING=1  TORCHINDUCTOR_DEFER_REDUCTION_SPLIT=1 TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1 TORCH_LOGS="fusion" buck run mode/opt scripts/shuqiyang/test_inductor:test_reduction 2>&1 | tee ~/test_reduction_log_4.txt
  • Includes both delayed scaling and Luca's rowwise scaling. The next step is to test it in an actual float8 training use case.

Unit test:
Defer_reduction_split test:

buck2 test 'fbcode//mode/opt' caffe2/test/inductor:defer_reduction_split

Differential Revision: D66157846

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

@pytorch-bot

pytorch-bot bot commented Nov 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/141082

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c5208d1 with merge base ecf3bae:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D66157846

self.name_to_buf.update({
    buf.get_name(): buf for buf in middle_node.get_outputs()
})
self.name_to_fused_node[middle_node.get_name()] = middle_node
Contributor Author

@y-sq y-sq Nov 20, 2024

hope I didn't overlook updating any other necessary fields in the scheduler...

@y-sq
Contributor Author

y-sq commented Nov 20, 2024

@pytorchbot label "topic: not user facing"

@vkuzo
Contributor

vkuzo commented Nov 21, 2024

Nice! I tried this on a torchao test case and got the following error:

> cd torchao
> TORCHINDUCTOR_DEFER_REDUCTION_SPLIT=1 TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1 python benchmarks/float8/profile_linear_float8.py ~/local/tmp/20241120_test --dtype_filter float8 --scaling_type_input delayed --scaling_type_weight delayed --scaling_type_grad_output delayed
...
  File "/data/users/vasiliy/pytorch/torch/_inductor/scheduler.py", line 1778, in __init__
    self._init(nodes)
  File "/data/users/vasiliy/pytorch/torch/_inductor/scheduler.py", line 1851, in _init
    has_new_split_nodes = self.split_reduction_node()
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/vasiliy/pytorch/torch/_inductor/scheduler.py", line 2031, in split_reduction_node
    self.compute_ancestors()
  File "/data/users/vasiliy/pytorch/torch/_inductor/scheduler.py", line 2406, in compute_ancestors
    ancestors |= name_to_ancestors[dep_node_name]
                 ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
KeyError: 'op8'

@y-sq
Contributor Author

y-sq commented Nov 21, 2024

@vkuzo thanks for sharing the error. Let me look into it.

Update: I updated the PR to fix this error, and attached the trace + generated kernels in the diff's test plan.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D66157846

Contributor

Hmm, would you say more about this error ?

Contributor Author

@y-sq y-sq Nov 27, 2024

The main change here is to modify the condition of calling loop_reordering from shared_data_score == 0 to shared_data_score < config.score_fusion_memory_threshold.

In the failed test (https://github.com/pytorch/pytorch/blob/main/test/inductor/test_loop_ordering.py#L425), the two pointwise nodes originally share some read data (the scaling factor, which is just a scalar). The shared_data_score is not 0, but it is smaller than config.score_fusion_memory_threshold.

Before the change:
shared_data_score > 0 -> won't loop_reorder -> can't be fused because shared_data_score < config.score_fusion_memory_threshold
After the change:
shared_data_score > 0 -> loop_reorder (since shared_data_score < config.score_fusion_memory_threshold) -> larger shared_data_score -> fused
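
A minimal pseudocode sketch of this condition change (illustrative names; the actual scheduler code differs):

def should_try_loop_reordering(shared_data_score: int, memory_threshold: int) -> bool:
    # Before: only reorder loops when nothing was shared at all.
    # return shared_data_score == 0
    # After: also reorder when the shared data (e.g. a scalar scaling factor)
    # is too small to pass the fusion threshold on its own.
    return shared_data_score < memory_threshold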

Summary:

* Added a config option `defer_reduction_split`. When it is enabled and `num_splits` would return a value `> 1`, we return `ReductionHint.DEFERRED_SPLIT, 1` instead.
* In the scheduler, when fusing nodes, skip any node marked `ReductionHint.DEFERRED_SPLIT` for now. After all non-DEFERRED_SPLIT nodes are fused, split the DEFERRED_SPLIT Reduction nodes.
  * I considered better places to put `split_reduction_nodes`, but there are cases where the tiled operations only appear after fusion. So I chose to run `fuse_nodes` first and then use the fused nodes to split the reduction nodes.
For example,
```
# op0. the reduction, read: arg0_1
y = x.abs().max()
# op1. read: arg0_1; write: buf0; no tiling
> z = x / 10.0   
# op2. read: transposed buf0              
z_t = z.t().contiguous().t()  

# Before fusion of op1 and op2:
Only op1 has shared data with op0; however, op1 has no tiling
# After fusion of op1 and op2:
fused op1_op2: reads arg0_1, and has a tiling that the reduction (op0) can try to match
```
* For a DEFERRED_SPLIT Reduction node, find all nodes that have shared data with it (if `loop_ordering_after_fusion` is enabled, loop reordering is also considered).
Use the other nodes' tilings as hints to split the reduction. If the hints are too different from the original num_split, they are ignored.
* Create new scheduler nodes from the split reduction nodes: remove the DEFERRED_SPLIT Reduction node, add the new Reduction nodes to the scheduler, and update all related scheduler fields. (Hopefully I didn't miss any necessary fields in the scheduler...)
* If new nodes were added, re-run `fuse_nodes`. The second `fuse_nodes` should finish quickly, since only the newly added reduction nodes can be fused into other nodes.

Test Plan:
Some test cases in my testing scripts
```
TORCHINDUCTOR_PROFILE_OUTPUT=/tmp/profile.txt TORCHINDUCTOR_PROFILE=1 TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING=1  TORCHINDUCTOR_DEFER_REDUCTION_SPLIT=1 TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1 TORCH_LOGS="fusion" buck run mode/opt scripts/shuqiyang/test_inductor:test_reduction 2>&1 | tee ~/test_reduction_log_4.txt
```
* Includes both delayed scaling and Luca's rowwise scaling. The next step is to test it in an actual float8 training use case.
------
Unit test:
Defer_reduction_split test:
```
buck2 test 'fbcode//mode/opt' caffe2/test/inductor:defer_reduction_split
```

There was a test failure in loop_reordering, https://fburl.com/test/yo50vt6x. Related fusion logs: P1682598849. Also fixed it in this diff.
```
buck2 test 'fbcode//mode/opt' caffe2/test/inductor:loop_ordering
```
----
Test in Float8 Training
```
TORCHINDUCTOR_DEFER_REDUCTION_SPLIT=1 TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1  buck run mode/opt scripts/shuqiyang/test_inductor:test_float8 --  ~/local/tmp/20241120_test --dtype_filter float8 --scaling_type_input delayed --scaling_type_weight delayed --scaling_type_grad_output delayed  2>&1 | tee ~/test_compile
```
Trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/trace.64960216286.json.gz&bucket=acadia
Codegen kernels: P1683645463, P1683645752

Differential Revision: D66157846
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D66157846

@vkuzo
Contributor

vkuzo commented Nov 26, 2024

thanks @y-sq ! Without activation checkpointing, I now see what I think are the optimal kernel patterns generated for fwd+bwd of a Float8Linear, measured with

TORCHINDUCTOR_DEFER_REDUCTION_SPLIT=1 TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1 python benchmarks/float8/profile_linear_float8.py ~/local/tmp/20241120_test --dtype_filter float8 --scaling_type_input delayed --scaling_type_weight delayed --scaling_type_grad_output delayed

If we add AC (requires pytorch/ao#1354), then I still see a lot of missing fusions:

TORCHINDUCTOR_DEFER_REDUCTION_SPLIT=1 TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1 python benchmarks/float8/profile_linear_float8.py ~/local/tmp/20241120_test --dtype_filter float8 --scaling_type_input delayed --scaling_type_weight delayed --scaling_type_grad_output delayed --enable_activation_checkpointing True

@y-sq
Contributor Author

y-sq commented Nov 26, 2024

@vkuzo thanks for sharing the SAC script. I'll check the SAC case.

In SAC, is the following the ideal case for delayed scaling?

Activations:

  • In fwd, amax + cast (without transpose); amax is saved for bwd, and the fp8 tensors are not
  • In bwd, transposed_cast (without amax because it's saved from fwd)

Weights:

  • In the FSDP case, it should be the same as activations: fp8 weights are not saved for bwd, amax is saved for bwd.

Gradients:

  • Not affected by SAC

@vkuzo
Contributor

vkuzo commented Nov 26, 2024

yep that sounds right!

this is what I see now with SAC:

forward

seems that max(abs(tensor)) is not fused with the cast to float8

Screenshot 2024-11-26 at 2 31 49 PM

backward

looks good (I think)

Screenshot 2024-11-26 at 2 32 17 PM

Contributor

@eellison eellison left a comment

Looks good ! a few comments. would you do a perf run on the dashboard ? see https://pytorch.org/docs/stable/torch.compiler_performance_dashboard.html

i still need to review one of the functions you added.

cc @jansel, @Chillee for further thoughts

Contributor

Hmm, would you say more about this error ?



if HAS_GPU:
    torch.set_default_device(GPU_TYPE)
Contributor

hm, i would put this in setUpClass and revert to prior default device in tearDownClass

# self.dtype represents the dst dtype
src_dtype: torch.dtype
reduction_hint: ReductionHint
input_node: Optional[IRNode] = None
Contributor

You can get this with reduction.data.get_read_names() - don't think we need to add an additional field for it.

Comment on lines +1138 to +1139
min_split_size = num_threads * min_elements_per_thread // 8 # 1024
max_split_size = num_threads * max_elements_per_thread * 2 # 262144
Contributor

Can you say where these numbers are coming from and add comments/docstring?

Contributor

If there is already a partial reduction occurring, as with #136267, I think it always makes sense to do the same split. This should be a clear win and doesn't require a heuristic. What do you think ?

Contributor Author

> If there is already a partial reduction occurring, as with https://github.com/pytorch/pytorch/issues/136267, I think it always makes sense to do the same split.

Yes, this makes sense. I think I should have two types of split_hints..? If it comes from a partial reduction, always follow it; if it comes from other tiled ops, only follow it if the split is not too different from the optimal one?
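
Illustrative pseudocode of that two-tier policy, with made-up names (not the actual implementation):

def choose_split(optimal_split, partial_reduction_hints, tiling_hints, max_ratio=4):
    # Hints coming from an existing partial reduction are always followed.
    if partial_reduction_hints:
        return partial_reduction_hints[0]
    # Hints coming from other tiled ops are only followed if they stay
    # reasonably close to the split heuristic's own answer.
    for hint in tiling_hints:
        if optimal_split / max_ratio <= hint <= optimal_split * max_ratio:
            return hint
    return optimal_split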

rvals_per_thread * split_size
)

def refine_split_num(split_num, reduction_numel_hint, split_hints) -> _IntLike:
Contributor

Could you add a comment/docstring here ?

reduction_type: str,
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
input_node: Optional[IRNode] = None,
ignore_defer_reduction_split: bool = False,
Contributor

I guess we are not yet supporting Welford, that's fine.

OUTER = 1
OUTER_TINY = 2
DEFAULT = 3
DEFERRED_SPLIT = 3
Contributor

It's a little weird for this to show up at runtime, since it's a compile time only construct.

self.logged_slow_fusion: OrderedSet[Tuple[str, str]] = OrderedSet()
if config._pre_fusion_custom_pass is not None:
self.nodes = config._pre_fusion_custom_pass(self.nodes)
self.nodes = self.fuse_nodes(self.nodes)
Contributor

Hmm, I understand we want to do some amount of fusion prior to fusion of split nodes, however, the current fusion occurs so that we greedily fuse most profitable fusions first (saved global memory). And we are currently willing to do it for 10 iterations or until we did not do any fusions in the current round.

I'm not sure what the best approach is here, but I don't think we want to do 20 total rounds. Potentially attempt some smaller set of possible fusions first, then split reductions, then fuse. And in the first round we could consider just attempting to fuse pointwise nodes.

def replace_operation_buffer(
self, orig_node: ir.OperationBuffer, new_node: ir.OperationBuffer
) -> None:
replaced_buf_name = new_node.get_name()
Contributor

nice, glad you found this !

else:
raise NotImplementedError(node)

def split_reduction_node(self) -> bool:
Contributor

Still need to review this function

@Chillee
Collaborator

Chillee commented Nov 27, 2024

How does this interact with the cooperative reductions from @jansel ?

Contributor

@jansel jansel left a comment

Can you talk a bit more about the motivation for this? Longer term I want to remove/disable split reductions in favor of cooperative reductions (which can fuse more things and work better with dynamic shapes).

Setting triton.cooperative_reductions=True might solve the same problem as this.

Contributor

@eellison eellison left a comment

@jansel the motivation is #128063. i think for fp8 they were hoping to have a working solution in the nearer term. my understanding with cooperative reductions is that it wouldn't necessarily be strictly more performant and would also require upstream triton changes for cooperative launches


def split_reduction_single_node(idx, snode, all_nodes) -> None:
split_hints = []
for other_node in all_nodes:
Contributor

We shouldn't need to iterate over all_nodes; see here, where we first group by buffer accesses.

tiling = SIMDScheduling.select_tiling(
other_node.get_nodes(), numel1, rnumel1
)
split_hints.append(tiling[0])
Contributor

why only tiling[0] here ?

self.name_to_node[snode.get_name()] = new_scheduler_node
self.name_to_fused_node[snode.get_name()] = new_scheduler_node

# It's possible that there is no additional middle-layer nodes
Contributor

meaning, we didn't split the reduction ? Can we check if there would be a split prior to creating a new reduction ?

assert isinstance(middle_buffer, ir.ComputedBuffer) # linter
middle_buffer.layout = middle_buffer.layout.as_fixed()
middle_node = self.create_scheduler_node(middle_buffer)
self.nodes.append(middle_node)
Contributor

why are we appending this in a loop?

Comment on lines +2039 to +2047
# The node whose users are the orig_node (un-split reduction) should be updated to be the first middle reduction node
for node in self.nodes:
for buf in node.get_outputs():
for user in buf.users:
if user.node.get_name() == snode.get_name():
user.node = self.nodes[orig_num_scheduler_node]
# Update the dependencies (buf.usesrs) of the new reduction nodes
self.compute_dependencies(
self.nodes[orig_num_scheduler_node:] + [new_scheduler_node]
Contributor

In each iteration here we're updating nodes/dependencies. this might be slow. Maybe it would make sense to collect all the buffers we're going to split first, then split, then update once ?

@y-sq
Contributor Author

y-sq commented Nov 27, 2024

How does this interact with the cooperative reductions from @jansel ?

@Chillee, currently this change should have no interaction with cooperative reductions.
The option "defer_reduction_split" only takes effect when the reduction would be split. When cooperative_reductions is enabled (https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py#L484), Reduction.num_splits always ends up with should_split == False (https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L1049), because V.graph.has_feature(device, BackendFeature.REDUCE_TO_SINGLE_ELEMENT) is True with cooperative_reductions (https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton.py#L3457).

but I'll run some local tests to see if cooperative_reductions itself is enough to handle the fp8 use case.

@jansel
Contributor

jansel commented Nov 30, 2024

@jansel the motivation #128063. i think for fp8 they were hoping to have a working solution in the nearer term. my understanding with cooperative reductions is it wouldn't necessarily be strictly more performant and also would require upstream triton changes for cooperative launches

I'd expect cooperative reductions to be strictly better than split. I don't think we need to wait on the upstream Triton changes, though we will need to add a flag to cuda kernel launches. I think that can be done by generating our own kernel launch code (similar to AOTI).

@y-sq
Contributor Author

y-sq commented Dec 2, 2024

Hi @jansel and @eellison , thanks for your comments about the cooperative reductions. I did some tests to see if it can directly help the fp8 cases.

Context of float8 training
For context about our needs in float8 training, we have the two cases that we want to handle:

  1. reduction with a tiled pointwise, for example:
def test_1(input_x):
    y = input_x.abs().max()

    z = input_x / 10.0
    z_t = z.t().contiguous().t()  # `z` and `z_t` will be fused into a tiled pointwise

    return y, z, z_t

test = torch.compile(test)
x = torch.randn(3072, 4096, device="cuda") / 10.0
y, z, z_t = test(x)

This fusion of y with z/z_t failed due to "invalid tiling" (#128063)

  2. reduction with a "partial reduction", for example:
def test_2(x):
    y = x.abs().max(dim=-1)
    z = x.abs().max()   # we want the first-level reduction of `z` to be fused with `y`.
    return y[0], z

test = torch.compile(test)
x = torch.randn(3072, 4096, device="cuda")
z = test(x)

This fusion also doesn't happen because y and the first-level reduction of z have different tilings (#136267)


I did some tests with TORCHINDUCTOR_COOPERATIVE_REDUCTIONS=1:

(TORCHINDUCTOR_PROFILE_OUTPUT=/tmp/profile.txt TORCHINDUCTOR_PROFILE=1 TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING=1 TORCH_LOGS="fusion" TORCHINDUCTOR_COOPERATIVE_REDUCTIONS=1 TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1 ...)

Performance of a single reduction

def test(x):
    z = x.abs().max()
    return z

test = torch.compile(test)
x = torch.randn(3072, 4096, device="cuda")
z = test(x)

With TORCHINDUCTOR_COOPERATIVE_REDUCTIONS=1, I got:

TRITON KERNELS BANDWIDTH INFO (/tmp/torchinductor_shuqiyang/tmpzov4m_tx/7u/c7ulsfjspmgzwufvmj6qik5zc2fedgdveqvzn7hxvv7sohhu3utl.py)
0.038ms    	0.050 GB 	 1329.40GB/s 	 100.00% 	 triton_unk_fused_abs_max_0
SUMMARY (/tmp/torchinductor_shuqiyang/tmpzov4m_tx/7u/c7ulsfjspmgzwufvmj6qik5zc2fedgdveqvzn7hxvv7sohhu3utl.py)
0.04ms   	 0.05 GB	 1329.40GB/s

Without TORCHINDUCTOR_COOPERATIVE_REDUCTIONS, using the split-reduction, I got:

TRITON KERNELS BANDWIDTH INFO (/tmp/torchinductor_shuqiyang/tmpyhnx8ovm/rk/crkiazuwvplph2xsey4jrrk454uw2imrohmlqkgoeswrqcncacqi.py)
0.029ms    	0.050 GB 	 1727.88GB/s 	 92.89% 	 triton_red_fused_abs_max_0
0.002ms    	0.000 GB 	    0.92GB/s 	 7.11% 	 triton_per_fused_abs_max_1
SUMMARY (/tmp/torchinductor_shuqiyang/tmpyhnx8ovm/rk/crkiazuwvplph2xsey4jrrk454uw2imrohmlqkgoeswrqcncacqi.py)
0.03ms   	 0.05 GB	 1605.07GB/s

It seems that the current performance of cooperative_reduction is worse than split_reduction. Does it need "the triton upstream update" or "a flag to cuda kernel launches" to fix the performance issue?

The two fp8 cases
In the first case, they still can't be fused:

scheduler.py:2491] [0/0] [__fusion] ===== attempting fusion (2/10): 2 nodes =====
scheduler.py:2781] [0/0] [__fusion] fuse_nodes_once, candidates:
scheduler.py:2783] [0/0] [__fusion]   SchedulerNode(name='op0'), Reduction(['[3072, 4096]', 'max', 'origins=OrderedSet([abs_1, max_1])'])
scheduler.py:2783] [0/0] [__fusion]   FusedSchedulerNode(nodes=op1_op2), snodes: ["SchedulerNode(name='op1'), Pointwise(['[3072, 4096]', 'origins=OrderedSet([div])'])", "SchedulerNode(name='op2'), Pointwise(['[4096, 3072]', 'origins=OrderedSet([clone])'])"]
scheduler.py:747] [0/0] [__fusion] cannot fuse op1_op2 with op0: invalid tiling for reduction
scheduler.py:2909] [0/0] [__fusion] found 0 possible fusions
scheduler.py:2498] [0/0] [__fusion] completed fusion round (2/10): fused 2 nodes into 2 nodes
scheduler.py:2498] [0/0] [__fusion] 
scheduler.py:2505] [0/0] [__fusion] ===== fusion complete (2 iterations) =====

In the second case, also can't be fused:

scheduler.py:2491] [0/0] [__fusion] ===== attempting fusion (1/10): 2 nodes =====
scheduler.py:2781] [0/0] [__fusion] fuse_nodes_once, candidates:
scheduler.py:2783] [0/0] [__fusion]   SchedulerNode(name='op0'), Reduction(['[4096]', 'max', 'origins=OrderedSet([abs_1, max_1])'])
scheduler.py:2783] [0/0] [__fusion]   SchedulerNode(name='op2'), Reduction(['[3072, 4096]', 'max', 'origins=OrderedSet([max_2, abs_2])'])
scheduler.py:747] [0/0] [__fusion] cannot fuse op0 with op2: numel/rnumel mismatch (reduce) (3072, 1), (4096, 12582912)
scheduler.py:2909] [0/0] [__fusion] found 0 possible fusions
scheduler.py:2498] [0/0] [__fusion] completed fusion round (1/10): fused 2 nodes into 2 nodes
scheduler.py:2498] [0/0] [__fusion] 
scheduler.py:2505] [0/0] [__fusion] ===== fusion complete (1 iterations) =====

(But I am not sure whether the fusion issues are easy or hard to fix. It might be easier to fix than split-reduction as there will be only one reduction kernel for cooperative_reduction...?)

@y-sq
Contributor Author

y-sq commented Dec 3, 2024

Updates of cooperative reductions performance (cc @eellison)

I re-ran the single reduction case with TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 as @eellison suggested, but the performance didn't change much:

With TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1:
Input shape = torch.randn(3072, 4096, device="cuda")

TRITON KERNELS BANDWIDTH INFO (/tmp/torchinductor_shuqiyang/tmpw9ql5q9u/oj/coju6dmdfrvyhochbsepy7nghcdrranfrgyidkyupddm3awkch4n.py)
0.038ms    	0.050 GB 	 1331.60GB/s 	 100.00% 	 triton_unk_fused_abs_max_0
SUMMARY (/tmp/torchinductor_shuqiyang/tmpw9ql5q9u/oj/coju6dmdfrvyhochbsepy7nghcdrranfrgyidkyupddm3awkch4n.py)
0.04ms   	 0.05 GB	 1331.60GB/s

Also tried with TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_COORDINATE_DESCENT_CHECK_ALL_DIRECTIONS=1 TORCHINDUCTOR_COORDINATE_DESCENT_RADIUS=5:

TRITON KERNELS BANDWIDTH INFO (/tmp/torchinductor_shuqiyang/tmpa3ocs9c7/mh/cmhhybrxlx3l7cir5qmowv5vi6sfmswfr3parh6eirkrhaddwdfq.py)
0.038ms    	0.050 GB 	 1331.87GB/s 	 100.00% 	 triton_unk_fused_abs_max_0
SUMMARY (/tmp/torchinductor_shuqiyang/tmpa3ocs9c7/mh/cmhhybrxlx3l7cir5qmowv5vi6sfmswfr3parh6eirkrhaddwdfq.py)
0.04ms   	 0.05 GB	 1331.87GB/s

With TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1, I tested different shapes:
x = torch.randn(4096, 4096, device="cuda")

TRITON KERNELS BANDWIDTH INFO (/tmp/torchinductor_shuqiyang/tmp1lwjlpzd/rh/crhu45jkzbkomsaitne6k37rbbggptdp7zpquqovemp2iezlguqf.py)
0.042ms    	0.067 GB 	 1580.72GB/s 	 100.00% 	 triton_unk_fused_abs_max_0
SUMMARY (/tmp/torchinductor_shuqiyang/tmp1lwjlpzd/rh/crhu45jkzbkomsaitne6k37rbbggptdp7zpquqovemp2iezlguqf.py)
0.04ms   	 0.07 GB	 1580.72GB/s

x = torch.randn(8192, 8192, device="cuda")

TRITON KERNELS BANDWIDTH INFO (/tmp/torchinductor_shuqiyang/tmp_33mlbi4/wn/cwnjagz4dofzqv6qbmkespi6mca6dcvs7sufganbqiga23y5icd3.py)
0.149ms    	0.268 GB 	 1799.95GB/s 	 100.00% 	 triton_unk_fused_abs_max_0
SUMMARY (/tmp/torchinductor_shuqiyang/tmp_33mlbi4/wn/cwnjagz4dofzqv6qbmkespi6mca6dcvs7sufganbqiga23y5icd3.py)
0.15ms   	 0.27 GB	 1799.95GB/s

The perf gets better as the input shape grows, but it is still slower than split_reduction.

@jansel
Contributor

jansel commented Dec 3, 2024

The other thing you could try is increasing/decreasing:


and (set them to be equal):
TRITON_MAX_RSPLIT = 64

That heuristic is the same idea as the split factor in split reductions, and it currently isn't well tuned. I'd expect that if you make RSPLIT match the split factor, the performance will be closer.

@y-sq
Contributor Author

y-sq commented Dec 3, 2024

I tried different TRITON_MAX_RSPLIT. With a proper RSPLIT value, the performance is very close to split_reduction.

TRITON_MAX_RSPLIT = 256,

TRITON KERNELS BANDWIDTH INFO (/tmp/torchinductor_shuqiyang/tmpf4fc6aqq/fi/cfikuh67balrjcaj27uqvppm44ot7hk6vq4llisvdfwv2ke3ycs5.py)
0.032ms    	0.050 GB 	 1590.97GB/s 	 100.00% 	 triton_unk_fused_abs_max_0
SUMMARY (/tmp/torchinductor_shuqiyang/tmpf4fc6aqq/fi/cfikuh67balrjcaj27uqvppm44ot7hk6vq4llisvdfwv2ke3ycs5.py)
0.03ms   	 0.05 GB	 1590.97GB/s

TRITON_MAX_RSPLIT = 512,

TRITON KERNELS BANDWIDTH INFO (/tmp/torchinductor_shuqiyang/tmp3qkw7bjj/6a/c6azb7rdb5pfwclyamp3pqztdqqrxccpjsuwgheanfnws6z5ih72.py)
0.031ms    	0.050 GB 	 1609.81GB/s 	 100.00% 	 triton_unk_fused_abs_max_0
SUMMARY (/tmp/torchinductor_shuqiyang/tmp3qkw7bjj/6a/c6azb7rdb5pfwclyamp3pqztdqqrxccpjsuwgheanfnws6z5ih72.py)
0.03ms   	 0.05 GB	 1609.81GB/s

However, besides the performance issue, we still need additional effort to make the fusion work for COOPERATIVE_REDUCTIONS...?

@eellison
Contributor

eellison commented Dec 5, 2024

def test_1(input_x):
    y = input_x.abs().max()

    z = input_x / 10.0
    z_t = z.t().contiguous().t()  # `z` and `z_t` will be fused into a tiled pointwise

    return y, z, z_t

test = torch.compile(test)
x = torch.randn(3072, 4096, device="cuda") / 10.0
y, z, z_t = test(x)

For this case, we create a tiled pointwise kernel and are unable to fuse it to the reduction because the pointwise node tiling is not equal to

(numel1, 1),
(numel2, rnumel2, 1),

That is, the pointwise kernel is not loading over the tensor contiguously, nor does the tiling match the output # of elems and the # of reduced inputs for the reduction - (numel2, rnumel2).

We don't currently support tiling reductions but @blaine-rister has a pr out for it here: #137243. I think in the end state with cooperative and tiled reductions both turned on we will be able to fuse this in a single kernel.

That said, it might take a little while for this to be on by default. Since we are mostly targeting a specific workload, maybe we could land an fx pass that is simpler and easier to maintain in the short term, either in tree or as a custom pass.

We already have one pass that targets the shared partial reduction. For that case, if we're targeting the joint graph saved activations, we'd need to do it as an fx pass regardless.

For the two step reduction that enables fusion with the tiled pointwise op, would it be feasible to write it this way in the user code temporarily ? potentially we could write a pass that looks for uses with transposed tensors.
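
(For illustration, a hedged sketch of that user-code rewrite; the exact form that best matches the pointwise tiling may differ:)

def test_1_two_step(input_x):
    # Express the full amax as an explicit two-step reduction so the row-wise
    # first step can share a kernel with the tiled pointwise ops.
    y_partial = input_x.abs().amax(dim=-1)
    y = y_partial.max()

    z = input_x / 10.0
    z_t = z.t().contiguous().t()
    return y, z, z_t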

@y-sq
Contributor Author

y-sq commented Dec 5, 2024

@eellison thanks for the quick response. Yes, I'll work on the fx pass solution next, and likely land it in the torchao repo...?
(cc @vkuzo the idea of the fx pass solution is also to split the reduction into two levels so that Inductor can do further fusion, so it will also handle the fusion of preceding ops.)
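
A rough, illustrative sketch of the direction for such a pass (hypothetical helper, not the actual torchao pass; it assumes the full reduction appears as aten.max.default in the graph):

import torch
import torch.fx as fx

def split_full_max_reductions(gm: fx.GraphModule) -> fx.GraphModule:
    # Rewrite full `aten.max` reductions into a row-wise amax followed by a
    # final max, so the first stage can share tiling with neighboring ops.
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is torch.ops.aten.max.default:
            (inp,) = node.args
            with gm.graph.inserting_before(node):
                partial = gm.graph.call_function(torch.ops.aten.amax.default, (inp, [-1]))
                final = gm.graph.call_function(torch.ops.aten.max.default, (partial,))
            node.replace_all_uses_with(final)
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm

Whether to match on aten.max, aten.amax, or the abs+max pattern specifically would depend on where in the compile pipeline the pass runs.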

y-sq added a commit to y-sq/pytorch that referenced this pull request Dec 9, 2024
…torch#142273)

Summary:

(Since I am trying the other solution for pytorch#141082, I moved out the test case fixes from that pr to a separate pr to land first.)

-----
Testing float8 dynamic scaling case with `TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1` didn't make any difference.

The test case for fp8 (https://github.com/pytorch/pytorch/blob/main/test/inductor/test_loop_ordering.py#L425) is also failing, https://www.internalfb.com/intern/test/844425111960859?ref_report_id=0

-------

The main change here is to modify the condition of calling `loop_reordering` from `shared_data_score == 0` to `shared_data_score < config.score_fusion_memory_threshold`.


Before the change:
`shared_data_score > 0 -> won't loop_reorder -> can't fused because of shared_data_score < config.score_fusion_memory_threshold`
After the change:
`shared_data_score > 0 -> loop_reorder (shared_data_score < config.score_fusion_memory_threshold) -> get a larger shared_data_score -> fused`

----
It's the same issue as fixed in pytorch#136782. But the condition to call loop_reorder might be changed later, causing the test case to fail again.

Test Plan:
```
buck2 test 'fbcode//mode/opt' caffe2/test/inductor:loop_ordering
```
-----
Ran a float8 dynamic scaling training script to verify it e2e

Reviewed By: eellison

Differential Revision: D66906175
pytorchmergebot pushed a commit that referenced this pull request Dec 10, 2024
…42273)

Pull Request resolved: #142273
Approved by: https://github.com/eellison
y-sq added a commit to y-sq/pytorch that referenced this pull request Dec 10, 2024
…torch#142474)

Summary:

Re-land the pr. The previous one was reverted because of a test failure on SM89. The fix is just removing `xfailIfSM89`.

```
_____________________ LoopOrderingTest.test_fp8_pattern_2 ______________________
Unexpected success
```
------
(Since I am trying the other solution for pytorch#141082, I moved out the test case fixes from that pr to a separate pr to land first.)

-----
Testing float8 dynamic scaling case with `TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1` didn't make any difference.

The test case for fp8 (https://github.com/pytorch/pytorch/blob/main/test/inductor/test_loop_ordering.py#L425) is also failing, https://www.internalfb.com/intern/test/844425111960859?ref_report_id=0

-------

The main change here is to modify the condition of calling `loop_reordering` from `shared_data_score == 0` to `shared_data_score < config.score_fusion_memory_threshold`.


Before the change:
`shared_data_score > 0 -> won't loop_reorder -> can't fused because of shared_data_score < config.score_fusion_memory_threshold`
After the change:
`shared_data_score > 0 -> loop_reorder (shared_data_score < config.score_fusion_memory_threshold) -> get a larger shared_data_score -> fused`

----
It's the same issue as fixed in pytorch#136782. But the condition to call loop_reorder might be changed later, causing the test case to fail again.

Test Plan:
```
buck2 test 'fbcode//mode/opt' caffe2/test/inductor:loop_ordering
```
-----
Ran a float8 dynamic scaling training script to verify it e2e

Differential Revision: D67012816
pytorchmergebot pushed a commit to y-sq/pytorch that referenced this pull request Dec 11, 2024

pytorchmergebot pushed a commit to y-sq/pytorch that referenced this pull request Dec 11, 2024

pytorch-bot bot pushed a commit that referenced this pull request Dec 16, 2024

pytorchmergebot pushed a commit that referenced this pull request Dec 17, 2024
…42474)

Summary:
**Re-land the pr**. The previous one was reverted because of a test failure on SM89. The fix is just removing `xfailIfSM89`.

```
_____________________ LoopOrderingTest.test_fp8_pattern_2 ______________________
Unexpected success
```
------
(Since I am trying the other solution for #141082, I moved out the test case fixes from that pr to a separate pr to land first.)

-----
Testing float8 dynamic scaling case with `TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1` didn't make any difference.

The test case for fp8 (https://github.com/pytorch/pytorch/blob/main/test/inductor/test_loop_ordering.py#L425) is also failing, https://www.internalfb.com/intern/test/844425111960859?ref_report_id=0

-------

The main change here is to modify the condition of calling `loop_reordering` from `shared_data_score == 0` to `shared_data_score < config.score_fusion_memory_threshold`.

Before the change:
`shared_data_score > 0 -> won't loop_reorder -> can't fused because of shared_data_score < config.score_fusion_memory_threshold`
After the change:
`shared_data_score > 0 -> loop_reorder (shared_data_score < config.score_fusion_memory_threshold) -> get a larger shared_data_score -> fused`

----
It's the same issue as fixed in #136782. But the condition to call loop_reorder might be changed later, causing the test case to fail again.

Test Plan:
```
buck2 test 'fbcode//mode/opt' caffe2/test/inductor:loop_ordering
```
-----
Ran a float8 dynamic scaling training script to verify it e2e

Differential Revision: D67012816

Pull Request resolved: #142474
Approved by: https://github.com/eellison, https://github.com/sijiac, https://github.com/shunting314
@github-actions
Contributor

github-actions bot commented Feb 3, 2025

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Feb 3, 2025
@github-actions github-actions bot closed this Mar 5, 2025
Successfully merging this pull request may close these issues:

[Inductor] Fusion of Tiled Point-Wise and Reduction Operators
PT2 should leverage partial reductions to speed up larger reductions

6 participants