Skip to content

Commit d0fe95a

Browse files
committed
[SLP] Enable vectorization of strided stores
1 parent 290d63c commit d0fe95a

File tree

2 files changed

+169
-262
lines changed

2 files changed

+169
-262
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 142 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,12 @@ static cl::opt<unsigned> MaxProfitableStride(
200200
"slp-max-stride", cl::init(8), cl::Hidden,
201201
cl::desc("The maximum stride, considered to be profitable."));
202202

203+
static cl::opt<bool>
204+
EnableStridedStores("slp-enable-strided-stores", cl::init(false),
205+
cl::Hidden,
206+
cl::desc("Enable SLP trees to be built from strided "
207+
"store chains."));
208+
203209
static cl::opt<bool>
204210
DisableTreeReorder("slp-disable-tree-reorder", cl::init(false), cl::Hidden,
205211
cl::desc("Disable tree reordering even if it is "
@@ -8865,8 +8871,9 @@ void BoUpSLP::buildReorderableOperands(
88658871
continue;
88668872
if (UserTE->getOpcode() == Instruction::InsertElement && I == 0)
88678873
continue;
8868-
if (UserTE->getOpcode() == Instruction::Store &&
8869-
UserTE->State == TreeEntry::Vectorize && I == 1)
8874+
if (UserTE->getOpcode() == Instruction::Store && I == 1 &&
8875+
(UserTE->State == TreeEntry::Vectorize ||
8876+
UserTE->State == TreeEntry::StridedVectorize))
88708877
continue;
88718878
if (UserTE->getOpcode() == Instruction::Load &&
88728879
(UserTE->State == TreeEntry::Vectorize ||
@@ -9280,7 +9287,6 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
92809287
Data.first->reorderOperands(Mask);
92819288
if (!isa<InsertElementInst, StoreInst>(Data.first->getMainOp()) ||
92829289
IsNotProfitableAltCodeNode(*Data.first) ||
9283-
Data.first->State == TreeEntry::StridedVectorize ||
92849290
Data.first->State == TreeEntry::CompressVectorize) {
92859291
reorderScalars(Data.first->Scalars, Mask);
92869292
reorderOrder(Data.first->ReorderIndices, MaskOrder,
@@ -10706,11 +10712,16 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
1070610712
Ptr0 = PointerOps[CurrentOrder.front()];
1070710713
PtrN = PointerOps[CurrentOrder.back()];
1070810714
}
10715+
Align CommonAlignment = computeCommonAlignment<StoreInst>(VL0);
1070910716
std::optional<int64_t> Dist =
1071010717
getPointersDiff(ScalarTy, Ptr0, ScalarTy, PtrN, *DL, *SE);
1071110718
// Check that the sorted pointer operands are consecutive.
1071210719
if (static_cast<uint64_t>(*Dist) == VL.size() - 1)
1071310720
return TreeEntry::Vectorize;
10721+
if (analyzeConstantStrideCandidate(PointerOps, ScalarTy, CommonAlignment,
10722+
CurrentOrder, *Dist, Ptr0, PtrN,
10723+
SPtrInfo))
10724+
return TreeEntry::StridedVectorize;
1071410725
}
1071510726

1071610727
LLVM_DEBUG(dbgs() << "SLP: Non-consecutive store.\n");
@@ -12596,6 +12607,18 @@ void BoUpSLP::buildTreeRec(ArrayRef<Value *> VLRef, unsigned Depth,
1259612607
return;
1259712608
}
1259812609
case Instruction::Store: {
12610+
if (State == TreeEntry::StridedVectorize) {
12611+
TreeEntry *TE =
12612+
newTreeEntry(VL, TreeEntry::StridedVectorize, Bundle, S,
12613+
UserTreeIdx, ReuseShuffleIndices, CurrentOrder);
12614+
TreeEntryToStridedPtrInfoMap[TE] = SPtrInfo;
12615+
LLVM_DEBUG(
12616+
dbgs() << "SLP: added a new TreeEntry (strided StoreInst).\n";
12617+
TE->dump());
12618+
TE->setOperands(Operands);
12619+
buildTreeRec(TE->getOperand(0), Depth + 1, {TE, 0});
12620+
return;
12621+
}
1259912622
bool Consecutive = CurrentOrder.empty();
1260012623
if (!Consecutive)
1260112624
fixupOrderingIndices(CurrentOrder);
@@ -14260,10 +14283,18 @@ void BoUpSLP::transformNodes() {
1426014283
/*VariableMask=*/false, CommonAlignment,
1426114284
BaseSI),
1426214285
CostKind);
14263-
if (StridedCost < OriginalVecCost)
14286+
if (StridedCost < OriginalVecCost) {
1426414287
// Strided store is more profitable than reverse + consecutive store -
1426514288
// transform the node to strided store.
1426614289
E.State = TreeEntry::StridedVectorize;
14290+
Type *StrideTy = DL->getIndexType(cast<StoreInst>(E.Scalars.front())
14291+
->getPointerOperand()
14292+
->getType());
14293+
StridedPtrInfo SPtrInfo;
14294+
SPtrInfo.StrideVal = ConstantInt::getSigned(StrideTy, -1);
14295+
SPtrInfo.Ty = VecTy;
14296+
TreeEntryToStridedPtrInfoMap[&E] = SPtrInfo;
14297+
}
1426714298
} else if (!E.ReorderIndices.empty()) {
1426814299
// Check for interleaved stores.
1426914300
auto IsInterleaveMask = [&, &TTI = *TTI](ArrayRef<int> Mask) {
@@ -22013,12 +22044,21 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
2201322044
}
2201422045
Align CommonAlignment = computeCommonAlignment<StoreInst>(E->Scalars);
2201522046
Type *StrideTy = DL->getIndexType(SI->getPointerOperandType());
22047+
22048+
const StridedPtrInfo &SPtrInfo = TreeEntryToStridedPtrInfoMap.at(E);
22049+
Value *Stride = SPtrInfo.StrideVal;
22050+
assert(Stride && "Missing StridedPointerInfo for tree entry.");
22051+
Value *StrideVal =
22052+
Builder.CreateIntCast(Stride, StrideTy, /*isSigned=*/true);
22053+
// vp_strided_store::stride is defined in bytes
22054+
StrideVal = Builder.CreateMul(
22055+
StrideVal,
22056+
ConstantInt::getSigned(
22057+
StrideTy, static_cast<int>(DL->getTypeAllocSize(ScalarTy))));
2201622058
auto *Inst = Builder.CreateIntrinsic(
2201722059
Intrinsic::experimental_vp_strided_store,
2201822060
{VecTy, Ptr->getType(), StrideTy},
22019-
{VecValue, Ptr,
22020-
ConstantInt::getSigned(
22021-
StrideTy, -static_cast<int>(DL->getTypeAllocSize(ScalarTy))),
22061+
{VecValue, Ptr, StrideVal,
2202222062
Builder.getAllOnesMask(VecTy->getElementCount()),
2202322063
Builder.getInt32(E->Scalars.size())});
2202422064
Inst->addParamAttr(
@@ -25368,9 +25408,10 @@ class StoreChainContext {
2536825408

2536925409
explicit StoreChainContext(ArrayRef<Value *> Ops,
2537025410
ArrayRef<SizePair> RangeSizes,
25371-
SmallVector<unsigned> &RangeSizesByIdx)
25411+
SmallVector<unsigned> &RangeSizesByIdx,
25412+
unsigned Stride)
2537225413
: Operands(Ops), RangeSizesStorage(RangeSizes),
25373-
RangeSizesByIdx(RangeSizesByIdx) {}
25414+
RangeSizesByIdx(RangeSizesByIdx), Stride(Stride) {}
2537425415

2537525416
/// Set up initial values using the already set Operands
2537625417
bool initializeContext(BoUpSLP &R, const DataLayout &DL,
@@ -25379,13 +25420,19 @@ class StoreChainContext {
2537925420
std::optional<unsigned> getCurrentVF() const;
2538025421
/// Return the maximum VF for the context
2538125422
unsigned getMaxVF() const { return MaxVF; }
25423+
/// Return the stride of the context
25424+
unsigned getStride() const { return Stride; }
2538225425
/// Attempt to vectorize Operands for the given VF
2538325426
/// Returns false if no more attempts should be made for the context
2538425427
bool vectorizeOneVF(const TargetTransformInfo &TTI, unsigned VF,
2538525428
BoUpSLP::ValueSet &VectorizedStores, bool &Changed,
2538625429
llvm::function_ref<std::optional<bool>(
2538725430
ArrayRef<Value *>, unsigned, unsigned, unsigned &)>
2538825431
VectorizeStoreChain);
25432+
/// Add an additional store to the chain
25433+
/// \p Store store too append to Operands
25434+
/// \p Idx position within TryToVectorize::StoreSeq
25435+
void addOperand(Value *Store, unsigned Idx);
2538925436

2539025437
private:
2539125438
bool isNotVectorized(const SizePair &P) const {
@@ -25482,10 +25529,17 @@ class StoreChainContext {
2548225529
unsigned Repeat = 1;
2548325530
/// Did any vectorization occur for the current iteration over CandidateVFs
2548425531
bool RepeatChanged = false;
25532+
/// For constant strided stores, what is the stride amount
25533+
const unsigned Stride = 0;
2548525534
/// Store information about failed vectorization attempts due to scheduling
2548625535
SmallDenseMap<Value *, SizePair> NonSchedulable;
2548725536
};
2548825537

25538+
void StoreChainContext::addOperand(Value *Store, unsigned Idx) {
25539+
Operands.push_back(Store);
25540+
RangeSizesStorage.push_back({Idx, 1});
25541+
}
25542+
2548925543
void StoreChainContext::markRangeVectorized(unsigned StartIdx, unsigned Length,
2549025544
unsigned &FirstUnvecStore,
2549125545
unsigned &MaxSliceEnd) {
@@ -25538,6 +25592,8 @@ bool StoreChainContext::initializeContext(BoUpSLP &R, const DataLayout &DL,
2553825592
ValueTy->getScalarType()));
2553925593
MinVF /= getNumElements(StoreTy);
2554025594
MinVF = std::max<unsigned>(2, MinVF);
25595+
if (Stride > 1)
25596+
MinVF = std::max<unsigned>(MinVF, MinProfitableStridedStores);
2554125597

2554225598
if (MaxVF < MinVF) {
2554325599
LLVM_DEBUG(dbgs() << "SLP: Vectorization infeasible as MaxVF (" << MaxVF
@@ -25903,46 +25959,93 @@ bool SLPVectorizerPass::vectorizeStores(
2590325959
bool Changed = false;
2590425960

2590525961
auto TryToVectorize = [&](const RelatedStoreInsts::DistToInstMap &StoreSeq) {
25906-
int64_t PrevDist = -1;
25907-
unsigned GlobalMaxVF = 0;
2590825962
SmallVector<unsigned> RangeSizesByIdx(StoreSeq.size(), 1);
2590925963
SmallVector<std::unique_ptr<StoreChainContext>> AllContexts;
2591025964
BoUpSLP::ValueList Operands;
2591125965
SmallVector<StoreChainContext::SizePair> RangeSizes;
25966+
25967+
// All chains that we're still building
25968+
struct PartialChainStatus {
25969+
// Has been added to AllContexts
25970+
// Wait to add to AllContexts until at least two stores are found
25971+
bool AddedToAllContexts;
25972+
union {
25973+
// If added, index into AllContexts
25974+
unsigned AllContextsIdx;
25975+
// Index into StoreSeq if not added to AllContexts yet
25976+
unsigned StoreSeqIdx;
25977+
};
25978+
// If not added to AllContexts, what is the single store in the chain
25979+
Value *FirstStore;
25980+
// What is the Stride of this chain
25981+
unsigned Stride;
25982+
};
25983+
SmallVector<SmallVector<PartialChainStatus, 1>> Chains(MaxProfitableStride +
25984+
1);
25985+
auto GetChainsKey = [&](int64_t Query) -> unsigned {
25986+
// Just modulo function, get the index into Chains for a given query
25987+
int64_t Rem = (Query % (int64_t)Chains.size());
25988+
if (Rem < 0)
25989+
Rem += (int64_t)Chains.size();
25990+
return (unsigned)Rem;
25991+
};
25992+
int64_t LastDist;
2591225993
for (auto [Idx, Data] : enumerate(StoreSeq)) {
2591325994
auto &[Dist, InstIdx] = Data;
25914-
if (Operands.empty() || Dist - PrevDist == 1) {
25915-
Operands.push_back(Stores[InstIdx]);
25916-
RangeSizes.emplace_back(Idx, 1);
25917-
PrevDist = Dist;
25918-
if (Idx != StoreSeq.size() - 1)
25995+
// Clean up chains that can't be continued
25996+
if (Idx > 0)
25997+
for (int64_t D = LastDist; D < Dist; ++D)
25998+
Chains[GetChainsKey(D)].clear();
25999+
LastDist = Dist;
26000+
26001+
// Track which stride lengths we found existing chains for
26002+
// Don't have to add new entries for these
26003+
SmallVector<bool> FoundStrides(MaxProfitableStride + 1, false);
26004+
for (auto &Status : Chains[GetChainsKey(Dist)]) {
26005+
if (Status.AddedToAllContexts) {
26006+
// Chain already in AllContexts()
26007+
AllContexts[Status.AllContextsIdx]->addOperand(Stores[InstIdx], Idx);
26008+
} else {
26009+
// Chain just a single element, not yet in AllContexts()
26010+
SmallVector<StoreChainContext::SizePair> RS = {
26011+
{Status.StoreSeqIdx, 1}, {Idx, 1}};
26012+
BoUpSLP::ValueList Ops = {Status.FirstStore, Stores[InstIdx]};
26013+
AllContexts.emplace_back(std::make_unique<StoreChainContext>(
26014+
Ops, RS, RangeSizesByIdx, Status.Stride));
26015+
Status.AllContextsIdx = AllContexts.size() - 1;
26016+
}
26017+
unsigned Key = GetChainsKey(Status.Stride + Dist);
26018+
Chains[Key].push_back({/*AddedToAllContexts*/ true,
26019+
Status.AllContextsIdx, /*FirstStore*/ nullptr,
26020+
Status.Stride});
26021+
FoundStrides[Status.Stride] = true;
26022+
}
26023+
26024+
// For any stride lengths that we didn't append to a chain for,
26025+
// instead start a new chain
26026+
for (auto Stride : seq<unsigned>(
26027+
1, (EnableStridedStores ? MaxProfitableStride : 1) + 1)) {
26028+
if (FoundStrides[Stride])
2591926029
continue;
25920-
}
25921-
25922-
if (Operands.size() > 1 &&
25923-
Visited
25924-
.insert({Operands.front(),
25925-
cast<StoreInst>(Operands.front())->getValueOperand(),
25926-
Operands.back(),
25927-
cast<StoreInst>(Operands.back())->getValueOperand(),
25928-
Operands.size()})
25929-
.second) {
25930-
AllContexts.emplace_back(std::make_unique<StoreChainContext>(
25931-
Operands, RangeSizes, RangeSizesByIdx));
25932-
if (!AllContexts.back()->initializeContext(R, *DL, *TTI))
25933-
AllContexts.pop_back();
25934-
else
25935-
GlobalMaxVF = std::max(GlobalMaxVF, AllContexts.back()->getMaxVF());
25936-
}
25937-
Operands.clear();
25938-
RangeSizes.clear();
25939-
if (Idx != StoreSeq.size() - 1) {
25940-
Operands.push_back(Stores[InstIdx]);
25941-
RangeSizes.emplace_back(Idx, 1);
25942-
PrevDist = Dist;
26030+
unsigned Key = GetChainsKey(Dist + Stride);
26031+
Chains[Key].push_back({/*AddedToAllContexts*/ false,
26032+
/*StoreSeqIdx*/ (unsigned)Idx, Stores[InstIdx],
26033+
Stride});
2594326034
}
2594426035
}
2594526036

26037+
unsigned GlobalMaxVF = 0;
26038+
for (const auto &CtxPtr : AllContexts)
26039+
if (CtxPtr->initializeContext(R, *DL, *TTI))
26040+
GlobalMaxVF = std::max(GlobalMaxVF, AllContexts.back()->getMaxVF());
26041+
26042+
// Prioritize non-strided chains (Stride = 1)
26043+
llvm::stable_sort(AllContexts,
26044+
[](const std::unique_ptr<StoreChainContext> &A,
26045+
const std::unique_ptr<StoreChainContext> &B) {
26046+
return A->getStride() < B->getStride();
26047+
});
26048+
2594626049
for (unsigned LimitVF = GlobalMaxVF; LimitVF > 0;
2594726050
LimitVF = bit_ceil(LimitVF) / 2) {
2594826051
for (auto &CtxPtr : AllContexts) {

0 commit comments

Comments
 (0)