Skip to content

Commit fe68b17

Browse files
Aniketsingh54Aniket Singh
authored andcommitted
[MLIR][SCFToOpenMP] Fix crash when lowering vector reductions (#173978)
This patch fixes a crash in the SCF to OpenMP conversion pass when encountering scf.parallel with vector reductions. - Extracts scalar element types for bitwidth calculations. - Uses DenseElementsAttr for vector splat initializers. - Bypasses llvm.atomicrmw for vector types (not supported in LLVM IR). Fixes #173860 --------- Co-authored-by: Aniket Singh <[email protected]>
1 parent 967834a commit fe68b17

2 files changed

Lines changed: 134 additions & 56 deletions

File tree

mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp

Lines changed: 105 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -150,32 +150,48 @@ static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
150150
llvm_unreachable("unknown float type");
151151
}
152152

153+
/// Helper to create a splat attribute for vector types, or return the scalar
154+
/// attribute for scalar types.
155+
static Attribute getSplatOrScalarAttr(Type type, Attribute val) {
156+
if (auto vecType = dyn_cast<VectorType>(type))
157+
return DenseElementsAttr::get(vecType, val);
158+
return val;
159+
}
160+
153161
/// Returns an attribute with the minimum (if `min` is set) or the maximum value
154162
/// (otherwise) for the given float type.
155163
static Attribute minMaxValueForFloat(Type type, bool min) {
156-
auto fltType = cast<FloatType>(type);
157-
return FloatAttr::get(
158-
type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
164+
Type elType = getElementTypeOrSelf(type);
165+
auto fltType = cast<FloatType>(elType);
166+
auto val = llvm::APFloat::getLargest(fltSemanticsForType(fltType), min);
167+
168+
return getSplatOrScalarAttr(type, FloatAttr::get(elType, val));
159169
}
160170

161171
/// Returns an attribute with the signed integer minimum (if `min` is set) or
162172
/// the maximum value (otherwise) for the given integer type, regardless of its
163173
/// signedness semantics (only the width is considered).
164174
static Attribute minMaxValueForSignedInt(Type type, bool min) {
165-
auto intType = cast<IntegerType>(type);
175+
Type elType = getElementTypeOrSelf(type);
176+
auto intType = cast<IntegerType>(elType);
166177
unsigned bitwidth = intType.getWidth();
167-
return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
168-
: llvm::APInt::getSignedMaxValue(bitwidth));
178+
auto val = min ? llvm::APInt::getSignedMinValue(bitwidth)
179+
: llvm::APInt::getSignedMaxValue(bitwidth);
180+
181+
return getSplatOrScalarAttr(type, IntegerAttr::get(elType, val));
169182
}
170183

171184
/// Returns an attribute with the unsigned integer minimum (if `min` is set) or
172185
/// the maximum value (otherwise) for the given integer type, regardless of its
173186
/// signedness semantics (only the width is considered).
174187
static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
175-
auto intType = cast<IntegerType>(type);
188+
Type elType = getElementTypeOrSelf(type);
189+
auto intType = cast<IntegerType>(elType);
176190
unsigned bitwidth = intType.getWidth();
177-
return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
178-
: llvm::APInt::getAllOnes(bitwidth));
191+
auto val =
192+
min ? llvm::APInt::getZero(bitwidth) : llvm::APInt::getAllOnes(bitwidth);
193+
194+
return getSplatOrScalarAttr(type, IntegerAttr::get(elType, val));
179195
}
180196

181197
/// Creates an OpenMP reduction declaration and inserts it into the provided
@@ -203,7 +219,7 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
203219
Operation *terminator =
204220
&reduce.getReductions()[reductionIndex].front().back();
205221
assert(isa<scf::ReduceReturnOp>(terminator) &&
206-
"expected reduce op to be terminated by redure return");
222+
"expected reduce op to be terminated by reduce return");
207223
builder.setInsertionPoint(terminator);
208224
builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
209225
terminator->getOperands());
@@ -237,6 +253,11 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
237253
return decl;
238254
}
239255

256+
/// Returns true if the type is supported by llvm.atomicrmw.
257+
/// LLVM IR currently does not support atomic operations on vector types.
258+
/// See LLVM Language Reference Manual on 'atomicrmw'.
259+
static bool supportsAtomic(Type type) { return !isa<VectorType>(type); }
260+
240261
/// Creates an OpenMP reduction declaration that corresponds to the given SCF
241262
/// reduction and returns it. Recognizes common reductions in order to identify
242263
/// the neutral value, necessary for the OpenMP declaration. If the reduction
@@ -261,91 +282,119 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
261282
// Match simple binary reductions that can be expressed with atomicrmw.
262283
Type type = reduce.getOperands()[reductionIndex].getType();
263284
Block &reduction = reduce.getReductions()[reductionIndex].front();
285+
286+
// Handle scalar element type extraction for vector bitwidth safety.
287+
Type elType = getElementTypeOrSelf(type);
288+
289+
// Arithmetic Reductions
264290
if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
265-
omp::DeclareReductionOp decl =
266-
createDecl(builder, symbolTable, reduce, reductionIndex,
267-
builder.getFloatAttr(type, 0.0));
268-
return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
269-
reductionIndex);
291+
omp::DeclareReductionOp decl = createDecl(
292+
builder, symbolTable, reduce, reductionIndex,
293+
getSplatOrScalarAttr(type, builder.getFloatAttr(elType, 0.0)));
294+
return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::fadd,
295+
decl, reduce, reductionIndex)
296+
: decl;
270297
}
271298
if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
272-
omp::DeclareReductionOp decl =
273-
createDecl(builder, symbolTable, reduce, reductionIndex,
274-
builder.getIntegerAttr(type, 0));
275-
return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
276-
reductionIndex);
299+
omp::DeclareReductionOp decl = createDecl(
300+
builder, symbolTable, reduce, reductionIndex,
301+
getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
302+
return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::add,
303+
decl, reduce, reductionIndex)
304+
: decl;
277305
}
278306
if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
279-
omp::DeclareReductionOp decl =
280-
createDecl(builder, symbolTable, reduce, reductionIndex,
281-
builder.getIntegerAttr(type, 0));
282-
return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
283-
reductionIndex);
307+
omp::DeclareReductionOp decl = createDecl(
308+
builder, symbolTable, reduce, reductionIndex,
309+
getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
310+
return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_or,
311+
decl, reduce, reductionIndex)
312+
: decl;
284313
}
285314
if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
286-
omp::DeclareReductionOp decl =
287-
createDecl(builder, symbolTable, reduce, reductionIndex,
288-
builder.getIntegerAttr(type, 0));
289-
return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
290-
reductionIndex);
315+
omp::DeclareReductionOp decl = createDecl(
316+
builder, symbolTable, reduce, reductionIndex,
317+
getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
318+
return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_xor,
319+
decl, reduce, reductionIndex)
320+
: decl;
291321
}
292322
if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
323+
APInt allOnes = llvm::APInt::getAllOnes(elType.getIntOrFloatBitWidth());
293324
omp::DeclareReductionOp decl = createDecl(
294325
builder, symbolTable, reduce, reductionIndex,
295-
builder.getIntegerAttr(
296-
type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
297-
return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
298-
reductionIndex);
326+
getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, allOnes)));
327+
return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_and,
328+
decl, reduce, reductionIndex)
329+
: decl;
299330
}
300331

301332
// Match simple binary reductions that cannot be expressed with atomicrmw.
302333
// TODO: add atomic region using cmpxchg (which needs atomic load to be
303334
// available as an op).
304335
if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
305-
return createDecl(builder, symbolTable, reduce, reductionIndex,
306-
builder.getFloatAttr(type, 1.0));
336+
return createDecl(
337+
builder, symbolTable, reduce, reductionIndex,
338+
getSplatOrScalarAttr(type, builder.getFloatAttr(elType, 1.0)));
307339
}
340+
308341
if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
309-
return createDecl(builder, symbolTable, reduce, reductionIndex,
310-
builder.getIntegerAttr(type, 1));
342+
return createDecl(
343+
builder, symbolTable, reduce, reductionIndex,
344+
getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 1)));
311345
}
312346

313347
// Match select-based min/max reductions.
314348
bool isMin;
315-
if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
349+
// Floating Point Min/Max
350+
if (matchSelectReduction<arith::CmpFOp, arith::SelectOp,
351+
arith::CmpFPredicate>(
316352
reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
317353
{arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
318-
matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
319-
reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
320-
{LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
354+
matchSelectReduction<arith::CmpFOp, arith::SelectOp,
355+
arith::CmpFPredicate>(
356+
reduction, {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE},
357+
{arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, isMin)) {
321358
return createDecl(builder, symbolTable, reduce, reductionIndex,
322359
minMaxValueForFloat(type, !isMin));
323360
}
324-
if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
361+
362+
// Integer Min/Max
363+
if (matchSelectReduction<arith::CmpIOp, arith::SelectOp,
364+
arith::CmpIPredicate>(
325365
reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
326366
{arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
327-
matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
328-
reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
329-
{LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
367+
matchSelectReduction<arith::CmpIOp, arith::SelectOp,
368+
arith::CmpIPredicate>(
369+
reduction, {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge},
370+
{arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, isMin)) {
330371
omp::DeclareReductionOp decl =
331372
createDecl(builder, symbolTable, reduce, reductionIndex,
332373
minMaxValueForSignedInt(type, !isMin));
333-
return addAtomicRMW(builder,
334-
isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
335-
decl, reduce, reductionIndex);
374+
return supportsAtomic(type) ? addAtomicRMW(builder,
375+
isMin ? LLVM::AtomicBinOp::min
376+
: LLVM::AtomicBinOp::max,
377+
decl, reduce, reductionIndex)
378+
: decl;
336379
}
337-
if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
380+
381+
// Unsigned Integer Min/Max
382+
if (matchSelectReduction<arith::CmpIOp, arith::SelectOp,
383+
arith::CmpIPredicate>(
338384
reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
339385
{arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
340-
matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
341-
reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
342-
{LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
386+
matchSelectReduction<arith::CmpIOp, arith::SelectOp,
387+
arith::CmpIPredicate>(
388+
reduction, {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge},
389+
{arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, isMin)) {
343390
omp::DeclareReductionOp decl =
344391
createDecl(builder, symbolTable, reduce, reductionIndex,
345392
minMaxValueForUnsignedInt(type, !isMin));
346-
return addAtomicRMW(
347-
builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
348-
decl, reduce, reductionIndex);
393+
return supportsAtomic(type) ? addAtomicRMW(builder,
394+
isMin ? LLVM::AtomicBinOp::umin
395+
: LLVM::AtomicBinOp::umax,
396+
decl, reduce, reductionIndex)
397+
: decl;
349398
}
350399

351400
return nullptr;
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: mlir-opt %s --convert-scf-to-openmp | FileCheck %s
2+
3+
// CHECK-LABEL: omp.declare_reduction @__scf_reduction : vector<2xi1>
4+
// CHECK: init {
5+
// CHECK: %[[INIT:.*]] = llvm.mlir.constant(dense<true> : vector<2xi1>) : vector<2xi1>
6+
// CHECK: omp.yield(%[[INIT]] : vector<2xi1>)
7+
// CHECK: }
8+
// CHECK: combiner {
9+
// CHECK: ^bb0(%[[ARG0:.*]]: vector<2xi1>, %[[ARG1:.*]]: vector<2xi1>):
10+
// CHECK: %[[RES:.*]] = arith.andi %[[ARG0]], %[[ARG1]] : vector<2xi1>
11+
// CHECK: omp.yield(%[[RES]] : vector<2xi1>)
12+
// CHECK: }
13+
// CHECK-NOT: atomic
14+
15+
func.func @vector_and_reduction() {
16+
%v_mask = vector.constant_mask [1] : vector<2xi1>
17+
%c0 = arith.constant 0 : index
18+
%c1 = arith.constant 1 : index
19+
%c2 = arith.constant 2 : index
20+
%result = scf.parallel (%i) = (%c0) to (%c2) step (%c1) init(%v_mask) -> vector<2xi1> {
21+
%val = vector.constant_mask [1] : vector<2xi1>
22+
scf.reduce (%val : vector<2xi1>) {
23+
^bb0(%lhs: vector<2xi1>, %rhs: vector<2xi1>):
24+
%0 = arith.andi %lhs, %rhs : vector<2xi1>
25+
scf.reduce.return %0 : vector<2xi1>
26+
}
27+
}
28+
return
29+
}

0 commit comments

Comments
 (0)