@@ -407,9 +407,9 @@ def get_path(
def get_hash(
    content: Union[str, bytes], extra: str = "", hash_type: str = "code"
) -> str:
    """Return the cache key for ``content`` according to ``hash_type``.

    Args:
        content: Payload to hash.
        extra: Additional text forwarded to ``code_hash``. It is only used
            for the "amdgcn", "code" and "ptx" hash types.
        hash_type: Selects how the key is derived. "amdgcn", "code" and
            "ptx" hash ``content`` together with ``extra``; "cubin",
            "hsaco" and "spv" hash ``repr(content)`` and ignore ``extra``.

    Returns:
        The value produced by ``code_hash`` for the selected input.

    Raises:
        AssertionError: If ``hash_type`` is not one of the known types.
    """
    # The literals must match the un-padded names exactly; in particular
    # "code" has to equal the default value of ``hash_type``.
    if hash_type in {"amdgcn", "code", "ptx"}:
        return code_hash(content, extra)
    if hash_type in {"cubin", "hsaco", "spv"}:
        return code_hash(repr(content))
    raise AssertionError(f"Unknown hash type {hash_type}")
415415
@@ -420,11 +420,13 @@ def write(
420420 extra : str = "" ,
421421 hash_type : str = "code" ,
422422 specified_dir : str = "" ,
423+ key : Optional [str ] = None ,
423424) -> tuple [str , str ]:
424- # use striped content to compute hash so we don't end up with different
425- # hashes just because the content begins/ends with different number of
426- # spaces.
427- key : str = get_hash (content .strip (), extra , hash_type )
425+ if key is None :
426+ # use stripped content to compute hash so we don't end up with different
427+ # hashes just because the content begins/ends with different number of
428+ # spaces.
429+ key = get_hash (content .strip (), extra , hash_type )
428430 basename , _subdir , path = get_path (key , extension , specified_dir )
429431 if not os .path .exists (path ):
430432 write_atomic (path , content , make_dirs = True )
@@ -1544,28 +1546,62 @@ class CudaKernelParamCache:
15441546 cache_clear = staticmethod (cache .clear )
15451547
@classmethod
def set(
    cls,
    key: str,
    params: dict[str, Optional[str]],
    cubin: str,
    bin_type: str,
    asm: Optional[str] = None,
    asm_type: Optional[str] = None,
) -> None:
    """Write a kernel binary (and optionally its assembly) and cache ``params``.

    The binary is written with ``write`` into the directory obtained from
    ``split_aot_inductor_output_path(config.aot_inductor.output_path)[0]``.
    The resulting paths are stored back into ``params`` and ``params`` is
    registered in ``cls.cache`` under ``key``.

    Args:
        key: Cache key under which ``params`` is stored.
        params: Kernel parameters. This dict is mutated: the entry named by
            ``get_cpp_wrapper_cubin_path_name()`` receives the binary path
            and ``"asm"`` receives the assembly path ("" when no assembly
            file is written). ``"mangled_name"`` is read when
            ``package_cpp_only`` is enabled.
        cubin: Kernel binary content passed to ``write``.
        bin_type: File extension of the binary; also used as its hash type.
        asm: Kernel assembly content. Required when either
            ``multi_arch_kernel_binary`` or ``package_cpp_only`` is enabled.
        asm_type: File extension / hash type of the assembly. Required under
            the same conditions as ``asm``.

    Raises:
        AssertionError: If a configuration precondition is violated (kernel
            names not unique, missing kernel name, non-"cubin" binary with
            ``multi_arch_kernel_binary``, or missing assembly inputs).
    """
    aot_config = config.aot_inductor

    basename = None
    if aot_config.package_cpp_only:
        assert config.triton.unique_kernel_names, (
            "package_cpp_only requires triton kernel names to be unique"
        )
        # Use .get so an absent entry is reported through the assertion
        # message instead of surfacing as a bare KeyError.
        mangled_name = params.get("mangled_name")
        assert mangled_name, "Missing kernel name"
        basename = mangled_name

    # Both the binary and the assembly go to the same directory; compute it
    # once instead of repeating the expression for each write.
    output_dir = split_aot_inductor_output_path(aot_config.output_path)[0]

    # With key=None, ``write`` derives the file name from a content hash.
    _, bin_path = write(
        cubin,
        bin_type,
        hash_type=bin_type,
        specified_dir=output_dir,
        key=basename,
    )
    # Retrieve the basename again in case it is a generated hashcode
    basename, _ = get_name_and_dir_from_output_file_path(bin_path)

    if aot_config.multi_arch_kernel_binary:
        assert bin_type == "cubin", (
            "multi_arch_kernel_binary only supported in CUDA"
        )
        # Only the recorded path is rewritten here; this method does not
        # create the .fatbin file itself.
        base_path, _ = os.path.splitext(bin_path)
        bin_path = base_path + ".fatbin"

    asm_path: str = ""
    if aot_config.multi_arch_kernel_binary or aot_config.package_cpp_only:
        assert asm, "Missing kernel assembly code"
        assert asm_type, "Missing kernel assembly type"
        _, asm_path = write(
            asm,
            asm_type,
            hash_type=asm_type,
            specified_dir=output_dir,
            # make sure asm file has the same basename
            key=basename,
        )

    params[get_cpp_wrapper_cubin_path_name()] = bin_path
    params["asm"] = asm_path
    cls.cache[key] = params
15701606
15711607 @classmethod
@@ -2007,13 +2043,33 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
20072043 for entry in gpu_codecache .cache .values ()
20082044 if entry .output_path .endswith (".o" )
20092045 ]
2046+ if gpu_kernels_o :
2047+ assert not config .aot_inductor .multi_arch_kernel_binary , (
2048+ "TODO: add multi_arch_kernel_binary support for cutlass kernels"
2049+ )
20102050
20112051 cubins_o = []
2012- if config .aot_inductor .embed_kernel_binary :
2013- # Embed cubin files into .so using objcopy
2014- ld , objcopy = get_ld_and_objcopy (use_relative_path )
2015- for kernel_name , value in CudaKernelParamCache .cache .items ():
2016- cubin_file = value [get_cpp_wrapper_cubin_path_name ()]
2052+ asm_files = []
2053+ ld , objcopy = get_ld_and_objcopy (use_relative_path )
2054+ for kernel_name , value in CudaKernelParamCache .cache .items ():
2055+ if asm_file := value ["asm" ]:
2056+ asm_files .append (asm_file )
2057+
2058+ cubin_file = value [get_cpp_wrapper_cubin_path_name ()]
2059+ if config .aot_inductor .multi_arch_kernel_binary :
2060+ # Compile .ptx into .fatbin
2061+ archs = OrderedSet (
2062+ [cuda_env .get_cuda_arch (), "80" , "86" , "89" , "90" ]
2063+ )
2064+ cmd = f"{ _cuda_compiler ()} -fatbin { asm_file } -o { cubin_file } "
2065+ for arch in archs :
2066+ cmd += f" -gencode arch=compute_{ arch } ,code=compute_{ arch } "
2067+ subprocess .run (
2068+ cmd .split (), capture_output = True , text = True , check = True
2069+ )
2070+
2071+ if config .aot_inductor .embed_kernel_binary :
2072+ # Embed cubin files into model.so using objcopy
20172073 cubins_o .append (
20182074 convert_cubin_to_obj (cubin_file , kernel_name , ld , objcopy )
20192075 )
@@ -2061,7 +2117,6 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
20612117
20622118 # If we only want to package the cpp, then we need to save the
20632119 # weights separately into a bin, and we also need to prevent compiling the so
2064-
20652120 if use_mmap_weights :
20662121 weight_file = str (
20672122 wrapper_path_operator .with_name (
@@ -2073,11 +2128,20 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
20732128 f_weights .write (struct .pack ("q" , magic_number ))
20742129
20752130 generated_files .append (weight_file )
2131+ else :
2132+ # TODO: unify to always use mmap_weights
2133+ generated_files .append (consts_o )
2134+ so_builder .save_src_to_cmake (cmake_path , consts_o )
2135+
2136+ if config .aot_inductor .multi_arch_kernel_binary :
2137+ # TODO: support multi-arch when package_cpp_only
2138+ pass
2139+ else :
2140+ obj_srcs = [* gpu_kernels_o , * cubins_o ]
2141+ generated_files .extend (obj_srcs )
2142+ for obj in obj_srcs :
2143+ so_builder .save_src_to_cmake (cmake_path , obj )
20762144
2077- obj_srcs = [consts_o , * gpu_kernels_o , * cubins_o ]
2078- generated_files .extend (obj_srcs )
2079- for obj in obj_srcs :
2080- so_builder .save_src_to_cmake (cmake_path , obj )
20812145 so_builder .save_link_cmd_to_cmake (cmake_path )
20822146 else :
20832147 so_builder .build ()
0 commit comments