-
Notifications
You must be signed in to change notification settings - Fork 27.5k
Add at::cpu namespace of functions for structured kernels #49505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f2e9581
e4d9115
02733fb
7d6170b
9dede96
76cba79
9576510
37a41e8
2e37a4a
01b3a91
aadd12c
f5a257a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| // ${generated_comment} | ||
|
|
||
| // NB: The implementing C++ file is RegisterDispatchKey.cpp | ||
|
|
||
| // TODO: tighten this include | ||
| #include <ATen/Functions.h> | ||
|
|
||
| namespace at { | ||
| namespace ${dispatch_namespace} { | ||
|
|
||
| ${dispatch_declarations} | ||
|
|
||
| } // namespace ${dispatch_namespace} | ||
| } // namespace at |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -222,13 +222,13 @@ def __call__(self, f: NativeFunction) -> Optional[str]: | |
| # API without having to disambiguate which overload you want | ||
| # (as would be the case if you directly registered native:: | ||
| # functions). | ||
| # - The tertiary function of this file is to generate *static* | ||
| # cpp API bindings which can be used to bypass dispatcher | ||
| # directly to kernels, but with user-friendly cpp-style API | ||
| @dataclass(frozen=True) | ||
| class RegisterDispatchKey: | ||
| dispatch_key: str | ||
|
|
||
| # TODO: Give more precise type Union[Literal[Target.DEFINITION, | ||
| # Target.REGISTRATION]]; requires Literal from typing_extensions | ||
| # which we don't have a dep for yet. | ||
| target: Target | ||
|
|
||
| # Selector object to determine which operators to generate | ||
|
|
@@ -238,9 +238,6 @@ class RegisterDispatchKey: | |
| # Whether or not we are actually code-genning for ROCm | ||
| rocm: bool | ||
|
|
||
| def __post_init__(self) -> None: | ||
| assert self.target is not Target.DECLARATION | ||
|
|
||
| @method_with_native_function | ||
| def __call__(self, f: Union[StructuredNativeFunctions, NativeFunction]) -> List[str]: | ||
| if isinstance(f, StructuredNativeFunctions): | ||
|
|
@@ -386,7 +383,6 @@ def gen_structured(self, g: StructuredNativeFunctions) -> List[str]: | |
| # you edit this, you may need to also edit gen_unstructured. | ||
| @with_native_function | ||
| def gen_one(f: NativeFunction) -> Optional[str]: | ||
| assert self.target is not Target.DECLARATION | ||
| assert not f.manual_kernel_registration | ||
|
|
||
| # TODO: put this into StructuredNativeFunctions itself | ||
|
|
@@ -402,9 +398,20 @@ def gen_one(f: NativeFunction) -> Optional[str]: | |
| return None | ||
|
|
||
| k = f.func.kind() | ||
| sig = NativeSignature.from_schema(f.func) | ||
| sig = NativeSignature(f.func, prefix="wrapper_") | ||
|
|
||
| # Unconditionally generate function version, never fallback binding. | ||
| # Some extra massaging would then be necessary in a hypothetical | ||
| # CPUTensor class | ||
| cpp_sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False) | ||
| # For now, don't generate faithful signature for simplicity | ||
| cpp_sig = cpp_sig_group.signature | ||
|
|
||
| if self.target is Target.DEFINITION: | ||
| if self.target is Target.DECLARATION: | ||
| # namespace is handled by template | ||
| return f"TORCH_API {cpp_sig.decl()};\n" | ||
|
|
||
| elif self.target is Target.DEFINITION: | ||
| if self.dispatch_key == 'Meta': | ||
| class_name = f"structured_{meta.name(g)}_meta_{k.name}" | ||
| parent_class = f"at::meta::{meta.name(g)}" | ||
|
|
@@ -452,6 +459,12 @@ def gen_one(f: NativeFunction) -> Optional[str]: | |
| }} | ||
|
|
||
| }} // anonymous namespace | ||
|
|
||
| namespace {self.dispatch_key.lower()} {{ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably not a big deal, but the individual Actually, isn't this generating at::<dispatch_key> functions for every dispatch key, and then only providing headers for the specific keys we want (cpu/cuda)? Shouldn't we keep those two in sync? ( only bother providing dispatcher-skipping implementations for dispatch keys that we provide headers for)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah. The alternative is to split up the implementations in codegen. I'm ambivalent about this; so if someone feels strongly I'll swap it around.
Technically yes, but in reality only CPU and CUDA are supported by structured, so there isn't actually any wastage.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Thought it was worth calling out, but I'm ambivalent as well :)
Ah right, yeah |
||
| {cpp_sig.defn()} {{ | ||
| return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))}); | ||
| }} | ||
| }} // namespace {self.dispatch_key.lower()} | ||
| """ | ||
|
|
||
| elif self.target is Target.REGISTRATION: | ||
|
|
@@ -468,9 +481,6 @@ def gen_one(f: NativeFunction) -> Optional[str]: | |
|
|
||
| @method_with_native_function | ||
| def gen_unstructured(self, f: NativeFunction) -> Optional[str]: | ||
| # for mypy type refinement; would be fixed by TODO on target | ||
| assert self.target is not Target.DECLARATION | ||
|
|
||
| if self.dispatch_key not in f.dispatch: | ||
| return None | ||
| if f.manual_kernel_registration: | ||
|
|
@@ -484,7 +494,9 @@ def gen_unstructured(self, f: NativeFunction) -> Optional[str]: | |
| args = native.arguments(f.func) | ||
| args_str = ', '.join(a.defn() for a in args) | ||
|
|
||
| if self.target is Target.DEFINITION: | ||
| if self.target is Target.DECLARATION: | ||
| return '' | ||
| elif self.target is Target.DEFINITION: | ||
| impl_name = f"at::native::{f.dispatch[self.dispatch_key]}" | ||
|
|
||
| args_exprs_str = ', '.join(a.name for a in args) | ||
|
|
@@ -768,7 +780,7 @@ def __call__(self, f: NativeFunction) -> Optional[str]: | |
| return None | ||
|
|
||
| name = native.name(f.func) | ||
| native_sig = NativeSignature.from_schema(f.func) | ||
| native_sig = NativeSignature(f.func) | ||
|
|
||
| if not any(isinstance(a.argument, TensorOptionsArguments) for a in native_sig.arguments()): | ||
| return None | ||
|
|
@@ -1284,15 +1296,19 @@ def make_file_manager(install_dir: str) -> FileManager: | |
| # kernels | ||
| "Meta", | ||
| ] | ||
| # Only a limited set of dispatch keys get CPUFunctions.h headers generated | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. s/CPUFunctions.h/{dispatch key}Functions.h/ or whatever |
||
| # for them; this is the set | ||
| functions_keys = { | ||
| "CPU", | ||
| "CUDA", | ||
| } | ||
| if options.backend_whitelist: | ||
| dispatch_keys = [k for k in dispatch_keys if is_generic_dispatch_key(k) or k in options.backend_whitelist] | ||
|
|
||
| for dispatch_key in dispatch_keys: | ||
| cpp_template = 'RegisterDispatchKey.cpp' | ||
|
|
||
| fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm | ||
|
|
||
| fm.write_with_template(f'Register{dispatch_key}.cpp', cpp_template, lambda: { | ||
| fm.write_with_template(f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp', lambda: { | ||
| 'extra_cuda_headers': extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else '', | ||
| 'legacy_th_headers': | ||
| '#include <ATen/LegacyTHFunctionsCPU.h>' if dispatch_key == "CPU" else | ||
|
|
@@ -1308,6 +1324,16 @@ def make_file_manager(install_dir: str) -> FileManager: | |
| grouped_native_functions | ||
| )), | ||
| }) | ||
|
|
||
| if dispatch_key in functions_keys: | ||
| fm.write_with_template(f'{dispatch_key}Functions.h', 'DispatchKeyFunctions.h', lambda: { | ||
| 'dispatch_namespace': dispatch_key.lower(), | ||
| 'dispatch_declarations': list(concatMap( | ||
| RegisterDispatchKey(dispatch_key, Target.DECLARATION, selector, rocm=options.rocm), | ||
| grouped_native_functions | ||
| )), | ||
| }) | ||
|
|
||
| del fm | ||
|
|
||
| # BackendSelect is generated specially | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are there any subtle issues involved in also generating faithful sig versions that might be called out here? I can imagine somebody hitting a future use case that wants them (say, binding straight from python to backend-specific functions) and getting caught on something nonobvious, might be worth calling out any such
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, I think we probably should generate faithful version too. I just didn't need it, so I didn't put in the logic for it.