@@ -150,32 +150,48 @@ static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
150150 llvm_unreachable (" unknown float type" );
151151}
152152
153+ // / Helper to create a splat attribute for vector types, or return the scalar
154+ // / attribute for scalar types.
155+ static Attribute getSplatOrScalarAttr (Type type, Attribute val) {
156+ if (auto vecType = dyn_cast<VectorType>(type))
157+ return DenseElementsAttr::get (vecType, val);
158+ return val;
159+ }
160+
153161// / Returns an attribute with the minimum (if `min` is set) or the maximum value
154162// / (otherwise) for the given float type.
155163static Attribute minMaxValueForFloat (Type type, bool min) {
156- auto fltType = cast<FloatType>(type);
157- return FloatAttr::get (
158- type, llvm::APFloat::getLargest (fltSemanticsForType (fltType), min));
164+ Type elType = getElementTypeOrSelf (type);
165+ auto fltType = cast<FloatType>(elType);
166+ auto val = llvm::APFloat::getLargest (fltSemanticsForType (fltType), min);
167+
168+ return getSplatOrScalarAttr (type, FloatAttr::get (elType, val));
159169}
160170
161171// / Returns an attribute with the signed integer minimum (if `min` is set) or
162172// / the maximum value (otherwise) for the given integer type, regardless of its
163173// / signedness semantics (only the width is considered).
164174static Attribute minMaxValueForSignedInt (Type type, bool min) {
165- auto intType = cast<IntegerType>(type);
175+ Type elType = getElementTypeOrSelf (type);
176+ auto intType = cast<IntegerType>(elType);
166177 unsigned bitwidth = intType.getWidth ();
167- return IntegerAttr::get (type, min ? llvm::APInt::getSignedMinValue (bitwidth)
168- : llvm::APInt::getSignedMaxValue (bitwidth));
178+ auto val = min ? llvm::APInt::getSignedMinValue (bitwidth)
179+ : llvm::APInt::getSignedMaxValue (bitwidth);
180+
181+ return getSplatOrScalarAttr (type, IntegerAttr::get (elType, val));
169182}
170183
171184// / Returns an attribute with the unsigned integer minimum (if `min` is set) or
172185// / the maximum value (otherwise) for the given integer type, regardless of its
173186// / signedness semantics (only the width is considered).
174187static Attribute minMaxValueForUnsignedInt (Type type, bool min) {
175- auto intType = cast<IntegerType>(type);
188+ Type elType = getElementTypeOrSelf (type);
189+ auto intType = cast<IntegerType>(elType);
176190 unsigned bitwidth = intType.getWidth ();
177- return IntegerAttr::get (type, min ? llvm::APInt::getZero (bitwidth)
178- : llvm::APInt::getAllOnes (bitwidth));
191+ auto val =
192+ min ? llvm::APInt::getZero (bitwidth) : llvm::APInt::getAllOnes (bitwidth);
193+
194+ return getSplatOrScalarAttr (type, IntegerAttr::get (elType, val));
179195}
180196
181197// / Creates an OpenMP reduction declaration and inserts it into the provided
@@ -203,7 +219,7 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
203219 Operation *terminator =
204220 &reduce.getReductions ()[reductionIndex].front ().back ();
205221 assert (isa<scf::ReduceReturnOp>(terminator) &&
206- " expected reduce op to be terminated by redure return" );
222+ " expected reduce op to be terminated by reduce return" );
207223 builder.setInsertionPoint (terminator);
208224 builder.replaceOpWithNewOp <omp::YieldOp>(terminator,
209225 terminator->getOperands ());
@@ -237,6 +253,11 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
237253 return decl;
238254}
239255
256+ // / Returns true if the type is supported by llvm.atomicrmw.
257+ // / LLVM IR currently does not support atomic operations on vector types.
258+ // / See LLVM Language Reference Manual on 'atomicrmw'.
259+ static bool supportsAtomic (Type type) { return !isa<VectorType>(type); }
260+
240261// / Creates an OpenMP reduction declaration that corresponds to the given SCF
241262// / reduction and returns it. Recognizes common reductions in order to identify
242263// / the neutral value, necessary for the OpenMP declaration. If the reduction
@@ -261,91 +282,119 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
261282 // Match simple binary reductions that can be expressed with atomicrmw.
262283 Type type = reduce.getOperands ()[reductionIndex].getType ();
263284 Block &reduction = reduce.getReductions ()[reductionIndex].front ();
285+
286+ // Handle scalar element type extraction for vector bitwidth safety.
287+ Type elType = getElementTypeOrSelf (type);
288+
289+ // Arithmetic Reductions
264290 if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
265- omp::DeclareReductionOp decl =
266- createDecl (builder, symbolTable, reduce, reductionIndex,
267- builder.getFloatAttr (type, 0.0 ));
268- return addAtomicRMW (builder, LLVM::AtomicBinOp::fadd, decl, reduce,
269- reductionIndex);
291+ omp::DeclareReductionOp decl = createDecl (
292+ builder, symbolTable, reduce, reductionIndex,
293+ getSplatOrScalarAttr (type, builder.getFloatAttr (elType, 0.0 )));
294+ return supportsAtomic (type) ? addAtomicRMW (builder, LLVM::AtomicBinOp::fadd,
295+ decl, reduce, reductionIndex)
296+ : decl;
270297 }
271298 if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
272- omp::DeclareReductionOp decl =
273- createDecl (builder, symbolTable, reduce, reductionIndex,
274- builder.getIntegerAttr (type, 0 ));
275- return addAtomicRMW (builder, LLVM::AtomicBinOp::add, decl, reduce,
276- reductionIndex);
299+ omp::DeclareReductionOp decl = createDecl (
300+ builder, symbolTable, reduce, reductionIndex,
301+ getSplatOrScalarAttr (type, builder.getIntegerAttr (elType, 0 )));
302+ return supportsAtomic (type) ? addAtomicRMW (builder, LLVM::AtomicBinOp::add,
303+ decl, reduce, reductionIndex)
304+ : decl;
277305 }
278306 if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
279- omp::DeclareReductionOp decl =
280- createDecl (builder, symbolTable, reduce, reductionIndex,
281- builder.getIntegerAttr (type, 0 ));
282- return addAtomicRMW (builder, LLVM::AtomicBinOp::_or, decl, reduce,
283- reductionIndex);
307+ omp::DeclareReductionOp decl = createDecl (
308+ builder, symbolTable, reduce, reductionIndex,
309+ getSplatOrScalarAttr (type, builder.getIntegerAttr (elType, 0 )));
310+ return supportsAtomic (type) ? addAtomicRMW (builder, LLVM::AtomicBinOp::_or,
311+ decl, reduce, reductionIndex)
312+ : decl;
284313 }
285314 if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
286- omp::DeclareReductionOp decl =
287- createDecl (builder, symbolTable, reduce, reductionIndex,
288- builder.getIntegerAttr (type, 0 ));
289- return addAtomicRMW (builder, LLVM::AtomicBinOp::_xor, decl, reduce,
290- reductionIndex);
315+ omp::DeclareReductionOp decl = createDecl (
316+ builder, symbolTable, reduce, reductionIndex,
317+ getSplatOrScalarAttr (type, builder.getIntegerAttr (elType, 0 )));
318+ return supportsAtomic (type) ? addAtomicRMW (builder, LLVM::AtomicBinOp::_xor,
319+ decl, reduce, reductionIndex)
320+ : decl;
291321 }
292322 if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
323+ APInt allOnes = llvm::APInt::getAllOnes (elType.getIntOrFloatBitWidth ());
293324 omp::DeclareReductionOp decl = createDecl (
294325 builder, symbolTable, reduce, reductionIndex,
295- builder.getIntegerAttr (
296- type, llvm::APInt::getAllOnes (type. getIntOrFloatBitWidth ())));
297- return addAtomicRMW (builder, LLVM::AtomicBinOp::_and, decl, reduce,
298- reductionIndex) ;
326+ getSplatOrScalarAttr (type, builder.getIntegerAttr (elType, allOnes)));
327+ return supportsAtomic (type) ? addAtomicRMW (builder, LLVM::AtomicBinOp::_and,
328+ decl, reduce, reductionIndex)
329+ : decl ;
299330 }
300331
301332 // Match simple binary reductions that cannot be expressed with atomicrmw.
302333 // TODO: add atomic region using cmpxchg (which needs atomic load to be
303334 // available as an op).
304335 if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) {
305- return createDecl (builder, symbolTable, reduce, reductionIndex,
306- builder.getFloatAttr (type, 1.0 ));
336+ return createDecl (
337+ builder, symbolTable, reduce, reductionIndex,
338+ getSplatOrScalarAttr (type, builder.getFloatAttr (elType, 1.0 )));
307339 }
340+
308341 if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) {
309- return createDecl (builder, symbolTable, reduce, reductionIndex,
310- builder.getIntegerAttr (type, 1 ));
342+ return createDecl (
343+ builder, symbolTable, reduce, reductionIndex,
344+ getSplatOrScalarAttr (type, builder.getIntegerAttr (elType, 1 )));
311345 }
312346
313347 // Match select-based min/max reductions.
314348 bool isMin;
315- if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
349+ // Floating Point Min/Max
350+ if (matchSelectReduction<arith::CmpFOp, arith::SelectOp,
351+ arith::CmpFPredicate>(
316352 reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
317353 {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
318- matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
319- reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
320- {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
354+ matchSelectReduction<arith::CmpFOp, arith::SelectOp,
355+ arith::CmpFPredicate>(
356+ reduction, {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE},
357+ {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, isMin)) {
321358 return createDecl (builder, symbolTable, reduce, reductionIndex,
322359 minMaxValueForFloat (type, !isMin));
323360 }
324- if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
361+
362+ // Integer Min/Max
363+ if (matchSelectReduction<arith::CmpIOp, arith::SelectOp,
364+ arith::CmpIPredicate>(
325365 reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
326366 {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
327- matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
328- reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
329- {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
367+ matchSelectReduction<arith::CmpIOp, arith::SelectOp,
368+ arith::CmpIPredicate>(
369+ reduction, {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge},
370+ {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, isMin)) {
330371 omp::DeclareReductionOp decl =
331372 createDecl (builder, symbolTable, reduce, reductionIndex,
332373 minMaxValueForSignedInt (type, !isMin));
333- return addAtomicRMW (builder,
334- isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
335- decl, reduce, reductionIndex);
374+ return supportsAtomic (type) ? addAtomicRMW (builder,
375+ isMin ? LLVM::AtomicBinOp::min
376+ : LLVM::AtomicBinOp::max,
377+ decl, reduce, reductionIndex)
378+ : decl;
336379 }
337- if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
380+
381+ // Unsigned Integer Min/Max
382+ if (matchSelectReduction<arith::CmpIOp, arith::SelectOp,
383+ arith::CmpIPredicate>(
338384 reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
339385 {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
340- matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
341- reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
342- {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
386+ matchSelectReduction<arith::CmpIOp, arith::SelectOp,
387+ arith::CmpIPredicate>(
388+ reduction, {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge},
389+ {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, isMin)) {
343390 omp::DeclareReductionOp decl =
344391 createDecl (builder, symbolTable, reduce, reductionIndex,
345392 minMaxValueForUnsignedInt (type, !isMin));
346- return addAtomicRMW (
347- builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
348- decl, reduce, reductionIndex);
393+ return supportsAtomic (type) ? addAtomicRMW (builder,
394+ isMin ? LLVM::AtomicBinOp::umin
395+ : LLVM::AtomicBinOp::umax,
396+ decl, reduce, reductionIndex)
397+ : decl;
349398 }
350399
351400 return nullptr ;
0 commit comments