Commit eba8f32

handle mixed-sign widening_shl
Parent: 19d4cdc

FindIntrinsics learns to rewrite a signed shift whose operands fit losslessly in the unsigned type of half the width into an unsigned widening shift, reinterpreted back to the signed type. DistributeShifts learns to look through that reinterpreting cast when distributing shifts as multiplies, and new ARM simd_op_check tests confirm that mixed-sign patterns such as i16_1 + i16(u8_2) * 2 fuse into umlal/umlsl.

3 files changed: +38 -11 lines

src/DistributeShifts.cpp

Lines changed: 19 additions & 11 deletions
@@ -143,20 +143,28 @@ class DistributeShiftsAsMuls : public IRMutator {
         return IRMutator::visit(op);
     }
 
-    template<typename T>
-    Expr visit_add_sub(const T *op) {
-        if (multiply_adds) {
-            Expr a, b;
-            if (const Call *a_call = op->a.template as<Call>()) {
-                if (a_call->is_intrinsic({Call::shift_left, Call::widening_shift_left})) {
-                    a = distribute_shift(a_call);
-                }
+    Expr handle_shift(const Expr &expr) {
+        Expr ret;
+        if (const Call *as_call = expr.template as<Call>()) {
+            if (as_call->is_intrinsic({Call::shift_left, Call::widening_shift_left})) {
+                ret = distribute_shift(as_call);
             }
-            if (const Call *b_call = op->b.template as<Call>()) {
-                if (b_call->is_intrinsic({Call::shift_left, Call::widening_shift_left})) {
-                    b = distribute_shift(b_call);
+        } else if (const Cast *as_cast = expr.template as<Cast>()) {
+            if (as_cast->is_reinterpret()) {
+                ret = handle_shift(as_cast->value);
+                if (ret.defined()) {
+                    ret = cast(as_cast->type, ret);
                 }
             }
+        }
+        return ret;
+    }
+
+    template<typename T>
+    Expr visit_add_sub(const T *op) {
+        if (multiply_adds) {
+            Expr a = handle_shift(op->a);
+            Expr b = handle_shift(op->b);
 
             if (a.defined() && b.defined()) {
                 return T::make(a, b);
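The new else-if arm is the DistributeShifts half of the fix: a shift that has been wrapped in a reinterpreting cast (which is what the FindIntrinsics change below produces) is no longer invisible when shifts are distributed into multiplies, and handle_shift re-applies the cast around the distributed result. A minimal sketch of a pipeline with that shape, using only Halide's public API; the func names, output filename, and target string are illustrative, not from the commit:

#include "Halide.h"
using namespace Halide;
using Halide::ConciseCasts::u16;

int main() {
    ImageParam acc(Int(16), 1), in(UInt(8), 1);
    Var x("x");
    Func f("f");
    // u16(u8) << 1 should be recognized as a widening shift; reinterpreting
    // its u16 result as i16 previously hid the shift from this pass, so the
    // add could not become a multiply-accumulate.
    f(x) = acc(x) + reinterpret(Int(16), u16(in(x)) << 1);
    f.vectorize(x, 16);
    f.compile_to_assembly("distribute_shift.s", {acc, in}, "f",
                          Target("arm-64-linux"));
    return 0;
}

Note that handle_shift returns an undefined Expr when nothing matches, which visit_add_sub uses (via a.defined() && b.defined()) to fall back to the default mutation.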

src/FindIntrinsics.cpp

Lines changed: 13 additions & 0 deletions
@@ -879,6 +879,19 @@ class FindIntrinsics : public IRMutator {
             return mutate(result);
         }
 
+        // Try to lossless cast to uint.
+        if (op->type.is_int() && bits >= 16) {
+            Type uint_type = op->type.narrow().with_code(halide_type_uint);
+            Expr a_narrow = lossless_cast(uint_type, op->args[0]);
+            Expr b_narrow = lossless_cast(uint_type, op->args[1]);
+            if (a_narrow.defined() && b_narrow.defined()) {
+                Expr result = op->is_intrinsic(Call::shift_left) ? widening_shift_left(a_narrow, b_narrow) : widening_shift_right(a_narrow, b_narrow);
+                internal_assert(result.type() != op->type);
+                result = Cast::make(op->type, result);
+                return mutate(result);
+            }
+        }
+
         // Try to turn this into a rounding shift.
         Expr rounding_shift = to_rounding_shift(op);
         if (rounding_shift.defined()) {
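A worked trace of the new path, on an illustrative signed 32-bit shift whose operands happen to fit in 16 unsigned bits:

    op        = shift_left(i32(u16_2), 1)               (type i32, bits = 32 >= 16)
    uint_type = UInt(16)                                (Int(32).narrow(), code set to uint)
    a_narrow  = lossless_cast(UInt(16), i32(u16_2))     = u16_2 (defined)
    b_narrow  = lossless_cast(UInt(16), 1)              = cast<uint16_t>(1) (defined)
    result    = widening_shift_left(a_narrow, b_narrow) (type u32 != i32, so the assert holds)
    returned  = mutate(Cast::make(Int(32), result))     (a same-width u32 -> i32 cast)

The final Cast changes only the type code, not the bit width, so it is precisely the reinterpret that the DistributeShifts change above learns to traverse.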

test/correctness/simd_op_check_arm.cpp

Lines changed: 6 additions & 0 deletions
@@ -404,28 +404,34 @@ class SimdOpCheckARM : public SimdOpCheckTest {
         check(arm32 ? "vmlal.s8" : "smlal", 8 * w, i16_1 + i16(i8_2) * 2);
         check(arm32 ? "vmlal.u8" : "umlal", 8 * w, u16_1 + u16(u8_2) * u8_3);
         check(arm32 ? "vmlal.u8" : "umlal", 8 * w, u16_1 + u16(u8_2) * 2);
+        check(arm32 ? "vmlal.u8" : "umlal", 8 * w, i16_1 + i16(u8_2) * 2);
         check(arm32 ? "vmlal.s16" : "smlal", 4 * w, i32_1 + i32(i16_2) * i16_3);
         check(arm32 ? "vmlal.s16" : "smlal", 4 * w, i32_1 + i32(i16_2) * 2);
         check(arm32 ? "vmlal.u16" : "umlal", 4 * w, u32_1 + u32(u16_2) * u16_3);
         check(arm32 ? "vmlal.u16" : "umlal", 4 * w, u32_1 + u32(u16_2) * 2);
+        check(arm32 ? "vmlal.u16" : "umlal", 4 * w, i32_1 + i32(u16_2) * 2);
         check(arm32 ? "vmlal.s32" : "smlal", 2 * w, i64_1 + i64(i32_2) * i32_3);
         check(arm32 ? "vmlal.s32" : "smlal", 2 * w, i64_1 + i64(i32_2) * 2);
         check(arm32 ? "vmlal.u32" : "umlal", 2 * w, u64_1 + u64(u32_2) * u32_3);
         check(arm32 ? "vmlal.u32" : "umlal", 2 * w, u64_1 + u64(u32_2) * 2);
+        check(arm32 ? "vmlal.u32" : "umlal", 2 * w, i64_1 + i64(u32_2) * 2);
 
         // VMLSL I - Multiply Subtract Long
         check(arm32 ? "vmlsl.s8" : "smlsl", 8 * w, i16_1 - i16(i8_2) * i8_3);
         check(arm32 ? "vmlsl.s8" : "smlsl", 8 * w, i16_1 - i16(i8_2) * 2);
         check(arm32 ? "vmlsl.u8" : "umlsl", 8 * w, u16_1 - u16(u8_2) * u8_3);
         check(arm32 ? "vmlsl.u8" : "umlsl", 8 * w, u16_1 - u16(u8_2) * 2);
+        check(arm32 ? "vmlsl.u8" : "umlsl", 8 * w, i16_1 - i16(u8_2) * 2);
         check(arm32 ? "vmlsl.s16" : "smlsl", 4 * w, i32_1 - i32(i16_2) * i16_3);
         check(arm32 ? "vmlsl.s16" : "smlsl", 4 * w, i32_1 - i32(i16_2) * 2);
         check(arm32 ? "vmlsl.u16" : "umlsl", 4 * w, u32_1 - u32(u16_2) * u16_3);
         check(arm32 ? "vmlsl.u16" : "umlsl", 4 * w, u32_1 - u32(u16_2) * 2);
+        check(arm32 ? "vmlsl.u16" : "umlsl", 4 * w, i32_1 - i32(u16_2) * 2);
         check(arm32 ? "vmlsl.s32" : "smlsl", 2 * w, i64_1 - i64(i32_2) * i32_3);
         check(arm32 ? "vmlsl.s32" : "smlsl", 2 * w, i64_1 - i64(i32_2) * 2);
         check(arm32 ? "vmlsl.u32" : "umlsl", 2 * w, u64_1 - u64(u32_2) * u32_3);
         check(arm32 ? "vmlsl.u32" : "umlsl", 2 * w, u64_1 - u64(u32_2) * 2);
+        check(arm32 ? "vmlsl.u32" : "umlsl", 2 * w, i64_1 - i64(u32_2) * 2);
 
         // VMOV X F, D Move Register or Immediate
         // This is for loading immediates, which we won't do in the inner loop anyway
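Each added check is the mixed-sign twin of the unsigned line above it: a signed accumulator combined with an unsigned operand that is widened into the signed type and multiplied by a small constant (which lowers to a shift). A standalone sketch of the first new pattern, with illustrative names and target; with this commit the inner loop on AArch64 should use umlal rather than a separate widening shift and add:

#include "Halide.h"
using namespace Halide;
using Halide::ConciseCasts::i16;

int main() {
    ImageParam acc(Int(16), 1), in(UInt(8), 1);
    Var x("x");
    Func f("f");
    // Mirrors the new test i16_1 + i16(u8_2) * 2: the multiply by 2 lowers
    // to a widening shift, which FindIntrinsics narrows to unsigned and
    // DistributeShifts turns back into a multiply-accumulate.
    f(x) = acc(x) + i16(in(x)) * 2;
    f.vectorize(x, 16);
    f.compile_to_assembly("umlal_check.s", {acc, in}, "f",
                          Target("arm-64-linux"));
    return 0;
}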
