@@ -383,6 +383,46 @@ static Value castTypeToIndexType(Value originalValue,
383383 originalValue);
384384}
385385
386+ static bool shouldUseBoundaryBitcast (mlir::Type fromTy, mlir::Type toTy) {
387+ auto isBitcastCompatibleScalarType = [](mlir::Type ty) {
388+ return mlir::isa<mlir::IntegerType, mlir::FloatType, fir::LogicalType>(
389+ ty) ||
390+ (mlir::isa<fir::CharacterType>(ty) &&
391+ mlir::cast<fir::CharacterType>(ty).getLen () ==
392+ fir::CharacterType::singleton ());
393+ };
394+ auto getKnownScalarBitWidth = [](mlir::Type ty) -> std::optional<unsigned > {
395+ if (auto intTy = mlir::dyn_cast<mlir::IntegerType>(ty))
396+ return intTy.getWidth ();
397+ if (auto floatTy = mlir::dyn_cast<mlir::FloatType>(ty))
398+ return floatTy.getWidth ();
399+ return std::nullopt ;
400+ };
401+
402+ if (fromTy == toTy)
403+ return false ;
404+ const bool fromStd = fir::isa_std_type (fromTy);
405+ const bool toStd = fir::isa_std_type (toTy);
406+ if (fromStd == toStd)
407+ return false ;
408+ if (!isBitcastCompatibleScalarType (fromTy) ||
409+ !isBitcastCompatibleScalarType (toTy))
410+ return false ;
411+ auto fromBits = getKnownScalarBitWidth (fromTy);
412+ auto toBits = getKnownScalarBitWidth (toTy);
413+ if (fromBits && toBits && *fromBits != *toBits)
414+ return false ;
415+ return true ;
416+ }
417+
418+ static mlir::Value createTypeConversion (PatternRewriter &rewriter,
419+ mlir::Location loc, mlir::Type toTy,
420+ mlir::Value value) {
421+ if (shouldUseBoundaryBitcast (value.getType (), toTy))
422+ return fir::BitcastOp::create (rewriter, loc, toTy, value);
423+ return fir::ConvertOp::create (rewriter, loc, toTy, value);
424+ }
425+
386426FailureOr<SmallVector<Value>>
387427FIRToMemRef::getMemrefIndices (fir::ArrayCoorOp arrayCoorOp, Operation *memref,
388428 PatternRewriter &rewriter, Value converted,
@@ -983,11 +1023,10 @@ void FIRToMemRef::rewriteLoadOp(fir::LoadOp load, PatternRewriter &rewriter,
9831023 LLVM_DEBUG (llvm::dbgs () << " FIRToMemRef: new memref.load op:\n " ;
9841024 loadOp.dump (); assert (succeeded (verify (loadOp))));
9851025
986- if (isa<fir::LogicalType>(originalType)) {
987- Value logicalVal =
988- fir::ConvertOp::create (rewriter, loadOp.getLoc (), originalType, loadOp);
989- loadOp.getResult ().replaceAllUsesExcept (logicalVal,
990- logicalVal.getDefiningOp ());
1026+ if (loadOp.getType () != originalType) {
1027+ Value castVal =
1028+ createTypeConversion (rewriter, loadOp.getLoc (), originalType, loadOp);
1029+ loadOp.getResult ().replaceAllUsesExcept (castVal, castVal.getDefiningOp ());
9911030 }
9921031
9931032 if (!isa<fir::LogicalType>(originalType))
@@ -1019,11 +1058,10 @@ void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter,
10191058 Value value = store.getValue ();
10201059 rewriter.setInsertionPointAfter (store);
10211060
1022- if (isa<fir::LogicalType>( value.getType ())) {
1023- Type convertedType = typeConverter. convertType ( value.getType ());
1061+ Type convertedType = typeConverter. convertType ( value.getType ());
1062+ if ( convertedType != value.getType ())
10241063 value =
1025- fir::ConvertOp::create (rewriter, store.getLoc (), convertedType, value);
1026- }
1064+ createTypeConversion (rewriter, store.getLoc (), convertedType, value);
10271065
10281066 Attribute attr = (store.getOperation ())->getAttr (" tbaa" );
10291067 memref::StoreOp storeOp = rewriter.replaceOpWithNewOp <memref::StoreOp>(
0 commit comments