Skip to content

Commit cb43170

Browse files
VitalyFedyunin authored and facebook-github-bot committed
Add memory format support to the resize_ op. (#28292)
Summary: Pull Request resolved: #28292 Allows to simplify patterns like: 1. output.resize_({sizeB, sizeC, osizeH, osizeW}).as_strided_({sizeB, sizeC, osizeH, osizeW}, {sizeC*osizeH*osizeW, 1, osizeW*sizeC, sizeC}); 2. output.resize_({nbatch, nInputPlane, outputHeight, outputWidth}); indices.resize_({nbatch, nInputPlane, outputHeight, outputWidth}); output.unsafeGetTensorImpl()->empty_tensor_restride(memory_format); indices.unsafeGetTensorImpl()->empty_tensor_restride(memory_format); 3. gradInput.resize_as_(input); gradInput.unsafeGetTensorImpl()->empty_tensor_restride(memory_format); Test Plan: Imported from OSS Differential Revision: D18044978 Pulled By: VitalyFedyunin fbshipit-source-id: bbf67c25f9cf88bc6e949089a3b247df50f86dc4
1 parent a7df369 commit cb43170

File tree

8 files changed

+66
-15
lines changed

8 files changed

+66
-15
lines changed

aten/src/ATen/native/Resize.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,27 @@
55

66
namespace at { namespace native {
77

8-
Tensor& resize_cpu_(Tensor& self, IntArrayRef size) {
8+
Tensor& resize_cpu_(
9+
Tensor& self,
10+
IntArrayRef size,
11+
c10::optional<MemoryFormat> optional_memory_format) {
912
#ifdef BUILD_NAMEDTENSOR
1013
if (self.has_names()) {
11-
return resize_named_tensor_(self, size);
14+
return resize_named_tensor_(self, size, optional_memory_format);
1215
}
1316
#endif
1417
auto* self_ = self.unsafeGetTensorImpl();
1518
resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt);
1619
self_->maybe_zero_dim(size.size() == 0);
20+
if (optional_memory_format.has_value()) {
21+
auto memory_format =
22+
optional_memory_format.value();
23+
TORCH_CHECK(
24+
memory_format != MemoryFormat::Preserve,
25+
"Unsupported memory format",
26+
memory_format);
27+
self_->empty_tensor_restride(memory_format);
28+
}
1729
return self;
1830
}
1931

aten/src/ATen/native/ResizeCommon.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,27 @@
66
namespace at { namespace native {
77

88
#ifdef BUILD_NAMEDTENSOR
9-
inline Tensor& resize_named_tensor_(Tensor& self, IntArrayRef size) {
9+
inline Tensor& resize_named_tensor_(
10+
Tensor& self,
11+
IntArrayRef size,
12+
c10::optional<MemoryFormat> optional_memory_format) {
1013
TORCH_INTERNAL_ASSERT(self.has_names());
1114
TORCH_CHECK(
1215
self.sizes() == size,
1316
"Cannot resize named tensor with resize_ or resize_as_ (tried to resize "
14-
"Tensor", self.names(), " with size ", self.sizes(), " to ", size,
17+
"Tensor",
18+
self.names(),
19+
" with size ",
20+
self.sizes(),
21+
" to ",
22+
size,
1523
"). This may be caused by passing a named tensor ",
1624
"as an `out=` argument; please ensure that the sizes are the same. ");
25+
TORCH_CHECK(
26+
!optional_memory_format.has_value(),
27+
"Unsupported memory format for named tensor resize ",
28+
optional_memory_format.value());
1729
return self;
1830
}
1931
#endif
20-
2132
}}

aten/src/ATen/native/cuda/Resize.cu

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,31 @@
44
#include <ATen/native/cuda/Resize.cuh>
55
#include <ATen/native/ResizeCommon.h>
66

7-
namespace at { namespace native {
7+
namespace at {
8+
namespace native {
89

9-
Tensor& resize_cuda_(Tensor& self, IntArrayRef size) {
10+
Tensor& resize_cuda_(
11+
Tensor& self,
12+
IntArrayRef size,
13+
c10::optional<MemoryFormat> optional_memory_format) {
1014
#ifdef BUILD_NAMEDTENSOR
1115
if (self.has_names()) {
12-
return resize_named_tensor_(self, size);
16+
return resize_named_tensor_(self, size, optional_memory_format);
1317
}
1418
#endif
1519
auto* self_ = self.unsafeGetTensorImpl();
1620
resize_impl_cuda_(self_, size, /*strides=*/c10::nullopt);
1721
self_->maybe_zero_dim(size.size() == 0);
22+
if (optional_memory_format.has_value()) {
23+
auto memory_format =
24+
optional_memory_format.value();
25+
TORCH_CHECK(
26+
memory_format != MemoryFormat::Preserve,
27+
"Unsupported memory format",
28+
memory_format);
29+
self_->empty_tensor_restride(memory_format);
30+
}
1831
return self;
1932
}
20-
21-
}}
33+
} // namespace native
34+
} // namespace at

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,7 @@
10151015
CPU: empty_per_channel_affine_quantized_other_backends_stub
10161016
QuantizedCPU: empty_per_channel_affine_quantized_cpu
10171017

1018-
- func: resize_(Tensor(a!) self, int[] size) -> Tensor(a!)
1018+
- func: resize_(Tensor(a!) self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)
10191019
supports_named_tensor: True
10201020
variants: method
10211021
device_guard: False

aten/src/ATen/native/quantized/cpu/tensor_operators.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,14 @@ AT_FORALL_OPERATORS(DEFINE_COMPARATOR)
5959
#undef AT_FORALL_OPERATORS
6060
#undef DEFINE_COMPARATOR
6161

62-
Tensor& quantized_resize_cpu_(Tensor& self, IntArrayRef size) {
62+
Tensor& quantized_resize_cpu_(
63+
Tensor& self,
64+
IntArrayRef size,
65+
c10::optional<MemoryFormat> optional_memory_format) {
66+
TORCH_CHECK(
67+
!optional_memory_format.has_value(),
68+
"Unsupported memory format for quantized tensor resize ",
69+
optional_memory_format.value());
6370
auto qscheme = self.quantizer()->qscheme();
6471
TORCH_CHECK(
6572
qscheme == QScheme::PER_TENSOR_AFFINE ||

test/test_torch.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12065,6 +12065,11 @@ def test_memory_format_resize_as(self, device):
1206512065
flat.resize_as_(nhwc, memory_format=torch.preserve_format)
1206612066
self.assertTrue(flat.is_contiguous(memory_format=torch.channels_last))
1206712067

12068+
def test_memory_format_resize_(self, device):
12069+
flat = torch.randn(10 * 3 * 32 * 32, device=device)
12070+
flat.resize_((10, 3, 32, 32), memory_format=torch.channels_last)
12071+
self.assertTrue(flat.is_contiguous(memory_format=torch.channels_last))
12072+
1206812073
def test_memory_format_empty_like(self, device):
1206912074
x = torch.randn(4, 3, 8, 8, device=device)
1207012075
nhwc = x.contiguous(memory_format=torch.channels_last)

tools/autograd/templates/VariableType.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ namespace VariableType {
6969
int64_t output_nr(const Tensor & self);
7070
int64_t _version(const Tensor & self);
7171
Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking);
72-
Tensor & resize_(Tensor & self, IntArrayRef size);
72+
Tensor & resize_(Tensor & self, IntArrayRef size, c10::optional<MemoryFormat> optional_memory_format);
7373
Tensor & resize_as_(Tensor & self, const Tensor & the_template, c10::optional<MemoryFormat> optional_memory_format);
7474
Tensor detach(const Tensor & self);
7575
Tensor & detach_(Tensor & self);

torch/csrc/autograd/VariableTypeManual.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,10 @@ Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) {
161161
return self;
162162
}
163163

164-
Tensor & resize_(Tensor & self, IntArrayRef size) {
164+
Tensor& resize_(
165+
Tensor& self,
166+
IntArrayRef size,
167+
c10::optional<MemoryFormat> optional_memory_format) {
165168
auto& self_ = unpack(self, "self", 0);
166169
if (as_variable_ref(self).requires_grad()) {
167170
AT_ERROR("cannot resize variables that require grad");
@@ -173,7 +176,7 @@ Tensor & resize_(Tensor & self, IntArrayRef size) {
173176
}
174177
{
175178
at::AutoNonVariableTypeMode non_var_type_mode(true);
176-
self_.resize_(size);
179+
self_.resize_(size, std::move(optional_memory_format));
177180
}
178181
return self;
179182
}

0 commit comments

Comments (0)