Skip to content

Commit b4f5efa

Browse files
ezyang authored and facebook-github-bot committed
Structured kernels generate Meta registrations (#48116)
Summary: Pull Request resolved: #48116 If you port kernels to be structured, you get Meta kernels automatically generated for you. This is one payoff of structured kernels. Code generation was mercifully really simple, although at risk of "swiss cheese" syndrome: there's two new conditionals in the codegen to tweak behavior when generating for meta keys. It's not too bad right now but there's a risk of things getting out of hand. One way to rationalize the logic here would be to transmit "TensorMeta-ness" inside the TensorOptions (so tensor_from_meta can deal with it); then the "Meta" kernel magic would literally just be generating empty out_impls to call after all the scaffolding is done. But I didn't do this because it seemed like it would be more annoying short term. Also had to teach resize_ to work on meta tensors, since we use them to implement the out kernels. Signed-off-by: Edward Z. Yang <[email protected]> Test Plan: Imported from OSS Reviewed By: bhosmer, ailzhang Differential Revision: D25056640 Pulled By: ezyang fbshipit-source-id: f8fcfa0dbb58a94d9b4196748f56e155f83b1521
1 parent 47db191 commit b4f5efa

File tree

7 files changed

+67
-12
lines changed

7 files changed

+67
-12
lines changed

BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ genrule(
131131
"aten/src/ATen/RegisterQuantizedCPU.cpp",
132132
"aten/src/ATen/RegisterSparseCPU.cpp",
133133
"aten/src/ATen/RegisterMath.cpp",
134+
"aten/src/ATen/RegisterMeta.cpp",
134135
"aten/src/ATen/RegisterDefaultBackend.cpp",
135136
"aten/src/ATen/RegisterSchema.cpp",
136137
"aten/src/ATen/Functions.h",

aten/src/ATen/TensorMeta.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ struct TensorMeta {
1414
: sizes(_sizes), options(_options) {}
1515
};
1616

17+
inline Tensor meta_tensor_from_meta(const TensorMeta& meta) {
18+
// TODO: eliminate indirection
19+
return at::empty_meta(meta.sizes, meta.options);
20+
}
21+
1722
inline Tensor tensor_from_meta(const TensorMeta& meta) {
1823
// TODO: eliminate indirection
1924
return at::empty(meta.sizes, meta.options);

aten/src/ATen/native/Resize.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,13 @@ Tensor& resize_as_(
7777
Tensor& resize_(
7878
Tensor& self,
7979
IntArrayRef size,
80-
c10::optional<MemoryFormat> optional_memory_format) {
80+
c10::optional<MemoryFormat> optional_memory_format,
81+
bool resize_storage) {
8182
if (self.has_names()) {
8283
return resize_named_tensor_(self, size, optional_memory_format);
8384
}
8485
auto* self_ = self.unsafeGetTensorImpl();
85-
resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt);
86+
resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt, resize_storage);
8687
if (optional_memory_format.has_value()) {
8788
auto memory_format =
8889
optional_memory_format.value();
@@ -95,5 +96,20 @@ Tensor& resize_(
9596
return self;
9697
}
9798

99+
Tensor& resize_(
100+
Tensor& self,
101+
IntArrayRef size,
102+
c10::optional<MemoryFormat> optional_memory_format) {
103+
return resize_(self, size, optional_memory_format, /*resize_storage=*/true);
104+
}
105+
106+
Tensor& resize_meta_(
107+
Tensor& self,
108+
IntArrayRef size,
109+
c10::optional<MemoryFormat> optional_memory_format) {
110+
// meta tensors don't have storage, so don't resize them
111+
return resize_(self, size, optional_memory_format, /*resize_storage=*/false);
112+
}
113+
98114
} // namespace native
99115
} // namespace at

aten/src/ATen/native/Resize.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ static inline void maybe_resize_storage_cpu(TensorImpl* self, int64_t new_size)
4343
inline TensorImpl* resize_impl_cpu_(
4444
TensorImpl* self,
4545
IntArrayRef size,
46-
c10::optional<IntArrayRef> stride) {
46+
c10::optional<IntArrayRef> stride,
47+
bool resize_storage = true) {
4748
if (self->sizes() == size && (!stride || self->strides() == stride)) {
4849
return self;
4950
}
@@ -57,7 +58,9 @@ inline TensorImpl* resize_impl_cpu_(
5758
self->set_sizes_contiguous(size);
5859
storage_size = self->numel();
5960
}
60-
maybe_resize_storage_cpu(self, storage_size);
61+
if (resize_storage) {
62+
maybe_resize_storage_cpu(self, storage_size);
63+
}
6164

6265
return self;
6366
}

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,6 +1693,7 @@
16931693
CPU: resize_
16941694
CUDA: resize_cuda_
16951695
QuantizedCPU: quantized_resize_cpu_
1696+
Meta: resize_meta_
16961697

16971698
- func: empty_quantized(int[] size, Tensor qtensor) -> Tensor
16981699
use_c10_dispatcher: full

test/test_torch.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2540,6 +2540,22 @@ def test_empty_meta(self):
25402540
z = x + y
25412541
self.assertEqual(z.size(), (2 ** 20, 2 ** 20))
25422542

2543+
def test_upsample_nearest1d_meta(self):
2544+
# TODO: this is not a sustainable way of testing meta functions,
2545+
# but I want some quick scaffolding first before a more
2546+
# integrated testing strategy
2547+
# NB: Can't make the exponent too big, or it will overflow
2548+
# signed 64-bit integer
2549+
x = torch.empty_meta(2 * 10 ** 8, 3, 2 * 10 ** 8)
2550+
z = torch.nn.functional.interpolate(x, scale_factor=2)
2551+
self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
2552+
2553+
# interpolate doesn't seem to support out=
2554+
# (not sure why passing None here doesn't work? How strange...)
2555+
z = torch.empty_meta(0)
2556+
torch._C._nn.upsample_nearest1d(x, (4 * 10 ** 8,), 2, out=z)
2557+
self.assertEqual(z.size(), (2 * 10 ** 8, 3, 4 * 10 ** 8))
2558+
25432559
def test_normal_shape(self):
25442560
warned = False
25452561
for device in torch.testing.get_all_device_types():

tools/codegen/gen.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,11 @@ def __call__(self, f: Union[StructuredNativeFunctions, NativeFunction]) -> List[
244244
assert_never(f)
245245

246246
def gen_structured(self, g: StructuredNativeFunctions) -> List[str]:
247-
if self.dispatch_key not in g.out.dispatch:
247+
if self.dispatch_key == 'Meta':
248+
assert self.dispatch_key not in g.out.dispatch, \
249+
"Do not explicitly specify Meta dispatch key on structured " \
250+
"functions, they will be automatically generated for you"
251+
elif self.dispatch_key not in g.out.dispatch:
248252
return []
249253

250254
# Inner helper function to close over g
@@ -272,14 +276,15 @@ def gen_one(f: NativeFunction) -> Optional[str]:
272276
sig = NativeSignature.from_schema(f.func)
273277

274278
if self.target is Target.DEFINITION:
275-
out_impl_name = f"at::native::{g.out.dispatch[self.dispatch_key]}"
276-
277279
# TODO: work a little harder to generate fresh names for 'result'
278280
# TODO: less praying that I picked the right argument name for 'self'
279281

280282
if k is SchemaKind.functional:
281283
out_expr = "result"
282-
prologue = "auto result = tensor_from_meta(meta_result);"
284+
if self.dispatch_key == "Meta":
285+
prologue = "auto result = meta_tensor_from_meta(meta_result);"
286+
else:
287+
prologue = "auto result = tensor_from_meta(meta_result);"
283288
elif k is SchemaKind.inplace:
284289
out_expr = "self"
285290
prologue = "// TODO: consistency check assert"
@@ -294,6 +299,12 @@ def gen_one(f: NativeFunction) -> Optional[str]:
294299
{out_expr}.resize_(meta_result.sizes);
295300
"""
296301

302+
if self.dispatch_key == "Meta":
303+
out_impl_call = "// meta function does nothing"
304+
else:
305+
out_impl_name = f"at::native::{g.out.dispatch[self.dispatch_key]}"
306+
out_impl_call = f"{out_impl_name}({out_expr}, {functional_exprs});"
307+
297308
device_guard = ""
298309

299310
if is_generic_dispatch_key(self.dispatch_key) or is_cuda_dispatch_key(self.dispatch_key):
@@ -317,7 +328,7 @@ def gen_one(f: NativeFunction) -> Optional[str]:
317328
{device_guard}
318329
auto meta_result = meta::{meta_name}({functional_exprs});
319330
{prologue}
320-
{out_impl_name}({out_expr}, {functional_exprs});
331+
{out_impl_call}
321332
return {out_expr};
322333
}}
323334
"""
@@ -1048,6 +1059,7 @@ def flatten_pre_group(d: Dict[SchemaKind, NativeFunction]) -> Sequence[Union[Nat
10481059

10491060
# TODO: how come ValuesView isn't a Sequence lol
10501061
grouped_native_functions = list(concatMap(flatten_pre_group, list(pre_grouped_native_functions.values())))
1062+
structured_native_functions = [g for g in grouped_native_functions if isinstance(g, StructuredNativeFunctions)]
10511063

10521064
template_dir = os.path.join(options.source_path, "templates")
10531065

@@ -1093,6 +1105,9 @@ def make_file_manager(install_dir: str) -> FileManager:
10931105
"QuantizedCUDA",
10941106
"Math",
10951107
"DefaultBackend",
1108+
# Meta is a magic key: it is automatically generated for structured
1109+
# kernels
1110+
"Meta",
10961111
]
10971112
if options.backend_whitelist:
10981113
dispatch_keys = [k for k in dispatch_keys if is_generic_dispatch_key(k) or k in options.backend_whitelist]
@@ -1129,9 +1144,7 @@ def make_file_manager(install_dir: str) -> FileManager:
11291144
})
11301145

11311146
cpu_fm.write('MetaFunctions.h', lambda: {
1132-
'declarations':
1133-
list(mapMaybe(compute_meta_function_declaration,
1134-
(g for g in grouped_native_functions if isinstance(g, StructuredNativeFunctions)))),
1147+
'declarations': list(map(compute_meta_function_declaration, structured_native_functions)),
11351148
})
11361149

11371150
schema_selector = selector

0 commit comments

Comments
 (0)