Skip to content

Commit cca9969

Browse files
maxyanghupytorchmergebot
authored andcommitted
dedup Externchoice
1 parent c83a47f commit cca9969

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

test/inductor/test_external_callables.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# Owner(s): ["module: inductor"]
22
import torch
3+
import unittest
34

45
from torch._inductor import config
5-
66
from torch._inductor.test_case import run_tests, TestCase
77

8+
from torch.testing._internal.common_cuda import TEST_CUDA
9+
810

911
class MatMulModule(torch.nn.Module):
1012
def __init__(self):
@@ -48,6 +50,7 @@ def test_matmul_cpu(self):
4850
msg=f"torch.compile(..., external_matmul = {matmul_cpu}) failed",
4951
)
5052

53+
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
5154
def test_matmul_cuda(self):
5255
device = torch.device("cuda")
5356
x = (torch.eye(128, 128) * 2).to(device=device)

torch/_inductor/kernel/mm.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,11 @@
122122
""",
123123
)
124124

125+
# prevent duplication registration of extern functions
126+
@functools.lru_cache(None)
127+
def lazy_register_extern_choice(fn):
128+
return ExternKernelChoice(fn)
129+
125130
aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")
126131

127132
aten_addmm = ExternKernelChoice(
@@ -248,7 +253,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
248253
from ..config import external_matmul
249254

250255
for k in external_matmul:
251-
choices.append(ExternKernelChoice(k).bind((mat1, mat2), layout))
256+
choices.append(lazy_register_extern_choice(k).bind((mat1, mat2), layout))
252257

253258
try:
254259
return autotune_select_algorithm(name, choices, [mat1, mat2], layout)

0 commit comments

Comments
 (0)