125 changes: 121 additions & 4 deletions aten/src/ATen/native/mps/kernels/Pooling.metal
@@ -448,6 +448,65 @@ void avg_pool_3d_input_iter(
*output = value_sum / static_cast<T>(divisor);
}

template <typename T>
void avg_pool_backward_3d_input_iter(
device AtomicType_t<T>* grad_input,
constant T* grad_output,
constant int32_t* grad_input_sizes,
constant int32_t* grad_input_strides,
int32_t grad_input_leading_offset,
thread int32_t (&pooling_dim_indices)[3],
constant int32_t* kernel_size,
constant int32_t* stride,
constant int32_t* padding,
bool count_include_pad,
bool has_divisor_override,
int32_t divisor_override) {
auto bounds0 = get_avg_pool_input_iter_bounds<0>(
grad_input_sizes,
pooling_dim_indices,
kernel_size,
stride,
padding,
count_include_pad);
auto bounds1 = get_avg_pool_input_iter_bounds<1>(
grad_input_sizes,
pooling_dim_indices,
kernel_size,
stride,
padding,
count_include_pad);
auto bounds2 = get_avg_pool_input_iter_bounds<2>(
grad_input_sizes,
pooling_dim_indices,
kernel_size,
stride,
padding,
count_include_pad);

auto divisor = has_divisor_override
? divisor_override
: (bounds0.count) * (bounds1.count) * (bounds2.count);
auto grad_val = *grad_output / static_cast<T>(divisor);
auto size12 = grad_input_sizes[1] * grad_input_sizes[2];
Contributor review comment on the line above: Unused variable.

Suggested change (remove this line):
auto size12 = grad_input_sizes[1] * grad_input_sizes[2];


for (auto i0 = bounds0.start; i0 < bounds0.end; i0++) {
auto offset0 = grad_input_strides[0] * i0;

for (auto i1 = bounds1.start; i1 < bounds1.end; i1++) {
auto offset1 = grad_input_strides[1] * i1;

for (auto i2 = bounds2.start; i2 < bounds2.end; i2++) {
auto offset2 = grad_input_strides[2] * i2;
auto pool_offset = offset0 + offset1 + offset2;

AtomicType<T>::atomic_add(
grad_input, grad_input_leading_offset + pool_offset, grad_val);
}
}
}
}

// Kernel computes one element of the output per kernel call.
template <typename T>
kernel void avg_pool(
@@ -500,6 +559,57 @@ kernel void avg_pool(
params.divisor_override);
}

template <typename T>
kernel void avg_pool_backward(
device AtomicType_t<T>* grad_input [[buffer(0)]],
constant T* grad_output [[buffer(1)]],
constant AvgPoolingParams<5>& params [[buffer(2)]],
uint tid [[thread_position_in_grid]]) {
auto pooling_dims = params.pooling_dims;
auto dims = params.dims;
auto grad_input_sizes = params.input_sizes.data();
auto grad_input_strides = params.input_strides.data();
auto grad_output_sizes = params.output_sizes.data();
auto grad_output_strides = params.output_strides.data();
auto kernel_size = params.kernel_size.data();
auto stride = params.stride.data();
auto padding = params.padding.data();
auto leading_dims = dims - pooling_dims;

// This buffer keeps track of the pooling dimension indices of this thread's
// element of the output. We need to fill it with the proper values below.
int32_t pooling_dim_indices[3];

PoolOffsets offsets = find_pool_offsets(
grad_output_sizes,
grad_output_strides,
/*indices_strides=*/nullptr,
grad_input_strides,
pooling_dim_indices,
dims,
leading_dims,
/*return_indices=*/false,
tid);

grad_output += offsets.output;
grad_input_sizes += leading_dims;
grad_input_strides += leading_dims;

avg_pool_backward_3d_input_iter<T>(
grad_input,
grad_output,
grad_input_sizes,
grad_input_strides,
offsets.input_leading,
pooling_dim_indices,
kernel_size,
stride,
padding,
params.count_include_pad,
params.has_divisor_override,
params.divisor_override);
}

#define REGISTER_POOL_OP(DTYPE) \
template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
@@ -521,13 +631,20 @@ kernel void avg_pool(
constant AvgPoolingParams<5> & params [[buffer(2)]], \
uint tid [[thread_position_in_grid]]);

#define REGISTER_MAX_POOL_BACKWARD_OP(DTYPE) \
#define REGISTER_POOL_BACKWARD_OP(DTYPE) \
template [[host_name("max_pool_backward_" #DTYPE)]] \
kernel void max_pool_backward<DTYPE>( \
device AtomicType_t<DTYPE> * grad_input [[buffer(0)]], \
constant DTYPE * grad_output_ [[buffer(1)]], \
constant int64_t* grad_indices_ [[buffer(2)]], \
constant PoolingBackwardParams<5>& params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]); \
\
template [[host_name("avg_pool_backward_" #DTYPE)]] \
kernel void avg_pool_backward<DTYPE>( \
device AtomicType_t<DTYPE> * grad_input [[buffer(0)]], \
constant DTYPE * grad_output [[buffer(1)]], \
constant AvgPoolingParams<5> & params [[buffer(2)]], \
uint tid [[thread_position_in_grid]]);

REGISTER_POOL_OP(float);
@@ -540,6 +657,6 @@ REGISTER_POOL_OP(char);
REGISTER_POOL_OP(uchar);
REGISTER_POOL_OP(bool);

REGISTER_MAX_POOL_BACKWARD_OP(float);
REGISTER_MAX_POOL_BACKWARD_OP(half);
REGISTER_MAX_POOL_BACKWARD_OP(bfloat);
REGISTER_POOL_BACKWARD_OP(float);
REGISTER_POOL_BACKWARD_OP(half);
REGISTER_POOL_BACKWARD_OP(bfloat);
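
Note on the backward kernel above: each thread reads one grad_output element, divides it by the window's divisor, and atomically adds the result into every grad_input cell of the corresponding pooling window. A minimal NumPy sketch of the same computation for a single (D, H, W) tensor follows; the function avg_pool3d_backward_ref and its NumPy formulation are illustrative only and not part of this change.

import numpy as np

def avg_pool3d_backward_ref(grad_output, input_shape, kernel_size, stride,
                            padding, count_include_pad=False,
                            divisor_override=None):
    # grad_output: ndarray of shape (outD, outH, outW); input_shape: (D, H, W).
    grad_input = np.zeros(input_shape, dtype=grad_output.dtype)
    for od in range(grad_output.shape[0]):
        for oh in range(grad_output.shape[1]):
            for ow in range(grad_output.shape[2]):
                starts, ends, counts = [], [], []
                for o, size, k, s, p in zip((od, oh, ow), input_shape,
                                            kernel_size, stride, padding):
                    lo, hi = o * s - p, o * s - p + k
                    # The divisor counts padded cells only when
                    # count_include_pad is set.
                    if count_include_pad:
                        counts.append(min(hi, size + p) - lo)
                    else:
                        counts.append(min(hi, size) - max(lo, 0))
                    starts.append(max(lo, 0))
                    ends.append(min(hi, size))
                divisor = (divisor_override if divisor_override is not None
                           else counts[0] * counts[1] * counts[2])
                g = grad_output[od, oh, ow] / divisor
                # The Metal kernel performs this scatter with atomic adds,
                # since neighbouring output elements can hit the same cell.
                grad_input[starts[0]:ends[0],
                           starts[1]:ends[1],
                           starts[2]:ends[2]] += g
    return grad_input
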
81 changes: 81 additions & 0 deletions aten/src/ATen/native/mps/operations/Pooling.mm
@@ -14,6 +14,7 @@
#include <ATen/ops/avg_pool2d_backward.h>
#include <ATen/ops/avg_pool2d_backward_native.h>
#include <ATen/ops/avg_pool2d_native.h>
#include <ATen/ops/avg_pool3d_backward_native.h>
#include <ATen/ops/avg_pool3d_native.h>
#include <ATen/ops/max_pool2d_backward_native.h>
#include <ATen/ops/max_pool2d_native.h>
@@ -725,6 +726,64 @@ static void avg_pool_out_mps_template(const Tensor& output,
});
}

static void avg_pool_backward_out_mps_template(const Tensor& grad_input,
const Tensor& input,
const Tensor& grad_output,
IntArrayRef _kernel_size,
IntArrayRef _stride,
IntArrayRef _padding,
bool ceil_mode,
bool count_include_pad,
std::optional<int64_t> divisor_override,
const int32_t pooling_dims,
const std::string& op_name) {
auto [dims, _, kernel_size, stride, padding, __] =
process_pool_sizes(input, _kernel_size, _stride, _padding, std::nullopt, ceil_mode, pooling_dims, op_name);

const auto memory_format = input.suggest_memory_format();
grad_input.resize_(input.sizes(), memory_format);
grad_input.fill_(0);

id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
const auto numThreads = grad_output.numel();

AvgPoolingParams<5> params;

params.dims = dims;
params.pooling_dims = pooling_dims;
params.count_include_pad = count_include_pad;
params.has_divisor_override = divisor_override.has_value();
if (divisor_override.has_value()) {
params.divisor_override = safe_downcast<int32_t, int64_t>(divisor_override.value());
}

for (const auto dim : c10::irange(dims)) {
params.output_sizes[dim] = safe_downcast<int32_t, int64_t>(grad_output.size(dim));
params.output_strides[dim] = safe_downcast<int32_t, int64_t>(grad_output.stride(dim));
params.input_sizes[dim] = safe_downcast<int32_t, int64_t>(grad_input.size(dim));
params.input_strides[dim] = safe_downcast<int32_t, int64_t>(grad_input.stride(dim));
}

memcpy(params.kernel_size.data(), kernel_size.data(), pooling_dims * sizeof(int32_t));
memcpy(params.stride.data(), stride.data(), pooling_dims * sizeof(int32_t));
memcpy(params.padding.data(), padding.data(), pooling_dims * sizeof(int32_t));

dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
auto PSO = lib.getPipelineStateForFunc("avg_pool_backward_" + scalarToMetalTypeString(input));

getMPSProfiler().beginProfileKernel(PSO, op_name, {grad_output});
[computeEncoder setComputePipelineState:PSO];
mtl_setArgs(computeEncoder, grad_input, grad_output, params);

mtl_dispatch1DJob(computeEncoder, PSO, numThreads);
getMPSProfiler().endProfileKernel(PSO);
}
});
}

} // namespace mps

Tensor mps_max_pool2d(const Tensor& input,
@@ -1083,4 +1142,26 @@ Tensor max_unpooling3d_forward_mps(const Tensor& self,
"avg_pool3d");
}

TORCH_IMPL_FUNC(avg_pool3d_backward_out_mps)(const Tensor& grad_output,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
bool ceil_mode,
bool count_include_pad,
std::optional<int64_t> divisor_override,
const Tensor& grad_input) {
mps::avg_pool_backward_out_mps_template(grad_input,
input,
grad_output,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override,
/*pooling_dims=*/3,
"avg_pool3d_backward");
}

} // namespace at::native
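
A quick way to exercise this new path, assuming a machine with MPS available and a build that includes this change, is to compare MPS gradients against the CPU reference:

import torch
import torch.nn.functional as F

x_cpu = torch.randn(2, 3, 8, 8, 8, requires_grad=True)
x_mps = x_cpu.detach().to("mps").requires_grad_(True)

for x in (x_cpu, x_mps):
    out = F.avg_pool3d(x, kernel_size=3, stride=2, padding=1,
                       count_include_pad=False)
    out.sum().backward()

# Gradients produced by the new Metal kernel should match the CPU reference.
torch.testing.assert_close(x_mps.grad.cpu(), x_cpu.grad)
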
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -12378,6 +12378,7 @@
dispatch:
CPU: avg_pool3d_backward_out_cpu
CUDA: avg_pool3d_backward_out_cuda
MPS: avg_pool3d_backward_out_mps
MkldnnCPU: mkldnn_avg_pool3d_backward_out

- func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor
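
With this dispatch entry in place, the op can also be invoked directly through the dispatcher, mirroring the inductor tests below. A sketch, assuming an MPS device and a build containing this change; argument order follows the schema above:

import torch
aten = torch.ops.aten

x = torch.randn(1, 2, 6, 6, 6, device="mps")
# kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
out = aten.avg_pool3d(x, [3, 3, 3], [2, 2, 2], [1, 1, 1], False, True, None)
grad_in = aten.avg_pool3d_backward(
    torch.ones_like(out), x, [3, 3, 3], [2, 2, 2], [1, 1, 1], False, True, None)
assert grad_in.shape == x.shape
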
4 changes: 0 additions & 4 deletions test/inductor/test_torchinductor.py
@@ -9570,7 +9570,6 @@ def fn(a, b):
)
assertGeneratedKernelCountEqual(self, 0)

@xfail_if_mps_unimplemented
def test_avg_pool3d_backward(self):
def fn(a, b):
return aten.avg_pool3d_backward(
@@ -9592,7 +9591,6 @@ def fn(a, b):
],
)

@xfail_if_mps_unimplemented
@skip_if_halide # compiles for 5+ minutes
def test_avg_pool3d_backward2(self):
def fn(a, b):
@@ -9615,7 +9613,6 @@ def fn(a, b):
],
)

@xfail_if_mps_unimplemented
def test_avg_pool3d_backward3(self):
def fn(a, b):
return aten.avg_pool3d_backward(
@@ -9639,7 +9636,6 @@ def fn(a, b):
)
assertGeneratedKernelCountEqual(self, 1)

@xfail_if_mps_unimplemented
def test_avg_pool3d_backward4(self):
def fn(a, b):
return aten.avg_pool3d_backward(
1 change: 1 addition & 0 deletions torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h
@@ -39,6 +39,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_angle(AtenTensorHandle self, Ate
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool2d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_baddbmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator);
5 changes: 0 additions & 5 deletions torch/testing/_internal/common_modules.py
@@ -4064,11 +4064,6 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad
),
ModuleInfo(torch.nn.LocalResponseNorm,
module_inputs_func=module_inputs_torch_nn_LocalResponseNorm,
skips=(
# uses avg_pool3d which is not supported on MPS backend
DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format'),
DecorateInfo(expectedFailureMPS, 'TestModule', 'test_non_contiguous_tensors'),
DecorateInfo(expectedFailureMPS, 'TestModule', 'test_non_contiguous'),)
),
ModuleInfo(torch.nn.LayerNorm,
module_inputs_func=module_inputs_torch_nn_LayerNorm,
1 change: 0 additions & 1 deletion torch/testing/_internal/common_mps.py
@@ -817,7 +817,6 @@ def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]:
"round": [torch.float16],
# topk fails with duplicate indices
"topk": [torch.float16],
"nn.functional.avg_pool3d": [torch.float32],
}

SKIPLIST_GRAD = {