Skip to content

Commit d2dab97

Browse files
authored
[SPIR-V] Decompose bitcasts involving bool vectors (#187960)
OpTypeBool has no defined bitwidth in SPIR-V, so OpBitcast is invalid for boolean vector types. Decompose `<N x i1> <-> iN` bitcasts into element-wise extract/shift/OR and AND/icmp/insert sequences during IR preprocessing. Fixes: https://github.com/kuhar/iree/blob/amdgcn-spirv/spirv-repros/bitcast_crash.ll and #185815
1 parent 710c2f0 commit d2dab97

File tree

3 files changed

+510
-0
lines changed

3 files changed

+510
-0
lines changed

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ class SPIRVEmitIntrinsics
302302
void useRoundingMode(ConstrainedFPIntrinsic *FPI, IRBuilder<> &B);
303303
bool processMaskedMemIntrinsic(IntrinsicInst &I);
304304
bool convertMaskedMemIntrinsics(Module &M);
305+
void preprocessBoolVectorBitcasts(Function &F);
305306

306307
void emitUnstructuredLoopControls(Function &F, IRBuilder<> &B);
307308

@@ -3157,6 +3158,7 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
31573158

31583159
preprocessUndefs(B);
31593160
preprocessCompositeConstants(B);
3161+
preprocessBoolVectorBitcasts(Func);
31603162
SmallVector<Instruction *> Worklist(
31613163
llvm::make_pointer_range(instructions(Func)));
31623164

@@ -3389,6 +3391,84 @@ bool SPIRVEmitIntrinsics::processMaskedMemIntrinsic(IntrinsicInst &I) {
33893391
return false;
33903392
}
33913393

3394+
// SPIR-V doesn't support bitcasts involving vector boolean type. Decompose such
3395+
// bitcasts into element-wise operations before building instructions
3396+
// worklist, so new instructions are properly visited and converted to
3397+
// SPIR-V intrinsics.
3398+
void SPIRVEmitIntrinsics::preprocessBoolVectorBitcasts(Function &F) {
3399+
struct BoolVecBitcast {
3400+
BitCastInst *BC;
3401+
FixedVectorType *BoolVecTy;
3402+
bool SrcIsBoolVec;
3403+
};
3404+
3405+
auto getAsBoolVec = [](Type *Ty) -> FixedVectorType * {
3406+
auto *VTy = dyn_cast<FixedVectorType>(Ty);
3407+
return (VTy && VTy->getElementType()->isIntegerTy(1)) ? VTy : nullptr;
3408+
};
3409+
3410+
SmallVector<BoolVecBitcast, 4> ToReplace;
3411+
for (auto &I : instructions(F)) {
3412+
auto *BC = dyn_cast<BitCastInst>(&I);
3413+
if (!BC)
3414+
continue;
3415+
if (auto *BVTy = getAsBoolVec(BC->getSrcTy()))
3416+
ToReplace.push_back({BC, BVTy, true});
3417+
else if (auto *BVTy = getAsBoolVec(BC->getDestTy()))
3418+
ToReplace.push_back({BC, BVTy, false});
3419+
}
3420+
3421+
for (auto &[BC, BoolVecTy, SrcIsBoolVec] : ToReplace) {
3422+
IRBuilder<> B(BC);
3423+
Value *Src = BC->getOperand(0);
3424+
unsigned BoolVecN = BoolVecTy->getNumElements();
3425+
// Use iN as the scalar intermediate type for the bool vector side.
3426+
Type *IntTy = B.getIntNTy(BoolVecN);
3427+
3428+
// Convert source to scalar integer.
3429+
Value *IntVal;
3430+
if (SrcIsBoolVec) {
3431+
// Extract each bool, zext, shift, and OR.
3432+
IntVal = ConstantInt::get(IntTy, 0);
3433+
for (unsigned I = 0; I < BoolVecN; ++I) {
3434+
Value *Elem = B.CreateExtractElement(Src, B.getInt32(I));
3435+
Value *Ext = B.CreateZExt(Elem, IntTy);
3436+
if (I > 0)
3437+
Ext = B.CreateShl(Ext, ConstantInt::get(IntTy, I));
3438+
IntVal = B.CreateOr(IntVal, Ext);
3439+
}
3440+
} else {
3441+
// Source is a non-bool type. If it's already a scalar integer, use it
3442+
// directly, otherwise bitcast to iN first.
3443+
IntVal = Src;
3444+
if (!Src->getType()->isIntegerTy())
3445+
IntVal = B.CreateBitCast(Src, IntTy);
3446+
}
3447+
3448+
// Convert scalar integer to destination type.
3449+
Value *Result;
3450+
if (!SrcIsBoolVec) {
3451+
// Test each bit with AND + icmp.
3452+
Result = PoisonValue::get(BoolVecTy);
3453+
for (unsigned I = 0; I < BoolVecN; ++I) {
3454+
Value *Mask = ConstantInt::get(IntTy, APInt::getOneBitSet(BoolVecN, I));
3455+
Value *And = B.CreateAnd(IntVal, Mask);
3456+
Value *Cmp = B.CreateICmpNE(And, ConstantInt::get(IntTy, 0));
3457+
Result = B.CreateInsertElement(Result, Cmp, B.getInt32(I));
3458+
}
3459+
} else {
3460+
// Destination is a non-bool type. If it's a scalar integer, use IntVal
3461+
// directly, otherwise bitcast from iN.
3462+
Result = IntVal;
3463+
if (!BC->getDestTy()->isIntegerTy())
3464+
Result = B.CreateBitCast(IntVal, BC->getDestTy());
3465+
}
3466+
3467+
BC->replaceAllUsesWith(Result);
3468+
BC->eraseFromParent();
3469+
}
3470+
}
3471+
33923472
bool SPIRVEmitIntrinsics::convertMaskedMemIntrinsics(Module &M) {
33933473
bool Changed = false;
33943474

0 commit comments

Comments
 (0)