Skip to content

Commit 7fde02e

Browse files
Strophox and RalfJung committed
enable Miri to pass const pointers through FFI
Co-authored-by: Ralf Jung <[email protected]>
1 parent 748c548 commit 7fde02e

File tree

15 files changed

+260
-17
lines changed

15 files changed

+260
-17
lines changed

compiler/rustc_const_eval/src/interpret/machine.rs

+1-11
Original file line numberDiff line numberDiff line change
@@ -357,17 +357,7 @@ pub trait Machine<'tcx>: Sized {
357357
ecx: &InterpCx<'tcx, Self>,
358358
id: AllocId,
359359
alloc: &'b Allocation,
360-
) -> InterpResult<'tcx, Cow<'b, Allocation<Self::Provenance, Self::AllocExtra, Self::Bytes>>>
361-
{
362-
// The default implementation does a copy; CTFE machines have a more efficient implementation
363-
// based on their particular choice for `Provenance`, `AllocExtra`, and `Bytes`.
364-
let kind = Self::GLOBAL_KIND
365-
.expect("if GLOBAL_KIND is None, adjust_global_allocation must be overwritten");
366-
let alloc = alloc.adjust_from_tcx(&ecx.tcx, |ptr| ecx.global_root_pointer(ptr))?;
367-
let extra =
368-
Self::init_alloc_extra(ecx, id, MemoryKind::Machine(kind), alloc.size(), alloc.align)?;
369-
Ok(Cow::Owned(alloc.with_extra(extra)))
370-
}
360+
) -> InterpResult<'tcx, Cow<'b, Allocation<Self::Provenance, Self::AllocExtra, Self::Bytes>>>;
371361

372362
/// Initialize the extra state of an allocation.
373363
///

compiler/rustc_middle/src/mir/interpret/allocation.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -358,10 +358,11 @@ impl Allocation {
358358
pub fn adjust_from_tcx<Prov: Provenance, Bytes: AllocBytes, Err>(
359359
&self,
360360
cx: &impl HasDataLayout,
361+
mut alloc_bytes: impl FnMut(&[u8], Align) -> Result<Bytes, Err>,
361362
mut adjust_ptr: impl FnMut(Pointer<CtfeProvenance>) -> Result<Pointer<Prov>, Err>,
362363
) -> Result<Allocation<Prov, (), Bytes>, Err> {
363364
// Copy the data.
364-
let mut bytes = Bytes::from_bytes(Cow::Borrowed(&*self.bytes), self.align);
365+
let mut bytes = alloc_bytes(&*self.bytes, self.align)?;
365366
// Adjust provenance of pointers stored in this allocation.
366367
let mut new_provenance = Vec::with_capacity(self.provenance.ptrs().len());
367368
let ptr_size = cx.data_layout().pointer_size.bytes_usize();

src/tools/miri/src/alloc_addresses/mod.rs

+67-1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ pub struct GlobalStateInner {
4242
/// they do not have an `AllocExtra`.
4343
/// This is the inverse of `int_to_ptr_map`.
4444
base_addr: FxHashMap<AllocId, u64>,
45+
/// Temporarily store prepared memory space for global allocations the first time their memory
46+
/// address is required. This is used to ensure that the memory is allocated before Miri assigns
47+
/// it an internal address, which is important for matching the internal address to the machine
48+
/// address so FFI can read from pointers.
49+
prepared_alloc_bytes: FxHashMap<AllocId, MiriAllocBytes>,
4550
/// A pool of addresses we can reuse for future allocations.
4651
reuse: ReusePool,
4752
/// Whether an allocation has been exposed or not. This cannot be put
@@ -59,6 +64,7 @@ impl VisitProvenance for GlobalStateInner {
5964
let GlobalStateInner {
6065
int_to_ptr_map: _,
6166
base_addr: _,
67+
prepared_alloc_bytes: _,
6268
reuse: _,
6369
exposed: _,
6470
next_base_addr: _,
@@ -78,6 +84,7 @@ impl GlobalStateInner {
7884
GlobalStateInner {
7985
int_to_ptr_map: Vec::default(),
8086
base_addr: FxHashMap::default(),
87+
prepared_alloc_bytes: FxHashMap::default(),
8188
reuse: ReusePool::new(config),
8289
exposed: FxHashSet::default(),
8390
next_base_addr: stack_addr,
@@ -166,7 +173,39 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
166173
assert!(!matches!(kind, AllocKind::Dead));
167174

168175
// This allocation does not have a base address yet, pick or reuse one.
169-
let base_addr = if let Some((reuse_addr, clock)) = global_state.reuse.take_addr(
176+
let base_addr = if ecx.machine.native_lib.is_some() {
177+
// In native lib mode, we use the "real" address of the bytes for this allocation.
178+
// This ensures the interpreted program and native code have the same view of memory.
179+
match kind {
180+
AllocKind::LiveData => {
181+
let ptr = if ecx.tcx.try_get_global_alloc(alloc_id).is_some() {
182+
// For new global allocations, we always pre-allocate the memory to be able to use the machine address directly.
183+
let prepared_bytes = MiriAllocBytes::zeroed(size, align)
184+
.unwrap_or_else(|| {
185+
panic!("Miri ran out of memory: cannot create allocation of {size:?} bytes")
186+
});
187+
let ptr = prepared_bytes.as_ptr();
188+
// Store prepared allocation space to be picked up for use later.
189+
global_state.prepared_alloc_bytes.try_insert(alloc_id, prepared_bytes).unwrap();
190+
ptr
191+
} else {
192+
ecx.get_alloc_bytes_unchecked_raw(alloc_id)?
193+
};
194+
// Ensure this pointer's provenance is exposed, so that it can be used by FFI code.
195+
ptr.expose_provenance().try_into().unwrap()
196+
}
197+
AllocKind::Function | AllocKind::VTable => {
198+
// Allocate some dummy memory to get a unique address for this function/vtable.
199+
let alloc_bytes = MiriAllocBytes::from_bytes(&[0u8; 1], Align::from_bytes(1).unwrap());
200+
// We don't need to expose these bytes as nobody is allowed to access them.
201+
let addr = alloc_bytes.as_ptr().addr().try_into().unwrap();
202+
// Leak the underlying memory to ensure it remains unique.
203+
std::mem::forget(alloc_bytes);
204+
addr
205+
}
206+
AllocKind::Dead => unreachable!()
207+
}
208+
} else if let Some((reuse_addr, clock)) = global_state.reuse.take_addr(
170209
&mut *rng,
171210
size,
172211
align,
@@ -318,6 +357,33 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
318357
Ok(base_ptr.wrapping_offset(offset, ecx))
319358
}
320359

360+
// This returns some prepared `MiriAllocBytes`, either because `addr_from_alloc_id` reserved
361+
// memory space in the past, or by doing the pre-allocation right upon being called.
362+
fn get_global_alloc_bytes(&self, id: AllocId, kind: MemoryKind, bytes: &[u8], align: Align) -> InterpResult<'tcx, MiriAllocBytes> {
363+
let ecx = self.eval_context_ref();
364+
Ok(if ecx.machine.native_lib.is_some() {
365+
// In native lib mode, MiriAllocBytes for global allocations are handled via `prepared_alloc_bytes`.
366+
// This additional call ensures that some `MiriAllocBytes` are always prepared.
367+
ecx.addr_from_alloc_id(id, kind)?;
368+
let mut global_state = ecx.machine.alloc_addresses.borrow_mut();
369+
// The memory we need here will have already been allocated during an earlier call to
370+
// `addr_from_alloc_id` for this allocation. So don't create a new `MiriAllocBytes` here, instead
371+
// fetch the previously prepared bytes from `prepared_alloc_bytes`.
372+
let mut prepared_alloc_bytes = global_state
373+
.prepared_alloc_bytes
374+
.remove(&id)
375+
.unwrap_or_else(|| panic!("alloc bytes for {id:?} have not been prepared"));
376+
// Sanity-check that the prepared allocation has the right size and alignment.
377+
assert!(prepared_alloc_bytes.as_ptr().is_aligned_to(align.bytes_usize()));
378+
assert_eq!(prepared_alloc_bytes.len(), bytes.len());
379+
// Copy allocation contents into prepared memory.
380+
prepared_alloc_bytes.copy_from_slice(bytes);
381+
prepared_alloc_bytes
382+
} else {
383+
MiriAllocBytes::from_bytes(std::borrow::Cow::Borrowed(&*bytes), align)
384+
})
385+
}
386+
321387
/// When a pointer is used for a memory access, this computes where in which allocation the
322388
/// access is going.
323389
fn ptr_get_alloc(

src/tools/miri/src/concurrency/thread.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -887,7 +887,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
887887
let alloc = this.ctfe_query(|tcx| tcx.eval_static_initializer(def_id))?;
888888
// We make a full copy of this allocation.
889889
let mut alloc =
890-
alloc.inner().adjust_from_tcx(&this.tcx, |ptr| this.global_root_pointer(ptr))?;
890+
alloc.inner().adjust_from_tcx(&this.tcx, |bytes, align| Ok(MiriAllocBytes::from_bytes(std::borrow::Cow::Borrowed(bytes), align)), |ptr| this.global_root_pointer(ptr))?;
891891
// This allocation will be deallocated when the thread dies, so it is not in read-only memory.
892892
alloc.mutability = Mutability::Mut;
893893
// Create a fresh allocation with this content.

src/tools/miri/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
#![feature(let_chains)]
1313
#![feature(trait_upcasting)]
1414
#![feature(strict_overflow_ops)]
15+
#![feature(strict_provenance)]
16+
#![feature(exposed_provenance)]
17+
#![feature(pointer_is_aligned_to)]
1518
// Configure clippy and other lints
1619
#![allow(
1720
clippy::collapsible_else_if,

src/tools/miri/src/machine.rs

+25
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Global machine state as well as implementation of the interpreter engine
22
//! `Machine` trait.
33
4+
use std::borrow::Cow;
45
use std::cell::RefCell;
56
use std::collections::hash_map::Entry;
67
use std::fmt;
@@ -1225,6 +1226,30 @@ impl<'tcx> Machine<'tcx> for MiriMachine<'tcx> {
12251226
})
12261227
}
12271228

1229+
/// Called to adjust global allocations to the Provenance and AllocExtra of this machine.
1230+
///
1231+
/// If `alloc` contains pointers, then they are all pointing to globals.
1232+
///
1233+
/// This should avoid copying if no work has to be done! If this returns an owned
1234+
/// allocation (because a copy had to be done to adjust things), machine memory will
1235+
/// cache the result. (This relies on `AllocMap::get_or` being able to add the
1236+
/// owned allocation to the map even when the map is shared.)
1237+
fn adjust_global_allocation<'b>(
1238+
ecx: &InterpCx<'tcx, Self>,
1239+
id: AllocId,
1240+
alloc: &'b Allocation,
1241+
) -> InterpResult<'tcx, Cow<'b, Allocation<Self::Provenance, Self::AllocExtra, Self::Bytes>>>
1242+
{
1243+
let kind = Self::GLOBAL_KIND.unwrap().into();
1244+
let alloc = alloc.adjust_from_tcx(&ecx.tcx,
1245+
|bytes, align| ecx.get_global_alloc_bytes(id, kind, bytes, align),
1246+
|ptr| ecx.global_root_pointer(ptr),
1247+
)?;
1248+
let extra =
1249+
Self::init_alloc_extra(ecx, id, kind, alloc.size(), alloc.align)?;
1250+
Ok(Cow::Owned(alloc.with_extra(extra)))
1251+
}
1252+
12281253
#[inline(always)]
12291254
fn before_memory_read(
12301255
_tcx: TyCtxtAt<'tcx>,

src/tools/miri/src/shims/native_lib.rs

+13
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ enum CArg {
194194
UInt64(u64),
195195
/// usize.
196196
USize(usize),
197+
/// Raw pointer, stored as C's `void*`.
198+
RawPtr(*mut std::ffi::c_void),
197199
}
198200

199201
impl<'a> CArg {
@@ -210,6 +212,7 @@ impl<'a> CArg {
210212
CArg::UInt32(i) => ffi::arg(i),
211213
CArg::UInt64(i) => ffi::arg(i),
212214
CArg::USize(i) => ffi::arg(i),
215+
CArg::RawPtr(i) => ffi::arg(i),
213216
}
214217
}
215218
}
@@ -234,6 +237,16 @@ fn imm_to_carg<'tcx>(v: ImmTy<'tcx>, cx: &impl HasDataLayout) -> InterpResult<'t
234237
ty::Uint(UintTy::U64) => CArg::UInt64(v.to_scalar().to_u64()?),
235238
ty::Uint(UintTy::Usize) =>
236239
CArg::USize(v.to_scalar().to_target_usize(cx)?.try_into().unwrap()),
240+
ty::RawPtr(_, mutability) => {
241+
// Arbitrary mutable pointer accesses are not currently supported in Miri.
242+
if mutability.is_mut() {
243+
throw_unsup_format!("unsupported mutable pointer type for native call: {}", v.layout.ty);
244+
} else {
245+
let s = v.to_scalar().to_pointer(cx)?.addr();
246+
// This relies on the `expose_provenance` in `addr_from_alloc_id`.
247+
CArg::RawPtr(std::ptr::with_exposed_provenance_mut(s.bytes_usize()))
248+
}
249+
},
237250
_ => throw_unsup_format!("unsupported argument type for native call: {}", v.layout.ty),
238251
})
239252
}
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
CODEABI_1.0 {
22
# Define which symbols to export.
33
global:
4+
# scalar_arguments.c
45
add_one_int;
56
printer;
67
test_stack_spill;
78
get_unsigned_int;
89
add_int16;
910
add_short_to_long;
11+
12+
# ptr_read_access.c
13+
print_pointer;
14+
access_simple;
15+
access_nested;
16+
access_static;
17+
1018
# The rest remains private.
1119
local: *;
1220
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//@only-target-linux
2+
//@only-on-host
3+
4+
fn main() {
5+
test_pointer();
6+
7+
test_simple();
8+
9+
test_nested();
10+
11+
test_static();
12+
}
13+
14+
// Test void function that dereferences a pointer and prints its contents from C.
15+
fn test_pointer() {
16+
extern "C" {
17+
fn print_pointer(ptr: *const i32);
18+
}
19+
20+
let x = 42;
21+
22+
unsafe { print_pointer(&x) };
23+
}
24+
25+
// Test function that dereferences a simple struct pointer and accesses a field.
26+
fn test_simple() {
27+
#[repr(C)]
28+
struct Simple {
29+
field: i32
30+
}
31+
32+
extern "C" {
33+
fn access_simple(s_ptr: *const Simple) -> i32;
34+
}
35+
36+
let simple = Simple { field: -42 };
37+
38+
assert_eq!(unsafe { access_simple(&simple) }, -42);
39+
}
40+
41+
// Test function that dereferences nested struct pointers and accesses fields.
42+
fn test_nested() {
43+
use std::ptr::NonNull;
44+
45+
#[derive(Debug, PartialEq, Eq)]
46+
#[repr(C)]
47+
struct Nested {
48+
value: i32,
49+
next: Option<NonNull<Nested>>,
50+
}
51+
52+
extern "C" {
53+
fn access_nested(n_ptr: *const Nested) -> i32;
54+
}
55+
56+
let mut nested_0 = Nested { value: 97, next: None };
57+
let mut nested_1 = Nested { value: 98, next: NonNull::new(&mut nested_0) };
58+
let nested_2 = Nested { value: 99, next: NonNull::new(&mut nested_1) };
59+
60+
assert_eq!(unsafe { access_nested(&nested_2) }, 97);
61+
}
62+
63+
// Test function that dereferences static struct pointers and accesses fields.
64+
fn test_static() {
65+
66+
#[repr(C)]
67+
struct Static {
68+
value: i32,
69+
recurse: &'static Static,
70+
}
71+
72+
extern "C" {
73+
fn access_static(n_ptr: *const Static) -> i32;
74+
}
75+
76+
static STATIC: Static = Static {
77+
value: 9001,
78+
recurse: &STATIC,
79+
};
80+
81+
assert_eq!(unsafe { access_static(&STATIC) }, 9001);
82+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
printing pointer dereference from C: 42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include <stdio.h>
2+
3+
/* Test: test_pointer */
4+
5+
void print_pointer(const int *ptr) {
6+
printf("printing pointer dereference from C: %d\n", *ptr);
7+
}
8+
9+
/* Test: test_simple */
10+
11+
typedef struct Simple {
12+
int field;
13+
} Simple;
14+
15+
int access_simple(const Simple *s_ptr) {
16+
return s_ptr->field;
17+
}
18+
19+
/* Test: test_nested */
20+
21+
typedef struct Nested {
22+
int value;
23+
struct Nested *next;
24+
} Nested;
25+
26+
// Returns the innermost/last value of a Nested pointer chain.
27+
int access_nested(const Nested *n_ptr) {
28+
// Edge case: `n_ptr == NULL` (i.e. first Nested is None).
29+
if (!n_ptr) { return 0; }
30+
31+
while (n_ptr->next) {
32+
n_ptr = n_ptr->next;
33+
}
34+
35+
return n_ptr->value;
36+
}
37+
38+
/* Test: test_static */
39+
40+
typedef struct Static {
41+
int value;
42+
struct Static *recurse;
43+
} Static;
44+
45+
int access_static(const Static *s_ptr) {
46+
return s_ptr->recurse->recurse->value;
47+
}

0 commit comments

Comments
 (0)