Skip to content

Commit c1df693

Browse files
authored
[TargetLowering] Use legally typed shifts to split chunks in expandDIVREMByConstant. (#187567)
This replaces LegalVT with HiLoVT and LegalWidth with HBitWidth, as they are the same for all current uses. Then we rewrite the shifts to operate on LL and LH. There's a slight regression on RISC-V because the different node creation order leads to a different DAG combine order. I have other refactoring I'd like to explore first; after that, I may try to fix the regression.
1 parent 7d7cd74 commit c1df693

File tree

2 files changed

+61
-47
lines changed

2 files changed

+61
-47
lines changed

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 48 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8246,23 +8246,18 @@ bool TargetLowering::expandDIVREMByConstant(SDNode *N,
82468246
unsigned BitWidth = VT.getScalarSizeInBits();
82478247
unsigned BestChunkWidth = 0;
82488248

8249-
// Determine the legal scalar integer type for chunk operations.
8250-
EVT LegalVT = getTypeToTransformTo(*DAG.getContext(), VT);
8251-
unsigned LegalWidth = LegalVT.getScalarSizeInBits();
8252-
unsigned MaxChunk = std::min<unsigned>(LegalWidth, BitWidth);
8253-
82548249
// Search for I where 2^I % Divisor == 1
8255-
for (unsigned I = MaxChunk, E = MaxChunk / 2; I > E; --I) {
8250+
for (unsigned I = HBitWidth - 1, E = HBitWidth / 2; I > E; --I) {
82568251
APInt Mod = APInt::getOneBitSet(Divisor.getBitWidth(), I).urem(Divisor);
82578252

82588253
if (Mod.isOne()) {
8259-
// Ensure (NumChunks * MaxChunkValue) doesn't overflow LegalVT
8254+
// Ensure (NumChunks * MaxChunkValue) doesn't overflow HiLoVT
82608255
unsigned NumChunks = divideCeil(BitWidth, I);
82618256

8262-
// Ensure the sum won't overflow the hardware register (LegalWidth).
8257+
// Ensure the sum won't overflow the hardware register (HBitWidth).
82638258
// Summing N chunks adds ceil(log2(N)) extra carry bits to the width.
82648259
// Safety check: Base Chunk Width (I) + Carry Bits <= Register Width.
8265-
if (I + llvm::bit_width(NumChunks - 1) <= LegalWidth) {
8260+
if (I + llvm::bit_width(NumChunks - 1) <= HBitWidth) {
82668261
BestChunkWidth = I;
82678262
break;
82688263
}
@@ -8272,46 +8267,64 @@ bool TargetLowering::expandDIVREMByConstant(SDNode *N,
82728267
if (!BestChunkWidth)
82738268
return false;
82748269

8275-
SDValue In =
8276-
LL ? DAG.getNode(ISD::BUILD_PAIR, dl, VT, LL, LH) : N->getOperand(0);
8270+
assert(!LL == !LH && "Expected both input halves or no input halves!");
8271+
if (!LL)
8272+
std::tie(LL, LH) = DAG.SplitScalar(N->getOperand(0), dl, HiLoVT, HiLoVT);
8273+
82778274
if (TrailingZeros) {
82788275
// Save the shifted off bits if we need the remainder.
82798276
if (Opcode != ISD::UDIV) {
8280-
APInt Mask = APInt::getLowBitsSet(BitWidth, TrailingZeros);
8281-
PartialRem =
8282-
DAG.getNode(ISD::AND, dl, VT, In, DAG.getConstant(Mask, dl, VT));
8277+
APInt Mask = APInt::getLowBitsSet(HBitWidth, TrailingZeros);
8278+
PartialRem = DAG.getNode(ISD::AND, dl, HiLoVT, LL,
8279+
DAG.getConstant(Mask, dl, HiLoVT));
82838280
}
8284-
EVT ShiftVT = getShiftAmountTy(VT, DAG.getDataLayout());
8285-
In = DAG.getNode(ISD::SRL, dl, VT, In,
8286-
DAG.getShiftAmountConstant(TrailingZeros, ShiftVT, dl));
82878281

8288-
std::tie(LL, LH) = DAG.SplitScalar(In, dl, HiLoVT, HiLoVT);
8289-
} else if (!LL) {
8290-
std::tie(LL, LH) = DAG.SplitScalar(In, dl, HiLoVT, HiLoVT);
8282+
if (isOperationLegal(ISD::FSHR, HiLoVT))
8283+
LL = DAG.getNode(ISD::FSHR, dl, HiLoVT, LH, LL,
8284+
DAG.getShiftAmountConstant(TrailingZeros, HiLoVT, dl));
8285+
else
8286+
LL = DAG.getNode(
8287+
ISD::OR, dl, HiLoVT,
8288+
DAG.getNode(ISD::SRL, dl, HiLoVT, LL,
8289+
DAG.getShiftAmountConstant(TrailingZeros, HiLoVT, dl)),
8290+
DAG.getNode(ISD::SHL, dl, HiLoVT, LH,
8291+
DAG.getShiftAmountConstant(HBitWidth - TrailingZeros,
8292+
HiLoVT, dl)));
8293+
LH = DAG.getNode(ISD::SRL, dl, HiLoVT, LH,
8294+
DAG.getShiftAmountConstant(TrailingZeros, HiLoVT, dl));
82918295
}
82928296

8293-
SDValue TotalSum = DAG.getConstant(0, dl, LegalVT);
8297+
Sum = DAG.getConstant(0, dl, HiLoVT);
82948298
SDValue Mask = DAG.getConstant(
8295-
APInt::getLowBitsSet(LegalWidth, BestChunkWidth), dl, LegalVT);
8299+
APInt::getLowBitsSet(HBitWidth, BestChunkWidth), dl, HiLoVT);
82968300

82978301
for (unsigned I = 0; I < BitWidth; I += BestChunkWidth) {
8298-
SDValue Shift = DAG.getShiftAmountConstant(I, VT, dl);
8299-
SDValue Chunk = DAG.getNode(ISD::SRL, dl, VT, In, Shift);
8300-
// Truncate to LegalVT
8301-
SDValue TruncChunk = DAG.getNode(ISD::TRUNCATE, dl, LegalVT, Chunk);
8302+
SDValue Chunk;
8303+
if (I == 0) {
8304+
Chunk = LL;
8305+
} else if (I >= HBitWidth) {
8306+
Chunk =
8307+
DAG.getNode(ISD::SRL, dl, HiLoVT, LH,
8308+
DAG.getShiftAmountConstant(I - HBitWidth, HiLoVT, dl));
8309+
} else if (isOperationLegal(ISD::FSHR, HiLoVT)) {
8310+
Chunk = DAG.getNode(ISD::FSHR, dl, HiLoVT, LH, LL,
8311+
DAG.getShiftAmountConstant(I, HiLoVT, dl));
8312+
} else {
8313+
Chunk = DAG.getNode(
8314+
ISD::OR, dl, HiLoVT,
8315+
DAG.getNode(ISD::SRL, dl, HiLoVT, LL,
8316+
DAG.getShiftAmountConstant(I, HiLoVT, dl)),
8317+
DAG.getNode(ISD::SHL, dl, HiLoVT, LH,
8318+
DAG.getShiftAmountConstant(HBitWidth - I, HiLoVT, dl)));
8319+
}
8320+
83028321
// For the last chunk, we might not need a mask if it's smaller than
83038322
// BestChunkWidth, but applying it is always safe.
8304-
SDValue MaskedChunk =
8305-
DAG.getNode(ISD::AND, dl, LegalVT, TruncChunk, Mask);
8306-
TotalSum = DAG.getNode(ISD::ADD, dl, LegalVT, TotalSum, MaskedChunk);
8323+
SDValue MaskedChunk = DAG.getNode(ISD::AND, dl, HiLoVT, Chunk, Mask);
8324+
Sum = DAG.getNode(ISD::ADD, dl, HiLoVT, Sum, MaskedChunk);
83078325
}
8308-
Sum = DAG.getNode(ISD::ZERO_EXTEND, dl, HiLoVT, TotalSum);
83098326
}
83108327

8311-
// If we didn't find a sum, we can't do the expansion.
8312-
if (!Sum)
8313-
return false;
8314-
83158328
// Perform a HiLoVT urem on the Sum using truncated divisor.
83168329
SDValue RemL =
83178330
DAG.getNode(ISD::UREM, dl, HiLoVT, Sum,
@@ -8346,8 +8359,7 @@ bool TargetLowering::expandDIVREMByConstant(SDNode *N,
83468359
RemL = DAG.getNode(ISD::SHL, dl, HiLoVT, RemL,
83478360
DAG.getShiftAmountConstant(TrailingZeros, HiLoVT, dl));
83488361

8349-
SDValue PartialRemLo = DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, PartialRem);
8350-
RemL = DAG.getNode(ISD::ADD, dl, HiLoVT, RemL, PartialRemLo);
8362+
RemL = DAG.getNode(ISD::ADD, dl, HiLoVT, RemL, PartialRem);
83518363
}
83528364
Result.push_back(RemL);
83538365
Result.push_back(DAG.getConstant(0, dl, HiLoVT));

llvm/test/CodeGen/RISCV/urem-lkk.ll

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -231,21 +231,23 @@ define i64 @dont_fold_urem_i64(i64 %x) nounwind {
231231
; RV32IM: # %bb.0:
232232
; RV32IM-NEXT: slli a2, a1, 31
233233
; RV32IM-NEXT: srli a3, a0, 1
234-
; RV32IM-NEXT: andi a4, a1, 2046
234+
; RV32IM-NEXT: lui a4, 512
235+
; RV32IM-NEXT: srli a5, a1, 1
235236
; RV32IM-NEXT: srli a1, a1, 11
236237
; RV32IM-NEXT: or a2, a3, a2
237-
; RV32IM-NEXT: slli a4, a4, 10
238+
; RV32IM-NEXT: slli a5, a5, 11
238239
; RV32IM-NEXT: srli a3, a2, 21
239-
; RV32IM-NEXT: or a3, a3, a4
240-
; RV32IM-NEXT: lui a4, 21400
241-
; RV32IM-NEXT: slli a2, a2, 11
242-
; RV32IM-NEXT: srli a2, a2, 11
243-
; RV32IM-NEXT: add a2, a2, a3
244-
; RV32IM-NEXT: li a3, 49
245-
; RV32IM-NEXT: addi a4, a4, -2006
240+
; RV32IM-NEXT: or a3, a3, a5
241+
; RV32IM-NEXT: lui a5, 21400
242+
; RV32IM-NEXT: addi a4, a4, -1
243+
; RV32IM-NEXT: and a2, a2, a4
246244
; RV32IM-NEXT: add a1, a2, a1
247-
; RV32IM-NEXT: mulhu a2, a1, a4
248-
; RV32IM-NEXT: mul a2, a2, a3
245+
; RV32IM-NEXT: li a2, 49
246+
; RV32IM-NEXT: addi a5, a5, -2006
247+
; RV32IM-NEXT: and a3, a3, a4
248+
; RV32IM-NEXT: add a1, a1, a3
249+
; RV32IM-NEXT: mulhu a3, a1, a5
250+
; RV32IM-NEXT: mul a2, a3, a2
249251
; RV32IM-NEXT: sub a1, a1, a2
250252
; RV32IM-NEXT: slli a1, a1, 1
251253
; RV32IM-NEXT: andi a0, a0, 1

0 commit comments

Comments
 (0)