[prototype only] Add cutlass as an alternative backend to inductor #106607
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106607.
Note: Links to docs will display an error until the docs builds have been completed.
❌ 4 New Failures as of commit 8cbe612 with merge base c0b8b7b.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/inductor/test_max_autotune.py (Outdated):

```python
import torch

torch.cuda.init()
```
Why is this needed? Does it have to be in global scope?
torch/_inductor/autotune_process.py (Outdated):

```python
) -> Callable[[], None]:
    raise NotImplementedError()

def profiling_keyword(self) -> str:
```
document this? not clear from name
This is useless now, will remove it.
torch/_inductor/codecache.py (Outdated):

```python
def _cutlass_include_paths() -> List[str]:
    _CUTLASS_PATH = os.path.join(
```
local variables should be lowercase
```python
def _dlclose(self):
    f_dlclose = None

    if is_windows():
        f_dlclose = ctypes.windll.kernel32.FreeLibrary
    elif is_linux():
        syms = ctypes.CDLL(None)
        if not hasattr(syms, "dlclose"):
            # Apline Linux
            syms = ctypes.CDLL("libc.so")

        if hasattr(syms, "dlclose"):
            f_dlclose = syms.dlclose

    if f_dlclose is not None:
        f_dlclose.argtypes = [ctypes.c_void_p]
        f_dlclose(self.DLL._handle)
    else:
        logging.warning(
            "dll unloading function was not found, library may not be unloaded properly!"
        )
```
Was this copied from somewhere? (I am assuming you aren't testing on windows.) If so add a comment explaining where it came from (or reuse it if it is within our codebase).
Yeah, I copied them from cpp_extension and AITemplate code. On second thought, I think I'll just remove the Windows logic since it's not tested. People could work on more proper Windows support later.
torch/_inductor/codegen/common.py (Outdated):

```python
outer = inplaced.other_names[-1]
inner = inplaced.inner_name
dtype = buffer_types[outer]
dtype = buffer_types.get(outer, V.graph.get_dtype(outer))
```
We might be able to just call V.graph.get_dtype(outer) directly and delete buffer_types from this function. I think get_dtype already handles constants (or we could easily make it handle them).
```python
    return f"[{', '.join(sizes)}]"

def call_kernel(self, name: str):
def call_kernel(self, name: str, node: IRNode = None):
```
Let's make node required and update the users. We can drop name since there is node.get_name().
Actually name != node.get_name(). name here is generated by self.define_kernel(src_code, node_schedule).
```python
# finalize must be called after adding epilogue above
src_code = partial_code.finalize()
src_code = partial_code if isinstance(partial_code, str) else partial_code.finalize()
```
Shouldn't be needed, see #105987
```python
if buffer.get_workspace_size() > 0:
    workspace_size = tuple(buffer.get_workspace_size())
    workspace_stride = tuple(1)
    res.append(
        f"{buffer.get_name()}_workspace = empty_strided("
        f"{self.codegen_shape_tuple(workspace_size)}, "
        f"{self.codegen_shape_tuple(workspace_stride)}, "
        f"device='{device.type}', dtype={torch.uint8})"
    )
return res
```
Will need to make this work with dynamic shapes.
Yes this is a TODO, won't be included in this PR.
```python
def make_buffer_free(self, buffer) -> List[str]:
    res = [f"del {buffer.get_name()}"]
    if buffer.get_workspace_size() > 0:
        res.append(f"del {buffer.get_name()}_workspace")
    return res

def make_buffer_reuse(self, old, new):
def make_buffer_reuse(self, old, new) -> List[str]:
    assert old.get_dtype() == new.get_dtype()
    del_line = ""
    res = []
    del_line0, del_line1 = ""
    if old.get_name() not in V.graph.get_output_names():
        del_line = f"; {self.make_buffer_free(old)}"
    if old.get_size() == new.get_size() and old.get_stride() == new.get_stride():
        return f"{self.declare}{new.get_name()} = {old.get_name()}{del_line} {self.comment} reuse"
        del_lines = self.make_buffer_free(old)
        del_line0 = "" if len(del_lines) < 1 else del_lines[0]
        del_line1 = "" if len(del_lines) < 2 else del_lines[1]
    if (
        old.get_size() == new.get_size()
        and old.get_stride() == new.get_stride()
        and old.get_workspace_size() == new.get_workspace_size()
    ):
        res.append(f"{self.declare}{new.get_name()} = {old.get_name()}{del_line0} {self.comment} reuse")
        if new.get_workspace_size() > 0:
            res.append(f"{self.declare}{new.get_name()}_workspace = {old.get_name()}_workspace{del_line1} {self.comment} reuse")
        return res

    return (
    res.append(
        f"{self.declare}{new.get_name()} = {self.namespace}as_strided({old.get_name()}, "
        f"{self.codegen_shape_tuple(new.get_size())}, "
        f"{self.codegen_shape_tuple(new.get_stride())}){del_line} {self.comment} reuse"
        f"{self.codegen_shape_tuple(new.get_stride())}){del_line0} {self.comment} reuse"
    )
    if new.get_workspace_size() > 0:
        res.append(f"{self.declare}{new.get_name()}_workspace = {old.get_name()}_workspace{del_line1} {self.comment} reuse")
    return res
```
I suspect it would be simpler to allocate workspaces separately from regular buffers. Doing it this way means workspaces can only be reused if the main buffer can be reused (which is often not the case).
```python
VarRanges = Dict[sympy.Expr, sympy.Expr]


def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> None:
```
Can we unify the way we benchmark triton and cutlass?
TritonTemplateKernel and CUDATemplateKernel benchmarking is unified in torch/_inductor/autotune_process.py:BenchmarkRequest::benchmark(). I introduced this method because the existing profiling method measures CPU-side kernel launch overhead and is not accurate. However, this method could also be fragile, as it needs to filter out all irrelevant kernels. I'll add more comments.
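For readers unfamiliar with the approach, here is a minimal, hedged sketch of profiler-based timing in the spirit of `do_bench_using_profiling`; the helper name and event filtering are illustrative, not the PR's actual implementation.

```python
import torch
from torch.profiler import ProfilerActivity, profile

def bench_gpu_time_via_profiler(fn, n_repeat: int = 100) -> float:
    """Illustrative only: time fn via profiler events rather than wall clock."""
    torch.cuda.synchronize()
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        for _ in range(n_repeat):
            fn()
        torch.cuda.synchronize()
    # Sum device-side time so CPU launch overhead is excluded; a real
    # implementation must also filter out irrelevant kernels and sync events.
    total_us = sum(
        evt.self_cuda_time_total
        for evt in prof.key_averages()
        if evt.self_cuda_time_total > 0
    )
    return total_us / n_repeat / 1000.0  # average milliseconds per call
```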
aakhundov left a comment:
@ipiszy thanks for the e2e prototype! I've looked through some of the code. Will look further and likely add more comments.
test/inductor/test_max_autotune.py (Outdated):

```python
{
    "max_autotune": True,
    "autotune_in_subproc": True,
    "max_autotune_gemm_backends": max_autotune_gemm_backends,
```
Looks like max_autotune_gemm_backends doesn't exist in here. Add a @parameterized?
Good catch! I'll just remove it.
torch/_inductor/autotune_process.py (Outdated):

```python
    + f"create tensor {create_tensor_elapse}, bench {bench_elapse}, "
    + f"collected time {out}"
)
self.close()
```
Should this be called by the caller of the benchmark function? Request closing itself looks somewhat asymmetric. Or is the semantics that a request may be benchmarked only once?
I'll rename it to cleanup_run_fn() so that it's symmetric with make_run_fn().
torch/_inductor/autotune_process.py (Outdated):

```python
    self.workspace_size == 0
), "Autotune cache needs to be updated to support non-zero workspace_size!"

workspace = torch.empty(self.workspace_size, dtype=torch.uint8, device="cuda")
```
Is the workspace Tensor going to be garbage-collected at the end of this function, as we're only using its .data_ptr() in the functools.partial(...) below? Perhaps better set it as an attribute of CUDABenchmarkRequest? Then you can also get rid of the workspace_size attribute, as it's only used once during the workspace computation.
Good catch! I think I still need to keep "workspace_size" as it needs to be propagated to the codegen stage for memory planning purposes. I'll make workspace a member variable.
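A minimal sketch of what "make workspace a member variable" could look like; the class name mirrors the discussion, while the run-fn wiring is an assumption rather than the PR's actual code.

```python
import torch

class CUDABenchmarkRequestSketch:
    def __init__(self, workspace_size: int):
        # workspace_size is still kept: it is propagated to codegen for memory planning.
        self.workspace_size = workspace_size
        self.workspace = None

    def make_run_fn(self, run_kernel):
        if self.workspace_size > 0:
            # Held on self so the allocation outlives this method and its
            # data_ptr() stays valid inside the returned closure.
            self.workspace = torch.empty(
                self.workspace_size, dtype=torch.uint8, device="cuda"
            )
        ptr = self.workspace.data_ptr() if self.workspace is not None else 0
        return lambda: run_kernel(workspace_ptr=ptr)

    def cleanup_run_fn(self):
        # Symmetric with make_run_fn(), as discussed above.
        self.workspace = None
```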
torch/_inductor/autotune_process.py (Outdated):

```python
)


def profiling_keyword(self) -> str:
    return "cutlass"
```
Nit: I'd suggest either using "cuda" here or CUTLASS in the class name.
```python
res = ""
if dst_file_suffix == "o":
    res = (
        _cuda_compiler() + " " + " ".join(options) + f" -c -o {dst_file} {src_file}"
```
Nit: put the parts on separate lines for clarity? Also below.
```python
# CUDA arch to use for CUDA template kernel compilation.
# Available options: "70", "75", "80", "90"
# When arch is None, the Inductor tries to detect CUDA arch by querying
# CUDA Python lib or "nvidia-smi" commandline.
```
Super-nit: "nvidia-smi" command.
torch/_inductor/config.py (Outdated):

```python
# e.g. "11.4", "12.1", etc.
# When version is None, the Inductor tries to detech CUDA version by
# querying CUDA Python lib for "nvidia-smi" commandline.
version = None
```
Is it possible to compile with an earlier CUDA version than installed? If not, should we just auto-detect it from the system in all cases?
torch/_inductor/config.py (Outdated):

```python
enable_ptxas_info = False

enable_debug_info = True
```
We may want to make this False by default, eventually.
```python


def get_layout(self):
    return self.data.get_layout()


```
Nit: extra blank lines above and below. Linter will likely complain.
```python
def get_workspace_size(self):
    return 0
```
Should we move this to TemplateBuffer, as that seems to be the lowest common subclass where this is relevant? Or is it necessary for arbitrary Buffer instances to be able to return this?
I feel it depends on the algorithm. In theory triton codegen code can also require extra workspace.
Also, I think you're invoking this method on the abstract Buffer in some contexts, as I recall. So ignore the comment.
```python
)


def use_cutlass_template(layout):
```
This looks very similar to use_triton_template. Maybe extract common part into a helper?
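One possible shape for the suggested refactor, sketched only; the concrete checks (dtype allow-list, any size heuristics) are placeholders rather than the real conditions used by `use_triton_template`.

```python
import torch

def _use_gemm_template(layout, allowed_dtypes) -> bool:
    # Checks shared by the Triton and Cutlass template predicates (placeholder logic).
    return layout.device.type == "cuda" and layout.dtype in allowed_dtypes

def use_triton_template(layout) -> bool:
    return _use_gemm_template(layout, (torch.float16, torch.bfloat16, torch.float32))

def use_cutlass_template(layout) -> bool:
    return _use_gemm_template(layout, (torch.float16, torch.bfloat16, torch.float32))
```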
```python
    is not None
    else "percentiles"
)
return triton_do_bench
```
This seems like a leftover?
```python
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
```
IIUC, for 0.1ms kernel this will render 250 warmup and 1000 profiling iterations. Do we really need that many?
I feel it doesn't matter? Is it an issue?
Purely from e2e time perspective. By decreasing warmup to, say, 10 iterations, you'll shave off ~1/5 of the total kernel runtime during profiling. Although maybe the total kernel runtime is not a bottleneck, so it doesn't matter indeed.
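For concreteness, the iteration counts the reviewer mentions fall directly out of the quoted defaults (warmup=25, rep=100, both in ms):

```python
estimate_ms = 0.1                           # a 0.1 ms kernel
n_warmup = max(1, int(25 / estimate_ms))    # -> 250 warmup iterations
n_repeat = max(1, int(100 / estimate_ms))   # -> 1000 measured iterations
```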
```python
def _should_keep(key: str) -> bool:
    if key.startswith("cuda"):
        return False
    if key == "Context Sync":
```
Recently, we've also noticed other profiler events like:
"Event Sync",
"Stream Wait Event",
"Stream Sync",
"Unknown Sync",
Will add them.
```python
if len(choices) == 1:
    return choices[0].output_node()
# if len(choices) == 1:
#     return choices[0].output_node()
```
Curious what's the reason of commenting this out?
Need to make sure that workspace_size is fetched via the benchmark process.
I see. I guess, this means that even one-option autotuning (like in the ATen-only cases) will run the kernel / invoke the subprocess etc.? I'm wondering if this can degrade compilation time. Likely the cache will be hit on subsequent attempts, though.
```python
)
if buffer.get_workspace_size() > 0:
    workspace_size = tuple(buffer.get_workspace_size())
    workspace_stride = tuple(1)
```
Should be tuple([1])? Throws an error otherwise (Python 3.11):

```python
>>> tuple(1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: 'int' object is not iterable
```

More importantly, this likely shows that the code path with workspace size > 1 is not verified.
I'll either remove this logic or add more unit tests. Default cutlass gemm kernels only have workspace_size == 0, so in the E2E test this path is not exercised. From Haicheng: only group_gemm and parallel split_k need this.
```python
    return f"auto {name} = outputs[{output_idx}];"
else:
    self.outputs_need_copy.add(name)
    [self.outputs_need_copy.add(name)]
```
[...] is not required here, but likely required on the L1295 above?
```python
)
if (
    inp_expanded.get_stride()[0] == 0
    use_aten_gemm_kernels()
```
Good catch!
```python
try:
    from cuda import cuda

    _assert_cuda(cuda.cuInit(0))
```
Can initializing CUDA outside torch lead to undesirable effects?
```python
def dtype_match(torch_dtype, cutlass_dtype) -> bool:
    if torch_dtype == torch.float or torch_dtype == torch.float32:
```
Nit: use in? Here and below.
```python
cutlass_dtype == cutlass_lib.DataType.f32
or cutlass_dtype == cutlass_lib.DataType.tf32
```
Is it safe to treat both f32 and tf32 as equivalent to torch.float?
tf32 is the NV way to process fp32 gemms using tensor core. In cutlass it's not used in other places. So I think it's fine.
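As supporting context (not part of the PR), eager PyTorch already exposes a switch that lets fp32 GEMMs execute as tf32 on tensor cores, which is consistent with matching torch.float against both f32 and tf32 here:

```python
import torch

# Existing PyTorch flag: when enabled, float32 matmuls may run in tf32 on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
```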
```python
for i in range(len(call_args)):
    if V.graph.is_unspec_arg(call_args[i]):
        call_args[i] = call_args[i] + ".item()"
    call_args[i] = f"c_void_p({call_args[i]}.data_ptr())"
```
This will break with the unspec args getting the .item() call above?
Might be..
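A small, self-contained illustration of the clash flagged above; the argument names are made up:

```python
call_args = ["buf0", "s0"]                  # suppose "s0" is an unspec scalar arg
call_args[1] = call_args[1] + ".item()"     # the unspec rewrite -> "s0.item()"
call_args = [f"c_void_p({arg}.data_ptr())" for arg in call_args]
print(call_args)
# ['c_void_p(buf0.data_ptr())', 'c_void_p(s0.item().data_ptr())']  <- second arg is invalid
```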
```python
_DTYPE_TO_CUTLASS = {
    torch.float32: "float",
    torch.float64: "double",
    torch.float16: "cutlass::half_t",
    torch.int32: "int",
    torch.int8: "int8_t",
    torch.uint8: "uint8_t",
    torch.bool: "bool",
    torch.bfloat16: "cutlass::bfloat16_t",
}

def dtype(self, node: IRNode) -> str:
    return self._DTYPE_TO_CUTLASS.get(node.get_dtype())
```
Would it make sense to move this to the CutlassTemplate below?
```cpp
#endif
#endif
using bfloat16 = nv_bfloat16;
```
Does this require #include <cuda_bf16.h>?
```python
if filter_res is not None:
    res.append(filter_res)
print(f"Got cutlass configs: {len(res)=}")
return [res[0]]
```
I guess, this is to speed up the profiling? But we shouldn't forget to return res here eventually. Also maybe for more thorough testing.
I've finished basic A100 / H100 gemm, gemm+bias, gemm+matrix tests and fixed some issues. Will work on separate PRs to address comments, thanks!
Force-pushed from 80f0c90 to 53e25a8.
ipiszy left a comment:
Thanks @jansel @aakhundov for the review! I replied to some comments and am working on the PRs now.
```python
    raise exc.CppCompileError(cmd, error.output) from error
cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path)

return (DLLWrapper(cls.cache[key].output_path), key, input_path)
```
I'm worried about naming conflicts if multiple DLLs have the same symbol name. e.g. In benchmark, all gemm_kernels have the same name to make sure code cache can be reused.
```python
is_load_uint8_as_float: bool = False


@functools.lru_cache(None)
```
@functools.cache is only available from Python 3.9. In the internal env lots of hosts are still using Python 3.8, so I'd like to make sure it won't break with Python 3.8.
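For reference, the unbounded `lru_cache` form used above behaves the same as `functools.cache` while staying Python 3.8-compatible; the decorated helper below is hypothetical:

```python
import functools

@functools.lru_cache(None)  # works on Python 3.8; equivalent to @functools.cache (3.9+)
def detected_cuda_arch() -> str:
    # Hypothetical cached helper, for illustration only.
    return "80"
```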
```cpp
// When workspace_size is not a nullptr, populates requested workspace_size and returns.
// Otherwise, compuates the Gemm kernel using the given workspace ptr.
extern "C" {
{{kernel.def_kernel(inputs=[X, W, Bias], outputs=[Y], names_str="X, W, Bias, Y", input_reorder=input_reorder)}} {
```
X, W, Bias are IRNodes. names_str is used to represent names of these nodes in generated code.
I feel passing IRNodes directly here would make code a bit easier to reason about. Besides, GemmTemplate also uses these IRNodes to select cutlass configs and generate templates. It looks a bit confusing to keep the node / name mapping inside the GemmTemplate.
```python
    return True  # skip checks for compatible tiling
# Only allow fusion for TritonTemplates for now.
# Fusion for CUDATemplates are not supported.
return isinstance(node1.node, TritonTemplateBuffer)
```
isinstance(None, TritonTemplateBuffer) should return False so fusion won't happen. It should be fine?
Hi @jansel @aakhundov, I created a new PR stack #107933, PTAL, thanks!
NOTE: This PR is a prototype to collect early feedback. I'll follow up on splitting it into small PRs and add proper unittests.
This PR adds cutlass support for simple gemm cases (matmul and linear), without arbitrary epilogue fusions. More support for epilogue fusion will be implemented on top of cutlass epilogue visitor, which is to be released in cutlass 3.2.
Key changes:
API:
- Introduce "Cutlass" as an alternative backend in `max_autotune_gemm_backends` in torch/_inductor/config.py. This option is exposed via the `torch.compile()` API `options` arg; a usage sketch follows below.
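A hedged usage sketch of the new knob described above; the exact backend string and defaults may change in the follow-up PRs:

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda().half()
compiled = torch.compile(
    model,
    options={
        "max_autotune": True,
        # Comma-separated backend list; "CUTLASS" is the new alternative added here.
        "max_autotune_gemm_backends": "ATEN,TRITON,CUTLASS",
    },
)
out = compiled(torch.randn(8, 1024, device="cuda", dtype=torch.half))
```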
CUDA env:
- Add a file torch/_inductor/codegen/cuda/cuda_env.py to get the CUDA version / arch. Not sure whether there are better ways to do this in PyTorch; one possible alternative is sketched below.
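One possible alternative using existing torch APIs, given here as a suggestion rather than what cuda_env.py actually does:

```python
import torch

def detect_cuda_arch() -> str:
    major, minor = torch.cuda.get_device_capability()  # e.g. (8, 0) on A100
    return f"{major}{minor}"                           # -> "80"

def detect_cuda_version() -> str:
    return torch.version.cuda                          # e.g. "12.1"; None on CPU-only builds
```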
CUDA / Cutlass Gemm template:
- `TritonTemplateBuffer` and `CUDATemplateBuffer` subclasses. CUDATemplateBuffer keeps workspace_size, which is required by cutlass kernels. Remembering the workspace_size is useful for memory planning.
- `CUDATemplate` and `CutlassGemmTemplate` for code rendering.
- `CUDAKernel` and `CUDATemplateKernel` for code rendering.
- `CUDAScheduling`, which simply inherits from `TritonScheduling`, since it's only used by template codegen instead of more general node fusions. Accordingly, rely on "triton.py" for `codegen_template()` and `can_fuse()` and add simple `if / else` branches for CUDAScheduling. This looks a bit hacky, please suggest better ways.

CUDA codecache / autotuning:
- `CUDACodeCache` for CUDA compilation. Currently it relies on nvcc. More follow-ups are needed, e.g. nvrtc, support for non-Linux platforms, and support for the internal production environment.
- `TritonBenchmarkRequest` and `CUDABenchmarkRequest` subclasses to handle benchmarking.

Profiling util:
- `torch.profile()` to get accurate timing info. However, one caveat is that it may incorrectly exclude the `vectorized_elementwise_kernel` kernel, which is used to clear the L2 cache.

Tests:
Static shape only.
Things to fix in this PR:
Things to fix in follow-up PRs:
Things to add:
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov