Skip to content

Commit e0ae6be

Browse files
Update
[ghstack-poisoned]
1 parent d327cad commit e0ae6be

File tree

2 files changed

+29
-46
lines changed

2 files changed

+29
-46
lines changed

aten/src/ATen/native/mps/kernels/ActivationKernel.metal

Lines changed: 23 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,20 @@ struct hardshrink_functor {
1111
}
1212
};
1313

14-
struct hardshrink_backward_functor {
14+
struct softshrink_functor {
15+
template <typename T>
16+
inline T operator()(const T x, const T lambda) {
17+
if (x > lambda) {
18+
return x - lambda;
19+
} else if (x < -lambda) {
20+
return x + lambda;
21+
} else {
22+
return T(0);
23+
}
24+
}
25+
};
26+
27+
struct shrink_backward_functor {
1528
template <typename T>
1629
inline T operator()(const T grad_output, const T x, const T lambda) {
1730
return (x >= -lambda && x <= lambda) ? T(0) : grad_output;
@@ -24,10 +37,16 @@ REGISTER_UNARY_ALPHA_OP(hardshrink, half, half, half);
2437
REGISTER_UNARY_ALPHA_OP(hardshrink, bfloat, bfloat, bfloat);
2538
#endif
2639

27-
REGISTER_BINARY_ALPHA_OP(hardshrink_backward, float, float, float);
28-
REGISTER_BINARY_ALPHA_OP(hardshrink_backward, half, half, half);
40+
REGISTER_UNARY_ALPHA_OP(softshrink, float, float, float);
41+
REGISTER_UNARY_ALPHA_OP(softshrink, half, half, half);
42+
#if __METAL_VERSION__ >= 310
43+
REGISTER_UNARY_ALPHA_OP(softshrink, bfloat, bfloat, bfloat);
44+
#endif
45+
46+
REGISTER_BINARY_ALPHA_OP(shrink_backward, float, float, float);
47+
REGISTER_BINARY_ALPHA_OP(shrink_backward, half, half, half);
2948
#if __METAL_VERSION__ >= 310
30-
REGISTER_BINARY_ALPHA_OP(hardshrink_backward, bfloat, bfloat, bfloat);
49+
REGISTER_BINARY_ALPHA_OP(shrink_backward, bfloat, bfloat, bfloat);
3150
#endif
3251

3352
struct hardsigmoid_functor {
@@ -128,39 +147,3 @@ REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, half, half, half);
128147
#if __METAL_VERSION__ >= 310
129148
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, bfloat, bfloat, bfloat);
130149
#endif
131-
132-
struct softshrink_functor {
133-
template <typename T>
134-
inline T operator()(const T x, const T lambda) {
135-
if (x > lambda) {
136-
return x - lambda;
137-
} else if (x < -lambda) {
138-
return x + lambda;
139-
} else {
140-
return T(0);
141-
}
142-
}
143-
};
144-
145-
struct softshrink_backward_functor {
146-
template <typename T>
147-
inline T operator()(const T grad_output, const T self, const T lambda) {
148-
if (self > lambda || self < -lambda) {
149-
return grad_output;
150-
} else {
151-
return T(0);
152-
}
153-
}
154-
};
155-
156-
REGISTER_UNARY_ALPHA_OP(softshrink, float, float, float);
157-
REGISTER_UNARY_ALPHA_OP(softshrink, half, half, half);
158-
#if __METAL_VERSION__ >= 310
159-
REGISTER_UNARY_ALPHA_OP(softshrink, bfloat, bfloat, bfloat);
160-
#endif
161-
162-
REGISTER_BINARY_ALPHA_OP(softshrink_backward, float, float, float);
163-
REGISTER_BINARY_ALPHA_OP(softshrink_backward, half, half, half);
164-
#if __METAL_VERSION__ >= 310
165-
REGISTER_BINARY_ALPHA_OP(softshrink_backward, bfloat, bfloat, bfloat);
166-
#endif

aten/src/ATen/native/mps/operations/ActivationKernel.mm

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@ static void hardshrink_kernel(TensorIteratorBase& iter, const Scalar& lambda = 0
1717
lib.exec_unary_kernel(iter, "hardshrink", lambda);
1818
}
1919

20+
static void softshrink_kernel(TensorIteratorBase& iter, const Scalar& lambda = 0.5) {
21+
lib.exec_unary_kernel(iter, "softshrink", lambda);
22+
}
23+
2024
static void shrink_backward_kernel(TensorIteratorBase& iter, const Scalar& lambda = 0.5) {
21-
lib.exec_binary_kernel(iter, "hardshrink_backward", lambda);
25+
lib.exec_binary_kernel(iter, "shrink_backward", lambda);
2226
}
2327

2428
static void hardsigmoid_kernel(TensorIteratorBase& iter) {
@@ -45,18 +49,14 @@ static void leaky_relu_backward_kernel(TensorIteratorBase& iter, const Scalar& n
4549
lib.exec_binary_kernel(iter, "leaky_relu_backward", negative_slope);
4650
}
4751

48-
static void softshrink_kernel(TensorIteratorBase& iter, const Scalar& lambda = 0.5) {
49-
lib.exec_unary_kernel(iter, "softshrink", lambda);
50-
}
51-
5252
REGISTER_DISPATCH(hardshrink_stub, hardshrink_kernel);
53+
REGISTER_DISPATCH(softshrink_stub, softshrink_kernel);
5354
REGISTER_DISPATCH(shrink_backward_stub, shrink_backward_kernel);
5455
REGISTER_DISPATCH(hardsigmoid_stub, hardsigmoid_kernel);
5556
REGISTER_DISPATCH(hardsigmoid_backward_stub, hardsigmoid_backward_kernel);
5657
REGISTER_DISPATCH(hardswish_stub, hardswish_kernel);
5758
REGISTER_DISPATCH(hardswish_backward_stub, hardswish_backward_kernel);
5859
REGISTER_DISPATCH(leaky_relu_stub, leaky_relu_kernel);
5960
REGISTER_DISPATCH(leaky_relu_backward_stub, leaky_relu_backward_kernel);
60-
REGISTER_DISPATCH(softshrink_stub, softshrink_kernel);
6161

6262
} // namespace at::native

0 commit comments

Comments
 (0)