Skip to content

Commit 995e307

Browse files
jansel authored and pytorchmergebot committed
[inductor] Fix for "Failed to find static RBLOCK" (#141434)
Summary: I expect this to fix https://fb.workplace.com/groups/1075192433118967/permalink/1547962839175255/
Test Plan: Ask poster to confirm fix
Differential Revision: D66413828
Pull Request resolved: #141434
Approved by: https://github.com/ezyang
1 parent f6eeab7 commit 995e307

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

torch/_inductor/codegen/triton.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3179,18 +3179,28 @@ def codegen_kernel(self, name=None):
31793179

31803180
return code.getvalue()
31813181

3182-
def _get_persistent_RBLOCK(self, rnumel):
3182+
@staticmethod
3183+
def _get_persistent_RBLOCK(rnumel):
31833184
rnumel = V.graph.sizevars.simplify(rnumel)
31843185
if isinstance(rnumel, (sympy.Integer, int)):
31853186
val = int(rnumel)
31863187
val = next_power_of_2(val)
31873188
else:
31883189
val = 128
31893190
while not V.graph.sizevars.statically_known_leq(rnumel, val):
3190-
assert val <= 16 * 1024, f"Failed to find static RBLOCK for {rnumel}"
3191+
if val > 16 * 1024:
3192+
raise ValueError(f"Failed to find static RBLOCK for {rnumel}")
31913193
val *= 2
31923194
return val
31933195

3196+
@staticmethod
3197+
def has_persistent_RBLOCK(rnumel):
3198+
try:
3199+
TritonKernel._get_persistent_RBLOCK(rnumel)
3200+
return True
3201+
except ValueError:
3202+
return False
3203+
31943204
def codegen_static_numels(self, code):
31953205
"""
31963206
We get a small speedup from hard coding numels if they are static.
@@ -3623,6 +3633,11 @@ def create_kernel_choices(
36233633
kernel_kwargs["override_persistent_reduction"] = True
36243634
kernel_kwargs["override_cooperative_reduction"] = False
36253635

3636+
if not TritonKernel.has_persistent_RBLOCK(kernel_features.reduction_numel):
3637+
# Cannot use persistent reduction with unknown dynamic rnumel
3638+
assert not kernel_kwargs.get("override_persistent_reduction")
3639+
kernel_kwargs["override_persistent_reduction"] = False
3640+
36263641
kernel_kwargs = V.choices.triton_kernel_kwargs(
36273642
kernel_type, kernel_features, kernel_args, kernel_kwargs
36283643
)

0 commit comments

Comments
 (0)