
Commit 5e66bf5

kurtamohler authored and pytorchmergebot committed
Avoid COW materialize in nn.functional forward ops (3) (#122443)
Affected ops:
* repeat
* unfold
* logsigmoid
* pixel_shuffle/unshuffle
* remaining norm ops

Pull Request resolved: #122443
Approved by: https://github.com/ezyang
1 parent b6982bf commit 5e66bf5
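
The pattern behind every change below: a forward op only reads its input, so it should request a read-only pointer. Calling the mutable data_ptr() on a copy-on-write (COW) tensor forces a materializing copy of the storage; const_data_ptr() does not. A minimal sketch of the before/after — scale_out is a hypothetical helper, not code from this commit, and it assumes a contiguous float tensor:

#include <ATen/ATen.h>

// Hypothetical read-only kernel illustrating the commit's pattern.
void scale_out(at::Tensor& out, const at::Tensor& in, float alpha) {
  // Before: in.data_ptr<float>() requests a mutable pointer and would
  // materialize a COW input even though we only read it.
  // After: const_data_ptr() declares read-only intent, so no copy happens.
  const float* src = in.const_data_ptr<float>();
  float* dst = out.data_ptr<float>();  // the output really is written
  for (int64_t i = 0; i < in.numel(); ++i) {
    dst[i] = alpha * src[i];
  }
}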

14 files changed: +85 -90 lines

aten/src/ATen/native/Repeat.cpp
Lines changed: 2 additions & 2 deletions

@@ -16,8 +16,8 @@
 
 template <typename index_t>
 static void compute_cpu(
-    index_t* repeat_ptr,
-    int64_t* cumsum_ptr,
+    const index_t* repeat_ptr,
+    const int64_t* cumsum_ptr,
     index_t* result_ptr,
     int64_t size,
    int64_t result_size) {

aten/src/ATen/native/Repeat.h
Lines changed: 3 additions & 3 deletions

@@ -14,7 +14,7 @@ namespace at::native {
 
 template <
     typename index_t,
-    void compute(index_t*, int64_t*, index_t*, int64_t, int64_t)>
+    void compute(const index_t*, const int64_t*, index_t*, int64_t, int64_t)>
 static inline Tensor repeat_interleave_common(
     const Tensor& repeats,
     c10::optional<int64_t> output_size) {
@@ -38,8 +38,8 @@ static inline Tensor repeat_interleave_common(
   }
 
   Tensor result = at::empty({total}, repeats.options());
-  index_t* repeat_ptr = repeats_.data_ptr<index_t>();
-  int64_t* cumsum_ptr = cumsum.data_ptr<int64_t>();
+  const index_t* repeat_ptr = repeats_.const_data_ptr<index_t>();
+  const int64_t* cumsum_ptr = cumsum.const_data_ptr<int64_t>();
   index_t* result_ptr = result.data_ptr<index_t>();
   compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0), total);
   return result;
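
Note why Repeat.h has to change in lockstep with Repeat.cpp: the kernel is passed as a non-type template parameter, so the parameter's function type must match the kernel's new const-qualified signature exactly. A standalone sketch of that constraint (names are illustrative, not from the commit):

// A function bound to a non-type template parameter must match the declared
// function type exactly; const-qualifying the kernel forces both to change.
template <void compute(const int*, int*)>
void run(const int* in, int* out) { compute(in, out); }

void copy_one(const int* in, int* out) { *out = *in; }

int main() {
  int src = 7, dst = 0;
  run<copy_one>(&src, &dst);  // a `void(int*, int*)` kernel would not bind
  return dst == 7 ? 0 : 1;
}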

aten/src/ATen/native/Unfold2d.h
Lines changed: 21 additions & 3 deletions

@@ -6,7 +6,25 @@
 
 namespace at::native {
 
-using unfold2d_fn = void (*)(
+using unfold2d_copy_fn = void (*)(
+    ScalarType dtype,
+    void *finput,
+    const void *input,
+    int64_t kH,
+    int64_t kW,
+    int64_t dH,
+    int64_t dW,
+    int64_t padH,
+    int64_t padW,
+    int64_t n_input_plane,
+    int64_t input_height,
+    int64_t input_width,
+    int64_t output_height,
+    int64_t output_width,
+    bool is_channels_last
+);
+
+using unfold2d_acc_fn = void (*)(
     ScalarType dtype,
     void *finput,
     void *input,
@@ -24,7 +42,7 @@ using unfold2d_fn = void (*)(
     bool is_channels_last
 );
 
-DECLARE_DISPATCH(unfold2d_fn, unfolded2d_copy_stub);
-DECLARE_DISPATCH(unfold2d_fn, unfolded2d_acc_stub);
+DECLARE_DISPATCH(unfold2d_copy_fn, unfolded2d_copy_stub);
+DECLARE_DISPATCH(unfold2d_acc_fn, unfolded2d_acc_stub);
 
 } // namespace at::native
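
The single unfold2d_fn alias is split because the two stubs now need different constness: the copy stub only reads its input (const void *input), while the acc stub, as the diff shows, keeps the mutable void *input since it appears to write through that pointer. A toy sketch of the idea (illustrative types, not the real stubs):

// Two function-pointer aliases that differ only in the input's constness;
// one shared alias can no longer serve both dispatch stubs.
using copy_fn = void (*)(float* dst, const float* src, long n);
using acc_fn  = void (*)(float* dst, float* src, long n);  // src may be written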

aten/src/ATen/native/cpu/Activation.cpp
Lines changed: 2 additions & 2 deletions

@@ -30,7 +30,7 @@ static void log_sigmoid_cpu_kernel(TensorBase &output, TensorBase &buffer, const
   using Vec = Vectorized<scalar_t>;
   scalar_t* output_data = output.data_ptr<scalar_t>();
   scalar_t* buffer_data = buffer.data_ptr<scalar_t>();
-  scalar_t* input_data = input.data_ptr<scalar_t>();
+  const scalar_t* input_data = input.const_data_ptr<scalar_t>();
   parallel_for(0, input.numel(), 1, [&] (int64_t begin, int64_t end) {
     int64_t size = end - begin;
     int64_t d = 0;
@@ -65,7 +65,7 @@ static void log_sigmoid_cpu_kernel(TensorBase &output, TensorBase &buffer, const
   using Vec = Vectorized<scalar_t>;
   scalar_t* output_data = output.data_ptr<scalar_t>();
   scalar_t* buffer_data = buffer.data_ptr<scalar_t>();
-  scalar_t* input_data = input.data_ptr<scalar_t>();
+  const scalar_t* input_data = input.const_data_ptr<scalar_t>();
   parallel_for(0, input.numel(), 1, [&] (int64_t begin, int64_t end) {
     int64_t size = end - begin;
     int64_t d = 0;

aten/src/ATen/native/cpu/PixelShuffleKernel.cpp
Lines changed: 5 additions & 5 deletions

@@ -17,7 +17,7 @@ void cpu_pixel_shuffle(
     TensorBase& output,
     const TensorBase& input,
     int64_t upscale_factor) {
-  auto input_data = input.data_ptr<scalar_t>();
+  auto input_data = input.const_data_ptr<scalar_t>();
   auto output_data = output.data_ptr<scalar_t>();
 
   // [(B1...Bn), C, H, W] => [N, C, H, W]
@@ -59,7 +59,7 @@ void cpu_pixel_shuffle_channels_last(
     int64_t upscale_factor) {
   TORCH_CHECK(input.ndimension() == 4,
               "pixel shuffle with channels last format supports tensors with 4 dims");
-  auto input_data = input.data_ptr<scalar_t>();
+  auto input_data = input.const_data_ptr<scalar_t>();
   auto output_data = output.data_ptr<scalar_t>();
 
   int64_t nbatch = input.size(0);
@@ -81,7 +81,7 @@ void cpu_pixel_shuffle_channels_last(
     data_index_init(begin, n, nbatch, h, height);
     for (const auto i : c10::irange(begin, end)) {
       for (const auto w : c10::irange(width)) {
-        scalar_t* input_ptr = input_data + n * height * width * channels + h * width * channels + w * channels;
+        const scalar_t* input_ptr = input_data + n * height * width * channels + h * width * channels + w * channels;
 
         // step 1: transpose each channel lane
         // from: [c, s1*s2]
@@ -115,7 +115,7 @@ void cpu_pixel_unshuffle(
     TensorBase& output,
     const TensorBase& input,
     int64_t downscale_factor) {
-  auto input_data = input.data_ptr<scalar_t>();
+  auto input_data = input.const_data_ptr<scalar_t>();
   auto output_data = output.data_ptr<scalar_t>();
 
   // [(B1...Bn), C, H, W] => [N, C, H, W]
@@ -158,7 +158,7 @@ void cpu_pixel_unshuffle_channels_last(
     int64_t downscale_factor) {
   TORCH_CHECK(input.ndimension() == 4,
               "pixel unshuffle with channels last format supports tensors with 4 dims");
-  auto input_data = input.data_ptr<scalar_t>();
+  auto input_data = input.const_data_ptr<scalar_t>();
   auto output_data = output.data_ptr<scalar_t>();
 
   int64_t nbatch = input.size(0);
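
Worth noting for the channels-last hunk above: input_data is declared with auto, so switching to const_data_ptr() changes its deduced type to const scalar_t*, which is why the derived input_ptr must become const as well. A tiny standalone illustration:

const float storage[4] = {0.f, 1.f, 2.f, 3.f};
auto data = &storage[0];   // deduced as const float*
auto ptr  = data + 2;      // pointer arithmetic preserves the const
// float* mut = ptr;       // would not compile: drops the const qualifier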

aten/src/ATen/native/cpu/Unfold2d.cpp
Lines changed: 7 additions & 7 deletions

@@ -228,7 +228,7 @@ void unfolded2d_acc_kernel(
 
 template <typename scalar_t>
 static void unfolded2d_copy(
-    scalar_t* input_data,
+    const scalar_t* input_data,
     scalar_t* finput_data,
     int64_t kH,
     int64_t kW,
@@ -256,7 +256,7 @@ static void unfolded2d_copy(
             nip * ((size_t)kH * kW * output_height * output_width) +
             kh * ((size_t)kW * output_height * output_width) +
             kw * ((size_t)output_height * output_width);
-        scalar_t* src =
+        const scalar_t* src =
            input_data + nip * ((size_t)input_height * input_width);
         if (padW > 0 || padH > 0) {
           // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@@ -335,7 +335,7 @@ static void unfolded2d_copy(
 
 template <typename scalar_t>
 static void unfolded2d_copy_channels_last(
-    scalar_t* input_data,
+    const scalar_t* input_data,
     scalar_t* finput_data,
     int64_t kH,
     int64_t kW,
@@ -355,7 +355,7 @@ static void unfolded2d_copy_channels_last(
 
   for (const auto k C10_UNUSED: c10::irange(start, end)) {
     scalar_t* dst = finput_data + y * output_width * kH * kW * n_input_plane + x * kH * kW * n_input_plane;
-    scalar_t* src = input_data;
+    const scalar_t* src = input_data;
 
     if (padW > 0 || padH > 0) {
       for (int64_t kh = 0; kh < kH; kh++) {
@@ -393,7 +393,7 @@ static void unfolded2d_copy_channels_last(
 void unfolded2d_copy_kernel(
     ScalarType dtype,
     void *finput_data,
-    void *input_data,
+    const void *input_data,
     int64_t kH,
     int64_t kW,
     int64_t dH,
@@ -415,7 +415,7 @@ void unfolded2d_copy_kernel(
   if (is_channels_last) {
     AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, dtype, "unfolded2d_copy_channels_last", [&] {
       unfolded2d_copy_channels_last(
-          static_cast<scalar_t*>(input_data),
+          static_cast<const scalar_t*>(input_data),
          static_cast<scalar_t*>(finput_data),
          kH, kW,
          dH, dW,
@@ -429,7 +429,7 @@ void unfolded2d_copy_kernel(
   } else {
     AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, dtype, "unfolded2d_copy", [&] {
       unfolded2d_copy(
-          static_cast<scalar_t*>(input_data),
+          static_cast<const scalar_t*>(input_data),
          static_cast<scalar_t*>(finput_data),
          kH, kW,
          dH, dW,
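
The kernel entry point shows the same change carried through type erasure: once the stub signature takes const void *input, each static_cast must restore the constness too, because static_cast cannot cast constness away from a const void*. A minimal standalone sketch:

void take(const void* erased) {
  const float* typed = static_cast<const float*>(erased);  // ok: const preserved
  // float* bad = static_cast<float*>(erased);  // error: casts away constness
  (void)typed;
}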

aten/src/ATen/native/cpu/batch_norm_kernel.cpp
Lines changed: 10 additions & 10 deletions

@@ -34,13 +34,13 @@ void batch_norm_cpu_collect_linear_and_constant_terms(
     const Tensor& save_mean, const Tensor& save_invstd,
     const Tensor& running_mean, const Tensor& running_var, bool train, double eps) {
 
-  const param_t* weight_data = weight.defined() ? weight.data_ptr<param_t>() : nullptr;
-  const param_t* bias_data = bias.defined() ? bias.data_ptr<param_t>() : nullptr;
+  const param_t* weight_data = weight.defined() ? weight.const_data_ptr<param_t>() : nullptr;
+  const param_t* bias_data = bias.defined() ? bias.const_data_ptr<param_t>() : nullptr;
 
-  auto save_mean_a = conditional_accessor_1d<param_t>(save_mean);
-  auto save_invstd_a = conditional_accessor_1d<param_t>(save_invstd);
-  auto running_mean_a = conditional_accessor_1d<param_t>(running_mean);
-  auto running_var_a = conditional_accessor_1d<param_t>(running_var);
+  auto save_mean_a = conditional_accessor_1d<const param_t>(save_mean);
+  auto save_invstd_a = conditional_accessor_1d<const param_t>(save_invstd);
+  auto running_mean_a = conditional_accessor_1d<const param_t>(running_mean);
+  auto running_var_a = conditional_accessor_1d<const param_t>(running_var);
 
   /// Collect the linear and constant terms regarding the input.
   /// output(n, c, h, w)
@@ -91,7 +91,7 @@ batch_norm_cpu_contiguous_impl(Tensor& output, const Tensor& input,
       save_mean, save_invstd, running_mean, running_var, train, eps);
 
   scalar_t* output_data = output.data_ptr<scalar_t>();
-  const scalar_t* input_data = input.data_ptr<scalar_t>();
+  const scalar_t* input_data = input.const_data_ptr<scalar_t>();
 
   // Apply the linear terms to the input,
   // output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c)
@@ -143,7 +143,7 @@ batch_norm_cpu_channels_last_impl(Tensor& output, const Tensor& input,
       save_mean, save_invstd, running_mean, running_var, train, eps);
 
   scalar_t* output_data = output.data_ptr<scalar_t>();
-  const scalar_t* input_data = input.data_ptr<scalar_t>();
+  const scalar_t* input_data = input.const_data_ptr<scalar_t>();
 
   // Apply the linear terms to the input,
   // output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c)
@@ -185,7 +185,7 @@ batch_norm_cpu_collect_stats_contiguous_impl(
   int64_t image_size = input.numel() / n_batch / n_channel;
   int64_t N = input.numel() / n_channel;
 
-  const scalar_t* input_data = input.data_ptr<scalar_t>();
+  const scalar_t* input_data = input.const_data_ptr<scalar_t>();
   scalar_t* mean_data = mean.data_ptr<scalar_t>();
   scalar_t* var_sum_data = var_sum.data_ptr<scalar_t>();
 
@@ -229,7 +229,7 @@ batch_norm_cpu_collect_stats_channels_last_impl(
   int64_t n_channel = input.size(1);
   int64_t N = input.numel() / n_channel;
 
-  const scalar_t* input_data = input.data_ptr<scalar_t>();
+  const scalar_t* input_data = input.const_data_ptr<scalar_t>();
   scalar_t* mean_data = mean.data_ptr<scalar_t>();
   scalar_t* var_sum_data = var_sum.data_ptr<scalar_t>();
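
The accessor changes apply the same rule at the type level: conditional_accessor_1d<const param_t> yields an accessor whose element type is const, so reads still compile against a read-only tensor but writes do not. A minimal analogue of such an accessor (illustrative only, not ATen's implementation):

#include <cstdint>

// Toy 1-d accessor: with T = const float, operator[] returns const float&,
// so the statistics tensors can be read but not written through it.
template <typename T>
struct Accessor1d {
  T* data;
  T& operator[](int64_t i) const { return data[i]; }
};

float running_mean[3] = {0.f, 0.f, 0.f};
Accessor1d<const float> a{running_mean};  // `a[0] = 1.f;` would not compile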

aten/src/ATen/native/cpu/group_norm_kernel.cpp
Lines changed: 6 additions & 6 deletions

@@ -43,9 +43,9 @@ void GroupNormKernelImplInternal(
   TORCH_CHECK(!beta.defined() || beta.numel() == C);
   const int64_t G = group;
   const int64_t D = C / G;
-  const T* X_data = X.data_ptr<T>();
-  const PT* gamma_data = gamma.defined() ? gamma.data_ptr<PT>() : nullptr;
-  const PT* beta_data = beta.defined() ? beta.data_ptr<PT>() : nullptr;
+  const T* X_data = X.const_data_ptr<T>();
+  const PT* gamma_data = gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;
+  const PT* beta_data = beta.defined() ? beta.const_data_ptr<PT>() : nullptr;
   T* Y_data = Y.data_ptr<T>();
   PT* mean_data = mean.data_ptr<PT>();
   PT* rstd_data = rstd.data_ptr<PT>();
@@ -298,9 +298,9 @@ void GroupNormKernelImplChannelsLastInternal(
   TORCH_CHECK(!beta.defined() || beta.numel() == C);
   const int64_t G = group;
   const int64_t D = C / G;
-  const T* X_data = X.data_ptr<T>();
-  const PT* gamma_data = gamma.defined() ? gamma.data_ptr<PT>() : nullptr;
-  const PT* beta_data = beta.defined() ? beta.data_ptr<PT>() : nullptr;
+  const T* X_data = X.const_data_ptr<T>();
+  const PT* gamma_data = gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;
+  const PT* beta_data = beta.defined() ? beta.const_data_ptr<PT>() : nullptr;
   T* Y_data = Y.data_ptr<T>();
   PT* mean_data = mean.data_ptr<PT>();
   PT* rstd_data = rstd.data_ptr<PT>();

aten/src/ATen/native/cuda/Activation.cpp
Lines changed: 1 addition & 1 deletion

@@ -80,7 +80,7 @@ std::tuple<Tensor&, Tensor&> log_sigmoid_forward_out_cuda(const Tensor& input, T
   // NOTE: buffer is only used by CPU dispatch, we just ignore it here
   auto iter = TensorIteratorConfig()
       .add_output(result)
-      .add_input(input)
+      .add_const_input(input)
       .build();
   launch_log_sigmoid_forward_kernel(iter);
   return std::forward_as_tuple(result, buffer);
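
TensorIterator has its own spelling of the same intent: add_const_input registers the tensor as read-only, so the iterator requests a const pointer instead of the mutable one add_input implies. A sketch mirroring the hunk above (assumes result and input are in scope inside an ATen kernel; not standalone code):

// Declaring the input const at iterator-construction time avoids a COW
// materialize when the iterator later fetches the data pointer.
auto iter = at::TensorIteratorConfig()
    .add_output(result)        // written: mutable access
    .add_const_input(input)    // read-only: no materializing copy
    .build();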

aten/src/ATen/native/cuda/Normalization.cuh
Lines changed: 9 additions & 9 deletions

@@ -210,12 +210,12 @@ __device__ __forceinline__ void welford_merge_block_vertical(C& count,
 
 template <typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, bool train, typename index_t>
 __global__ void batch_norm_transform_input_kernel(
-    const GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> input,
+    const GenericPackedTensorAccessor<const input_scalar_t, 3, RestrictPtrTraits, index_t> input,
     GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> output,
     const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> mean_,
     const GenericPackedTensorAccessor<typename std::conditional<train, stat_accscalar_t, stat_scalar_t>::type, 1, RestrictPtrTraits, index_t> var_or_invstd,
-    const GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> weight,
-    const GenericPackedTensorAccessor<stat_scalar_t, 1, RestrictPtrTraits, index_t> bias,
+    const GenericPackedTensorAccessor<const stat_scalar_t, 1, RestrictPtrTraits, index_t> weight,
+    const GenericPackedTensorAccessor<const stat_scalar_t, 1, RestrictPtrTraits, index_t> bias,
     stat_accscalar_t epsilon) {
 
   index_t plane = blockIdx.x;
@@ -267,7 +267,7 @@ struct Var {
 
 template <typename VarTransform, typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t>
 __global__ void batch_norm_collect_statistics_kernel(
-    const GenericPackedTensorAccessor<input_scalar_t, 3, RestrictPtrTraits, index_t> input,
+    const GenericPackedTensorAccessor<const input_scalar_t, 3, RestrictPtrTraits, index_t> input,
     const stat_accscalar_t epsilon,
     const stat_accscalar_t momentum,
     GenericPackedTensorAccessor<stat_accscalar_t, 1, RestrictPtrTraits, index_t> save_mean,
@@ -582,7 +582,7 @@ __global__ void batch_norm_backward_elemt_kernel(
 template <typename scalar_t, int64_t dim, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
 static GenericPackedTensorAccessor<scalar_t, dim, PtrTraits, index_t> get_packed_accessor(
     const Tensor& t, c10::string_view var_name) {
-  constexpr auto expect_type = c10::CppTypeToScalarType<scalar_t>::value;
+  constexpr auto expect_type = c10::CppTypeToScalarType<typename std::remove_const<scalar_t>::type>::value;
   const auto actual_type = t.scalar_type();
   TORCH_CHECK(actual_type == expect_type, "Expected ", var_name,
               " to have type ", expect_type, " but got ", actual_type);
@@ -670,7 +670,7 @@ void batch_norm_stats_cuda_template(
   resize_output(out_mean, {n_input});
   resize_output(out_invstd, {n_input});
   auto input = get_packed_accessor<
-      scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
+      const scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
   TORCH_INTERNAL_ASSERT(out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
                         out_invstd.sizes()[0]);
   TORCH_INTERNAL_ASSERT(out_mean.dim() == 1 && out_mean.is_contiguous() &&
@@ -700,13 +700,13 @@ void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_,
   auto output_reshaped = output_.view({input_.size(0), input_.size(1), -1});
 
   auto input = get_packed_accessor<
-      input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
+      const input_scalar_t, 3, RestrictPtrTraits, index_t>(input_reshaped, "input");
   auto output = get_packed_accessor<
       input_scalar_t, 3, RestrictPtrTraits, index_t>(output_reshaped, "output");
   auto weight = packed_accessor_or_dummy<
-      stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_, "weight");
+      const stat_scalar_t, 1, RestrictPtrTraits, index_t>(weight_, "weight");
   auto bias = packed_accessor_or_dummy<
-      stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias");
+      const stat_scalar_t, 1, RestrictPtrTraits, index_t>(bias_, "bias");
   auto mean = packed_accessor_or_dummy<
       stat_accscalar_t, 1, RestrictPtrTraits, index_t>(mean_, "mean");
   auto invstd = packed_accessor_or_dummy<