@jwlawson commented Mar 2, 2021

This is currently done as a pass before the 2D ramps are flattened. After flattening, the loads are converted to a concat of multiple loads, and it is hard to recover the strides from the index expressions, mainly because Halide is cautious about overflow and so can't simplify the expressions back to the simple base + i * stride indexes it started with.
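To make the index structure concrete, here is a minimal standalone sketch (plain C++, not Halide code). `Ramp2D`, `flatten` and `recover` are hypothetical names made up for illustration. While the ramp is still structured, the strides are just fields to read. After flattening they have to be inferred from the indexes and then verified. With concrete integers, as here, that is easy; the difficulty described above is doing the same on symbolic expressions that the simplifier won't fold.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a 2D ramp:
//   index(y, x) = base + y * row_stride + x * col_stride
struct Ramp2D {
    int64_t base;
    int64_t row_stride;
    int64_t col_stride;
    int rows;
    int cols;
};

// Flatten the ramp into explicit indexes, one row after another.
// This mirrors the "concat of multiple loads" shape: one 1D ramp per row.
std::vector<int64_t> flatten(const Ramp2D &r) {
    assert(r.rows > 0 && r.cols > 0);
    std::vector<int64_t> idx;
    idx.reserve(static_cast<std::size_t>(r.rows) * static_cast<std::size_t>(r.cols));
    for (int y = 0; y < r.rows; ++y) {
        for (int x = 0; x < r.cols; ++x) {
            idx.push_back(r.base + y * r.row_stride + x * r.col_stride);
        }
    }
    return idx;
}

// Try to recover base and strides from flat indexes of a known tile shape.
// Returns std::nullopt if the indexes are not of the base + y*rs + x*cs form.
std::optional<Ramp2D> recover(const std::vector<int64_t> &idx, int rows, int cols) {
    if (rows <= 0 || cols <= 0 ||
        idx.size() != static_cast<std::size_t>(rows) * static_cast<std::size_t>(cols)) {
        return std::nullopt;
    }
    Ramp2D guess{idx[0],
                 rows > 1 ? idx[static_cast<std::size_t>(cols)] - idx[0] : 0,
                 cols > 1 ? idx[1] - idx[0] : 0,
                 rows, cols};
    // The guess only counts if it reproduces every index.
    if (flatten(guess) != idx) {
        return std::nullopt;
    }
    return guess;
}
```

For example, flattening `Ramp2D{100, 64, 1, 4, 8}` and passing the result to `recover(idx, 4, 8)` returns a row stride of 64 and a column stride of 1. Perturbing any single index makes `recover` return `std::nullopt`.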

Introduces an AMXTile type, so that the tiles that have to be stored by Halide are of the right type, and so that LLVM can alloca the right thing. Trying to use <256 x i32> causes problems, as Halide tries to break the loads and stores into multiple vector loads. This type doesn't really need to be externally available, but I'm not sure if there's a way to have an internal-only type. For the tiles that do not need to be stored (i.e. those used directly in a tile matmul intrinsic) we don't need to use the AMX type, which should allow us to use the overloaded intrinsics to select the right instruction for the datatypes.
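As a rough illustration of the sizes involved, the sketch below does the arithmetic for a full AMX tile (at most 16 rows of 64 bytes each) and counts how many native vector accesses a plain wide vector of that size would be split into. The constant names and `native_slices` are hypothetical, and the 64-byte and 32-byte native vector widths are assumptions chosen for the example, not values taken from this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Architectural maximum shape of one AMX tile register.
constexpr std::size_t kTileRows = 16;
constexpr std::size_t kTileRowBytes = 64;
constexpr std::size_t kTileBytes = kTileRows * kTileRowBytes;
constexpr std::size_t kI32Lanes = kTileBytes / sizeof(int32_t);

static_assert(kTileBytes == 1024, "a full tile is 1 KiB");
static_assert(kI32Lanes == 256, "a full tile holds 256 32-bit lanes");

// Number of native-width accesses needed if the tile is treated as an
// ordinary wide vector and split to fit the target's vector registers.
inline std::size_t native_slices(std::size_t total_bytes, std::size_t native_vector_bytes) {
    assert(native_vector_bytes > 0);
    return (total_bytes + native_vector_bytes - 1) / native_vector_bytes;
}
```

Under these assumptions, a 256-lane i32 vector becomes 16 separate 64-byte accesses (or 32 separate 32-byte accesses) instead of a single tile load or store. A dedicated tile type avoids that splitting.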

Currently this (ab)uses the way Halide calls intrinsics, but these tile intrinsics load from and store to memory, so the default call->setDoesNotAccessMemory(); is not valid. I'm not too sure how to handle this better. There are also some hacks to get the tile_store intrinsic to work, as it really should return void, but Halide makes some assumptions about the return types of Call Exprs, which caused problems.
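The hazard with the memory attribute can be shown with a toy model. This is not LLVM or Halide code: `ToyCall`, `eliminate_dead_calls` and `contains_call` are hypothetical, and the pass is a deliberately simplified dead-call elimination. The point is only that once a call is marked as not accessing memory, an optimizer may drop it when its result is unused, and a store whose result is a dummy value is in exactly that position.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical, simplified model of a call in a compiler IR.
struct ToyCall {
    std::string name;
    bool accesses_memory;  // false models a "does not access memory" attribute
    bool result_used;      // whether any later instruction consumes the result
};

// Toy dead-call elimination: a call may be dropped only if it is known
// not to touch memory and nothing uses its result.
std::vector<ToyCall> eliminate_dead_calls(const std::vector<ToyCall> &calls) {
    std::vector<ToyCall> kept;
    for (const ToyCall &c : calls) {
        const bool removable = !c.accesses_memory && !c.result_used;
        if (!removable) {
            kept.push_back(c);
        }
    }
    return kept;
}

// Helper: does the list still contain a call with this name?
bool contains_call(const std::vector<ToyCall> &calls, const std::string &name) {
    return std::any_of(calls.begin(), calls.end(),
                       [&](const ToyCall &c) { return c.name == name; });
}
```

In this model, a `tile_store` entry with `accesses_memory = false` and an unused result is removed by the pass, while the same entry with `accesses_memory = true` survives. That is why the default attribute is unsafe for intrinsics that write memory.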

@jwlawson closed this Mar 18, 2021
@alexreinking modified the milestone: v12.0.0 May 19, 2021
3 participants