[X86] combine widening shl + adjacent addition into VPMADDWD #179326

folkertdev merged 2 commits into llvm:main

Conversation
RKSimon left a comment
A few thoughts - feel free to ignore
      return SDValue();

    N0 = ShVal.getOperand(0);
    if (N0.getValueType() != TruncVT)
I suppose technically we can support N0.getValueType() smaller than i16 and we just sign-extend to TruncVT?
True, but I don't think that would ever come up organically?
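For reference, a hypothetical sketch of what that extension might look like (not part of this PR), assuming N0 keeps TruncVT's element count and only its element width differs:

```cpp
N0 = ShVal.getOperand(0);
if (N0.getValueType() != TruncVT) {
  EVT SrcVT = N0.getValueType();
  // Hypothetical: widen narrow sources up to i16 instead of bailing,
  // but only when the element counts line up.
  if (SrcVT.getVectorElementCount() != TruncVT.getVectorElementCount() ||
      SrcVT.getScalarSizeInBits() >= TruncVT.getScalarSizeInBits())
    return SDValue();
  N0 = DAG.getNode(ISD::SIGN_EXTEND, DL, TruncVT, N0);
}
```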
      // A shift by more than 15 would overflow an i16.
      if (ShiftAmount > 15)
        return SDValue();
      MulConsts.push_back(DAG.getConstant(1u << ShiftAmount, DL, MVT::i16));
Not sure but we might be able to use matchUnaryPredicate and then FoldConstantArithmetic to do all of this for us?
Would that be simpler? After a cursory look I think the check for the shift being at most 15 would still require explicit looping over each element.
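For reference, a rough, untested sketch of that direction (an interpretation of the suggestion, not code from the PR), where ISD::matchUnaryPredicate absorbs the per-element bound check, including the undef rejection:

```cpp
// Every shift amount must be at most 15 so that (1 << Amt) fits in an
// i16 multiplier lane; undef elements are rejected by default.
SDValue ShAmt = Mul.getOperand(1);
if (!ISD::matchUnaryPredicate(ShAmt, [](ConstantSDNode *C) {
      return C->getAPIntValue().ule(15);
    }))
  return SDValue();

// Fold (shl 1, Amt) per element into the multiplier vector, then
// truncate it to the i16 vector type that PMADDWD expects.
EVT WideVT = Mul.getValueType();
SDValue One = DAG.getConstant(1, DL, WideVT);
SDValue Mults = DAG.FoldConstantArithmetic(ISD::SHL, DL, WideVT, {One, ShAmt});
if (!Mults)
  return SDValue();
N1 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Mults);
```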
@llvm/pr-subscribers-backend-x86

Author: Folkert de Vries (folkertdev)

Changes

I added an optimization for VPMADDWD earlier in #174149. That one is used in the adler32 checksum. That PR missed another pattern, used in base64 decoding, that uses a shl instead of a mul but should also optimize to VPMADDWD.

To make the shift semantically equal to the multiplication case, I'm bailing on shifts by more than 15, because 1 << 16 is not representable in an i16.

Code-wise, I suspect that I'm missing some convenient way to access the integer values of a constant vector.

2 Files Affected:

- llvm/lib/Target/X86/X86ISelLowering.cpp
- llvm/test/CodeGen/X86/combine-pmadd.ll
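As a standalone illustration of the scalar identity the combine relies on (not part of the patch): a left shift by s is a multiply by 1 << s, so whenever that multiplier fits in a 16-bit lane, the shl-based pattern can reuse the existing mul-based VPMADDWD matching.

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // For each lane, sext(x) << s == sext(x) * (1 << s). The shift goes
  // through uint32_t to stay well-defined for negative x pre-C++20.
  for (int16_t x : {int16_t(-3), int16_t(0), int16_t(12345)}) {
    for (unsigned s = 0; s <= 12; ++s) {
      int32_t shifted = int32_t(uint32_t(int32_t(x)) << s); // widening shl
      int32_t multiplied = int32_t(x) * (1 << s);           // equivalent mul
      assert(shifted == multiplied);
    }
  }
  return 0;
}
```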
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 144d6451b981f..ee6a5ecc1165e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -58696,8 +58696,8 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N,
// (extract_elt Mul, 3),
// (extract_elt Mul, 5),
// ...
- // and identify Mul. Mul must be either ISD::MUL, or can be ISD::SIGN_EXTEND
- // in which case we add a trivial multiplication by 1.
+ // and identify Mul. Mul must be either ISD::MUL, ISD::SHL, or can be
+ // ISD::SIGN_EXTEND in which case we add a trivial multiplication by 1.
SDValue Mul;
for (unsigned i = 0, e = VT.getVectorNumElements(); i != e; i += 2) {
SDValue Op0L = Op0->getOperand(i), Op1L = Op1->getOperand(i),
@@ -58728,7 +58728,7 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N,
// with 2X number of vector elements than the BUILD_VECTOR.
// Both extracts must be from same MUL.
Mul = Vec0L;
- if ((Mul.getOpcode() != ISD::MUL &&
+ if ((Mul.getOpcode() != ISD::MUL && Mul.getOpcode() != ISD::SHL &&
Mul.getOpcode() != ISD::SIGN_EXTEND) ||
Mul.getValueType().getVectorNumElements() != 2 * e)
return SDValue();
@@ -58751,6 +58751,38 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N,
N0 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Mul.getOperand(0));
N1 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Mul.getOperand(1));
+ } else if (Mul.getOpcode() == ISD::SHL) {
+ SDValue ShVal = Mul.getOperand(0);
+ if (ShVal.getOpcode() != ISD::SIGN_EXTEND)
+ return SDValue();
+
+ N0 = ShVal.getOperand(0);
+ if (N0.getValueType() != TruncVT)
+ return SDValue();
+
+ unsigned NumElts = TruncVT.getVectorNumElements();
+ SmallVector<SDValue, 32> MulConsts;
+ MulConsts.reserve(NumElts);
+
+ auto *BV = dyn_cast<BuildVectorSDNode>(Mul.getOperand(1));
+ if (!BV || BV->getNumOperands() != NumElts)
+ return SDValue();
+
+ for (unsigned i = 0; i != NumElts; ++i) {
+ SDValue E = BV->getOperand(i);
+ if (E.isUndef())
+ return SDValue();
+ auto *C = dyn_cast<ConstantSDNode>(E);
+ if (!C)
+ return SDValue();
+ unsigned ShiftAmount = C->getZExtValue();
+ // A shift by more than 15 would overflow an i16.
+ if (ShiftAmount > 15)
+ return SDValue();
+ MulConsts.push_back(DAG.getConstant(1u << ShiftAmount, DL, MVT::i16));
+ }
+
+ N1 = DAG.getBuildVector(TruncVT, DL, MulConsts);
} else {
assert(Mul.getOpcode() == ISD::SIGN_EXTEND);
diff --git a/llvm/test/CodeGen/X86/combine-pmadd.ll b/llvm/test/CodeGen/X86/combine-pmadd.ll
index 231b9f97a5e3f..656aff18f02ef 100644
--- a/llvm/test/CodeGen/X86/combine-pmadd.ll
+++ b/llvm/test/CodeGen/X86/combine-pmadd.ll
@@ -360,3 +360,115 @@ define <8 x i32> @sext_pairwise_add(<16 x i16> %x) {
%4 = add nsw <8 x i32> %2, %3
ret <8 x i32> %4
}
+
+define <8 x i32> @combine_with_mul(<16 x i16> %v) {
+; SSE-LABEL: combine_with_mul:
+; SSE: # %bb.0: # %bb1
+; SSE-NEXT: movdqa {{.*#+}} xmm2 = [4096,1,4096,1,4096,1,4096,1]
+; SSE-NEXT: pmaddwd %xmm2, %xmm0
+; SSE-NEXT: pmaddwd %xmm2, %xmm1
+; SSE-NEXT: retq
+;
+; AVX1-LABEL: combine_with_mul:
+; AVX1: # %bb.0: # %bb1
+; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm1
+; AVX1-NEXT: vbroadcastss {{.*#+}} xmm2 = [4096,1,4096,1,4096,1,4096,1]
+; AVX1-NEXT: vpmaddwd %xmm2, %xmm1, %xmm1
+; AVX1-NEXT: vpmaddwd %xmm2, %xmm0, %xmm0
+; AVX1-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm0
+; AVX1-NEXT: retq
+;
+; AVX2-LABEL: combine_with_mul:
+; AVX2: # %bb.0: # %bb1
+; AVX2-NEXT: vpmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 # [4096,1,4096,1,4096,1,4096,1,4096,1,4096,1,4096,1,4096,1]
+; AVX2-NEXT: retq
+bb1:
+ %0 = sext <16 x i16> %v to <16 x i32>
+ %1 = mul nsw <16 x i32> %0, <i32 4096, i32 1, i32 4096, i32 1, i32 4096, i32 1, i32 4096, i32 1, i32 4096, i32 1, i32 4096, i32 1, i32 4096, i32 1, i32 4096, i32 1>
+ %2 = shufflevector <16 x i32> %1, <16 x i32> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+ %3 = shufflevector <16 x i32> %1, <16 x i32> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+ %4 = add <8 x i32> %2, %3
+ ret <8 x i32> %4
+}
+
+define <8 x i32> @combine_with_shl(<16 x i16> %v) {
+; SSE-LABEL: combine_with_shl:
+; SSE: # %bb.0: # %bb1
+; SSE-NEXT: movdqa {{.*#+}} xmm2 = [4096,1,4096,1,4096,1,4096,1]
+; SSE-NEXT: pmaddwd %xmm2, %xmm0
+; SSE-NEXT: pmaddwd %xmm2, %xmm1
+; SSE-NEXT: retq
+;
+; AVX1-LABEL: combine_with_shl:
+; AVX1: # %bb.0: # %bb1
+; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm1
+; AVX1-NEXT: vbroadcastss {{.*#+}} xmm2 = [4096,1,4096,1,4096,1,4096,1]
+; AVX1-NEXT: vpmaddwd %xmm2, %xmm1, %xmm1
+; AVX1-NEXT: vpmaddwd %xmm2, %xmm0, %xmm0
+; AVX1-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm0
+; AVX1-NEXT: retq
+;
+; AVX2-LABEL: combine_with_shl:
+; AVX2: # %bb.0: # %bb1
+; AVX2-NEXT: vpmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 # [4096,1,4096,1,4096,1,4096,1,4096,1,4096,1,4096,1,4096,1]
+; AVX2-NEXT: retq
+bb1:
+ %0 = sext <16 x i16> %v to <16 x i32>
+ %1 = shl nsw <16 x i32> %0, <i32 12, i32 0, i32 12, i32 0, i32 12, i32 0, i32 12, i32 0, i32 12, i32 0, i32 12, i32 0, i32 12, i32 0, i32 12, i32 0>
+ %2 = shufflevector <16 x i32> %1, <16 x i32> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+ %3 = shufflevector <16 x i32> %1, <16 x i32> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+ %4 = add <8 x i32> %2, %3
+ ret <8 x i32> %4
+}
+
+; This version cannot use `vpmaddwd` because the multiply would overflow an i16.
+define <8 x i32> @combine_with_shl_overflow(<16 x i16> %v) {
+; SSE-LABEL: combine_with_shl_overflow:
+; SSE: # %bb.0: # %bb1
+; SSE-NEXT: pxor %xmm2, %xmm2
+; SSE-NEXT: pxor %xmm4, %xmm4
+; SSE-NEXT: punpckhwd {{.*#+}} xmm4 = xmm4[4],xmm1[4],xmm4[5],xmm1[5],xmm4[6],xmm1[6],xmm4[7],xmm1[7]
+; SSE-NEXT: pxor %xmm3, %xmm3
+; SSE-NEXT: punpcklwd {{.*#+}} xmm3 = xmm3[0],xmm1[0],xmm3[1],xmm1[1],xmm3[2],xmm1[2],xmm3[3],xmm1[3]
+; SSE-NEXT: phaddd %xmm4, %xmm3
+; SSE-NEXT: pxor %xmm1, %xmm1
+; SSE-NEXT: punpckhwd {{.*#+}} xmm1 = xmm1[4],xmm0[4],xmm1[5],xmm0[5],xmm1[6],xmm0[6],xmm1[7],xmm0[7]
+; SSE-NEXT: punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm0[0],xmm2[1],xmm0[1],xmm2[2],xmm0[2],xmm2[3],xmm0[3]
+; SSE-NEXT: phaddd %xmm1, %xmm2
+; SSE-NEXT: movdqa %xmm2, %xmm0
+; SSE-NEXT: movdqa %xmm3, %xmm1
+; SSE-NEXT: retq
+;
+; AVX1-LABEL: combine_with_shl_overflow:
+; AVX1: # %bb.0: # %bb1
+; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm1
+; AVX1-NEXT: vpxor %xmm2, %xmm2, %xmm2
+; AVX1-NEXT: vpunpcklwd {{.*#+}} xmm3 = xmm2[0],xmm1[0],xmm2[1],xmm1[1],xmm2[2],xmm1[2],xmm2[3],xmm1[3]
+; AVX1-NEXT: vpunpcklwd {{.*#+}} xmm4 = xmm2[0],xmm0[0],xmm2[1],xmm0[1],xmm2[2],xmm0[2],xmm2[3],xmm0[3]
+; AVX1-NEXT: vphaddd %xmm3, %xmm4, %xmm3
+; AVX1-NEXT: vpunpckhwd {{.*#+}} xmm1 = xmm2[4],xmm1[4],xmm2[5],xmm1[5],xmm2[6],xmm1[6],xmm2[7],xmm1[7]
+; AVX1-NEXT: vpunpckhwd {{.*#+}} xmm0 = xmm2[4],xmm0[4],xmm2[5],xmm0[5],xmm2[6],xmm0[6],xmm2[7],xmm0[7]
+; AVX1-NEXT: vphaddd %xmm1, %xmm0, %xmm0
+; AVX1-NEXT: vinsertf128 $1, %xmm0, %ymm0, %ymm0
+; AVX1-NEXT: vinsertf128 $1, %xmm3, %ymm3, %ymm1
+; AVX1-NEXT: vshufpd {{.*#+}} ymm0 = ymm1[0],ymm0[0],ymm1[3],ymm0[3]
+; AVX1-NEXT: retq
+;
+; AVX2-LABEL: combine_with_shl_overflow:
+; AVX2: # %bb.0: # %bb1
+; AVX2-NEXT: vpmovzxwd {{.*#+}} ymm1 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
+; AVX2-NEXT: vextracti128 $1, %ymm0, %xmm0
+; AVX2-NEXT: vpmovzxwd {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
+; AVX2-NEXT: vpslld $16, %ymm0, %ymm0
+; AVX2-NEXT: vpslld $16, %ymm1, %ymm1
+; AVX2-NEXT: vphaddd %ymm0, %ymm1, %ymm0
+; AVX2-NEXT: vpermq {{.*#+}} ymm0 = ymm0[0,2,1,3]
+; AVX2-NEXT: retq
+bb1:
+ %0 = sext <16 x i16> %v to <16 x i32>
+ %1 = shl nsw <16 x i32> %0, splat (i32 16)
+ %2 = shufflevector <16 x i32> %1, <16 x i32> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+ %3 = shufflevector <16 x i32> %1, <16 x i32> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+ %4 = add <8 x i32> %2, %3
+ ret <8 x i32> %4
+}