Skip to content

Commit 55111e8

Browse files
authored
[flang] use fir.bitcast for FIRToMemRef scalar reinterpretation (#188328)
Use fir.bitcast in FIR-to-MemRef casts so bit patterns are preserved (e.g. TRANSFER), while keeping fir.convert for memref/reference marshaling and non-bitcast-compatible cases.
1 parent 2c0e63c commit 55111e8

File tree

2 files changed

+49
-11
lines changed

2 files changed

+49
-11
lines changed

flang/lib/Optimizer/Transforms/FIRToMemRef.cpp

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,46 @@ static Value castTypeToIndexType(Value originalValue,
383383
originalValue);
384384
}
385385

386+
static bool shouldUseBoundaryBitcast(mlir::Type fromTy, mlir::Type toTy) {
387+
auto isBitcastCompatibleScalarType = [](mlir::Type ty) {
388+
return mlir::isa<mlir::IntegerType, mlir::FloatType, fir::LogicalType>(
389+
ty) ||
390+
(mlir::isa<fir::CharacterType>(ty) &&
391+
mlir::cast<fir::CharacterType>(ty).getLen() ==
392+
fir::CharacterType::singleton());
393+
};
394+
auto getKnownScalarBitWidth = [](mlir::Type ty) -> std::optional<unsigned> {
395+
if (auto intTy = mlir::dyn_cast<mlir::IntegerType>(ty))
396+
return intTy.getWidth();
397+
if (auto floatTy = mlir::dyn_cast<mlir::FloatType>(ty))
398+
return floatTy.getWidth();
399+
return std::nullopt;
400+
};
401+
402+
if (fromTy == toTy)
403+
return false;
404+
const bool fromStd = fir::isa_std_type(fromTy);
405+
const bool toStd = fir::isa_std_type(toTy);
406+
if (fromStd == toStd)
407+
return false;
408+
if (!isBitcastCompatibleScalarType(fromTy) ||
409+
!isBitcastCompatibleScalarType(toTy))
410+
return false;
411+
auto fromBits = getKnownScalarBitWidth(fromTy);
412+
auto toBits = getKnownScalarBitWidth(toTy);
413+
if (fromBits && toBits && *fromBits != *toBits)
414+
return false;
415+
return true;
416+
}
417+
418+
static mlir::Value createTypeConversion(PatternRewriter &rewriter,
419+
mlir::Location loc, mlir::Type toTy,
420+
mlir::Value value) {
421+
if (shouldUseBoundaryBitcast(value.getType(), toTy))
422+
return fir::BitcastOp::create(rewriter, loc, toTy, value);
423+
return fir::ConvertOp::create(rewriter, loc, toTy, value);
424+
}
425+
386426
FailureOr<SmallVector<Value>>
387427
FIRToMemRef::getMemrefIndices(fir::ArrayCoorOp arrayCoorOp, Operation *memref,
388428
PatternRewriter &rewriter, Value converted,
@@ -983,11 +1023,10 @@ void FIRToMemRef::rewriteLoadOp(fir::LoadOp load, PatternRewriter &rewriter,
9831023
LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: new memref.load op:\n";
9841024
loadOp.dump(); assert(succeeded(verify(loadOp))));
9851025

986-
if (isa<fir::LogicalType>(originalType)) {
987-
Value logicalVal =
988-
fir::ConvertOp::create(rewriter, loadOp.getLoc(), originalType, loadOp);
989-
loadOp.getResult().replaceAllUsesExcept(logicalVal,
990-
logicalVal.getDefiningOp());
1026+
if (loadOp.getType() != originalType) {
1027+
Value castVal =
1028+
createTypeConversion(rewriter, loadOp.getLoc(), originalType, loadOp);
1029+
loadOp.getResult().replaceAllUsesExcept(castVal, castVal.getDefiningOp());
9911030
}
9921031

9931032
if (!isa<fir::LogicalType>(originalType))
@@ -1019,11 +1058,10 @@ void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter,
10191058
Value value = store.getValue();
10201059
rewriter.setInsertionPointAfter(store);
10211060

1022-
if (isa<fir::LogicalType>(value.getType())) {
1023-
Type convertedType = typeConverter.convertType(value.getType());
1061+
Type convertedType = typeConverter.convertType(value.getType());
1062+
if (convertedType != value.getType())
10241063
value =
1025-
fir::ConvertOp::create(rewriter, store.getLoc(), convertedType, value);
1026-
}
1064+
createTypeConversion(rewriter, store.getLoc(), convertedType, value);
10271065

10281066
Attribute attr = (store.getOperation())->getAttr("tbaa");
10291067
memref::StoreOp storeOp = rewriter.replaceOpWithNewOp<memref::StoreOp>(

flang/test/Transforms/FIRToMemRef/logical.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// CHECK-NEXT: [[DECLARE:%[0-9]+]] = fir.declare %arg0 dummy_scope [[DUMMY]]
55
// CHECK-NEXT: [[CONVERT:%[0-9]+]] = fir.convert [[DECLARE]] : (!fir.ref<!fir.logical<4>>) -> memref<i32>
66
// CHECK-NEXT: [[LOAD:%[0-9]+]] = memref.load [[CONVERT]][] : memref<i32>
7-
// CHECK-NEXT: fir.convert [[LOAD]] : (i32) -> !fir.logical<4>
7+
// CHECK-NEXT: fir.bitcast [[LOAD]] : (i32) -> !fir.logical<4>
88
func.func @load_scalar(%arg0: !fir.ref<!fir.logical<4>>) {
99
%0 = fir.undefined !fir.dscope
1010
%1 = fir.declare %arg0 dummy_scope %0 {uniq_name = "a"} : (!fir.ref<!fir.logical<4>>, !fir.dscope) -> !fir.ref<!fir.logical<4>>
@@ -18,7 +18,7 @@ func.func @load_scalar(%arg0: !fir.ref<!fir.logical<4>>) {
1818
// CHECK: [[DECLARE:%[0-9]+]] = fir.declare %arg0 dummy_scope [[DUMMY]]
1919
// CHECK-NEXT: [[CONVERT:%[0-9]+]] = fir.convert [[CONSTTRUE]] : (i1) -> !fir.logical<4>
2020
// CHECK-NEXT: [[CONVERT1:%[0-9]+]] = fir.convert [[DECLARE]] : (!fir.ref<!fir.logical<4>>) -> memref<i32>
21-
// CHECK-NEXT: [[INT:%[0-9]+]] = fir.convert [[CONVERT]] : (!fir.logical<4>) -> i32
21+
// CHECK-NEXT: [[INT:%[0-9]+]] = fir.bitcast [[CONVERT]] : (!fir.logical<4>) -> i32
2222
// CHECK-NEXT: memref.store [[INT]], [[CONVERT1]][] : memref<i32>
2323
func.func @store_scalar(%arg0: !fir.ref<!fir.logical<4>>) {
2424
%true = arith.constant true

0 commit comments

Comments
 (0)