[MLIR][SCFToOpenMP] Fix crash when lowering vector reductions#173978
Conversation
This patch fixes a crash in the SCF to OpenMP conversion pass when encountering scf.parallel with vector reductions.
- Extracts scalar element types for bitwidth calculations.
- Uses DenseElementsAttr for vector splat initializers.
- Bypasses llvm.atomicrmw for vector types (not supported in LLVM IR).
Fixes llvm#173860
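As a rough illustration of the first two points, the following is a minimal standalone sketch, not the MLIR code itself. It models how the initializer of an AND reduction might be derived from the element bitwidth and then splatted across the vector lanes. `IntTypeModel`, `allOnes`, and `andReductionInit` are hypothetical names introduced only for this example.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for an MLIR integer type or vector-of-integer type.
struct IntTypeModel {
  unsigned elementBitWidth; // width of the scalar element, e.g. 1 for i1
  unsigned numLanes;        // 0 means scalar, otherwise the number of lanes
};

// All-ones value for the given element width: the neutral element of AND.
std::uint64_t allOnes(unsigned bitWidth) {
  assert(bitWidth >= 1 && bitWidth <= 64 && "model only covers 1..64 bits");
  return bitWidth == 64 ? ~std::uint64_t{0}
                        : ((std::uint64_t{1} << bitWidth) - 1);
}

// Builds the reduction initializer: a single value for a scalar type, or the
// same value repeated once per lane (a splat) for a vector type. The width is
// always taken from the element, never from the vector as a whole.
std::vector<std::uint64_t> andReductionInit(const IntTypeModel &type) {
  const std::uint64_t neutral = allOnes(type.elementBitWidth);
  const std::size_t count = type.numLanes == 0 ? 1 : type.numLanes;
  return std::vector<std::uint64_t>(count, neutral);
}
```

For a `vector<2xi1>` operand this model yields two lanes that are both set, which corresponds to the `dense<true>` constant that the new test expects.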
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-flang-openmp
Author: Aniket Singh (Aniketsingh54)
Changes
This patch fixes a crash in the SCF to OpenMP conversion pass when encountering scf.parallel with vector reductions.
Fixes #173860
2 Files Affected:
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..3d3c601d92d1b 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -153,29 +153,58 @@ static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
/// Returns an attribute with the minimum (if `min` is set) or the maximum value
/// (otherwise) for the given float type.
static Attribute minMaxValueForFloat(Type type, bool min) {
- auto fltType = cast<FloatType>(type);
- return FloatAttr::get(
- type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min));
+ // If the type is a vector, we need to find the neutral value for the
+ // underlying element type and then create a splat attribute.
+ Type elType = type;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ elType = vecType.getElementType();
+
+ auto fltType = cast<FloatType>(elType);
+ auto val = llvm::APFloat::getLargest(fltSemanticsForType(fltType), min);
+
+ // For vector types, return a DenseElementsAttr (splat).
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return DenseElementsAttr::get(vecType, val);
+
+ return FloatAttr::get(type, val);
}
/// Returns an attribute with the signed integer minimum (if `min` is set) or
/// the maximum value (otherwise) for the given integer type, regardless of its
/// signedness semantics (only the width is considered).
static Attribute minMaxValueForSignedInt(Type type, bool min) {
- auto intType = cast<IntegerType>(type);
+ // Extract scalar element type to handle vector reductions.
+ Type elType = type;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ elType = vecType.getElementType();
+
+ auto intType = cast<IntegerType>(elType);
unsigned bitwidth = intType.getWidth();
- return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth)
- : llvm::APInt::getSignedMaxValue(bitwidth));
+ auto val = min ? llvm::APInt::getSignedMinValue(bitwidth)
+ : llvm::APInt::getSignedMaxValue(bitwidth);
+
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return DenseElementsAttr::get(vecType, val);
+ return IntegerAttr::get(type, val);
}
/// Returns an attribute with the unsigned integer minimum (if `min` is set) or
/// the maximum value (otherwise) for the given integer type, regardless of its
/// signedness semantics (only the width is considered).
static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
- auto intType = cast<IntegerType>(type);
+ // Extract scalar element type to handle vector reductions.
+ Type elType = type;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ elType = vecType.getElementType();
+
+ auto intType = cast<IntegerType>(elType);
unsigned bitwidth = intType.getWidth();
- return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth)
- : llvm::APInt::getAllOnes(bitwidth));
+ auto val =
+ min ? llvm::APInt::getZero(bitwidth) : llvm::APInt::getAllOnes(bitwidth);
+
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return DenseElementsAttr::get(vecType, val);
+ return IntegerAttr::get(type, val);
}
/// Creates an OpenMP reduction declaration and inserts it into the provided
@@ -203,7 +232,7 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
Operation *terminator =
&reduce.getReductions()[reductionIndex].front().back();
assert(isa<scf::ReduceReturnOp>(terminator) &&
- "expected reduce op to be terminated by redure return");
+ "expected reduce op to be terminated by reduce return");
builder.setInsertionPoint(terminator);
builder.replaceOpWithNewOp<omp::YieldOp>(terminator,
terminator->getOperands());
@@ -236,6 +265,11 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
omp::YieldOp::create(builder, reduce.getLoc(), ArrayRef<Value>());
return decl;
}
+/// Returns true if the type is supported by llvm.atomicrmw.
+/// LLVM IR does not support atomic operations on vector types.
+static bool supportsAtomic(Type type) {
+ return !isa<VectorType>(type);
+}
/// Creates an OpenMP reduction declaration that corresponds to the given SCF
/// reduction and returns it. Recognizes common reductions in order to identify
@@ -261,41 +295,55 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
// Match simple binary reductions that can be expressed with atomicrmw.
Type type = reduce.getOperands()[reductionIndex].getType();
Block &reduction = reduce.getReductions()[reductionIndex].front();
+
+ // Handle scalar element type extraction for vector bitwidth safety.
+ Type elType = type;
+ if (auto vecType = dyn_cast<VectorType>(type))
+ elType = vecType.getElementType();
+
+ // Helper to create splat (for vectors) or scalar attributes.
+ auto getAttr = [&](Attribute val) -> Attribute {
+ if (auto vecType = dyn_cast<VectorType>(type))
+ return DenseElementsAttr::get(vecType, val);
+ return val;
+ };
+
+ // Arithmetic Reductions
if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getFloatAttr(type, 0.0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
- reductionIndex);
+ auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+ getAttr(builder.getFloatAttr(elType, 0.0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::fadd,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
- reductionIndex);
+ auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+ getAttr(builder.getIntegerAttr(elType, 0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::add,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
- reductionIndex);
+ auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+ getAttr(builder.getIntegerAttr(elType, 0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_or,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
- omp::DeclareReductionOp decl =
- createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
- reductionIndex);
+ auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+ getAttr(builder.getIntegerAttr(elType, 0)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_xor,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
- omp::DeclareReductionOp decl = createDecl(
- builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(
- type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth())));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
- reductionIndex);
+ auto allOnes = llvm::APInt::getAllOnes(elType.getIntOrFloatBitWidth());
+ auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
+ getAttr(builder.getIntegerAttr(elType, allOnes)));
+ return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_and,
+ decl, reduce, reductionIndex)
+ : decl;
}
// Match simple binary reductions that cannot be expressed with atomicrmw.
@@ -303,12 +351,11 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
// available as an op).
if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
return createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getFloatAttr(type, 1.0));
- }
- if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
+ getAttr(builder.getFloatAttr(elType, 1.0)));
+
+ if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction))
return createDecl(builder, symbolTable, reduce, reductionIndex,
- builder.getIntegerAttr(type, 1));
- }
+ getAttr(builder.getIntegerAttr(elType, 1)));
// Match select-based min/max reductions.
bool isMin;
@@ -329,10 +376,12 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
{LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
omp::DeclareReductionOp decl =
createDecl(builder, symbolTable, reduce, reductionIndex,
- minMaxValueForSignedInt(type, !isMin));
- return addAtomicRMW(builder,
- isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
- decl, reduce, reductionIndex);
+ minMaxValueForSignedInt(type, !isMin));
+ return supportsAtomic(type)
+ ? addAtomicRMW(builder,
+ isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
+ decl, reduce, reductionIndex)
+ : decl;
}
if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
@@ -342,10 +391,12 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
{LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
omp::DeclareReductionOp decl =
createDecl(builder, symbolTable, reduce, reductionIndex,
- minMaxValueForUnsignedInt(type, !isMin));
- return addAtomicRMW(
- builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
- decl, reduce, reductionIndex);
+ minMaxValueForUnsignedInt(type, !isMin));
+ return supportsAtomic(type)
+ ? addAtomicRMW(builder,
+ isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
+ decl, reduce, reductionIndex)
+ : decl;
}
return nullptr;
@@ -370,6 +421,13 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
SmallVector<omp::DeclareReductionOp> ompReductionDecls;
auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator());
for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
+ // Ensure validity of reduction type for vector bitwidth calculations.
+ Type reductionType = reduce.getOperands()[i].getType();
+ if (auto vecType = dyn_cast<VectorType>(reductionType))
+ (void)vecType.getElementType().getIntOrFloatBitWidth();
+ else
+ (void)reductionType.getIntOrFloatBitWidth();
+
omp::DeclareReductionOp decl = declareReduction(rewriter, reduce, i);
ompReductionDecls.push_back(decl);
if (!decl)
@@ -427,7 +485,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
Operation *cloneOp = builder.clone(op, mapper);
if (auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) {
assert(yieldOp && yieldOp.getResults().size() == 1 &&
- "expect YieldOp in reduction region to return one result");
+ "expect YieldOp in reduction region to return one result");
Value redVal = yieldOp.getResults()[0];
LLVM::StoreOp::create(rewriter, loc, redVal, pvtRedVar);
rewriter.eraseOp(yieldOp);
diff --git a/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
new file mode 100644
index 0000000000000..38d7e3ec2aff1
--- /dev/null
+++ b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s --convert-scf-to-openmp | FileCheck %s
+
+// CHECK-LABEL: omp.declare_reduction @__scf_reduction : vector<2xi1> init
+// CHECK-NEXT: ^bb0(%arg0: vector<2xi1>):
+// CHECK-NEXT: %[[CONST:.*]] = llvm.mlir.constant(dense<true> : vector<2xi1>) : vector<2xi1>
+// CHECK-NEXT: omp.yield(%[[CONST]] : vector<2xi1>)
+
+func.func @vector_and_reduction() {
+ %v_mask = vector.constant_mask [1] : vector<2xi1>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %result = scf.parallel (%i) = (%c0) to (%c2) step (%c1) init(%v_mask) -> vector<2xi1> {
+ %val = vector.constant_mask [1] : vector<2xi1>
+ scf.reduce (%val : vector<2xi1>) {
+ ^bb0(%lhs: vector<2xi1>, %rhs: vector<2xi1>):
+ %0 = arith.andi %lhs, %rhs : vector<2xi1>
+ scf.reduce.return %0 : vector<2xi1>
+ }
+ }
+ return
+}
\ No newline at end of file
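The min/max helpers in the diff above take the bitwidth from the element type and pass `!isMin` when building the initializer, so a min reduction starts from the largest value and a max reduction from the smallest. A minimal standalone sketch of that arithmetic, assuming widths of at most 64 bits, is shown below. The function names are hypothetical and the code does not use `llvm::APInt`.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Smallest (min == true) or largest signed value representable in bitWidth bits.
std::int64_t signedExtreme(unsigned bitWidth, bool min) {
  assert(bitWidth >= 1 && bitWidth <= 64 && "model only covers 1..64 bits");
  if (bitWidth == 64)
    return min ? std::numeric_limits<std::int64_t>::min()
               : std::numeric_limits<std::int64_t>::max();
  const std::int64_t half = std::int64_t{1} << (bitWidth - 1);
  return min ? -half : half - 1;
}

// Smallest (zero) or largest (all ones) unsigned value in bitWidth bits.
std::uint64_t unsignedExtreme(unsigned bitWidth, bool min) {
  assert(bitWidth >= 1 && bitWidth <= 64 && "model only covers 1..64 bits");
  if (min)
    return 0;
  return bitWidth == 64 ? ~std::uint64_t{0}
                        : ((std::uint64_t{1} << bitWidth) - 1);
}

// Initializer of a signed min/max reduction: the opposite extreme, so that the
// first combined element always replaces it.
std::int64_t signedReductionInit(unsigned bitWidth, bool isMinReduction) {
  return signedExtreme(bitWidth, !isMinReduction);
}
```

For a vector operand the same per-element value would be repeated in every lane. Only the element width matters, which is why the patch stops querying the bitwidth of the vector type itself.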
✅ With the latest revision this PR passed the C/C++ code formatter.
tblah left a comment:
Thanks for your patch! Just a few comments.
+  Type elType = type;
+  if (auto vecType = dyn_cast<VectorType>(type))
+    elType = vecType.getElementType();
nit: perhaps this should be a helper function, as the same pattern repeats a lot.
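One possible shape for such a helper is sketched below on a small stand-in type model rather than on MLIR types. `ScalarTypeModel`, `VectorTypeModel`, `elementTypeOrSelf`, and `elementBitWidth` are hypothetical names for illustration. MLIR's `mlir/IR/TypeUtilities.h` declares a `getElementTypeOrSelf` function with similar intent, though whether the updated patch uses it is not shown here.

```cpp
#include <variant>

// Hypothetical scalar type: just a width and a float/integer flag.
struct ScalarTypeModel {
  unsigned bitWidth;
  bool isFloat;
};

// Hypothetical vector type wrapping a scalar element type.
struct VectorTypeModel {
  ScalarTypeModel element;
  unsigned lanes;
};

using TypeModel = std::variant<ScalarTypeModel, VectorTypeModel>;

// Returns the element type for vectors and the type itself for scalars, so
// that callers no longer repeat the "if vector, unwrap" pattern.
ScalarTypeModel elementTypeOrSelf(const TypeModel &type) {
  if (const auto *vec = std::get_if<VectorTypeModel>(&type))
    return vec->element;
  return std::get<ScalarTypeModel>(type);
}

// Bitwidth query that is safe for both scalar and vector types.
unsigned elementBitWidth(const TypeModel &type) {
  return elementTypeOrSelf(type).bitWidth;
}
```

With a helper like this, each of the three min/max functions and `declareReduction` would make a single call instead of repeating the three-line unwrap.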
-        builder.getFloatAttr(type, 0.0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
-                        reductionIndex);
+    auto decl = createDecl(builder, symbolTable, reduce, reductionIndex,
Why did you switch to auto here? Spelling out the type is preferred.
-        builder.getFloatAttr(type, 1.0));
-  }
-  if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
+        getAttr(builder.getFloatAttr(elType, 1.0)));
Is a closing brace missing after this line?
+// CHECK-LABEL: omp.declare_reduction @__scf_reduction : vector<2xi1> init
+// CHECK-NEXT: ^bb0(%arg0: vector<2xi1>):
+// CHECK-NEXT: %[[CONST:.*]] = llvm.mlir.constant(dense<true> : vector<2xi1>) : vector<2xi1>
+// CHECK-NEXT: omp.yield(%[[CONST]] : vector<2xi1>)
Could you please also check the generation of the combiner region and add a CHECK-NOT for the atomic region?
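The property this request is meant to pin down can be modelled in a few lines: with the patch, a reduction declaration always gets its init and combiner regions, while the atomic region is added only for non-vector types. The sketch below is a simplified standalone model with hypothetical names (`declaredRegions`, `hasRegion`). It does not reproduce the real `omp.declare_reduction` builder.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Lists the regions a reduction declaration would carry in this model.
// The atomic region is skipped for vector types, mirroring supportsAtomic().
std::vector<std::string> declaredRegions(bool isVectorType) {
  std::vector<std::string> regions = {"init", "combiner"};
  if (!isVectorType)
    regions.push_back("atomic");
  return regions;
}

// Convenience lookup used to express "CHECK" / "CHECK-NOT" style expectations.
bool hasRegion(const std::vector<std::string> &regions,
               const std::string &name) {
  return std::find(regions.begin(), regions.end(), name) != regions.end();
}
```

In these terms, the extended test would assert that the combiner is present and the atomic region is absent for the `vector<2xi1>` reduction.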
🐧 Linux x64 Test Results
✅ The build succeeded and all tests passed.
🪟 Windows x64 Test Results
✅ The build succeeded and all tests passed.
@tblah Thanks for the review! I have updated the PR with the requested changes.
tblah left a comment:
LGTM if the CI passes, thanks for the update!
Will you need help merging?
Yes, how can we merge?
Force-pushed from 4b8e065 to 56c1bcb.
All pre-commit CI checks must pass before a patch can be merged. In this case the patch formatting is incorrect (you can install and use […]). There are also build/test failures. See https://github.com/llvm/llvm-project/actions/runs/20921937013/job/60109518436?pr=173978
Head branch was pushed to by a user without write access
Force-pushed from 2186488 to 0a1fb08.
Force-pushed from 0a1fb08 to 8fce6cb.
I have tested it again and also run it through the clang-format check. Can you confirm and run the CI workflows?
|
@Aniketsingh54 Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again. If you don't get any reports, no action is required from you. Your changes are working as expected, well done! |
|
/cherry-pick e259175 |
|
Failed to cherry-pick: e259175 Please manually backport the fix and push it to your github fork. Once this is done, please create a pull request |
|
/cherry-pick e259175 |
|
Failed to cherry-pick: e259175 Please manually backport the fix and push it to your github fork. Once this is done, please create a pull request |
…73978) This patch fixes a crash in the SCF to OpenMP conversion pass when encountering scf.parallel with vector reductions. - Extracts scalar element types for bitwidth calculations. - Uses DenseElementsAttr for vector splat initializers. - Bypasses llvm.atomicrmw for vector types (not supported in LLVM IR). Fixes llvm#173860 --------- Co-authored-by: Aniket Singh <[email protected]>
|
Backport PR #175933 |