Skip to content

Commit 0b06741

Browse files
ptrblckfacebook-github-bot
authored andcommitted
Fix strict aliasing rule violation in bitwise_binary_op (#66194)
Summary: Fixes #66119 Failure on ARM Neoverse N1 before this PR: ``` ====================================================================== FAIL: test_bitwise_ops_cpu_int16 (__main__.TestBinaryUfuncsCPU) ---------------------------------------------------------------------- Traceback (most recent call last): File "/opt/pytorch/pytorch/torch/testing/_internal/common_device_type.py", line 373, in instantiated_test result = test(self, **param_kwargs) File "test_binary_ufuncs.py", line 315, in test_bitwise_ops self.assertEqual(op(a, b), op(a_np, b_np)) File "/opt/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 1633, in assertEqual self.assertEqual( File "/opt/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 1611, in assertEqual super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) AssertionError: False is not true : Tensors failed to compare as equal!Found 176 different element(s) (out of 225), with the greatest difference of 21850 (-21846 vs. 4) occuring at index (0, 2). ====================================================================== FAIL: test_bitwise_ops_cpu_int32 (__main__.TestBinaryUfuncsCPU) ---------------------------------------------------------------------- Traceback (most recent call last): File "/opt/pytorch/pytorch/torch/testing/_internal/common_device_type.py", line 373, in instantiated_test result = test(self, **param_kwargs) File "test_binary_ufuncs.py", line 315, in test_bitwise_ops self.assertEqual(op(a, b), op(a_np, b_np)) File "/opt/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 1633, in assertEqual self.assertEqual( File "/opt/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 1611, in assertEqual super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg)) AssertionError: False is not true : Tensors failed to compare as equal!Found 188 different element(s) (out of 225), with the greatest difference of 1335341061 (-1335341056 vs. 5) occuring at index (14, 8). ---------------------------------------------------------------------- ``` which passes now. CC malfet ezyang Pull Request resolved: #66194 Reviewed By: dagitses, bdhirsh, ngimel Differential Revision: D31430274 Pulled By: malfet fbshipit-source-id: bcf1c9d584c02eff328dd5b1f7af064fac5942c9
1 parent d176c82 commit 0b06741

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

aten/src/ATen/cpu/vec/vec_base.h

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ struct Vectorized {
147147
inline operator T*() {
148148
return values;
149149
}
150+
// Return the values as char* for type punning
151+
auto as_bytes() const -> const char* {
152+
return reinterpret_cast<const char*>(values);
153+
}
150154
template <int64_t mask_>
151155
static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
152156
int64_t mask = mask_;
@@ -736,15 +740,33 @@ inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
736740

737741
#else
738742

743+
template <typename T>
744+
auto load(char const* data) -> T {
745+
T ret;
746+
std::memcpy(&ret, data, sizeof(ret));
747+
return ret;
748+
}
749+
739750
template<class T, typename Op>
740751
static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
741752
static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t);
742753
__at_align__ intmax_t buffer[element_no];
743-
const intmax_t *a_ptr = reinterpret_cast<const intmax_t*>((const T*) a);
744-
const intmax_t *b_ptr = reinterpret_cast<const intmax_t*>((const T*) b);
745-
for (uint32_t i = 0U; i < element_no; ++ i) {
746-
buffer[i] = op(a_ptr[i], b_ptr[i]);
747-
}
754+
static_assert(VECTOR_WIDTH % sizeof(intmax_t) == 0, "VECTOR_WIDTH not a multiple of sizeof(intmax_t)");
755+
static_assert(sizeof(buffer) == sizeof(Vectorized<T>), "sizeof(buffer) must match sizeof(Vectorized<T>)");
756+
// We should be using memcpy in order to respect the strict aliasing rule
757+
// see: https://github.com/pytorch/pytorch/issues/66119
758+
// Using char* is defined in the C11 standard 6.5 Expression paragraph 7
759+
// (http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf)
760+
const auto* a_data = a.as_bytes();
761+
const auto* b_data = b.as_bytes();
762+
// load each intmax_t chunk and process; increase pointers by sizeof(intmax_t)
763+
for (auto& out : buffer) {
764+
out = op(load<intmax_t>(a_data), load<intmax_t>(b_data));
765+
a_data += sizeof(intmax_t);
766+
b_data += sizeof(intmax_t);
767+
}
768+
assert(a_data == a.as_bytes() + sizeof(a));
769+
assert(b_data == b.as_bytes() + sizeof(b));
748770
return Vectorized<T>::loadu(buffer);
749771
}
750772

0 commit comments

Comments
 (0)