Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ genrule(
"aten/src/ATen/RegisterMeta.cpp",
"aten/src/ATen/RegisterDefaultBackend.cpp",
"aten/src/ATen/RegisterSchema.cpp",
"aten/src/ATen/CPUFunctions.h",
"aten/src/ATen/CUDAFunctions.h",
"aten/src/ATen/Functions.h",
"aten/src/ATen/Functions.cpp",
"aten/src/ATen/NativeFunctions.h",
Expand Down
14 changes: 14 additions & 0 deletions aten/src/ATen/templates/DispatchKeyFunctions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// ${generated_comment}

// Template for the per-dispatch-key static functions header (instantiated as
// e.g. CPUFunctions.h / CUDAFunctions.h by the codegen).  ${dispatch_namespace}
// is filled with the lowercased dispatch key and ${dispatch_declarations} with
// the declarations of wrapper functions that call the kernels for that key
// directly, bypassing the dispatcher.

// NB: The implementing C++ file is RegisterDispatchKey.cpp

// TODO: tighten this include
#include <ATen/Functions.h>

namespace at {
namespace ${dispatch_namespace} {

${dispatch_declarations}

} // namespace ${dispatch_namespace}
} // namespace at
8 changes: 3 additions & 5 deletions tools/codegen/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,10 @@ class NativeSignature:
# The schema this signature is derived from
func: FunctionSchema

prefix: str = ""

def name(self) -> str:
return native.name(self.func)
return self.prefix + native.name(self.func)

def defn(self, name: Optional[str] = None) -> str:
args_str = ', '.join(a.defn() for a in self.arguments())
Expand All @@ -265,9 +267,5 @@ def arguments(self) -> List[Binding]:
def dispatcher_exprs(self) -> List[Expr]:
return translate.translate(self.arguments(), dispatcher.arguments(self.func), method=False)

@staticmethod
def from_schema(func: FunctionSchema) -> 'NativeSignature':
return NativeSignature(func)

# Functions only, no types
from tools.codegen.api import cpp, dispatcher, native, translate
60 changes: 43 additions & 17 deletions tools/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,13 @@ def __call__(self, f: NativeFunction) -> Optional[str]:
# API without having to disambiguate which overload you want
# (as would be the case if you directly registered native::
# functions).
# - The tertiary function of this file is to generate *static*
# cpp API bindings which can be used to bypass dispatcher
# directly to kernels, but with user-friendly cpp-style API
@dataclass(frozen=True)
class RegisterDispatchKey:
dispatch_key: str

# TODO: Give more precise type Union[Literal[Target.DEFINITION,
# Target.REGISTRATION]]; requires Literal from typing_extensions
# which we don't have a dep for yet.
target: Target

# Selector object to determine which operators to generate
Expand All @@ -238,9 +238,6 @@ class RegisterDispatchKey:
# Whether or not we are actually code-genning for ROCm
rocm: bool

def __post_init__(self) -> None:
assert self.target is not Target.DECLARATION

@method_with_native_function
def __call__(self, f: Union[StructuredNativeFunctions, NativeFunction]) -> List[str]:
if isinstance(f, StructuredNativeFunctions):
Expand Down Expand Up @@ -386,7 +383,6 @@ def gen_structured(self, g: StructuredNativeFunctions) -> List[str]:
# you edit this, you may need to also edit gen_unstructured.
@with_native_function
def gen_one(f: NativeFunction) -> Optional[str]:
assert self.target is not Target.DECLARATION
assert not f.manual_kernel_registration

# TODO: put this into StructuredNativeFunctions itself
Expand All @@ -402,9 +398,20 @@ def gen_one(f: NativeFunction) -> Optional[str]:
return None

k = f.func.kind()
sig = NativeSignature.from_schema(f.func)
sig = NativeSignature(f.func, prefix="wrapper_")

# Unconditionally generate function version, never fallback binding.
# Some extra massaging would then be necessary in a hypothetical
# CPUTensor class
cpp_sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False)
# For now, don't generate faithful signature for simplicity
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any subtle issues involved in also generating faithful sig versions that might be called out here? I can imagine somebody hitting a future use case that wants them (say, binding straight from python to backend-specific functions) and getting caught on something nonobvious, so it might be worth calling out any such issues explicitly.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I think we probably should generate faithful version too. I just didn't need it, so I didn't put in the logic for it.

cpp_sig = cpp_sig_group.signature

if self.target is Target.DEFINITION:
if self.target is Target.DECLARATION:
# namespace is handled by template
return f"TORCH_API {cpp_sig.decl()};\n"

elif self.target is Target.DEFINITION:
if self.dispatch_key == 'Meta':
class_name = f"structured_{meta.name(g)}_meta_{k.name}"
parent_class = f"at::meta::{meta.name(g)}"
Expand Down Expand Up @@ -452,6 +459,12 @@ def gen_one(f: NativeFunction) -> Optional[str]:
}}

}} // anonymous namespace

namespace {self.dispatch_key.lower()} {{
Copy link
Copy Markdown
Collaborator

@bdhirsh bdhirsh Dec 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not a big deal, but the individual namespace cpu {...} for every op will probably make these files a few thousand lines longer, vs. grouping them all together in one namespace block.

Actually, isn't this generating at::<dispatch_key> functions for every dispatch key, and then only providing headers for the specific keys we want (cpu/cuda)? Shouldn't we keep those two in sync? ( only bother providing dispatcher-skipping implementations for dispatch keys that we provide headers for)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. The alternative is to split up the implementations in codegen. I'm ambivalent about this; so if someone feels strongly I'll swap it around.

Actually, isn't this generating at::<dispatch_key> functions for every dispatch key, and then only providing headers for the specific keys we want (cpu/cuda)?

Technically yes, but in reality only CPU and CUDA are supported by structured, so there isn't actually any wastage.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. The alternative is to split up the implementations in codegen. I'm ambivalent about this; so if someone feels strongly I'll swap it around.

Thought it was worth calling out, but I'm ambivalent as well :)

Technically yes, but in reality only CPU and CUDA are supported by structured, so there isn't actually any wastage.

Ah right, yeah

{cpp_sig.defn()} {{
return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
}} // namespace {self.dispatch_key.lower()}
"""

elif self.target is Target.REGISTRATION:
Expand All @@ -468,9 +481,6 @@ def gen_one(f: NativeFunction) -> Optional[str]:

@method_with_native_function
def gen_unstructured(self, f: NativeFunction) -> Optional[str]:
# for mypy type refinement; would be fixed by TODO on target
assert self.target is not Target.DECLARATION

if self.dispatch_key not in f.dispatch:
return None
if f.manual_kernel_registration:
Expand All @@ -484,7 +494,9 @@ def gen_unstructured(self, f: NativeFunction) -> Optional[str]:
args = native.arguments(f.func)
args_str = ', '.join(a.defn() for a in args)

if self.target is Target.DEFINITION:
if self.target is Target.DECLARATION:
return ''
elif self.target is Target.DEFINITION:
impl_name = f"at::native::{f.dispatch[self.dispatch_key]}"

args_exprs_str = ', '.join(a.name for a in args)
Expand Down Expand Up @@ -768,7 +780,7 @@ def __call__(self, f: NativeFunction) -> Optional[str]:
return None

name = native.name(f.func)
native_sig = NativeSignature.from_schema(f.func)
native_sig = NativeSignature(f.func)

if not any(isinstance(a.argument, TensorOptionsArguments) for a in native_sig.arguments()):
return None
Expand Down Expand Up @@ -1284,15 +1296,19 @@ def make_file_manager(install_dir: str) -> FileManager:
# kernels
"Meta",
]
# Only a limited set of dispatch keys get {dispatch key}Functions.h headers generated
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/CPUFunctions.h/{dispatch key}Functions.h/ or whatever

# for them; this is the set
functions_keys = {
"CPU",
"CUDA",
}
if options.backend_whitelist:
dispatch_keys = [k for k in dispatch_keys if is_generic_dispatch_key(k) or k in options.backend_whitelist]

for dispatch_key in dispatch_keys:
cpp_template = 'RegisterDispatchKey.cpp'

fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm

fm.write_with_template(f'Register{dispatch_key}.cpp', cpp_template, lambda: {
fm.write_with_template(f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp', lambda: {
'extra_cuda_headers': extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else '',
'legacy_th_headers':
'#include <ATen/LegacyTHFunctionsCPU.h>' if dispatch_key == "CPU" else
Expand All @@ -1308,6 +1324,16 @@ def make_file_manager(install_dir: str) -> FileManager:
grouped_native_functions
)),
})

if dispatch_key in functions_keys:
fm.write_with_template(f'{dispatch_key}Functions.h', 'DispatchKeyFunctions.h', lambda: {
'dispatch_namespace': dispatch_key.lower(),
'dispatch_declarations': list(concatMap(
RegisterDispatchKey(dispatch_key, Target.DECLARATION, selector, rocm=options.rocm),
grouped_native_functions
)),
})

del fm

# BackendSelect is generated specially
Expand Down
34 changes: 3 additions & 31 deletions torch/csrc/jit/runtime/static/ops.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <torch/csrc/jit/runtime/static/ops.h>

#include <ATen/CPUFunctions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/quantized/cpu/qembeddingbag.h>
#include <torch/csrc/jit/ir/ir.h>
Expand Down Expand Up @@ -50,34 +51,6 @@ bool canRunNatively(Node* n) {
return true;
}

// TODO: PLEASE DON'T COPY PASTE THIS, this is copy pasted
// generated code to unblock, need to make this nicer
struct static_add final : public at::native::structured_add_out {
static_add(at::Tensor& output) : output_(output) {}
void set_output(
int64_t output_idx,
at::IntArrayRef sizes,
at::IntArrayRef strides,
at::TensorOptions options,
at::DimnameList names) override {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx == 0);
// NB: do NOT use resize_output as it will complain if not zero sized.
at::native::resize_(output_, sizes);
if (!strides.empty()) {
TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
output_.as_strided_(sizes, strides);
} else if (options.memory_format_opt().has_value()) {
output_.unsafeGetTensorImpl()->empty_tensor_restride(
*options.memory_format_opt());
}
}
const at::Tensor& maybe_get_output(int64_t output_idx) override {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx == 0);
return output_;
}
at::Tensor& output_;
};

REGISTER_OPERATOR_FUNCTOR(aten::add, aten_add, [](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
auto& in0_t = p_node->Input(0).toTensor();
Expand All @@ -87,9 +60,8 @@ REGISTER_OPERATOR_FUNCTOR(aten::add, aten_add, [](Node* n) -> SROperator {
p_node->Output(0) = create_empty_from(in0_t);
}
auto& out_t = p_node->Output(0).toTensor();
static_add op{out_t};
op.meta(in0_t, in1_t, in2_s);
op.impl(in0_t, in1_t, in2_s, out_t);
out_t.resize_({0});
at::cpu::add_out(out_t, in0_t, in1_t, in2_s);
};
});

Expand Down