Skip to content

Commit 797916b

Browse files
authored
[OpenMP][flang] Fix crash in host offload (#187847)
Guard `getGridValue` in `OMPIRBuilder` to avoid reaching the `unreachable` in `getGridValue` when offloading to host device without an explicit num_threads clause.
1 parent 1422665 commit 797916b

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@ static void restoreIPandDebugLoc(llvm::IRBuilderBase &Builder,
169169
Builder.SetCurrentDebugLocation(BB->back().getStableDebugLoc());
170170
}
171171

172+
static bool hasGridValue(const Triple &T) {
173+
return T.isAMDGPU() || T.isNVPTX() || T.isSPIRV();
174+
}
175+
172176
static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
173177
if (T.isAMDGPU()) {
174178
StringRef Features =
@@ -7773,9 +7777,15 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
77737777
// If MaxThreads not set, select the maximum between the default workgroup
77747778
// size and the MinThreads value.
77757779
int32_t MaxThreadsVal = Attrs.MaxThreads.front();
7776-
if (MaxThreadsVal < 0)
7777-
MaxThreadsVal = std::max(
7778-
int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), Attrs.MinThreads);
7780+
if (MaxThreadsVal < 0) {
7781+
if (hasGridValue(T)) {
7782+
MaxThreadsVal =
7783+
std::max(int32_t(getGridValue(T, Kernel).GV_Default_WG_Size),
7784+
Attrs.MinThreads);
7785+
} else {
7786+
MaxThreadsVal = Attrs.MinThreads;
7787+
}
7788+
}
77797789

77807790
if (MaxThreadsVal > 0)
77817791
writeThreadBoundsForKernel(T, *Kernel, Attrs.MinThreads, MaxThreadsVal);
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// Verify that host offloading doesn't crash the OMPIRBuilder.
4+
module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu", omp.is_target_device = true} {
5+
llvm.func @omp_target_region_host_device() {
6+
omp.target {
7+
omp.terminator
8+
}
9+
llvm.return
10+
}
11+
}
12+
13+
// CHECK: define void @omp_target_region_host_device()
14+
// CHECK: define weak_odr protected void @__omp_offloading_{{[^_]+}}_{{[^_]+}}_omp_target_region_host_device_l{{[0-9]+}}(ptr %[[ADDR_A:.*]])

0 commit comments

Comments
 (0)