Skip to content

Commit 6bbbdba

Browse files
authored
[AMDGPU] Simplify the logic in checkWMMACoexecutionHazards, NFC (#200717)
1 parent daac50b commit 6bbbdba

1 file changed

Lines changed: 58 additions & 64 deletions

File tree

llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp

Lines changed: 58 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -2085,45 +2085,50 @@ static bool isCoexecutableVALUInst(const MachineInstr &MI) {
20852085
!SIInstrInfo::isSWMMAC(MI) && !SIInstrInfo::isLDSDMA(MI);
20862086
}
20872087

2088-
static bool IsWMMAHazardInstInCategory(const MachineInstr &MI,
2089-
const SIInstrInfo *TII, unsigned Latency,
2090-
unsigned Category) {
2091-
assert(TII->isXDLWMMA(MI) && (Latency == 8 || Latency == 16) &&
2092-
"Handle me if the xdl wmma instruction latency changes");
2093-
2094-
switch (Category) {
2095-
case 0: // Dense WMMA Instructions:
2096-
// WMMA_*F16, WMMA_*BF16
2097-
// WMMA_*FP8FP8
2098-
// WMMA_*FP8BF8
2099-
// WMMA_*BF8FP8
2100-
// WMMA_*BF8BF8
2101-
// WMMA_*F8F6F4 if SRCA & SRCB != F8
2102-
return Latency == 8 && SIInstrInfo::isWMMA(MI);
2103-
2104-
case 1: // Dense WMMA Instructions:
2105-
// WMMA_IU8
2106-
// WMMA_IU4
2107-
// WMMA_*F8F6F4 if SRCA OR SRCB == F8
2108-
return Latency == 16 && SIInstrInfo::isWMMA(MI);
2109-
2110-
case 2: // Dense SWMMAC Instructions
2111-
// SWMMAC_*F16, SWMMAC_*BF16,
2112-
// SWMMAC_*FP8FP8
2113-
// SWMMAC_*BF8FP8
2114-
// SWMMAC_*FP8BF8
2115-
// SWMMAC_*BF8BF8
2116-
return Latency == 8 && SIInstrInfo::isSWMMAC(MI);
2117-
2118-
case 3: // Sparse WMMA Instructions:
2119-
// SWMMAC_IU8
2120-
// SWMMAC_IU4
2121-
return Latency == 16 && SIInstrInfo::isSWMMAC(MI);
2122-
default:
2088+
// Classify XDL WMMA instructions into co-execution hazard categories
2089+
// (Refer to SPG 4.6.12.1), mainly based on instruction latency.
2090+
//
2091+
// Category 0: WMMA with Latency 8
2092+
// WMMA_*F16, WMMA_*BF16
2093+
// WMMA_*FP8FP8
2094+
// WMMA_*FP8BF8
2095+
// WMMA_*BF8FP8
2096+
// WMMA_*BF8BF8
2097+
// WMMA_*F8F6F4 if neither SRCA nor SRCB is F8
2098+
//
2099+
// Category 1: WMMA with Latency 16
2100+
// WMMA_IU8
2101+
// WMMA_*F8F6F4 if either SRCA or SRCB is F8
2102+
//
2103+
// Category 2: SWMMAC with Latency 8
2104+
// SWMMAC_*F16, SWMMAC_*BF16,
2105+
// SWMMAC_*FP8FP8
2106+
// SWMMAC_*BF8FP8
2107+
// SWMMAC_*FP8BF8
2108+
// SWMMAC_*BF8BF8
2109+
//
2110+
// Category 3: SWMMAC with Latency 16
2111+
// SWMMAC_IU8
2112+
static unsigned
2113+
getWMMAHazardInstInCategory(const MachineInstr &MI, const SIInstrInfo *TII,
2114+
const TargetSchedModel &SchedModel) {
2115+
assert(TII->isXDLWMMA(MI) && "must be xdl wmma");
2116+
bool IsSWMMAC = SIInstrInfo::isSWMMAC(MI);
2117+
unsigned Category = 0;
2118+
2119+
unsigned Latency = SchedModel.computeInstrLatency(&MI);
2120+
switch (Latency) {
2121+
case 8:
2122+
Category = IsSWMMAC ? 2 : 0;
21232123
break;
2124+
case 16:
2125+
Category = IsSWMMAC ? 3 : 1;
2126+
break;
2127+
default:
2128+
llvm_unreachable("unexpected xdl wmma latency");
21242129
} // end switch.
21252130

2126-
return false;
2131+
return Category;
21272132
}
21282133

21292134
int GCNHazardRecognizer::checkWMMACoexecutionHazards(MachineInstr *MI) const {
@@ -2147,51 +2152,40 @@ int GCNHazardRecognizer::checkWMMACoexecutionHazards(MachineInstr *MI) const {
21472152
if (!TII->isXDLWMMA(I))
21482153
return false;
21492154

2150-
unsigned Latency = TSchedModel.computeInstrLatency(&I);
2151-
if (!IsWMMAHazardInstInCategory(I, TII, Latency, Category))
2152-
return false;
2153-
2155+
Category = getWMMAHazardInstInCategory(I, TII, TSchedModel);
21542156
return hasWMMAToWMMARegOverlap(I, *MI);
21552157
};
21562158

21572159
auto IsVALUHazardFn = [MI, TII, &Category, this](const MachineInstr &I) {
21582160
if (!TII->isXDLWMMA(I))
21592161
return false;
21602162

2161-
unsigned Latency = TSchedModel.computeInstrLatency(&I);
2162-
if (!IsWMMAHazardInstInCategory(I, TII, Latency, Category))
2163-
return false;
2164-
2163+
Category = getWMMAHazardInstInCategory(I, TII, TSchedModel);
21652164
return hasWMMAToVALURegOverlap(I, *MI);
21662165
};
21672166

2168-
int Limit = 0;
2169-
21702167
auto GetWaitStatesFn = [](const MachineInstr &I) {
21712168
return SIInstrInfo::isVALU(I) ? 1 : 0;
21722169
};
21732170

21742171
int WaitStatesNeeded = -1;
2172+
int ExistingVALUs = 0; // Existing number of VALU ops in between.
2173+
2174+
// getWaitStatesSince checks for a hazard between instruction 'I' and 'MI':
2175+
// - If a hazard exists: returns the number of VALUs in between and sets
2176+
// 'Category' via IsWMMAHazardFn/IsVALUHazardFn for instruction 'I'.
2177+
// - If no hazard exists: returns INT_MAX, making WaitStatesNeeded negative,
2178+
// so no V_NOP insertion is needed.
21752179
if (TII->isXDLWMMA(*MI)) {
2176-
for (Category = 0; WaitStatesNeeded < 0 && Category < 4; Category++) {
2177-
Limit = WMMAWaitStates[Category]; // for IsExpiredFn.
2178-
// 'getWaitStatesSince' returns the number of VALUs in between if hazard
2179-
// exists, and INT_MAX if there is no hazard. As a result, a negative
2180-
// WaitStatesNeeded here means no hazard, and we will continue to search
2181-
// for other categories.
2182-
WaitStatesNeeded =
2183-
Limit - getWaitStatesSince(IsWMMAHazardFn, Limit, GetWaitStatesFn);
2184-
}
2180+
const int WMMAWaitsLimit = 9; // Maximum of WMMAWaitStates
2181+
ExistingVALUs =
2182+
getWaitStatesSince(IsWMMAHazardFn, WMMAWaitsLimit, GetWaitStatesFn);
2183+
WaitStatesNeeded = WMMAWaitStates[Category] - ExistingVALUs;
21852184
} else { // Must be a co-executable VALU.
2186-
for (Category = 0; WaitStatesNeeded < 0 && Category < 4; Category++) {
2187-
Limit = VALUWaitStates[Category]; // for IsExpiredFn.
2188-
// 'getWaitStatesSince' returns the number of VALUs in between if hazard
2189-
// exists, and INT_MAX if there is no hazard. As a result, a negative
2190-
// WaitStatesNeeded here means no hazard, and we will continue to search
2191-
// for other categories.
2192-
WaitStatesNeeded =
2193-
Limit - getWaitStatesSince(IsVALUHazardFn, Limit, GetWaitStatesFn);
2194-
}
2185+
const int VALUWaitsLimit = 8; // Maximum of VALUWaitStates
2186+
ExistingVALUs =
2187+
getWaitStatesSince(IsVALUHazardFn, VALUWaitsLimit, GetWaitStatesFn);
2188+
WaitStatesNeeded = VALUWaitStates[Category] - ExistingVALUs;
21952189
}
21962190

21972191
return WaitStatesNeeded;

0 commit comments

Comments
 (0)