1515
1616namespace at ::native {
1717
18- Tensor NestedTensor_abs (const Tensor& self) {
19- return map_nt (self, at::abs);
20- }
18+ #define DEFINE_TORCH_NESTED_TENSOR_UNARY_OP (op_name ) \
19+ Tensor NestedTensor_##op_name(const Tensor& self) { \
20+ return map_nt (self, at::op_name); \
21+ }
22+
23+ // Use the macro to define operations concisely
24+ DEFINE_TORCH_NESTED_TENSOR_UNARY_OP (abs)
25+ DEFINE_TORCH_NESTED_TENSOR_UNARY_OP (sgn)
26+ DEFINE_TORCH_NESTED_TENSOR_UNARY_OP (logical_not)
27+ DEFINE_TORCH_NESTED_TENSOR_UNARY_OP (isinf)
28+ DEFINE_TORCH_NESTED_TENSOR_UNARY_OP (isposinf)
29+ DEFINE_TORCH_NESTED_TENSOR_UNARY_OP (isneginf)
30+ DEFINE_TORCH_NESTED_TENSOR_UNARY_OP (isnan)
31+ DEFINE_TORCH_NESTED_TENSOR_UNARY_OP (relu)
32+ DEFINE_TORCH_NESTED_TENSOR_UNARY_OP (silu)
33+ DEFINE_TORCH_NESTED_TENSOR_UNARY_OP (sin)
34+ DEFINE_TORCH_NESTED_TENSOR_UNARY_OP (sqrt)
35+ DEFINE_TORCH_NESTED_TENSOR_UNARY_OP (cos)
36+ DEFINE_TORCH_NESTED_TENSOR_UNARY_OP (neg)
37+ DEFINE_TORCH_NESTED_TENSOR_UNARY_OP (tanh)
38+
39+ #undef DEFINE_TORCH_NESTED_TENSOR_UNARY_OP
2140
2241Tensor& NestedTensor_abs_ (Tensor& self) {
2342 auto self_ptr = get_nested_tensor_impl (self);
@@ -96,10 +115,6 @@ Tensor& NestedTensor_where_out(const Tensor& condition, const Tensor& self, cons
96115 return out;
97116}
98117
99- Tensor NestedTensor_sgn (const Tensor& self) {
100- return map_nt (self, at::sgn);
101- }
102-
103118Tensor& NestedTensor_sgn_ (Tensor& self) {
104119 auto self_ptr = get_nested_tensor_impl (self);
105120 check_numel_equals_buffer_size (self_ptr);
@@ -116,25 +131,6 @@ Tensor& NestedTensor_logical_not_(Tensor& self){
116131 return self;
117132}
118133
119- Tensor NestedTensor_logical_not (const Tensor& self) {
120- return map_nt (self, at::logical_not);
121- }
122-
123- Tensor NestedTensor_isinf (const Tensor& self) {
124- return map_nt (self, at::isinf);
125- }
126-
127- Tensor NestedTensor_isposinf (const Tensor& self) {
128- return map_nt (self, at::isposinf);
129- }
130-
131- Tensor NestedTensor_isneginf (const Tensor& self) {
132- return map_nt (self, at::isneginf);
133- }
134-
135- Tensor NestedTensor_isnan (const Tensor& self) {
136- return map_nt (self, at::isnan);
137- }
138134
139135Tensor& NestedTensor_relu_ (Tensor& self) {
140136 auto self_ptr = get_nested_tensor_impl (self);
@@ -144,10 +140,6 @@ Tensor& NestedTensor_relu_(Tensor& self) {
144140 return self;
145141}
146142
147- Tensor NestedTensor_relu (const Tensor& self) {
148- return map_nt (self, at::relu);
149- }
150-
151143Tensor& NestedTensor_gelu_ (Tensor& self, c10::string_view approximate) {
152144 auto self_ptr = get_nested_tensor_impl (self);
153145 check_numel_equals_buffer_size (self_ptr);
@@ -172,10 +164,6 @@ Tensor& NestedTensor_tanh_(Tensor& self) {
172164 return self;
173165}
174166
175- Tensor NestedTensor_tanh (const Tensor& self) {
176- return map_nt (self, at::tanh);
177- }
178-
179167Tensor& NestedTensor_neg_ (Tensor& self) {
180168 auto self_ptr = get_nested_tensor_impl (self);
181169 check_numel_equals_buffer_size (self_ptr);
@@ -184,20 +172,12 @@ Tensor& NestedTensor_neg_(Tensor& self) {
184172 return self;
185173}
186174
187- Tensor NestedTensor_neg (const Tensor& self) {
188- return map_nt (self, at::neg);
189- }
190-
191175Tensor& zero_nested_ (Tensor& self) {
192176 const auto & self_buf = get_nested_tensor_impl (self)->get_buffer ();
193177 self_buf.fill_ (0 );
194178 return self;
195179}
196180
197- Tensor NestedTensor_silu (const Tensor& self){
198- return map_nt (self, at::silu);
199- }
200-
201181Tensor& NestedTensor_silu_ (Tensor& self){
202182 auto self_ptr = get_nested_tensor_impl (self);
203183 check_numel_equals_buffer_size (self_ptr);
@@ -206,14 +186,6 @@ Tensor& NestedTensor_silu_(Tensor& self){
206186 return self;
207187}
208188
209- Tensor sin_nested (const Tensor& self) {
210- return map_nt (self, at::sin);
211- }
212-
213- Tensor cos_nested (const Tensor& self) {
214- return map_nt (self, at::cos);
215- }
216-
217189Tensor _pin_memory_nested (const Tensor& self, std::optional<Device> device) {
218190 auto * nt_input = get_nested_tensor_impl (self);
219191 const auto & input_buffer = nt_input->get_unsafe_storage_as_tensor ();
0 commit comments