@@ -11,7 +11,20 @@ struct hardshrink_functor {
1111 }
1212};
1313
14- struct hardshrink_backward_functor {
14+ struct softshrink_functor {
15+ template <typename T>
16+ inline T operator ()(const T x, const T lambda) {
17+ if (x > lambda) {
18+ return x - lambda;
19+ } else if (x < -lambda) {
20+ return x + lambda;
21+ } else {
22+ return T (0 );
23+ }
24+ }
25+ };
26+
27+ struct shrink_backward_functor {
1528 template <typename T>
1629 inline T operator ()(const T grad_output, const T x, const T lambda) {
1730 return (x >= -lambda && x <= lambda) ? T (0 ) : grad_output;
@@ -24,10 +37,16 @@ REGISTER_UNARY_ALPHA_OP(hardshrink, half, half, half);
2437REGISTER_UNARY_ALPHA_OP (hardshrink, bfloat, bfloat, bfloat);
2538#endif
2639
27- REGISTER_BINARY_ALPHA_OP (hardshrink_backward, float , float , float );
28- REGISTER_BINARY_ALPHA_OP (hardshrink_backward, half, half, half);
40+ REGISTER_UNARY_ALPHA_OP (softshrink, float , float , float );
41+ REGISTER_UNARY_ALPHA_OP (softshrink, half, half, half);
42+ #if __METAL_VERSION__ >= 310
43+ REGISTER_UNARY_ALPHA_OP (softshrink, bfloat, bfloat, bfloat);
44+ #endif
45+
46+ REGISTER_BINARY_ALPHA_OP (shrink_backward, float , float , float );
47+ REGISTER_BINARY_ALPHA_OP (shrink_backward, half, half, half);
2948#if __METAL_VERSION__ >= 310
30- REGISTER_BINARY_ALPHA_OP (hardshrink_backward , bfloat, bfloat, bfloat);
49+ REGISTER_BINARY_ALPHA_OP (shrink_backward , bfloat, bfloat, bfloat);
3150#endif
3251
3352struct hardsigmoid_functor {
@@ -128,39 +147,3 @@ REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, half, half, half);
128147#if __METAL_VERSION__ >= 310
129148REGISTER_BINARY_ALPHA_OP (leaky_relu_backward, bfloat, bfloat, bfloat);
130149#endif
131-
132- struct softshrink_functor {
133- template <typename T>
134- inline T operator ()(const T x, const T lambda) {
135- if (x > lambda) {
136- return x - lambda;
137- } else if (x < -lambda) {
138- return x + lambda;
139- } else {
140- return T (0 );
141- }
142- }
143- };
144-
145- struct softshrink_backward_functor {
146- template <typename T>
147- inline T operator ()(const T grad_output, const T self, const T lambda) {
148- if (self > lambda || self < -lambda) {
149- return grad_output;
150- } else {
151- return T (0 );
152- }
153- }
154- };
155-
156- REGISTER_UNARY_ALPHA_OP (softshrink, float , float , float );
157- REGISTER_UNARY_ALPHA_OP (softshrink, half, half, half);
158- #if __METAL_VERSION__ >= 310
159- REGISTER_UNARY_ALPHA_OP (softshrink, bfloat, bfloat, bfloat);
160- #endif
161-
162- REGISTER_BINARY_ALPHA_OP (softshrink_backward, float , float , float );
163- REGISTER_BINARY_ALPHA_OP (softshrink_backward, half, half, half);
164- #if __METAL_VERSION__ >= 310
165- REGISTER_BINARY_ALPHA_OP (softshrink_backward, bfloat, bfloat, bfloat);
166- #endif
0 commit comments