@@ -77,6 +77,21 @@ class VectorizedN {
7777 return result;
7878 }
7979
80+ template <typename Op>
81+ inline VectorizedN<T, N> ternary_op (
82+ const VectorizedN<T, N>& other,
83+ const VectorizedN<T, N>& other2,
84+ Op op) const {
85+ VectorizedN<T, N> result;
86+ #ifndef _MSC_VER
87+ #pragma unroll
88+ #endif
89+ for (int i = 0 ; i < N; ++i) {
90+ result.values [i] = op (values[i], other.values [i], other2.values [i]);
91+ }
92+ return result;
93+ }
94+
8095 VectorizedN () = default ;
8196
8297 explicit VectorizedN (T val) {
@@ -89,7 +104,8 @@ class VectorizedN {
89104 VectorizedN (const Vectorized<T>& val) : values({val}) {}
90105
91106 template <int L = N, typename std::enable_if_t <L == 2 , int > = 0 >
92- VectorizedN (const Vectorized<T>& val_0, const Vectorized<T>& val_1) : values({val_0, val_1}) {}
107+ VectorizedN (const Vectorized<T>& val_0, const Vectorized<T>& val_1)
108+ : values({val_0, val_1}) {}
93109
94110 template <int L = N, typename std::enable_if_t <L == 1 , int > = 0 >
95111 inline operator Vectorized<T>() const {
@@ -110,7 +126,8 @@ class VectorizedN {
110126 const VectorizedN<T, N>& b) {
111127 VectorizedN<T, N> result;
112128 for (int i = 0 ; i < N; ++i) {
113- result.values [i] = Vectorized<T>::template blend<mask>(a.values [i], b.values [i]);
129+ result.values [i] =
130+ Vectorized<T>::template blend<mask>(a.values [i], b.values [i]);
114131 }
115132 return result;
116133 }
@@ -306,6 +323,20 @@ class VectorizedN {
306323 }); \
307324 }
308325
326+ #define VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL (op ) \
327+ template <typename T, int N> \
328+ inline VectorizedN<T, N> op ( \
329+ const VectorizedN<T, N>& a, \
330+ const VectorizedN<T, N>& b, \
331+ const VectorizedN<T, N>& c) { \
332+ return a.ternary_op ( \
333+ b, \
334+ c, \
335+ [](const Vectorized<T>& a, \
336+ const Vectorized<T>& b, \
337+ const Vectorized<T>& c) { return op (a, b, c); }); \
338+ }
339+
309340#define VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL (op ) \
310341 template <typename T, int N> \
311342 inline VectorizedN<T, N>& op ( \
@@ -326,9 +357,9 @@ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator<<)
326357VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (operator >>)
327358VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (maximum)
328359VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (minimum)
329- VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (fmadd)
330- VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (fmsub)
331- VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (clamp)
360+ VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL (fmadd)
361+ VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL (fmsub)
362+ VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL (clamp)
332363VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (clamp_max)
333364VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (clamp_min)
334365VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL (operator &)
@@ -357,5 +388,17 @@ inline T vec_reduce_all(const OpVec& vec_fun, VectorizedN<T, N> acc_vec) {
357388 return vec_reduce_all (vec_fun, vec_result);
358389}
359390
391+ template <typename T, int N>
392+ std::ostream& operator <<(std::ostream& stream, const VectorizedN<T, N>& vec_n) {
393+ stream << " vec_n[" ;
394+ for (int i = 0 ; i < N; ++i) {
395+ if (i != 0 ) {
396+ stream << " , " ;
397+ }
398+ stream << vec_n[i];
399+ }
400+ stream << ' ]' ;
401+ return stream;
402+ }
360403} // namespace CPU_CAPABILITY
361404} // namespace at::vec
0 commit comments