Skip to content

Commit 034b105

Browse files
Skylion007pytorchmergebot
authored andcommitted
[BE][Ez]: Add NT unary op macro (#140213)
* Adds a macro to simplify adding more unary ops to NT. * Adds sqrt support to NT Pull Request resolved: #140213 Approved by: https://github.com/jbschlosser
1 parent 069a710 commit 034b105

File tree

3 files changed

+26
-52
lines changed

3 files changed

+26
-52
lines changed

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,7 +1791,7 @@
17911791
variants: function, method
17921792
structured_delegate: cos.out
17931793
dispatch:
1794-
NestedTensorCPU, NestedTensorCUDA: cos_nested
1794+
NestedTensorCPU, NestedTensorCUDA: NestedTensor_cos
17951795
tags: [core, pointwise]
17961796

17971797
- func: cos_(Tensor(a!) self) -> Tensor(a!)
@@ -5321,7 +5321,7 @@
53215321
dispatch:
53225322
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sin_sparse_csr
53235323
SparseCPU, SparseCUDA: sin_sparse
5324-
NestedTensorCPU, NestedTensorCUDA: sin_nested
5324+
NestedTensorCPU, NestedTensorCUDA: NestedTensor_sin
53255325
tags: [core, pointwise]
53265326

53275327
- func: sin_(Tensor(a!) self) -> Tensor(a!)
@@ -5819,6 +5819,7 @@
58195819
structured_delegate: sqrt.out
58205820
variants: function, method
58215821
dispatch:
5822+
NestedTensorCPU, NestedTensorCUDA: NestedTensor_sqrt
58225823
SparseCPU, SparseCUDA: sqrt_sparse
58235824
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sqrt_sparse_csr
58245825
tags: [core, pointwise]

aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp

Lines changed: 22 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,28 @@
1515

1616
namespace at::native {
1717

18-
Tensor NestedTensor_abs(const Tensor& self) {
19-
return map_nt(self, at::abs);
20-
}
18+
#define DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(op_name) \
19+
Tensor NestedTensor_##op_name(const Tensor& self) { \
20+
return map_nt(self, at::op_name); \
21+
}
22+
23+
// Use the macro to define operations concisely
24+
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(abs)
25+
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(sgn)
26+
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(logical_not)
27+
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(isinf)
28+
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(isposinf)
29+
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(isneginf)
30+
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(isnan)
31+
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(relu)
32+
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(silu)
33+
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(sin)
34+
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(sqrt)
35+
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(cos)
36+
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(neg)
37+
DEFINE_TORCH_NESTED_TENSOR_UNARY_OP(tanh)
38+
39+
#undef DEFINE_TORCH_NESTED_TENSOR_UNARY_OP
2140

2241
Tensor& NestedTensor_abs_(Tensor& self) {
2342
auto self_ptr = get_nested_tensor_impl(self);
@@ -96,10 +115,6 @@ Tensor& NestedTensor_where_out(const Tensor& condition, const Tensor& self, cons
96115
return out;
97116
}
98117

99-
Tensor NestedTensor_sgn(const Tensor& self) {
100-
return map_nt(self, at::sgn);
101-
}
102-
103118
Tensor& NestedTensor_sgn_(Tensor& self) {
104119
auto self_ptr = get_nested_tensor_impl(self);
105120
check_numel_equals_buffer_size(self_ptr);
@@ -116,25 +131,6 @@ Tensor& NestedTensor_logical_not_(Tensor& self){
116131
return self;
117132
}
118133

119-
Tensor NestedTensor_logical_not(const Tensor& self) {
120-
return map_nt(self, at::logical_not);
121-
}
122-
123-
Tensor NestedTensor_isinf(const Tensor& self) {
124-
return map_nt(self, at::isinf);
125-
}
126-
127-
Tensor NestedTensor_isposinf(const Tensor& self) {
128-
return map_nt(self, at::isposinf);
129-
}
130-
131-
Tensor NestedTensor_isneginf(const Tensor& self) {
132-
return map_nt(self, at::isneginf);
133-
}
134-
135-
Tensor NestedTensor_isnan(const Tensor& self) {
136-
return map_nt(self, at::isnan);
137-
}
138134

139135
Tensor& NestedTensor_relu_(Tensor& self) {
140136
auto self_ptr = get_nested_tensor_impl(self);
@@ -144,10 +140,6 @@ Tensor& NestedTensor_relu_(Tensor& self) {
144140
return self;
145141
}
146142

147-
Tensor NestedTensor_relu(const Tensor& self) {
148-
return map_nt(self, at::relu);
149-
}
150-
151143
Tensor& NestedTensor_gelu_(Tensor& self, c10::string_view approximate) {
152144
auto self_ptr = get_nested_tensor_impl(self);
153145
check_numel_equals_buffer_size(self_ptr);
@@ -172,10 +164,6 @@ Tensor& NestedTensor_tanh_(Tensor& self) {
172164
return self;
173165
}
174166

175-
Tensor NestedTensor_tanh(const Tensor& self) {
176-
return map_nt(self, at::tanh);
177-
}
178-
179167
Tensor& NestedTensor_neg_(Tensor& self) {
180168
auto self_ptr = get_nested_tensor_impl(self);
181169
check_numel_equals_buffer_size(self_ptr);
@@ -184,20 +172,12 @@ Tensor& NestedTensor_neg_(Tensor& self) {
184172
return self;
185173
}
186174

187-
Tensor NestedTensor_neg(const Tensor& self) {
188-
return map_nt(self, at::neg);
189-
}
190-
191175
Tensor& zero_nested_(Tensor& self) {
192176
const auto& self_buf = get_nested_tensor_impl(self)->get_buffer();
193177
self_buf.fill_(0);
194178
return self;
195179
}
196180

197-
Tensor NestedTensor_silu(const Tensor& self){
198-
return map_nt(self, at::silu);
199-
}
200-
201181
Tensor& NestedTensor_silu_(Tensor& self){
202182
auto self_ptr = get_nested_tensor_impl(self);
203183
check_numel_equals_buffer_size(self_ptr);
@@ -206,14 +186,6 @@ Tensor& NestedTensor_silu_(Tensor& self){
206186
return self;
207187
}
208188

209-
Tensor sin_nested(const Tensor& self) {
210-
return map_nt(self, at::sin);
211-
}
212-
213-
Tensor cos_nested(const Tensor& self) {
214-
return map_nt(self, at::cos);
215-
}
216-
217189
Tensor _pin_memory_nested(const Tensor& self, std::optional<Device> device) {
218190
auto* nt_input = get_nested_tensor_impl(self);
219191
const auto& input_buffer = nt_input->get_unsafe_storage_as_tensor();

test/test_nestedtensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,6 +1289,7 @@ def test_nested_tensor_indexing(self, device, dtype):
12891289
subtest(torch.isposinf, name="isposinf"),
12901290
subtest(torch.isneginf, name="isneginf"),
12911291
subtest(torch.isnan, name="isnan"),
1292+
subtest(torch.sqrt, name="sqrt"),
12921293
],
12931294
)
12941295
def test_unary_funcs(self, device, func):

0 commit comments

Comments
 (0)