Commit 16b8e6a

ezyang authored and facebook-github-bot committed
Class-based structured kernels, with migration of add to framework (#48718)
Summary:
Pull Request resolved: #48718

This PR rewrites structured kernels to use a class-based mechanism (instead of defining a meta and impl function, they are methods on a class), and adds enough customizability to the class to support TensorIterator. To show it works, add is made a structured kernel. Don't forget to check pytorch/rfcs#9 for a mostly up-to-date high-level description of what's going on here.

High level structure of this PR (the order in which you should review files):

* TensorMeta.h - TensorMeta is deleted entirely; instead, meta functions call `set_output` to allocate/resize their outputs. MetaBase gets a new `maybe_get_output` virtual method for retrieving the (possibly non-existent) output tensor in a meta function; this makes it easier to do special promotion behavior, e.g., as in TensorIterator.
* TensorIterator.cpp - Two major changes. First, we add TensorIteratorBase::set_output, which is a "light" version of TensorIterator::set_output; it sets up the internal data structures in TensorIterator, but it doesn't do allocation (that is assumed to have been handled by the structured kernels framework). The control flow is that someone calls the subclassed set_output, which allocates the output, and then calls the parent class (TensorIteratorBase) to populate the fields in TensorIterator so that other TensorIterator phases can keep track of it. Second, we add some tests for meta tensors, and skip the parts of TensorIterator which are not necessary when data is not available.
* tools/codegen/model.py - One new field in native_functions.yaml, structured_inherits. This lets you override the parent class of a structured meta class; normally it's MetaBase, but you can make it point at TensorIteratorBase instead for TensorIterator-based kernels.
* tools/codegen/gen.py - Now generates all of the classes we promised. It's kind of hairy because this is the first draft. Check the RFC for what the output looks like, and then follow the logic here. There are some complications: I need to continue to generate old-style wrapper functions even if an operator is structured, because SparseCPU/SparseCUDA/etc. won't actually use structured kernels to start. The most complicated code generation is the instantiation of `set_output`, which by and large replicates the logic in `TensorIterator::set_output`. This will continue to live in codegen for the foreseeable future, as we would like to specialize this logic per device.
* aten/src/ATen/native/UpSampleNearest1d.cpp - The previous structured kernel is ported to the new format. The changes are very modest.
* aten/src/ATen/native/BinaryOps.cpp - add is ported to structured.

TODO:

* Work out an appropriate entry point for static runtime, since native:: function stubs are no longer generated
* Refactor TensorIteratorConfig construction into helper functions, like before
* Make Tensor-Scalar addition structured to fix the perf regression
* Fix `verify_api_visibility.cpp`
* Refactor tools/codegen/gen.py for clarity
* Figure out why header changes resulted in an undefined reference to `at::Tensor::operator[](long) const`

Signed-off-by: Edward Z. Yang <[email protected]>

Test Plan: Imported from OSS

Reviewed By: bhosmer

Differential Revision: D25278031

Pulled By: ezyang

fbshipit-source-id: 57c43a6e5df21929b68964d485995fbbae4d1f7b
1 parent a6fa3b2 commit 16b8e6a

21 files changed: +506 additions, -210 deletions


aten/src/ATen/TensorIterator.cpp

Lines changed: 128 additions & 23 deletions
@@ -402,14 +402,14 @@ void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
     // TODO: reuse temporaries when possible (e.g. for inplace operations)
     if (common_device == kCPU) {
       // Casts to outputs by creating temporaries of the correct dtype (if needed)
-      if (config.cast_common_dtype_to_outputs_ && op.is_output && op.current_dtype != common_dtype_) {
+      // NB: we skip this on is_meta_, because the temporary allocation here is
+      // unnecessary if we aren't going to actually do the compute
+      if (config.cast_common_dtype_to_outputs_ && op.is_output && op.current_dtype != common_dtype_ && !is_meta_) {
         TORCH_INTERNAL_ASSERT(op.tensor.defined());
+        // Marker [Output original_tensor is set]
         op.original_tensor = op.tensor;
         // NB: do NOT use set_output here, as the temporary is NOT a true output;
         // op.tensor is the true output and it was pre-provided for us.
-        // TODO: When we extend this to work with meta tensors, we'll need to
-        // skip this temporary allocation in that case (because it's
-        // unnecessary)
         // TODO: The logic for cast_outputs will need to be handled by the
         // structured kernels implementation. What probably should happen
         // is that we pass in the inferred dtype into the out kernel, and
@@ -488,10 +488,10 @@ void TensorIteratorBase::allocate_or_resize_outputs() {
         set_output(i, tensor_shape, tensor_stride, op.options(), names_);
       }
       op.current_dtype = op.target_dtype;
-    } else if (op.tensor.defined() && !names_.empty()) {
-      // Even if we don't resize, we may still propagate names, esp
-      // if we were doing an inplace operation
-      namedinference::propagate_names(op.tensor, names_);
+    } else if (op.tensor.defined()) {
+      // Even if we don't resize, we still need to tell set_output about
+      // the output, so that we properly set guard and propagate names
+      set_output(i, op.tensor.sizes(), {}, op.tensor.options(), names_);
     }
   }
 }
@@ -765,6 +765,8 @@ void TensorIteratorBase::cast_outputs() {
   for (auto& op : operands_) {
     if (op.is_output && op.original_tensor.defined() &&
         op.original_tensor.scalar_type() != op.current_dtype) {
+      // TODO: Now that set_output resizes both the original_tensor
+      // and tensor, this condition should no longer ever be true
      if (op.original_tensor.sizes() != op.tensor.sizes()){
        op.original_tensor.resize_as_(op.tensor).as_strided_(op.tensor.sizes(), op.tensor.strides());
      }
@@ -808,18 +810,22 @@ void TensorIteratorBase::select_all_keeping_dim(int start_dim, IntArrayRef indices) {
   }
 }

-TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a,
-    const Tensor& b) {
-  return TensorIteratorConfig()
-     .set_check_mem_overlap(true)
-     .add_output(out)
-     .add_input(a)
-     .add_input(b)
-     .allow_cpu_scalars(true)
-     .promote_inputs_to_common_dtype(true)
-     .cast_common_dtype_to_outputs(true)
-     .enforce_safe_casting_to_output(true)
-     .build();
+void TensorIteratorBase::build_binary_op(const Tensor& out, const Tensor& a, const Tensor& b) {
+  build(TensorIteratorConfig()
+     .set_check_mem_overlap(true)
+     .add_output(out)
+     .add_input(a)
+     .add_input(b)
+     .allow_cpu_scalars(true)
+     .promote_inputs_to_common_dtype(true)
+     .cast_common_dtype_to_outputs(true)
+     .enforce_safe_casting_to_output(true));
+}
+
+TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a, const Tensor& b) {
+  TensorIterator iter;
+  iter.build_binary_op(out, a, b);
+  return iter;
 }

 // Helper to construct a binary op that promotes integer inputs to float.
@@ -940,6 +946,13 @@ TensorIterator TensorIterator::reduce_op(Tensor& out1, Tensor& out2, const Tensor& a) {

 void TensorIteratorBase::populate_operands(TensorIteratorConfig& config) {
   for (auto& tensor: config.tensors_) {
+    // If *any* of the arguments is a meta tensor, the overall
+    // computation is a meta computation (don't do any work,
+    // just compute output information). This aligns with
+    // our multiple dispatch semantics.
+    if (tensor.is_meta()) {
+      is_meta_ = true;
+    }
     operands_.emplace_back(std::move(tensor));
   }
   num_outputs_ = config.num_outputs_;
@@ -988,6 +1001,10 @@ void TensorIteratorBase::compute_mem_overlaps(const TensorIteratorConfig& config) {
   if (!config.check_mem_overlap_) {
     return;
   }
+  if (is_meta_) {
+    // We don't have pointer addresses, cannot check for overlap!
+    return;
+  }
   for (int i = 0; i < num_outputs_; i++) {
     const auto& output = operands_[i].tensor;
     if (!output.defined()) continue;
@@ -1265,9 +1282,11 @@ void TensorIteratorBase::build(TensorIteratorConfig& config) {
     // allocate the output tensor if it's not provided
     allocate_or_resize_outputs();
     // coalesce adjacent dimensions when possible
-    coalesce_dimensions();
+    if (!is_meta_) coalesce_dimensions();
   }

+  if (is_meta_) return;
+
   for (auto& op : operands_) {
     TORCH_INTERNAL_ASSERT(op.tensor.defined());
     op.data = op.tensor.data_ptr();
@@ -1281,14 +1300,92 @@ void TensorIteratorBase::build(TensorIteratorConfig& config) {
   view_offsets_ = DimVector(ndim_offsets, 0);
 }

+// This is the structured kernels implementation of set_output. It is
+// NEVER actually called directly; instead, a subclass of TensorIteratorBase
+// will override set_output to actually do the operation, and then call
+// set_output on the TensorIteratorBase to setup TI's metadata.
+// The precondition for this function is that maybe_get_output() now
+// unconditionally returns a real Tensor (prior to output setting,
+// this function may return an undefined tensor.)
+void TensorIteratorBase::set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) {
+  auto& op = operands_[output_idx];
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
+  const auto& t = maybe_get_output(output_idx);
+  TORCH_INTERNAL_ASSERT(t.defined());
+  if (!op.tensor.defined()) {
+    op.tensor = t;
+    op.current_dtype = op.target_dtype;
+  } else if (op.will_resize) {
+    if (op.original_tensor.defined()) {
+      // OK, so this is pretty weird. To understand how we can end up in
+      // this situation, first look at Marker [Output original_tensor is set].
+      // That is the sole site where original_tensor may be set on an
+      // output operand. Essentially, when we are given an explicit output
+      // tensor whose dtype doesn't match the computed common dtype from
+      // the input operands, we do a switcheroo: we replace the (incorrectly
+      // typed) output tensor with a correctly typed, *temporary* tensor,
+      // and remember the original tensor in original_tensor (which will
+      // then get written back to when we cast_outputs).
+      //
+      // Now, what if the given output tensor also happened to be zero
+      // size (meaning that we will_resize it)? Well, at the call site
+      // above, we don't necessarily(*) know what the correct shape should
+      // be, so we give the temporary tensor the same shape as the original.
+      // At the time of set_output is when we DO know what the correct size
+      // is, and the subclass's implementation of set_output in structured class
+      // is responsible for resizing original_tensor. But we still have this
+      // incorrectly sized temporary output which the structured subclass
+      // knows nothing about, so we are obligated to also resize it here.
+      //
+      // This is a slight memory pessimization, because previously
+      // original_tensor only got resized at the end of the computation, rather
+      // than at the beginning (as happens here). However, the peak memory
+      // usage is the same, since you need to materialize both original tensor
+      // and temporary tensor to do the copy.
+      //
+      // (*) Actually, technically, we probably do know what the shape
+      // should be, since we do shape computation before dtype computation.
+      // So hypothetically we could figure out what the correct shape is
+      // at that point in time and directly allocate the temporary at
+      // the right size.
+      //
+      // But a better solution is to delay allocation of temporaries until
+      // after TensorIterator builder, waiting until we actually want
+      // to do the computation. That would also remove the necessity
+      // for the is_meta_ test.
+      TORCH_INTERNAL_ASSERT(op.original_tensor.is_same(t));
+      TORCH_INTERNAL_ASSERT(!op.tensor.is_same(t));
+      at::native::resize_output(op.tensor, sizes);
+      if (!strides.empty()) {
+        TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
+        op.tensor.as_strided_(sizes, strides);
+      } else if (options.memory_format_opt().has_value()) {
+        op.tensor.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
+      }
+    }
+  }
+}
+
+// This is the "traditional" implementation of set_output. On TensorIterator
+// instances, it is invoked directly from various call sites in this file. No
+// funny business.
 void TensorIterator::set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) {
+  // NB: intentionally no superclass call
   auto& op = operands_[output_idx];
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
   if (!op.tensor.defined()) {
     if (strides.empty()) {
-      op.tensor = at::empty(sizes, options);
+      if (is_meta_) {
+        op.tensor = at::empty_meta(sizes, options);
+      } else {
+        op.tensor = at::empty(sizes, options);
+      }
     } else {
-      op.tensor = at::empty_strided(sizes, strides, options);
+      if (is_meta_) {
+        TORCH_INTERNAL_ASSERT(0, "meta strided not yet implemented");
+      } else {
+        op.tensor = at::empty_strided(sizes, strides, options);
+      }
     }
     op.current_dtype = op.target_dtype;
   } else if (op.will_resize) {
@@ -1306,6 +1403,14 @@ void TensorIterator::set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) {
   }
 }

+// Not actually used by anything (TensorIterator subclass calls
+// its own implementation of set_output which knows exactly where
+// all the outputs are), but we have to provide all pure virtual methods
+// for MetaBase
+const Tensor& TensorIterator::maybe_get_output(int64_t output_idx) {
+  return operands_[output_idx].tensor;
+}
+
 SplitUntil32Bit TensorIteratorBase::with_32bit_indexing() const {
   return SplitUntil32Bit(*this);
 }

aten/src/ATen/TensorIterator.h

Lines changed: 8 additions & 0 deletions
@@ -297,6 +297,10 @@ struct CAFFE2_API TensorIteratorBase : public impl::MetaBase {
     return true;
   }

+  void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) override;
+
+  void build_binary_op(const Tensor& out, const Tensor& a, const Tensor& b);
+
 protected:
   // Mutable reference as it moves tensors out of TensorIteratorConfig
   void populate_operands(TensorIteratorConfig&);
@@ -399,6 +403,9 @@ struct CAFFE2_API TensorIteratorBase : public impl::MetaBase {

   // From TensorIteratorConfig
   bool is_reduction_ = false;
+
+  /// Set by populate_operands(), says if we're handling meta tensors
+  bool is_meta_ = false;
 };

 struct CAFFE2_API TensorIterator final : public TensorIteratorBase {
@@ -415,6 +422,7 @@ struct CAFFE2_API TensorIterator final : public TensorIteratorBase {
   static TensorIterator reduce_op(Tensor& out, const Tensor& a);
   static TensorIterator reduce_op(Tensor& out1, Tensor& out2, const Tensor& a);

+  const Tensor& maybe_get_output(int64_t output_idx) override;
   void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) override;
 };

aten/src/ATen/TensorMeta.cpp

Lines changed: 0 additions & 16 deletions
@@ -1,21 +1,5 @@
 #include <ATen/TensorMeta.h>
-#include <ATen/ATen.h>

 namespace at {

-Tensor meta_tensor_from_meta(const TensorMeta& meta) {
-  // TODO: eliminate indirection
-  return at::empty_meta(meta.sizes, meta.options);
-}
-
-Tensor tensor_from_meta(const TensorMeta& meta) {
-  // TODO: eliminate indirection
-  return at::empty(meta.sizes, meta.options);
-}
-
-// Analogous to self.new_empty(sizes)
-TensorMeta new_meta(const Tensor& self, IntArrayRef sizes) {
-  return TensorMeta(sizes, self.options());
-}
-
 } // namespace at

aten/src/ATen/TensorMeta.h

Lines changed: 41 additions & 15 deletions
@@ -10,28 +10,54 @@ class Tensor;

 namespace impl {

-struct MetaBase {
+// Use this to define the prototype for a meta function. There are two
+// versions; one that takes one argument (just the operator name), or FUNC2
+// variant that takes two arguments (operator name and overload name).
+//
+// Example usage:
+//
+//    TORCH_META_FUNC2(add, Tensor) (
+//      const Tensor& self, const Tensor& other
+//    ) {
+//      ... compute sizes and options ...
+//      set_output(sizes, options);
+//    }
+//
+#define TORCH_META_FUNC(name) void name::meta
+#define TORCH_META_FUNC2(name, overload) void name##_##overload::meta
+
+// Use this to define the prototype for an implementation. This takes only
+// one argument, which is the name of the dispatch key entry you're
+// implementing.
+//
+// Example usage:
+//
+//    TORCH_IMPL_FUNC(add_cpu) (
+//      Tensor& result, const Tensor& self, const Tensor& other
+//    ) {
+//      ... do the actual implementation ...
+//    }
+//
+#define TORCH_IMPL_FUNC(name) void structured_##name::impl
+
+// Base class for all structured kernel classes. The set_output virtual
+// method is varied depending whether or not the operator is
+// functional/out/inplace, and could also be specialized for CPU/CUDA/etc
+// (although presently it isn't).
+//
+// A notable subclass of this interface is TensorIteratorBase.
+struct CAFFE2_API MetaBase {
   virtual void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) = 0;
+  virtual const Tensor& maybe_get_output(int64_t output_idx) = 0;
   void set_output(IntArrayRef sizes, TensorOptions options) {
     set_output(0, sizes, {}, options, {});
   }
+  // Returns a reference to an undefined tensor if there is no presupplied
+  // output
+  const Tensor& maybe_get_output() { return maybe_get_output(0); }
   virtual ~MetaBase() {}
 };

 } // namespace impl

-struct TensorMeta {
-  DimVector sizes;
-  // TODO: DimVector strides;
-  TensorOptions options;
-
-  TensorMeta(IntArrayRef _sizes, TensorOptions _options)
-    : sizes(_sizes), options(_options) {}
-};
-
-CAFFE2_API Tensor meta_tensor_from_meta(const TensorMeta& meta);
-CAFFE2_API Tensor tensor_from_meta(const TensorMeta& meta);
-// Analogous to self.new_empty(sizes)
-CAFFE2_API TensorMeta new_meta(const Tensor& self, IntArrayRef sizes);
-
 } // namespace at
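The TORCH_META_FUNC / TORCH_IMPL_FUNC macros added in TensorMeta.h are just shorthand for out-of-line method definitions on the codegen-emitted classes. A minimal, self-contained re-creation of the scheme (the DEMO_ macros and the relu_demo classes are hypothetical stand-ins for what gen.py would emit, not the real generated code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Re-creation of the macro scheme: the macro expands to the method
// header, and the kernel author supplies the parameter list and body.
#define DEMO_META_FUNC(name) void name::meta
#define DEMO_IMPL_FUNC(name) void structured_##name::impl

// Codegen would emit classes like these per operator:
struct relu_demo {
  std::vector<int64_t> out_sizes;
  void meta(const std::vector<int64_t>& in_sizes);
};
struct structured_relu_demo : relu_demo {
  void impl(const std::vector<float>& in, std::vector<float>& out);
};

// The kernel author writes only the bodies, via the macros.
// DEMO_META_FUNC(relu_demo) expands to `void relu_demo::meta`:
DEMO_META_FUNC(relu_demo)(const std::vector<int64_t>& in_sizes) {
  out_sizes = in_sizes;  // relu is shape-preserving
}
// DEMO_IMPL_FUNC(relu_demo) expands to `void structured_relu_demo::impl`:
DEMO_IMPL_FUNC(relu_demo)(const std::vector<float>& in,
                          std::vector<float>& out) {
  out.clear();
  for (float v : in) out.push_back(v > 0 ? v : 0.f);
}
```

This is why TORCH_IMPL_FUNC prefixes the name with `structured_`: each dispatch key's implementation lives on its own subclass of the operator's meta class.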
