Skip to content

Commit 3150b73

Browse files
authored
[MLIR][XeGPU] Clean up helpers in XeGPUPropagateLayout (#175857)
In XeGPUPropagateLayout.cpp, the helper getDefaultSIMTLayoutInfo is implemented via multiple overloads that differ significantly in semantics, not just parameter types. Reusing the same function name for these semantically different behaviors makes call sites harder to read and reason about and increases the maintenance burden. This PR improves readability and maintainability of layout propagation logic.
1 parent 1727337 commit 3150b73

File tree

1 file changed

+34
-47
lines changed

1 file changed

+34
-47
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp

Lines changed: 34 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -287,58 +287,47 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
287287
return LayoutInfo(xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1}));
288288
}
289289

290-
/// Helper to get the default layout for a vector type.
291-
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
292-
const xegpu::uArch::uArch *uArch,
293-
unsigned packingSize,
294-
bool isScattered = false) {
290+
/// Helper to get the default layout for 2D block operations.
291+
template <typename Ty>
292+
static LayoutInfo getSIMTLayoutInforForBlockIO(Ty ty,
293+
const xegpu::uArch::uArch *uArch,
294+
unsigned packingSize) {
295295
// Expecting a 1D or 2D vector.
296-
assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
296+
assert((ty.getRank() == 1 || ty.getRank() == 2) &&
297297
"Expected 1D or 2D vector.");
298298
// Expecting int or float element type.
299-
assert(vectorTy.getElementType().isIntOrFloat() &&
299+
assert(ty.getElementType().isIntOrFloat() &&
300300
"Expected int or float element type.");
301301
// If the rank is 1, then return default layout for 1D vector.
302-
if (vectorTy.getRank() == 1)
303-
return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
302+
if (ty.getRank() == 1)
303+
return getDefaultSIMTLayoutInfo(ty.getContext(), 1, uArch);
304304
// Packing factor is determined by the element type bitwidth.
305-
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
305+
unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
306306
int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
307-
if (isScattered) {
308-
return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
309-
{uArch->getSubgroupSize(), 1},
310-
{1, packingFactor}));
311-
}
312-
return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
313-
{1, uArch->getSubgroupSize()},
314-
{1, packingFactor}));
307+
return LayoutInfo(xegpu::LayoutAttr::get(
308+
ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
315309
}
316310

317311
/// Helper to get the default layout for a vector type.
318-
static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
319-
const xegpu::uArch::uArch *uArch,
320-
unsigned packingSize,
321-
bool isScattered = false) {
312+
static LayoutInfo
313+
getSIMTLayoutInforForScatterIO(VectorType vectorTy,
314+
const xegpu::uArch::uArch *uArch,
315+
unsigned packingSize) {
322316
// Expecting a 1D or 2D vector.
323-
assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
324-
"Expected 1D or 2D TensorDesc.");
317+
assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
318+
"Expected 1D or 2D vector.");
325319
// Expecting int or float element type.
326-
assert(tdescTy.getElementType().isIntOrFloat() &&
320+
assert(vectorTy.getElementType().isIntOrFloat() &&
327321
"Expected int or float element type.");
328322
// If the rank is 1, then return default layout for 1D vector.
329-
if (tdescTy.getRank() == 1)
330-
return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
323+
if (vectorTy.getRank() == 1)
324+
return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
331325
// Packing factor is determined by the element type bitwidth.
332-
unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
333-
int subgroupSize = uArch->getSubgroupSize();
326+
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
334327
int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
335-
if (isScattered) {
336-
return LayoutInfo(xegpu::LayoutAttr::get(
337-
tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
338-
}
339-
340-
return LayoutInfo(xegpu::LayoutAttr::get(
341-
tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
328+
return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
329+
{uArch->getSubgroupSize(), 1},
330+
{1, packingFactor}));
342331
}
343332

344333
/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
@@ -365,7 +354,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
365354
xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
366355
}
367356
// Otherwise, return the default layout for the vector type.
368-
return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
357+
return getSIMTLayoutInforForBlockIO(vectorTy, uArch, packingSize);
369358
}
370359

371360
//===----------------------------------------------------------------------===//
@@ -587,7 +576,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
587576
prefetchLayout =
588577
LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
589578
else
590-
prefetchLayout = getDefaultSIMTLayoutInfo(
579+
prefetchLayout = getSIMTLayoutInforForBlockIO(
591580
tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
592581

593582
prefetch.setLayoutAttr(
@@ -840,9 +829,9 @@ void LayoutInfoPropagation::visitStoreNdOp(
840829
storeLayout =
841830
LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
842831
else
843-
storeLayout =
844-
getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
845-
uArchInstruction->getPackedFormatBitSize());
832+
storeLayout = getSIMTLayoutInforForBlockIO(
833+
store.getValueType(), uArch,
834+
uArchInstruction->getPackedFormatBitSize());
846835
store.setLayoutAttr(
847836
dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
848837
}
@@ -1000,9 +989,8 @@ void LayoutInfoPropagation::visitLoadGatherOp(
1000989
loadLayout =
1001990
LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
1002991
else
1003-
loadLayout = getDefaultSIMTLayoutInfo(
1004-
payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
1005-
/*scattered*/ true);
992+
loadLayout = getSIMTLayoutInforForScatterIO(
993+
payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
1006994

1007995
// Mask operand should have 1D default layout.
1008996
maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
@@ -1078,9 +1066,8 @@ void LayoutInfoPropagation::visitStoreScatterOp(
10781066
"Expected the first dimension of 2D tensor descriptor to be "
10791067
"equal to "
10801068
"subgroup size.");
1081-
payloadLayout = getDefaultSIMTLayoutInfo(
1082-
payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
1083-
/*scattered=*/true);
1069+
payloadLayout = getSIMTLayoutInforForScatterIO(
1070+
payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
10841071
}
10851072

10861073
maskLayout =

0 commit comments

Comments
 (0)