[MLIR][SCFToOpenMP] Fix crash when lowering vector reductions (#173978) #175933
Closed
tblah wants to merge 1 commit into
…73978)

This patch fixes a crash in the SCF to OpenMP conversion pass when encountering scf.parallel with vector reductions.

- Extracts scalar element types for bitwidth calculations.
- Uses DenseElementsAttr for vector splat initializers.
- Bypasses llvm.atomicrmw for vector types (not supported in LLVM IR).

Fixes llvm#173860

---------

Co-authored-by: Aniket Singh <[email protected]>
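The three changes in the commit message can be illustrated without MLIR. The following is a minimal sketch, not the patch itself: `ToyType`, `allOnes`, and this `splatOrScalar` are hypothetical stand-ins invented for illustration, and only the name `supportsAtomic` is borrowed from the patch. It shows the neutral value being computed from the element width and then repeated once per lane for vectors, with atomics skipped for vector types.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for an MLIR type (not an MLIR class): an integer
// element width plus a lane count, where lanes == 0 means "scalar".
struct ToyType {
  unsigned elementBits;
  std::size_t lanes;
  bool isVector() const { return lanes != 0; }
};

// All-ones value for an element width between 1 and 64 bits. This is the
// neutral element of a bitwise AND reduction.
std::uint64_t allOnes(unsigned bits) {
  assert(bits >= 1 && bits <= 64);
  return bits == 64 ? ~std::uint64_t{0} : ((std::uint64_t{1} << bits) - 1);
}

// Toy analogue of the splat-or-scalar idea: the value is derived from the
// element type, then repeated once per lane for vectors (one entry for
// scalars).
std::vector<std::uint64_t> splatOrScalar(ToyType type, std::uint64_t value) {
  return std::vector<std::uint64_t>(type.isVector() ? type.lanes : 1, value);
}

// Toy analogue of the patch's check: only non-vector types take the
// atomic path.
bool supportsAtomic(ToyType type) { return !type.isVector(); }
```

For a two-lane, one-bit type, `splatOrScalar({1, 2}, allOnes(1))` yields two entries equal to 1, which corresponds to the `dense<true> : vector<2xi1>` initializer expected by the new test.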
@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-flang-openmp

Author: Tom Eccles (tblah)

Changes

This patch fixes a crash in the SCF to OpenMP conversion pass when encountering scf.parallel with vector reductions.

Fixes #173860

2 Files Affected:
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..5fcaea7f39c3c 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -150,32 +150,48 @@ static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
llvm_unreachable("unknown float type");
}
+/// Helper to create a splat attribute for vector types, or return the scalar
+/// attribute for scalar types.
+static Attribute getSplatOrScalarAttr(Type type, Attribute val) {
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return DenseElementsAttr::get(vecType, val);
+ return val;
+}
+
/// Returns an attribute with the minimum (if `min` is set) or the maximum value
/// (otherwise) for the given float type.
static Attribute minMaxValueForFloat(Type type, bool min) {
- auto fltType = cast<FloatType>(type);
- return FloatAttr::get(
- type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
+ Type elType = getElementTypeOrSelf(type);
+ auto fltType = cast<FloatType>(elType);
+ auto val = llvm::APFloat::getLargest(fltSemanticsForType(fltType), min);
+
+ return getSplatOrScalarAttr(type, FloatAttr::get(elType, val));
}
/// Returns an attribute with the signed integer minimum (if `min` is set) or
/// the maximum value (otherwise) for the given integer type, regardless of its
/// signedness semantics (only the width is considered).
static Attribute minMaxValueForSignedInt(Type type, bool min) {
- auto intType = cast<IntegerType>(type);
+ Type elType = getElementTypeOrSelf(type);
+ auto intType = cast<IntegerType>(elType);
unsigned bitwidth = intType.getWidth();
- return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
- : llvm::APInt::getSignedMaxValue(bitwidth));
+ auto val = min ? llvm::APInt::getSignedMinValue(bitwidth)
+ : llvm::APInt::getSignedMaxValue(bitwidth);
+
+ return getSplatOrScalarAttr(type, IntegerAttr::get(elType, val));
}
/// Returns an attribute with the unsigned integer minimum (if `min` is set) or
/// the maximum value (otherwise) for the given integer type, regardless of its
/// signedness semantics (only the width is considered).
static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
- auto intType = cast<IntegerType>(type);
+ Type elType = getElementTypeOrSelf(type);
+ auto intType = cast<IntegerType>(elType);
unsigned bitwidth = intType.getWidth();
- return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
- : llvm::APInt::getAllOnes(bitwidth));
+ auto val =
+ min ? llvm::APInt::getZero(bitwidth) : llvm::APInt::getAllOnes(bitwidth);
+
+ return getSplatOrScalarAttr(type, IntegerAttr::get(elType, val));
}
/// Creates an OpenMP reduction declaration and inserts it into the provided
@@ -203,7 +219,7 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
Operation *terminator =
&reduce.getReductions()[reductionIndex].front().back();
assert(isa<scf::ReduceReturnOp>(terminator) &&
- "expected reduce op to be terminated by redure return");
+ "expected reduce op to be terminated by reduce return");
builder.setInsertionPoint(terminator);
builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
terminator->getOperands());
@@ -237,6 +253,11 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
return decl;
}
+/// Returns true if the type is supported by llvm.atomicrmw.
+/// LLVM IR currently does not support atomic operations on vector types.
+/// See LLVM Language Reference Manual on 'atomicrmw'.
+static bool supportsAtomic(Type type) { return !isa<VectorType>(type); }
+
/// Creates an OpenMP reduction declaration that corresponds to the given SCF
/// reduction and returns it. Recognizes common reductions in order to identify
/// the neutral value, necessary for the OpenMP declaration. If the reduction
@@ -261,91 +282,119 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
// Match simple binary reductions that can be expressed with atomicrmw.
Type type = reduce.getOperands()[reductionIndex].getType();
Block &reduction = reduce.getReductions()[reductionIndex].front();
+
+ // Handle scalar element type extraction for vector bitwidth safety.
+ Type elType = getElementTypeOrSelf(type);
+
+ // Arithmetic Reductions
if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getFloatAttr(type, 0.0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
- reductionIndex);
+ omp::DeclareReductionOp decl = createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getSplatOrScalarAttr(type, builder.getFloatAttr(elType, 0.0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::fadd,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
- reductionIndex);
+ omp::DeclareReductionOp decl = createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::add,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
- reductionIndex);
+ omp::DeclareReductionOp decl = createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_or,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
- reductionIndex);
+ omp::DeclareReductionOp decl = createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_xor,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
+ APInt allOnes = llvm::APInt::getAllOnes(elType.getIntOrFloatBitWidth());
omp::DeclareReductionOp decl = createDecl(
builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(
- type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
- reductionIndex);
+ getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, allOnes)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_and,
+ decl, reduce, reductionIndex)
+ : decl;
}
// Match simple binary reductions that cannot be expressed with atomicrmw.
// TODO: add atomic region using cmpxchg (which needs atomic load to be
// available as an op).
if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
- return createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getFloatAttr(type, 1.0));
+ return createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getSplatOrScalarAttr(type, builder.getFloatAttr(elType, 1.0)));
}
+
if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
- return createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 1));
+ return createDecl(
+ builder, symbolTable, reduce, reductionIndex,
+ getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 1)));
}
// Match select-based min/max reductions.
bool isMin;
- if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
+ // Floating Point Min/Max
+ if (matchSelectReduction<arith::CmpFOp, arith::SelectOp,
+ arith::CmpFPredicate>(
reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
{arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
- matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
- reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
- {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
+ matchSelectReduction<arith::CmpFOp, arith::SelectOp,
+ arith::CmpFPredicate>(
+ reduction, {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE},
+ {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, isMin)) {
return createDecl(builder, symbolTable, reduce, reductionIndex,
minMaxValueForFloat(type, !isMin));
}
- if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
+
+ // Integer Min/Max
+ if (matchSelectReduction<arith::CmpIOp, arith::SelectOp,
+ arith::CmpIPredicate>(
reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
{arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
- matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
- reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
- {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
+ matchSelectReduction<arith::CmpIOp, arith::SelectOp,
+ arith::CmpIPredicate>(
+ reduction, {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge},
+ {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, isMin)) {
omp::DeclareReductionOp decl =
createDecl(builder, symbolTable, reduce, reductionIndex,
minMaxValueForSignedInt(type, !isMin));
- return addAtomicRMW(builder,
- isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
- decl, reduce, reductionIndex);
+ return supportsAtomic(type) ? addAtomicRMW(builder,
+ isMin ? LLVM::AtomicBinOp::min
+ : LLVM::AtomicBinOp::max,
+ decl, reduce, reductionIndex)
+ : decl;
}
- if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
+
+ // Unsigned Integer Min/Max
+ if (matchSelectReduction<arith::CmpIOp, arith::SelectOp,
+ arith::CmpIPredicate>(
reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
{arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
- matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
- reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
- {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
+ matchSelectReduction<arith::CmpIOp, arith::SelectOp,
+ arith::CmpIPredicate>(
+ reduction, {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge},
+ {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, isMin)) {
omp::DeclareReductionOp decl =
createDecl(builder, symbolTable, reduce, reductionIndex,
minMaxValueForUnsignedInt(type, !isMin));
- return addAtomicRMW(
- builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
- decl, reduce, reductionIndex);
+ return supportsAtomic(type) ? addAtomicRMW(builder,
+ isMin ? LLVM::AtomicBinOp::umin
+ : LLVM::AtomicBinOp::umax,
+ decl, reduce, reductionIndex)
+ : decl;
}
return nullptr;
diff --git a/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
new file mode 100644
index 0000000000000..018f8a03c8e34
--- /dev/null
+++ b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s --convert-scf-to-openmp | FileCheck %s
+
+// CHECK-LABEL: omp.declare_reduction @__scf_reduction : vector<2xi1>
+// CHECK: init {
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant(dense<true> : vector<2xi1>) : vector<2xi1>
+// CHECK: omp.yield(%[[INIT]] : vector<2xi1>)
+// CHECK: }
+// CHECK: combiner {
+// CHECK: ^bb0(%[[ARG0:.*]]: vector<2xi1>, %[[ARG1:.*]]: vector<2xi1>):
+// CHECK: %[[RES:.*]] = arith.andi %[[ARG0]], %[[ARG1]] : vector<2xi1>
+// CHECK: omp.yield(%[[RES]] : vector<2xi1>)
+// CHECK: }
+// CHECK-NOT: atomic
+
+func.func @vector_and_reduction() {
+ %v_mask = vector.constant_mask [1] : vector<2xi1>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %result = scf.parallel (%i) = (%c0) to (%c2) step (%c1) init(%v_mask) -> vector<2xi1> {
+ %val = vector.constant_mask [1] : vector<2xi1>
+ scf.reduce (%val : vector<2xi1>) {
+ ^bb0(%lhs: vector<2xi1>, %rhs: vector<2xi1>):
+ %0 = arith.andi %lhs, %rhs : vector<2xi1>
+ scf.reduce.return %0 : vector<2xi1>
+ }
+ }
+ return
+}
\ No newline at end of file
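The test above expects an all-true initializer for the vector AND reduction. A small sketch of why that value is a safe starting point follows; `Mask2`, `combine`, and `reduce` are hypothetical names used only here, and the code models the lane-wise combiner in plain C++ rather than reproducing anything from the pass.

```cpp
#include <array>
#include <cstddef>

// Toy model of a vector<2xi1> value.
using Mask2 = std::array<bool, 2>;

// Lane-wise AND, the toy counterpart of `arith.andi` on vector<2xi1>.
Mask2 combine(Mask2 lhs, Mask2 rhs) {
  return {lhs[0] && rhs[0], lhs[1] && rhs[1]};
}

// Folds `count` copies of `value` into an accumulator that starts at `init`.
Mask2 reduce(Mask2 init, Mask2 value, std::size_t count) {
  Mask2 acc = init;
  for (std::size_t i = 0; i < count; ++i)
    acc = combine(acc, value);
  return acc;
}
```

Starting from `{true, true}`, the accumulator ends up equal to the AND of the contributed values alone, so a per-thread private copy initialized this way does not disturb the final result. An all-false start would force every lane to false regardless of the inputs.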
Backport to 22.x for #173978
c-rhodes
approved these changes
Jan 15, 2026
merged fe68b17