
Commit cde82d2

desertfire authored and pytorchmergebot committed
[AOTI] Add a multi_arch_kernel_binary option (#154413)
Summary: CUDA supports multi-arch binaries via the fatbin format. Add a multi_arch_kernel_binary option so that the compiled model binary can run across different GPU archs.

Differential Revision: [D75452094](https://our.internmc.facebook.com/intern/diff/D75452094)

Pull Request resolved: #154413
Approved by: https://github.com/angelayi
ghstack dependencies: #154412
1 parent 4d8f3d5 commit cde82d2
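
In practice the new option is enabled through the aot_inductor config group, alongside the existing embed_kernel_binary flag. A minimal usage sketch, assuming the aoti_compile_and_package entry point and its inductor_configs dict from recent PyTorch releases (the two dotted config keys below come from this PR and the pre-existing embed option):

import torch
from torch._inductor import aoti_compile_and_package


class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(10, 16)

    def forward(self, x, y):
        return x + self.linear(y)


model = Model().cuda()
example_inputs = (
    torch.randn(10, 16, device="cuda"),
    torch.randn(10, 10, device="cuda"),
)

ep = torch.export.export(model, example_inputs)
# multi_arch_kernel_binary recompiles the Triton PTX into a fatbin covering
# several compute capabilities, so the packaged model can run on other GPU archs.
package_path = aoti_compile_and_package(
    ep,
    inductor_configs={
        "aot_inductor.multi_arch_kernel_binary": True,
        # Optionally also link the kernel binaries directly into model.so.
        "aot_inductor.embed_kernel_binary": True,
    },
)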

File tree

5 files changed: +139, -32 lines


test/inductor/test_aot_inductor.py

Lines changed: 37 additions & 0 deletions
@@ -156,6 +156,43 @@ def forward(self, x, y):
             model, example_inputs, "AOTInductorModelRunMinimalArrayrefInterface(", 1
         )

+    @unittest.skipIf(
+        IS_FBCODE,
+        "toolchain doesn't support ptx to fatbin",
+    )
+    @skipIfRocm
+    @skipIfXpu
+    @common_utils.parametrize("embed_kernel_binary", [True, False])
+    def test_simple_multi_arch(self, embed_kernel_binary):
+        if self.device != GPU_TYPE:
+            raise unittest.SkipTest("requires GPU_TYPE")
+
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 16)
+
+            def forward(self, x, y):
+                return x + self.linear(y)
+
+        example_inputs = (
+            torch.randn(10, 16, device=self.device),
+            torch.randn(10, 10, device=self.device),
+        )
+        model = Model()
+        with config.patch(
+            {
+                "aot_inductor.embed_kernel_binary": embed_kernel_binary,
+                "aot_inductor.multi_arch_kernel_binary": True,
+            }
+        ):
+            self.check_model(model, example_inputs)
+            if not embed_kernel_binary:
+                _, code = run_and_get_cpp_code(
+                    AOTIRunnerUtil.compile, model, example_inputs
+                )
+                FileCheck().check(".fatbin").run(code)
+
     def test_small_constant(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:

torch/_inductor/codecache.py

Lines changed: 92 additions & 28 deletions
@@ -407,9 +407,9 @@ def get_path(
 def get_hash(
     content: Union[str, bytes], extra: str = "", hash_type: str = "code"
 ) -> str:
-    if hash_type == "code":
+    if hash_type in {"amdgcn", "code", "ptx"}:
         return code_hash(content, extra)
-    if hash_type in ["cubin", "hsaco", "spv"]:
+    if hash_type in {"cubin", "hsaco", "spv"}:
         return code_hash(repr(content))
     raise AssertionError(f"Unknown hash type {hash_type}")

@@ -420,11 +420,13 @@ def write(
     extra: str = "",
     hash_type: str = "code",
     specified_dir: str = "",
+    key: Optional[str] = None,
 ) -> tuple[str, str]:
-    # use striped content to compute hash so we don't end up with different
-    # hashes just because the content begins/ends with different number of
-    # spaces.
-    key: str = get_hash(content.strip(), extra, hash_type)
+    if key is None:
+        # use striped content to compute hash so we don't end up with different
+        # hashes just because the content begins/ends with different number of
+        # spaces.
+        key = get_hash(content.strip(), extra, hash_type)
     basename, _subdir, path = get_path(key, extension, specified_dir)
     if not os.path.exists(path):
         write_atomic(path, content, make_dirs=True)

@@ -1544,28 +1546,62 @@ class CudaKernelParamCache:
     cache_clear = staticmethod(cache.clear)

     @classmethod
-    def set(cls, key: str, params: dict[str, str], cubin: str, bin_type: str) -> None:
-        _, path = write(
+    def set(
+        cls,
+        key: str,
+        params: dict[str, Optional[str]],
+        cubin: str,
+        bin_type: str,
+        asm: Optional[str] = None,
+        asm_type: Optional[str] = None,
+    ) -> None:
+        basename = None
+        if config.aot_inductor.package_cpp_only:
+            assert config.triton.unique_kernel_names, (
+                "package_cpp_only requires triton kernel names to be unique"
+            )
+            assert params["mangled_name"], "Missing kernel name"
+            basename = params["mangled_name"]
+
+        _, bin_path = write(
             cubin,
             bin_type,
             hash_type=bin_type,
             specified_dir=split_aot_inductor_output_path(
                 config.aot_inductor.output_path
             )[0],
+            key=basename,
         )
-        if config.aot_inductor.package_cpp_only:
-            assert config.triton.unique_kernel_names, (
-                "package_cpp_only requires triton kernel names to be unique"
+        # Retrieve the basename again in case it is a generated hashcode
+        basename, _ = get_name_and_dir_from_output_file_path(bin_path)
+
+        if config.aot_inductor.multi_arch_kernel_binary:
+            assert bin_type == "cubin", (
+                "multi_arch_kernel_binary only supported in CUDA"
             )
-            dir_name = os.path.dirname(path)
-            _, ext = os.path.splitext(path)
-            # Construct the new full path
-            new_path = os.path.join(dir_name, params["mangled_name"] + ext)
-            os.rename(path, new_path)
-            path = new_path
+            base_path, _ = os.path.splitext(bin_path)
+            bin_path = base_path + ".fatbin"

-        params[get_cpp_wrapper_cubin_path_name()] = path
+        asm_path: str = ""
+        if (
+            config.aot_inductor.multi_arch_kernel_binary
+            or config.aot_inductor.package_cpp_only
+        ):
+            assert asm, "Missing kernel assembly code"
+            assert asm_type, "Missing kernel assembly type"
+            _, asm_path = write(
+                asm,
+                asm_type,
+                hash_type=asm_type,
+                specified_dir=split_aot_inductor_output_path(
+                    config.aot_inductor.output_path
+                )[0],
+                # make sure asm file has the same basename
+                key=basename,
+            )

+        params[get_cpp_wrapper_cubin_path_name()] = bin_path
+        params["asm"] = asm_path
         cls.cache[key] = params

     @classmethod

@@ -2007,13 +2043,33 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
             for entry in gpu_codecache.cache.values()
             if entry.output_path.endswith(".o")
         ]
+        if gpu_kernels_o:
+            assert not config.aot_inductor.multi_arch_kernel_binary, (
+                "TODO: add multi_arch_kernel_binary support for cutlass kernels"
+            )

         cubins_o = []
-        if config.aot_inductor.embed_kernel_binary:
-            # Embed cubin files into .so using objcopy
-            ld, objcopy = get_ld_and_objcopy(use_relative_path)
-            for kernel_name, value in CudaKernelParamCache.cache.items():
-                cubin_file = value[get_cpp_wrapper_cubin_path_name()]
+        asm_files = []
+        ld, objcopy = get_ld_and_objcopy(use_relative_path)
+        for kernel_name, value in CudaKernelParamCache.cache.items():
+            if asm_file := value["asm"]:
+                asm_files.append(asm_file)
+
+            cubin_file = value[get_cpp_wrapper_cubin_path_name()]
+            if config.aot_inductor.multi_arch_kernel_binary:
+                # Compile .ptx into .fatbin
+                archs = OrderedSet(
+                    [cuda_env.get_cuda_arch(), "80", "86", "89", "90"]
+                )
+                cmd = f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file}"
+                for arch in archs:
+                    cmd += f" -gencode arch=compute_{arch},code=compute_{arch}"
+                subprocess.run(
+                    cmd.split(), capture_output=True, text=True, check=True
+                )
+
+            if config.aot_inductor.embed_kernel_binary:
+                # Embed cubin files into model.so using objcopy
                 cubins_o.append(
                     convert_cubin_to_obj(cubin_file, kernel_name, ld, objcopy)
                 )

@@ -2061,7 +2117,6 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:

             # If we only want to package the cpp, then we need to save the
             # weights separately into a bin, and we also need to prevent compiling the so
-
             if use_mmap_weights:
                 weight_file = str(
                     wrapper_path_operator.with_name(

@@ -2073,11 +2128,20 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
                     f_weights.write(struct.pack("q", magic_number))

                 generated_files.append(weight_file)
+            else:
+                # TODO: unify to alway use mmap_weights
+                generated_files.append(consts_o)
+                so_builder.save_src_to_cmake(cmake_path, consts_o)
+
+            if config.aot_inductor.multi_arch_kernel_binary:
+                # TODO: support multi-arch when package_cpp_only
+                pass
+            else:
+                obj_srcs = [*gpu_kernels_o, *cubins_o]
+                generated_files.extend(obj_srcs)
+                for obj in obj_srcs:
+                    so_builder.save_src_to_cmake(cmake_path, obj)

-            obj_srcs = [consts_o, *gpu_kernels_o, *cubins_o]
-            generated_files.extend(obj_srcs)
-            for obj in obj_srcs:
-                so_builder.save_src_to_cmake(cmake_path, obj)
             so_builder.save_link_cmd_to_cmake(cmake_path)
         else:
             so_builder.build()
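
The fatbin step in the hunk above is the heart of the feature: when multi_arch_kernel_binary is set, the stored .ptx is recompiled into a fatbin that targets the current arch plus a fixed set of virtual architectures. A standalone sketch of that step, assuming _cuda_compiler() resolves to nvcc on the host (the arch list and flags mirror the diff):

import subprocess


def compile_ptx_to_fatbin(ptx_file: str, fatbin_file: str,
                          archs=("80", "86", "89", "90")) -> None:
    # Roughly: nvcc -fatbin kernel.ptx -o kernel.fatbin \
    #              -gencode arch=compute_80,code=compute_80 ...
    cmd = ["nvcc", "-fatbin", ptx_file, "-o", fatbin_file]
    for arch in archs:
        # code=compute_XX embeds PTX for the virtual arch, so the driver can
        # still JIT the kernel on GPUs newer than those listed.
        cmd += ["-gencode", f"arch=compute_{arch},code=compute_{arch}"]
    subprocess.run(cmd, capture_output=True, text=True, check=True)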

torch/_inductor/config.py

Lines changed: 3 additions & 0 deletions
@@ -1330,6 +1330,9 @@ class aot_inductor:
     # Embed generated kernel binary files into model.so
     embed_kernel_binary: bool = False

+    # Generate kernel binary files that support multiple archs
+    multi_arch_kernel_binary: bool = False
+
     # Custom ops that have implemented C shim wrappers, defined as an op to C shim declaration dict
     custom_ops_to_c_shims: dict[torch._ops.OpOverload, list[str]] = {}
     # custom op libs that have implemented C shim wrappers

torch/_inductor/cpp_builder.py

Lines changed: 3 additions & 3 deletions
@@ -182,11 +182,11 @@ def convert_cubin_to_obj(
     obj_file = cubin_file + ".o"
     # Convert .cubin to .o
     cmd = f"{ld} -r -b binary -z noexecstack -o {obj_file} {cubin_file}"
-    subprocess.run(cmd.split(), capture_output=True, text=True)
+    subprocess.run(cmd.split(), capture_output=True, text=True, check=True)
     os.remove(cubin_file)
     # Rename .data to .rodata
     cmd = f"{objcopy} --rename-section .data=.rodata,alloc,load,readonly,data,contents {obj_file}"
-    subprocess.run(cmd.split(), capture_output=True, text=True)
+    subprocess.run(cmd.split(), capture_output=True, text=True, check=True)
     # By default objcopy will create *_start, *_size, *_end symbols using the full path
     # Rename to use the unique kernel name
     file_name = re.sub(r"[\W]", "_", cubin_file)

@@ -197,7 +197,7 @@
         + f"--redefine-sym _binary_{file_name}_end=__{kernel_name}_end "
         + obj_file
     )
-    subprocess.run(cmd.split(), capture_output=True, text=True)
+    subprocess.run(cmd.split(), capture_output=True, text=True, check=True)
     return obj_file

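
For context, convert_cubin_to_obj wraps a kernel binary into a relocatable object so it can be linked into model.so; the change above only makes each subprocess.run fail loudly via check=True. A condensed sketch of that flow, assuming GNU ld and objcopy (the _start/_size redefinitions are implied by the comment in the hunk but not shown there):

import re
import subprocess


def embed_kernel_binary(bin_file: str, kernel_name: str,
                        ld: str = "ld", objcopy: str = "objcopy") -> str:
    obj_file = bin_file + ".o"
    # Wrap the raw binary in a relocatable object; ld defines
    # _binary_<mangled path>_start/_size/_end symbols around the payload.
    subprocess.run(
        [ld, "-r", "-b", "binary", "-z", "noexecstack", "-o", obj_file, bin_file],
        capture_output=True, text=True, check=True,
    )
    # Move the payload into a read-only section.
    subprocess.run(
        [objcopy, "--rename-section",
         ".data=.rodata,alloc,load,readonly,data,contents", obj_file],
        capture_output=True, text=True, check=True,
    )
    # Replace the path-derived symbol names with stable, kernel-specific ones.
    sym = re.sub(r"[\W]", "_", bin_file)
    subprocess.run(
        [objcopy,
         "--redefine-sym", f"_binary_{sym}_start=__{kernel_name}_start",
         "--redefine-sym", f"_binary_{sym}_size=__{kernel_name}_size",
         "--redefine-sym", f"_binary_{sym}_end=__{kernel_name}_end",
         obj_file],
        capture_output=True, text=True, check=True,
    )
    return obj_file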

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 4 additions & 1 deletion
@@ -1002,8 +1002,11 @@ def save_gpu_kernel(self, stream, launcher):

         bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin")
         binary = launcher.bin.asm[bin_type]
-        CudaKernelParamCache.set(key, params, binary, bin_type)
+        # Also store asm code which can be used for debugging and generating cpp package
+        asm_type = {"hip": "amdgcn", "cuda": "ptx"}.get(self.device_props.type, None)
+        asm = launcher.bin.asm.get(asm_type, None)

+        CudaKernelParamCache.set(key, params, binary, bin_type, asm, asm_type)
         self.cuda_kernel_saved = True

     def coordinate_descent_tuning(self, launcher, *args, **kwargs):
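
For reference, the binary and assembly formats that save_gpu_kernel now records per device type (a restatement of the two dict lookups above; XPU stores no assembly):

# device type -> (kernel binary format, stored assembly format)
KERNEL_FORMATS = {
    "cuda": ("cubin", "ptx"),    # the ptx is what multi_arch_kernel_binary turns into a fatbin
    "hip": ("hsaco", "amdgcn"),
    "xpu": ("spv", None),        # no assembly is saved for XPU
}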

0 commit comments
