@@ -287,58 +287,47 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
287287 return LayoutInfo (xegpu::LayoutAttr::get (ctx, {1 , subgroupSize}, {1 , 1 }));
288288}
289289
290- // / Helper to get the default layout for a vector type .
291- static LayoutInfo getDefaultSIMTLayoutInfo (VectorType vectorTy,
292- const xegpu::uArch::uArch *uArch ,
293- unsigned packingSize ,
294- bool isScattered = false ) {
290+ // / Helper to get the default layout for 2D block operations .
291+ template < typename Ty>
292+ static LayoutInfo getSIMTLayoutInforForBlockIO (Ty ty ,
293+ const xegpu::uArch::uArch *uArch ,
294+ unsigned packingSize ) {
295295 // Expecting a 1D or 2D vector.
296- assert ((vectorTy .getRank () == 1 || vectorTy .getRank () == 2 ) &&
296+ assert ((ty .getRank () == 1 || ty .getRank () == 2 ) &&
297297 " Expected 1D or 2D vector." );
298298 // Expecting int or float element type.
299- assert (vectorTy .getElementType ().isIntOrFloat () &&
299+ assert (ty .getElementType ().isIntOrFloat () &&
300300 " Expected int or float element type." );
301301 // If the rank is 1, then return default layout for 1D vector.
302- if (vectorTy .getRank () == 1 )
303- return getDefaultSIMTLayoutInfo (vectorTy .getContext (), 1 , uArch);
302+ if (ty .getRank () == 1 )
303+ return getDefaultSIMTLayoutInfo (ty .getContext (), 1 , uArch);
304304 // Packing factor is determined by the element type bitwidth.
305- unsigned bitwidth = vectorTy .getElementType ().getIntOrFloatBitWidth ();
305+ unsigned bitwidth = ty .getElementType ().getIntOrFloatBitWidth ();
306306 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1 ;
307- if (isScattered) {
308- return LayoutInfo (xegpu::LayoutAttr::get (vectorTy.getContext (),
309- {uArch->getSubgroupSize (), 1 },
310- {1 , packingFactor}));
311- }
312- return LayoutInfo (xegpu::LayoutAttr::get (vectorTy.getContext (),
313- {1 , uArch->getSubgroupSize ()},
314- {1 , packingFactor}));
307+ return LayoutInfo (xegpu::LayoutAttr::get (
308+ ty.getContext (), {1 , uArch->getSubgroupSize ()}, {1 , packingFactor}));
315309}
316310
317311// / Helper to get the default layout for a vector type.
318- static LayoutInfo getDefaultSIMTLayoutInfo (xegpu::TensorDescType tdescTy,
319- const xegpu::uArch::uArch *uArch ,
320- unsigned packingSize ,
321- bool isScattered = false ) {
312+ static LayoutInfo
313+ getSIMTLayoutInforForScatterIO (VectorType vectorTy ,
314+ const xegpu::uArch::uArch *uArch ,
315+ unsigned packingSize ) {
322316 // Expecting a 1D or 2D vector.
323- assert ((tdescTy .getRank () == 1 || tdescTy .getRank () == 2 ) &&
324- " Expected 1D or 2D TensorDesc ." );
317+ assert ((vectorTy .getRank () == 1 || vectorTy .getRank () == 2 ) &&
318+ " Expected 1D or 2D vector ." );
325319 // Expecting int or float element type.
326- assert (tdescTy .getElementType ().isIntOrFloat () &&
320+ assert (vectorTy .getElementType ().isIntOrFloat () &&
327321 " Expected int or float element type." );
328322 // If the rank is 1, then return default layout for 1D vector.
329- if (tdescTy .getRank () == 1 )
330- return getDefaultSIMTLayoutInfo (tdescTy .getContext (), 1 , uArch);
323+ if (vectorTy .getRank () == 1 )
324+ return getDefaultSIMTLayoutInfo (vectorTy .getContext (), 1 , uArch);
331325 // Packing factor is determined by the element type bitwidth.
332- unsigned bitwidth = tdescTy.getElementType ().getIntOrFloatBitWidth ();
333- int subgroupSize = uArch->getSubgroupSize ();
326+ unsigned bitwidth = vectorTy.getElementType ().getIntOrFloatBitWidth ();
334327 int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1 ;
335- if (isScattered) {
336- return LayoutInfo (xegpu::LayoutAttr::get (
337- tdescTy.getContext (), {subgroupSize, 1 }, {1 , packingFactor}));
338- }
339-
340- return LayoutInfo (xegpu::LayoutAttr::get (
341- tdescTy.getContext (), {1 , subgroupSize}, {1 , packingFactor}));
328+ return LayoutInfo (xegpu::LayoutAttr::get (vectorTy.getContext (),
329+ {uArch->getSubgroupSize (), 1 },
330+ {1 , packingFactor}));
342331}
343332
344333// / Helper Function to get the expected layouts for DPAS operands. `lane_data`
@@ -365,7 +354,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
365354 xegpu::LayoutAttr::get (vectorTy.getContext (), layout, data));
366355 }
367356 // Otherwise, return the default layout for the vector type.
368- return getDefaultSIMTLayoutInfo (vectorTy, uArch, packingSize);
357+ return getSIMTLayoutInforForBlockIO (vectorTy, uArch, packingSize);
369358}
370359
371360// ===----------------------------------------------------------------------===//
@@ -587,7 +576,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
587576 prefetchLayout =
588577 LayoutInfo (xegpu::LayoutAttr::get (tdescTy.getContext (), instData));
589578 else
590- prefetchLayout = getDefaultSIMTLayoutInfo (
579+ prefetchLayout = getSIMTLayoutInforForBlockIO (
591580 tdescTy, uArch, uArchInstruction->getPackedFormatBitSize ());
592581
593582 prefetch.setLayoutAttr (
@@ -840,9 +829,9 @@ void LayoutInfoPropagation::visitStoreNdOp(
840829 storeLayout =
841830 LayoutInfo (xegpu::LayoutAttr::get (dataTy.getContext (), instData));
842831 else
843- storeLayout =
844- getDefaultSIMTLayoutInfo ( store.getValueType (), uArch,
845- uArchInstruction->getPackedFormatBitSize ());
832+ storeLayout = getSIMTLayoutInforForBlockIO (
833+ store.getValueType (), uArch,
834+ uArchInstruction->getPackedFormatBitSize ());
846835 store.setLayoutAttr (
847836 dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get ()));
848837 }
@@ -1000,9 +989,8 @@ void LayoutInfoPropagation::visitLoadGatherOp(
1000989 loadLayout =
1001990 LayoutInfo (xegpu::LayoutAttr::get (load.getContext (), instData));
1002991 else
1003- loadLayout = getDefaultSIMTLayoutInfo (
1004- payloadTy, uArch, uArch->getGeneralPackedFormatBitSize (),
1005- /* scattered*/ true );
992+ loadLayout = getSIMTLayoutInforForScatterIO (
993+ payloadTy, uArch, uArch->getGeneralPackedFormatBitSize ());
1006994
1007995 // Mask operand should have 1D default layout.
1008996 maskLayout = getDefaultSIMTLayoutInfo (load->getContext (), 1 , subgroupSize);
@@ -1078,9 +1066,8 @@ void LayoutInfoPropagation::visitStoreScatterOp(
10781066 " Expected the first dimension of 2D tensor descriptor to be "
10791067 " equal to "
10801068 " subgroup size." );
1081- payloadLayout = getDefaultSIMTLayoutInfo (
1082- payloadTy, uArch, uArch->getGeneralPackedFormatBitSize (),
1083- /* scattered=*/ true );
1069+ payloadLayout = getSIMTLayoutInforForScatterIO (
1070+ payloadTy, uArch, uArch->getGeneralPackedFormatBitSize ());
10841071 }
10851072
10861073 maskLayout =
0 commit comments