34 changes: 34 additions & 0 deletions aten/src/ATen/native/mps/kernels/ActivationKernel.metal
@@ -29,3 +29,37 @@ REGISTER_BINARY_ALPHA_OP(hardshrink_backward, half, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_ALPHA_OP(hardshrink_backward, bfloat, bfloat, bfloat);
#endif

struct hardsigmoid_functor {
template <typename T>
inline T operator()(const T x) {
return static_cast<T>(min(max(x + 3.0f, .0f), 6.f) / 6.f);
}
};

struct hardsigmoid_backward_functor {
template <typename T>
inline T operator()(const T grad_output, const T self) {
constexpr T zero(0);
constexpr T neg_three(-3);
constexpr T three(3);

if (self < neg_three || self > three) {
return zero;
} else {
return static_cast<T>(grad_output * (1.0f / 6.0f));
}
Comment on lines +43 to +51 (Contributor):

Suggested change (replace the constexpr constants and the if/else above with a branchless form):
constexpr auto one_over_six = 1.0f / 6.0f;
return static_cast<T>(abs(float(self)) > 3.0f ? 0.0f : float(grad_output) * one_over_six);

}
};

REGISTER_UNARY_OP(hardsigmoid, float, float);
REGISTER_UNARY_OP(hardsigmoid, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_OP(hardsigmoid, bfloat, bfloat);
#endif

REGISTER_BINARY_OP(hardsigmoid_backward, float, float);
REGISTER_BINARY_OP(hardsigmoid_backward, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_OP(hardsigmoid_backward, bfloat, bfloat);
#endif
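
For reference, the forward functor above computes hardsigmoid(x) = clamp(x + 3, 0, 6) / 6, and the backward functor returns grad_output / 6 when self lies in [-3, 3] and 0 otherwise. The host-side C++ sketch below mirrors that arithmetic for sanity-checking kernel output; it is illustrative only, and the _ref function names are not part of the PR.

// Host-side reference for the Metal functors above (illustrative, not part of the PR).
#include <algorithm>
#include <cstdio>

// hardsigmoid(x) = clamp(x + 3, 0, 6) / 6
static float hardsigmoid_ref(float x) {
  return std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;
}

// Backward: grad_output / 6 inside [-3, 3], zero outside,
// matching the branch in hardsigmoid_backward_functor.
static float hardsigmoid_backward_ref(float grad_output, float x) {
  return (x < -3.0f || x > 3.0f) ? 0.0f : grad_output * (1.0f / 6.0f);
}

int main() {
  for (float x : {-4.0f, -3.0f, 0.0f, 3.0f, 4.0f}) {
    std::printf("x=%5.1f  y=%.4f  dy/dx=%.4f\n",
                x, hardsigmoid_ref(x), hardsigmoid_backward_ref(1.0f, x));
  }
  return 0;
}
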
101 changes: 0 additions & 101 deletions aten/src/ATen/native/mps/operations/Activation.mm
@@ -17,8 +17,6 @@
#include <ATen/ops/gelu_native.h>
#include <ATen/ops/glu_backward_native.h>
#include <ATen/ops/glu_native.h>
#include <ATen/ops/hardsigmoid_backward_native.h>
#include <ATen/ops/hardsigmoid_native.h>
#include <ATen/ops/hardswish_backward_native.h>
#include <ATen/ops/hardswish_native.h>
#include <ATen/ops/hardtanh_backward_native.h>
@@ -1752,105 +1750,6 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) {
}
}

TORCH_IMPL_FUNC(hardsigmoid_out_mps)(const Tensor& self, const Tensor& result) {
using namespace mps;
using CachedGraph = MPSUnaryCachedGraph;

TORCH_CHECK(self.is_mps());

// Empty output
if (result.numel() == 0)
return;

MPSStream* stream = getCurrentMPSStream();

@autoreleasepool {
std::string key = "hardsigmoid_out_mps:" + getTensorsStringKey({self});

auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:getMPSDataType(self)];
MPSGraphTensor* threeTensor = [mpsGraph constantWithScalar:3.0 shape:@[ @1 ] dataType:getMPSDataType(self)];
MPSGraphTensor* sixTensor = [mpsGraph constantWithScalar:6.0 shape:@[ @1 ] dataType:getMPSDataType(self)];
MPSGraphTensor* inputPlusThreeTensor = [mpsGraph additionWithPrimaryTensor:inputTensor
secondaryTensor:threeTensor
name:nil];

MPSGraphTensor* outputTensor = [mpsGraph clampWithTensor:inputPlusThreeTensor
minValueTensor:zeroTensor
maxValueTensor:sixTensor
name:nil];
outputTensor = [mpsGraph divisionWithPrimaryTensor:outputTensor secondaryTensor:sixTensor name:nil];
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
});

Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);

auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
}

TORCH_IMPL_FUNC(hardsigmoid_backward_out_mps)
(const Tensor& grad_output, const Tensor& self, const Tensor& grad_input) {
using namespace mps;
using CachedGraph = MPSUnaryGradCachedGraph;
TORCH_CHECK(self.is_mps());

// Empty output
if (grad_input.numel() == 0)
return;

MPSStream* stream = getCurrentMPSStream();

@autoreleasepool {
std::string key = "hardsigmoid_backward_out_mps:" + getTensorsStringKey({self});

auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:getMPSDataType(self)];
MPSGraphTensor* highTensor = [mpsGraph constantWithScalar:3.0 shape:@[ @1 ] dataType:getMPSDataType(self)];
MPSGraphTensor* lowTensor = [mpsGraph constantWithScalar:-3.0 shape:@[ @1 ] dataType:getMPSDataType(self)];
MPSGraphTensor* oneSixTensor = [mpsGraph constantWithScalar:1.0 / 6.0
shape:@[ @1 ]
dataType:getMPSDataType(self)];
MPSGraphTensor* inputLessThanHighPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
secondaryTensor:highTensor
name:nil];
MPSGraphTensor* inputGreaterThanLowPredicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor
secondaryTensor:lowTensor
name:nil];
MPSGraphTensor* inIntervalTensor = [mpsGraph logicalANDWithPrimaryTensor:inputLessThanHighPredicateTensor
secondaryTensor:inputGreaterThanLowPredicateTensor
name:nil];
MPSGraphTensor* outputTensor = [mpsGraph multiplicationWithPrimaryTensor:gradOutputTensor
secondaryTensor:oneSixTensor
name:nil];

outputTensor = [mpsGraph selectWithPredicateTensor:inIntervalTensor
truePredicateTensor:outputTensor
falsePredicateTensor:zeroTensor
name:nil];
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->gradOutputTensor_ = gradOutputTensor;
newCachedGraph->gradInputTensor_ = outputTensor;
});

Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input);

// Create dictionary of inputs and outputs
auto feeds = dictionaryFromPlaceholders(selfPlaceholder, gradOutputPlaceholder);
auto results = dictionaryFromPlaceholders(gradInputPlaceholder);

runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}
}

// -------------------------------------------------
// Hardtanh backward

10 changes: 10 additions & 0 deletions aten/src/ATen/native/mps/operations/ActivationKernel.mm
@@ -21,7 +21,17 @@ static void hardshrink_backward_kernel(TensorIteratorBase& iter, const Scalar& l
lib.exec_binary_kernel(iter, "hardshrink_backward", lambda);
}

static void hardsigmoid_kernel(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "hardsigmoid");
}

static void hardsigmoid_backward_kernel(TensorIteratorBase& iter) {
lib.exec_binary_kernel(iter, "hardsigmoid_backward");
}

REGISTER_DISPATCH(hardshrink_stub, hardshrink_kernel);
REGISTER_DISPATCH(shrink_backward_stub, hardshrink_backward_kernel);
REGISTER_DISPATCH(hardsigmoid_stub, hardsigmoid_kernel);
REGISTER_DISPATCH(hardsigmoid_backward_stub, hardsigmoid_backward_kernel);

} // namespace at::native
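
The REGISTER_DISPATCH lines above wire the new Metal kernels into ATen's existing hardsigmoid_stub and hardsigmoid_backward_stub, so the structured operators reach MPS through the shared TensorIterator path rather than a dedicated MPSGraph implementation. Below is a rough stand-alone sketch of that stub pattern; the types and names are hypothetical (the real DispatchStub lives in ATen and is keyed by DeviceType).

// Simplified stand-in for ATen's DispatchStub / REGISTER_DISPATCH machinery
// (hypothetical types and names, for illustration only).
#include <cstdio>
#include <map>
#include <string>

using UnaryKernel = void (*)(float* out, const float* in, int n);

struct DispatchStub {
  std::map<std::string, UnaryKernel> kernels;  // backend name -> kernel
  void call(const std::string& backend, float* out, const float* in, int n) {
    kernels.at(backend)(out, in, n);
  }
};

DispatchStub hardsigmoid_stub;  // shared across backends, like the real stub

// Analogue of hardsigmoid_kernel above: the backend-specific implementation
// that REGISTER_DISPATCH associates with the stub.
static void mps_like_hardsigmoid_kernel(float* out, const float* in, int n) {
  for (int i = 0; i < n; ++i) {
    float y = in[i] + 3.0f;
    y = y < 0.0f ? 0.0f : (y > 6.0f ? 6.0f : y);
    out[i] = y / 6.0f;
  }
}

int main() {
  // Registration (done at static-initialization time by REGISTER_DISPATCH in ATen).
  hardsigmoid_stub.kernels["mps"] = mps_like_hardsigmoid_kernel;

  float in[3] = {-5.0f, 0.0f, 5.0f};
  float out[3];
  hardsigmoid_stub.call("mps", out, in, 3);  // roughly what hardsigmoid_out triggers
  std::printf("%.3f %.3f %.3f\n", out[0], out[1], out[2]);
  return 0;
}
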
6 changes: 2 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
@@ -11906,8 +11906,7 @@
device_check: NoCheck # TensorIterator
python_module: nn
dispatch:
CPU, CUDA: hardsigmoid_out
MPS: hardsigmoid_out_mps
CPU, CUDA, MPS: hardsigmoid_out
QuantizedCPU: hardsigmoid_out_quantized_cpu

- func: hardsigmoid(Tensor self) -> Tensor
@@ -11928,8 +11927,7 @@
structured_inherits: TensorIteratorBase
python_module: nn
dispatch:
CPU, CUDA: hardsigmoid_backward_out
MPS: hardsigmoid_backward_out_mps
CPU, CUDA, MPS: hardsigmoid_backward_out

- func: hardsigmoid_backward(Tensor grad_output, Tensor self) -> Tensor
structured_delegate: hardsigmoid_backward.grad_input
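
With the dispatch entries collapsed to the shared CPU, CUDA, MPS lines above, hardsigmoid on an MPS tensor now flows through the structured hardsigmoid_out / hardsigmoid_backward_out kernels backed by the Metal implementations. A quick parity check against CPU could look like the sketch below; it assumes a libtorch build with MPS enabled and that at::hasMPS() is available, and is not part of the PR.

// Parity sketch (illustrative, not part of the PR): compare MPS hardsigmoid
// results against the CPU implementation.
#include <torch/torch.h>
#include <iostream>

int main() {
  if (!at::hasMPS()) {  // assumption: at::hasMPS() reports MPS availability in this build
    std::cout << "MPS not available\n";
    return 0;
  }
  const torch::Device mps("mps");

  auto x_cpu = torch::randn({1024});
  auto x_mps = x_cpu.to(mps);

  // Forward: dispatches to hardsigmoid_out, now backed by the Metal kernel on MPS.
  auto y_cpu = at::hardsigmoid(x_cpu);
  auto y_mps = at::hardsigmoid(x_mps).to(torch::kCPU);
  std::cout << "forward  allclose: " << at::allclose(y_cpu, y_mps) << "\n";

  // Backward: exercises hardsigmoid_backward_out / the binary Metal kernel.
  auto grad = torch::ones_like(x_cpu);
  auto gi_cpu = at::hardsigmoid_backward(grad, x_cpu);
  auto gi_mps = at::hardsigmoid_backward(grad.to(mps), x_mps).to(torch::kCPU);
  std::cout << "backward allclose: " << at::allclose(gi_cpu, gi_mps) << "\n";
  return 0;
}
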