@@ -721,11 +721,6 @@ class WaitcntBrackets {
721721 return T == Context->SmemAccessCounter || T == X_CNT;
722722 }
723723
724- unsigned getSgprScoresIdx (InstCounterType T) const {
725- assert (isSmemCounter (T) && " Invalid SMEM counter" );
726- return T == X_CNT ? 1 : 0 ;
727- }
728-
729724 unsigned getOutstanding (InstCounterType T) const {
730725 return ScoreUBs[T] - ScoreLBs[T];
731726 }
@@ -754,7 +749,7 @@ class WaitcntBrackets {
754749
755750 unsigned getSGPRScore (MCRegUnit RU, InstCounterType T) const {
756751 auto It = SGPRs.find (RU);
757- return It != SGPRs.end () ? It->second .Scores [ getSgprScoresIdx (T)] : 0 ;
752+ return It != SGPRs.end () ? It->second .get (T) : 0 ;
758753 }
759754
760755 unsigned getVMemScore (VMEMID TID, InstCounterType T) const {
@@ -927,9 +922,8 @@ class WaitcntBrackets {
927922 for (MCRegUnit RU : regunits (Reg))
928923 VMem[toVMEMID (RU)].Scores [T] = Val;
929924 } else if (TRI.isSGPRReg (Context->MRI , Reg)) {
930- auto STy = getSgprScoresIdx (T);
931925 for (MCRegUnit RU : regunits (Reg))
932- SGPRs[RU].Scores [STy] = Val;
926+ SGPRs[RU].get (T) = Val;
933927 } else {
934928 llvm_unreachable (" Register cannot be tracked/unknown register!" );
935929 }
@@ -974,14 +968,24 @@ class WaitcntBrackets {
974968 bool empty () const { return all_of (Scores, equal_to (0 )) && !VMEMTypes; }
975969 };
976970
977- struct SGPRInfo {
978- // Wait cnt scores for every sgpr, the DS_CNT (corresponding to LGKMcnt
979- // pre-gfx12) or KM_CNT (gfx12+ only), and X_CNT (gfx1250) are relevant.
980- // Row 0 represents the score for either DS_CNT or KM_CNT and row 1 keeps
981- // the X_CNT score.
982- std::array<unsigned , 2 > Scores = {0 };
971+ // / Wait cnt scores for every sgpr, the DS_CNT (corresponding to LGKMcnt
972+ // / pre-gfx12) or KM_CNT (gfx12+ only), and X_CNT (gfx1250) are relevant.
973+ class SGPRInfo {
974+ // / Either DS_CNT or KM_CNT score.
975+ unsigned ScoreDsKmCnt = 0 ;
976+ unsigned ScoreXCnt = 0 ;
977+
978+ public:
979+ unsigned get (InstCounterType T) const {
980+ assert ((T == DS_CNT || T == KM_CNT || T == X_CNT) && " Invalid counter" );
981+ return T == X_CNT ? ScoreXCnt : ScoreDsKmCnt;
982+ }
983+ unsigned &get (InstCounterType T) {
984+ assert ((T == DS_CNT || T == KM_CNT || T == X_CNT) && " Invalid counter" );
985+ return T == X_CNT ? ScoreXCnt : ScoreDsKmCnt;
986+ }
983987
984- bool empty () const { return !Scores[ 0 ] && !Scores[ 1 ] ; }
988+ bool empty () const { return !ScoreDsKmCnt && !ScoreXCnt ; }
985989 };
986990
987991 DenseMap<VMEMID, VMEMInfo> VMem; // VGPR + LDS DMA
@@ -1359,7 +1363,7 @@ void WaitcntBrackets::print(raw_ostream &OS) const {
13591363 SmallVector<MCRegUnit> SortedSMEMIDs (SGPRs.keys ());
13601364 sort (SortedSMEMIDs);
13611365 for (auto ID : SortedSMEMIDs) {
1362- unsigned RegScore = SGPRs.at (ID).Scores [ getSgprScoresIdx (T)] ;
1366+ unsigned RegScore = SGPRs.at (ID).get (T);
13631367 if (RegScore <= LB)
13641368 continue ;
13651369 unsigned RelScore = RegScore - LB - 1 ;
@@ -3089,12 +3093,10 @@ bool WaitcntBrackets::merge(const WaitcntBrackets &Other) {
30893093 StrictDom |= mergeScore (M, Info.Scores [T], Other.getVMemScore (RegID, T));
30903094
30913095 if (isSmemCounter (T)) {
3092- unsigned Idx = getSgprScoresIdx (T);
30933096 for (auto &[RegID, Info] : SGPRs) {
30943097 auto It = Other.SGPRs .find (RegID);
3095- unsigned OtherScore =
3096- (It != Other.SGPRs .end ()) ? It->second .Scores [Idx] : 0 ;
3097- StrictDom |= mergeScore (M, Info.Scores [Idx], OtherScore);
3098+ unsigned OtherScore = (It != Other.SGPRs .end ()) ? It->second .get (T) : 0 ;
3099+ StrictDom |= mergeScore (M, Info.get (T), OtherScore);
30983100 }
30993101 }
31003102 }
0 commit comments