
Commit 3da2e09

vfdev-5 authored and facebook-github-bot committed
Added antialias flag to interpolate (CPU only, bilinear) (#65142)
Summary:
Description:
- Added antialias flag to interpolate (CPU only)
- forward and backward for bilinear mode
- added tests

### Benchmarks

<details>
<summary> Forward pass, CPU. PTH interpolation vs PIL </summary>

Cases:
- PTH RGB 3 Channels, float32 vs PIL RGB uint8 (apples vs pears)
- PTH 1 Channel, float32 vs PIL 1 Channel Float

Code: https://gist.github.com/vfdev-5/b173761a567f2283b3c649c3c0574112

```
# OMP_NUM_THREADS=1 python bench_interp_aa_vs_pillow.py

Torch config: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - CPU capability usage: AVX2
  - CUDA Runtime 11.1
  - NVCC architecture flags: -gencode;arch=compute_75,code=sm_75
  - CuDNN 8.0.5
  - Build settings: BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.10.0, USE_CUDA=1, USE_CUDNN=1, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=0, USE_OPENMP=ON,

Num threads: 1

[------------------------ Downsampling: torch.Size([1, 3, 906, 438]) -> (320, 196) ------------------------]
                                                  |  Reference, PIL 8.3.2, mode: RGB  |  1.10.0a0+git1e87d91
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                2.9                |         3.1
      channels_last non-contiguous torch.float32  |                2.6                |         3.6

Times are in milliseconds (ms).

[------------------------ Downsampling: torch.Size([1, 3, 906, 438]) -> (460, 220) ------------------------]
                                                  |  Reference, PIL 8.3.2, mode: RGB  |  1.10.0a0+git1e87d91
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                3.4                |         4.0
      channels_last non-contiguous torch.float32  |                3.4                |         4.8

Times are in milliseconds (ms).

[------------------------ Downsampling: torch.Size([1, 3, 906, 438]) -> (120, 96) -------------------------]
                                                  |  Reference, PIL 8.3.2, mode: RGB  |  1.10.0a0+git1e87d91
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                1.6                |         1.8
      channels_last non-contiguous torch.float32  |                1.6                |         1.9

Times are in milliseconds (ms).

[----------------------- Downsampling: torch.Size([1, 3, 906, 438]) -> (1200, 196) ------------------------]
                                                  |  Reference, PIL 8.3.2, mode: RGB  |  1.10.0a0+git1e87d91
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                9.0                |        11.3
      channels_last non-contiguous torch.float32  |                8.9                |        12.5

Times are in milliseconds (ms).

[----------------------- Downsampling: torch.Size([1, 3, 906, 438]) -> (120, 1200) ------------------------]
                                                  |  Reference, PIL 8.3.2, mode: RGB  |  1.10.0a0+git1e87d91
1 threads: -------------------------------------------------------------------------------------------------
      channels_first contiguous torch.float32     |                2.1                |         1.8
      channels_last non-contiguous torch.float32  |                2.1                |         3.4

Times are in milliseconds (ms).

[--------------- Downsampling: torch.Size([1, 1, 906, 438]) -> (320, 196) --------------]
                                |  Reference, PIL 8.3.2, mode: F  |  1.10.0a0+git1e87d91
1 threads: ------------------------------------------------------------------------------
      contiguous torch.float32  |               1.2               |         1.0

Times are in milliseconds (ms).

[--------------- Downsampling: torch.Size([1, 1, 906, 438]) -> (460, 220) --------------]
                                |  Reference, PIL 8.3.2, mode: F  |  1.10.0a0+git1e87d91
1 threads: ------------------------------------------------------------------------------
      contiguous torch.float32  |               1.4               |         1.3

Times are in milliseconds (ms).

[--------------- Downsampling: torch.Size([1, 1, 906, 438]) -> (120, 96) ---------------]
                                |  Reference, PIL 8.3.2, mode: F  |  1.10.0a0+git1e87d91
1 threads: ------------------------------------------------------------------------------
      contiguous torch.float32  |              719.9              |        599.9

Times are in microseconds (us).

[-------------- Downsampling: torch.Size([1, 1, 906, 438]) -> (1200, 196) --------------]
                                |  Reference, PIL 8.3.2, mode: F  |  1.10.0a0+git1e87d91
1 threads: ------------------------------------------------------------------------------
      contiguous torch.float32  |               3.7               |         3.5

Times are in milliseconds (ms).

[-------------- Downsampling: torch.Size([1, 1, 906, 438]) -> (120, 1200) --------------]
                                |  Reference, PIL 8.3.2, mode: F  |  1.10.0a0+git1e87d91
1 threads: ------------------------------------------------------------------------------
      contiguous torch.float32  |              834.4              |        605.7

Times are in microseconds (us).
```

</details>

Code is moved from torchvision: pytorch/vision#4208

Pull Request resolved: #65142

Reviewed By: mrshenli

Differential Revision: D32432405

Pulled By: jbschlosser

fbshipit-source-id: b66c548347f257c522c36105868532e8bc1d4c6d
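
For orientation, here is a minimal sketch (not the benchmark script from the gist above) of how the new flag is exercised from Python. The input shape and target size mirror the 3-channel benchmark case, and the timing is a rough single-threaded run in the spirit of the tables above:

```python
import torch
import torch.nn.functional as F
from torch.utils.benchmark import Timer

# Float RGB input in NCHW layout, matching the 3-channel benchmark case above.
x = torch.rand(1, 3, 906, 438)

# Bilinear downsampling with the new antialias flag (CPU-only in this change).
y = F.interpolate(x, size=(320, 196), mode="bilinear",
                  align_corners=False, antialias=True)
print(y.shape)  # torch.Size([1, 3, 320, 196])

# Rough single-threaded timing of the antialiased path.
t = Timer(
    stmt="F.interpolate(x, size=(320, 196), mode='bilinear', "
         "align_corners=False, antialias=True)",
    globals={"F": F, "x": x},
    num_threads=1,
)
print(t.blocked_autorange())
```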
1 parent 143491e commit 3da2e09

14 files changed: +729 additions, −56 deletions


aten/src/ATen/core/aten_interned_strings.h

Lines changed: 3 additions & 0 deletions
@@ -739,6 +739,9 @@ _(aten, unsqueeze) \
 _(aten, upsample_bilinear2d) \
 _(aten, upsample_bilinear2d_backward) \
 _(aten, upsample_bilinear2d_forward) \
+_(aten, _upsample_bilinear2d_aa) \
+_(aten, _upsample_bilinear2d_aa_backward) \
+_(aten, _upsample_bilinear2d_aa_forward) \
 _(aten, upsample_bicubic2d) \
 _(aten, upsample_bicubic2d_backward) \
 _(aten, upsample_bicubic2d_forward) \
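
The interned symbols above register the names of the new antialiased operators. As a hedged illustration only (the overload and argument order below are assumed by analogy with `aten::upsample_bilinear2d`, not stated in this diff), the op can be poked directly through `torch.ops`:

```python
import torch

x = torch.rand(1, 3, 64, 64)

# Assumed schema, by analogy with aten::upsample_bilinear2d:
#   _upsample_bilinear2d_aa(Tensor self, int[2] output_size, bool align_corners,
#                           float? scales_h=None, float? scales_w=None) -> Tensor
y = torch.ops.aten._upsample_bilinear2d_aa(x, [32, 32], False)
print(y.shape)  # expected: torch.Size([1, 3, 32, 32])
```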

aten/src/ATen/native/UpSample.h

Lines changed: 3 additions & 0 deletions
@@ -72,6 +72,7 @@ using upsampling_nearest3d = void(*)(const Tensor& output, const Tensor& input,
 using _upsampling_nearest_exact3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w);
 using upsampling_linear1d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_w);
 using upsampling_bilinear2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
+using _upsampling_bilinear2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
 using upsampling_trilinear3d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_d, scale_t scales_h, scale_t scales_w);
 using upsampling_bicubic2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
 DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_kernel);

@@ -88,9 +89,11 @@ DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_backward_kernel);
 DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_backward_kernel);
 DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_kernel);
 DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_kernel);
+DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_kernel);
 DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_kernel);
 DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_backward_kernel);
 DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_backward_kernel);
+DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_backward_kernel);
 DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_backward_kernel);
 DECLARE_DISPATCH(upsampling_bicubic2d, upsample_bicubic2d_kernel);
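
The header above only adds CPU dispatch stubs, which matches the "CPU only" scope of the change. A short sketch of what that means from the Python side (the CUDA failure is an assumption based on that scope, not a message taken from this diff; later releases added CUDA kernels):

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 64, 64)

# CPU path: routed through the _upsample_bilinear2d_aa dispatch stub declared above.
y = F.interpolate(x, size=(32, 32), mode="bilinear",
                  align_corners=False, antialias=True)

# CUDA path: only CPU kernels are registered by this change, so a CUDA input is
# expected to fail here (assumed behavior; CUDA support arrived in later releases).
if torch.cuda.is_available():
    try:
        F.interpolate(x.cuda(), size=(32, 32), mode="bilinear",
                      align_corners=False, antialias=True)
    except RuntimeError as err:
        print("antialias=True not supported on CUDA by this change:", err)
```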

aten/src/ATen/native/UpSampleBilinear2d.cpp

Lines changed: 89 additions & 0 deletions
@@ -48,6 +48,45 @@ TORCH_META_FUNC(upsample_bilinear2d_backward) (
   set_output(input_size, grad_output.options().memory_format(grad_output.suggest_memory_format()));
 }
 
+TORCH_META_FUNC(_upsample_bilinear2d_aa) (
+  const Tensor& input, IntArrayRef output_size, bool align_corners, c10::optional<double> scales_h, c10::optional<double> scales_w
+) {
+  auto full_output_size = native::upsample_2d_common_check(input.sizes(), output_size);
+
+  // Allow for empty batch size but not other dimensions
+  TORCH_CHECK(
+      input.numel() != 0 || c10::multiply_integers(input.sizes().begin() + 1, input.sizes().end()),
+      "Non-empty 4D data tensor expected but got a tensor with sizes ",
+      input.sizes());
+
+  set_output(full_output_size, input.options().memory_format(input.suggest_memory_format()));
+}
+
+TORCH_META_FUNC(_upsample_bilinear2d_aa_backward) (
+  const Tensor& grad_output,
+  IntArrayRef output_size,
+  IntArrayRef input_size,
+  bool align_corners,
+  c10::optional<double> scales_h,
+  c10::optional<double> scales_w
+) {
+  auto full_output_size = native::upsample_2d_common_check(input_size, output_size);
+
+  TORCH_CHECK(
+      grad_output.dim() == 4,
+      "Expected grad_output to be a tensor of dimension 4 but got: dimension ", grad_output.dim());
+
+  for (int i = 0; i < 4; ++i) {
+    TORCH_CHECK(
+        grad_output.size(i) == full_output_size[i],
+        "Expected grad_output to have the same shape as output;",
+        " output.size(", i, ") = ", full_output_size[i],
+        " but got grad_output.size(", i, ") = ", grad_output.size(i));
+  }
+
+  set_output(input_size, grad_output.options().memory_format(grad_output.suggest_memory_format()));
+}
+
 } // namespace meta
 
 namespace native {

@@ -76,6 +115,31 @@ TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_cpu) (
   upsample_bilinear2d_backward_kernel(kCPU, grad_input, grad_output, align_corners, scales_h, scales_w);
 }
 
+
+TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_out_cpu) (
+    const Tensor& input,
+    IntArrayRef output_size,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w,
+    const Tensor& output
+) {
+  _upsample_bilinear2d_aa_kernel(kCPU, output, input, align_corners, scales_h, scales_w);
+}
+
+TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_backward_out_cpu) (
+    const Tensor& grad_output,
+    IntArrayRef output_size,
+    IntArrayRef input_size,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w,
+    const Tensor& grad_input
+) {
+  grad_input.zero_();
+  _upsample_bilinear2d_aa_backward_kernel(kCPU, grad_input, grad_output, align_corners, scales_h, scales_w);
+}
+
 using at::native::upsample::compute_output_size;
 using at::native::upsample::get_scale_value;

@@ -102,8 +166,33 @@ Tensor upsample_bilinear2d_backward(
   return at::upsample_bilinear2d_backward(grad_output, osize, input_size, align_corners, scale_h, scale_w);
 }
 
+Tensor _upsample_bilinear2d_aa(
+    const Tensor& input,
+    c10::optional<IntArrayRef> output_size,
+    bool align_corners,
+    c10::optional<ArrayRef<double>> scale_factors) {
+  auto osize = compute_output_size(input.sizes(), output_size, scale_factors);
+  auto scale_h = get_scale_value(scale_factors, 0);
+  auto scale_w = get_scale_value(scale_factors, 1);
+  return at::_upsample_bilinear2d_aa(input, osize, align_corners, scale_h, scale_w);
+}
+
+Tensor _upsample_bilinear2d_aa_backward(
+    const Tensor& grad_output,
+    c10::optional<IntArrayRef> output_size,
+    IntArrayRef input_size,
+    bool align_corners,
+    c10::optional<ArrayRef<double>> scale_factors) {
+  auto osize = compute_output_size(input_size, output_size, scale_factors);
+  auto scale_h = get_scale_value(scale_factors, 0);
+  auto scale_w = get_scale_value(scale_factors, 1);
+  return at::_upsample_bilinear2d_aa_backward(grad_output, osize, input_size, align_corners, scale_h, scale_w);
+}
+
 DEFINE_DISPATCH(upsample_bilinear2d_kernel);
 DEFINE_DISPATCH(upsample_bilinear2d_backward_kernel);
+DEFINE_DISPATCH(_upsample_bilinear2d_aa_kernel);
+DEFINE_DISPATCH(_upsample_bilinear2d_aa_backward_kernel);
 
 } // namespace native
 } // namespace at
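
Since this file wires up both the forward meta/impl functions and the `_aa_backward` pair, the backward path can be sanity-checked numerically. A minimal sketch using `torch.autograd.gradcheck` (the tiny double-precision input is chosen purely to keep the finite-difference check cheap; it is not taken from the PR's tests):

```python
import torch
import torch.nn.functional as F

# Small double-precision input keeps gradcheck's finite differences fast and accurate.
x = torch.rand(1, 2, 12, 12, dtype=torch.float64, requires_grad=True)

def downsample_aa(inp):
    # Exercises _upsample_bilinear2d_aa forward and, via autograd, its backward on CPU.
    return F.interpolate(inp, size=(7, 5), mode="bilinear",
                         align_corners=False, antialias=True)

print(torch.autograd.gradcheck(downsample_aa, (x,)))  # True if analytic and numeric grads agree
```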
