Skip to content

Commit 7e51ac5

Browse files
Freey0facebook-github-bot
authored andcommitted
Port gcd to structured (#57624)
Summary: Pull Request resolved: #57624 Test Plan: Imported from OSS Reviewed By: VitalyFedyunin Differential Revision: D28224832 Pulled By: ezyang fbshipit-source-id: 30a8eba025c67d990103e49c03a396810f9d4006
1 parent 5044d9d commit 7e51ac5

File tree

5 files changed

+22
-18
lines changed

5 files changed

+22
-18
lines changed

aten/src/ATen/native/BinaryOps.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,14 @@ TORCH_META_FUNC(atan2) (const Tensor& self, const Tensor& other) {
6767
build_binary_float_op(maybe_get_output(), self, other);
6868
}
6969

70+
// These are normal binary ops that preserve dtype
71+
#define CREATE_BINARY_META_FUNC(func) \
72+
TORCH_META_FUNC(func) (const Tensor& self, const Tensor& other) { \
73+
build_binary_op(maybe_get_output(), self, other); \
74+
}
75+
76+
CREATE_BINARY_META_FUNC(gcd);
77+
7078
} // namespace meta
7179

7280

@@ -198,6 +206,13 @@ TORCH_IMPL_FUNC(special_xlog1py_out) (const Tensor& self, const Tensor& other, c
198206
xlog1py_stub(device_type(), *this);
199207
}
200208

209+
#define CREATE_BINARY_TORCH_IMPL_FUNC(func) \
210+
TORCH_IMPL_FUNC(func##_out) (const Tensor& self, const Tensor& other, const Tensor& result) { \
211+
func##_stub(device_type(), *this); \
212+
}
213+
214+
CREATE_BINARY_TORCH_IMPL_FUNC(gcd);
215+
201216
Tensor special_xlog1py(const Scalar& x, const Tensor& y) {
202217
return at::special_xlog1py(wrapped_scalar_tensor(x), y);
203218
}
@@ -1062,21 +1077,6 @@ Tensor logaddexp2(const Tensor& self, const Tensor& other) {
10621077
return at::logaddexp2_out(result, self, other);
10631078
}
10641079

1065-
Tensor& gcd_out(const Tensor& self, const Tensor& other, Tensor& result) {
1066-
auto iter = TensorIterator::binary_op(result, self, other);
1067-
gcd_stub(iter.device_type(), iter);
1068-
return result;
1069-
}
1070-
1071-
Tensor gcd(const Tensor& self, const Tensor& other) {
1072-
Tensor result = at::empty({0}, self.options());
1073-
return at::gcd_out(result, self, other);
1074-
}
1075-
1076-
Tensor& gcd_(Tensor& self, const Tensor& other) {
1077-
return at::gcd_out(self, self, other);
1078-
}
1079-
10801080
Tensor& lcm_out(const Tensor& self, const Tensor& other, Tensor& result) {
10811081
auto iter = TensorIterator::binary_op(result, self, other);
10821082
lcm_stub(iter.device_type(), iter);

aten/src/ATen/native/BinaryOps.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ DECLARE_DISPATCH(binary_fn, mse_stub);
8383
DECLARE_DISPATCH(binary_fn, fmod_stub);
8484
DECLARE_DISPATCH(binary_fn, logaddexp_stub);
8585
DECLARE_DISPATCH(binary_fn, logaddexp2_stub);
86-
DECLARE_DISPATCH(binary_fn, gcd_stub);
86+
DECLARE_DISPATCH(structured_binary_fn, gcd_stub);
8787
DECLARE_DISPATCH(binary_fn, lcm_stub);
8888
DECLARE_DISPATCH(binary_fn, hypot_stub);
8989
DECLARE_DISPATCH(binary_fn, igamma_stub);

aten/src/ATen/native/cpu/BinaryOpsKernel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,7 @@ void logaddexp2_kernel(TensorIterator& iter) {
858858
});
859859
}
860860

861-
void gcd_kernel(TensorIterator& iter) {
861+
void gcd_kernel(TensorIteratorBase& iter) {
862862
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "gcd_cpu", [&]() {
863863
cpu_kernel(
864864
iter,

aten/src/ATen/native/cuda/GcdLcmKernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
namespace at { namespace native {
1212

13-
void gcd_kernel_cuda(TensorIterator& iter) {
13+
void gcd_kernel_cuda(TensorIteratorBase& iter) {
1414
AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "gcd_cuda", [&]() {
1515
gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
1616
return calc_gcd(a, b);

aten/src/ATen/native/native_functions.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1977,13 +1977,17 @@
19771977
CPU: from_file
19781978

19791979
- func: gcd.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
1980+
structured: True
1981+
structured_inherits: TensorIteratorBase
19801982
dispatch:
19811983
CPU, CUDA: gcd_out
19821984

19831985
- func: gcd(Tensor self, Tensor other) -> Tensor
1986+
structured_delegate: gcd.out
19841987
variants: function, method
19851988

19861989
- func: gcd_(Tensor(a!) self, Tensor other) -> Tensor(a!)
1990+
structured_delegate: gcd.out
19871991
variants: function, method
19881992

19891993
- func: lcm.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)

0 commit comments

Comments
 (0)