Skip to content

Commit 368f38b

Browse files
authored
[AMDGPU][SIInsertWaitcnts][NFC] SGPRInfo: Move score selection logic closer (#186518)
Selecting the score in SGPRInfo used to require an index, which you would get by calling getSgprScoresIdx(), a function defined in a different class. This patch moves the score selection logic into SGPRInfo, which makes the interface simpler and more intuitive. Given that SGPRInfo contains only two scores, this patch also replaces the score array with individual score variables. This should be NFC.
1 parent a60b3a8 commit 368f38b

File tree

1 file changed

+22
-20
lines changed

1 file changed

+22
-20
lines changed

llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -721,11 +721,6 @@ class WaitcntBrackets {
721721
return T == Context->SmemAccessCounter || T == X_CNT;
722722
}
723723

724-
unsigned getSgprScoresIdx(InstCounterType T) const {
725-
assert(isSmemCounter(T) && "Invalid SMEM counter");
726-
return T == X_CNT ? 1 : 0;
727-
}
728-
729724
unsigned getOutstanding(InstCounterType T) const {
730725
return ScoreUBs[T] - ScoreLBs[T];
731726
}
@@ -754,7 +749,7 @@ class WaitcntBrackets {
754749

755750
unsigned getSGPRScore(MCRegUnit RU, InstCounterType T) const {
756751
auto It = SGPRs.find(RU);
757-
return It != SGPRs.end() ? It->second.Scores[getSgprScoresIdx(T)] : 0;
752+
return It != SGPRs.end() ? It->second.get(T) : 0;
758753
}
759754

760755
unsigned getVMemScore(VMEMID TID, InstCounterType T) const {
@@ -927,9 +922,8 @@ class WaitcntBrackets {
927922
for (MCRegUnit RU : regunits(Reg))
928923
VMem[toVMEMID(RU)].Scores[T] = Val;
929924
} else if (TRI.isSGPRReg(Context->MRI, Reg)) {
930-
auto STy = getSgprScoresIdx(T);
931925
for (MCRegUnit RU : regunits(Reg))
932-
SGPRs[RU].Scores[STy] = Val;
926+
SGPRs[RU].get(T) = Val;
933927
} else {
934928
llvm_unreachable("Register cannot be tracked/unknown register!");
935929
}
@@ -974,14 +968,24 @@ class WaitcntBrackets {
974968
bool empty() const { return all_of(Scores, equal_to(0)) && !VMEMTypes; }
975969
};
976970

977-
struct SGPRInfo {
978-
// Wait cnt scores for every sgpr, the DS_CNT (corresponding to LGKMcnt
979-
// pre-gfx12) or KM_CNT (gfx12+ only), and X_CNT (gfx1250) are relevant.
980-
// Row 0 represents the score for either DS_CNT or KM_CNT and row 1 keeps
981-
// the X_CNT score.
982-
std::array<unsigned, 2> Scores = {0};
971+
/// Wait cnt scores for every sgpr, the DS_CNT (corresponding to LGKMcnt
972+
/// pre-gfx12) or KM_CNT (gfx12+ only), and X_CNT (gfx1250) are relevant.
973+
class SGPRInfo {
974+
/// Either DS_CNT or KM_CNT score.
975+
unsigned ScoreDsKmCnt = 0;
976+
unsigned ScoreXCnt = 0;
977+
978+
public:
979+
unsigned get(InstCounterType T) const {
980+
assert((T == DS_CNT || T == KM_CNT || T == X_CNT) && "Invalid counter");
981+
return T == X_CNT ? ScoreXCnt : ScoreDsKmCnt;
982+
}
983+
unsigned &get(InstCounterType T) {
984+
assert((T == DS_CNT || T == KM_CNT || T == X_CNT) && "Invalid counter");
985+
return T == X_CNT ? ScoreXCnt : ScoreDsKmCnt;
986+
}
983987

984-
bool empty() const { return !Scores[0] && !Scores[1]; }
988+
bool empty() const { return !ScoreDsKmCnt && !ScoreXCnt; }
985989
};
986990

987991
DenseMap<VMEMID, VMEMInfo> VMem; // VGPR + LDS DMA
@@ -1359,7 +1363,7 @@ void WaitcntBrackets::print(raw_ostream &OS) const {
13591363
SmallVector<MCRegUnit> SortedSMEMIDs(SGPRs.keys());
13601364
sort(SortedSMEMIDs);
13611365
for (auto ID : SortedSMEMIDs) {
1362-
unsigned RegScore = SGPRs.at(ID).Scores[getSgprScoresIdx(T)];
1366+
unsigned RegScore = SGPRs.at(ID).get(T);
13631367
if (RegScore <= LB)
13641368
continue;
13651369
unsigned RelScore = RegScore - LB - 1;
@@ -3089,12 +3093,10 @@ bool WaitcntBrackets::merge(const WaitcntBrackets &Other) {
30893093
StrictDom |= mergeScore(M, Info.Scores[T], Other.getVMemScore(RegID, T));
30903094

30913095
if (isSmemCounter(T)) {
3092-
unsigned Idx = getSgprScoresIdx(T);
30933096
for (auto &[RegID, Info] : SGPRs) {
30943097
auto It = Other.SGPRs.find(RegID);
3095-
unsigned OtherScore =
3096-
(It != Other.SGPRs.end()) ? It->second.Scores[Idx] : 0;
3097-
StrictDom |= mergeScore(M, Info.Scores[Idx], OtherScore);
3098+
unsigned OtherScore = (It != Other.SGPRs.end()) ? It->second.get(T) : 0;
3099+
StrictDom |= mergeScore(M, Info.get(T), OtherScore);
30983100
}
30993101
}
31003102
}

0 commit comments

Comments
 (0)