Skip to content

Commit 2e89443

Browse files
committed
add a macro to declare thread unblock callbacks
1 parent e6bb468 commit 2e89443

File tree

6 files changed

+256
-249
lines changed

6 files changed

+256
-249
lines changed

src/tools/miri/src/concurrency/init_once.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
7575
fn init_once_enqueue_and_block(
7676
&mut self,
7777
id: InitOnceId,
78-
callback: impl UnblockCallback<'mir, 'tcx> + 'tcx,
78+
callback: impl UnblockCallback<'tcx> + 'tcx,
7979
) {
8080
let this = self.eval_context_mut();
8181
let thread = this.active_thread();

src/tools/miri/src/concurrency/sync.rs

+130-171
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ macro_rules! declare_id {
3535
}
3636
}
3737

38+
impl $crate::VisitProvenance for $name {
39+
fn visit_provenance(&self, _visit: &mut VisitWith<'_>) {}
40+
}
41+
3842
impl Idx for $name {
3943
fn new(idx: usize) -> Self {
4044
// We use 0 as a sentinel value (see the comment above) and,
@@ -258,6 +262,25 @@ pub(super) trait EvalContextExtPriv<'mir, 'tcx: 'mir>:
258262
Ok(new_index)
259263
}
260264
}
265+
266+
fn condvar_reacquire_mutex(
267+
&mut self,
268+
mutex: MutexId,
269+
retval: Scalar<Provenance>,
270+
dest: MPlaceTy<'tcx, Provenance>,
271+
) -> InterpResult<'tcx> {
272+
let this = self.eval_context_mut();
273+
if this.mutex_is_locked(mutex) {
274+
assert_ne!(this.mutex_get_owner(mutex), this.active_thread());
275+
this.mutex_enqueue_and_block(mutex, retval, dest);
276+
} else {
277+
// We can have it right now!
278+
this.mutex_lock(mutex);
279+
// Don't forget to write the return value.
280+
this.write_scalar(retval, &dest)?;
281+
}
282+
Ok(())
283+
}
261284
}
262285

263286
// Public interface to synchronization primitives. Please note that in most
@@ -384,29 +407,23 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
384407
assert!(this.mutex_is_locked(id), "queing on unlocked mutex");
385408
let thread = this.active_thread();
386409
this.machine.sync.mutexes[id].queue.push_back(thread);
387-
this.block_thread(BlockReason::Mutex(id), None, Callback { id, retval, dest });
388-
389-
struct Callback<'tcx> {
390-
id: MutexId,
391-
retval: Scalar<Provenance>,
392-
dest: MPlaceTy<'tcx, Provenance>,
393-
}
394-
impl<'tcx> VisitProvenance for Callback<'tcx> {
395-
fn visit_provenance(&self, visit: &mut VisitWith<'_>) {
396-
let Callback { id: _, retval, dest } = self;
397-
retval.visit_provenance(visit);
398-
dest.visit_provenance(visit);
399-
}
400-
}
401-
impl<'mir, 'tcx: 'mir> UnblockCallback<'mir, 'tcx> for Callback<'tcx> {
402-
fn unblock(self: Box<Self>, this: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> {
403-
assert!(!this.mutex_is_locked(self.id));
404-
this.mutex_lock(self.id);
405-
406-
this.write_scalar(self.retval, &self.dest)?;
407-
Ok(())
408-
}
409-
}
410+
this.block_thread(
411+
BlockReason::Mutex(id),
412+
None,
413+
callback!(
414+
@capture<'tcx> {
415+
id: MutexId,
416+
retval: Scalar<Provenance>,
417+
dest: MPlaceTy<'tcx, Provenance>,
418+
}
419+
@unblock = |this| {
420+
assert!(!this.mutex_is_locked(id));
421+
this.mutex_lock(id);
422+
this.write_scalar(retval, &dest)?;
423+
Ok(())
424+
}
425+
),
426+
);
410427
}
411428

412429
#[inline]
@@ -500,27 +517,22 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
500517
let thread = this.active_thread();
501518
assert!(this.rwlock_is_write_locked(id), "read-queueing on not write locked rwlock");
502519
this.machine.sync.rwlocks[id].reader_queue.push_back(thread);
503-
this.block_thread(BlockReason::RwLock(id), None, Callback { id, retval, dest });
504-
505-
struct Callback<'tcx> {
506-
id: RwLockId,
507-
retval: Scalar<Provenance>,
508-
dest: MPlaceTy<'tcx, Provenance>,
509-
}
510-
impl<'tcx> VisitProvenance for Callback<'tcx> {
511-
fn visit_provenance(&self, visit: &mut VisitWith<'_>) {
512-
let Callback { id: _, retval, dest } = self;
513-
retval.visit_provenance(visit);
514-
dest.visit_provenance(visit);
515-
}
516-
}
517-
impl<'mir, 'tcx: 'mir> UnblockCallback<'mir, 'tcx> for Callback<'tcx> {
518-
fn unblock(self: Box<Self>, this: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> {
519-
this.rwlock_reader_lock(self.id);
520-
this.write_scalar(self.retval, &self.dest)?;
521-
Ok(())
522-
}
523-
}
520+
this.block_thread(
521+
BlockReason::RwLock(id),
522+
None,
523+
callback!(
524+
@capture<'tcx> {
525+
id: RwLockId,
526+
retval: Scalar<Provenance>,
527+
dest: MPlaceTy<'tcx, Provenance>,
528+
}
529+
@unblock = |this| {
530+
this.rwlock_reader_lock(id);
531+
this.write_scalar(retval, &dest)?;
532+
Ok(())
533+
}
534+
),
535+
);
524536
}
525537

526538
/// Lock by setting the writer that owns the lock.
@@ -588,27 +600,22 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
588600
assert!(this.rwlock_is_locked(id), "write-queueing on unlocked rwlock");
589601
let thread = this.active_thread();
590602
this.machine.sync.rwlocks[id].writer_queue.push_back(thread);
591-
this.block_thread(BlockReason::RwLock(id), None, Callback { id, retval, dest });
592-
593-
struct Callback<'tcx> {
594-
id: RwLockId,
595-
retval: Scalar<Provenance>,
596-
dest: MPlaceTy<'tcx, Provenance>,
597-
}
598-
impl<'tcx> VisitProvenance for Callback<'tcx> {
599-
fn visit_provenance(&self, visit: &mut VisitWith<'_>) {
600-
let Callback { id: _, retval, dest } = self;
601-
retval.visit_provenance(visit);
602-
dest.visit_provenance(visit);
603-
}
604-
}
605-
impl<'mir, 'tcx: 'mir> UnblockCallback<'mir, 'tcx> for Callback<'tcx> {
606-
fn unblock(self: Box<Self>, this: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> {
607-
this.rwlock_writer_lock(self.id);
608-
this.write_scalar(self.retval, &self.dest)?;
609-
Ok(())
610-
}
611-
}
603+
this.block_thread(
604+
BlockReason::RwLock(id),
605+
None,
606+
callback!(
607+
@capture<'tcx> {
608+
id: RwLockId,
609+
retval: Scalar<Provenance>,
610+
dest: MPlaceTy<'tcx, Provenance>,
611+
}
612+
@unblock = |this| {
613+
this.rwlock_writer_lock(id);
614+
this.write_scalar(retval, &dest)?;
615+
Ok(())
616+
}
617+
),
618+
);
612619
}
613620

614621
/// Is the conditional variable awaited?
@@ -648,71 +655,37 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
648655
this.block_thread(
649656
BlockReason::Condvar(condvar),
650657
timeout,
651-
Callback { condvar, mutex, retval_succ, retval_timeout, dest },
652-
);
653-
return Ok(());
654-
655-
struct Callback<'tcx> {
656-
condvar: CondvarId,
657-
mutex: MutexId,
658-
retval_succ: Scalar<Provenance>,
659-
retval_timeout: Scalar<Provenance>,
660-
dest: MPlaceTy<'tcx, Provenance>,
661-
}
662-
impl<'tcx> VisitProvenance for Callback<'tcx> {
663-
fn visit_provenance(&self, visit: &mut VisitWith<'_>) {
664-
let Callback { condvar: _, mutex: _, retval_succ, retval_timeout, dest } = self;
665-
retval_succ.visit_provenance(visit);
666-
retval_timeout.visit_provenance(visit);
667-
dest.visit_provenance(visit);
668-
}
669-
}
670-
impl<'tcx, 'mir> Callback<'tcx> {
671-
#[allow(clippy::boxed_local)]
672-
fn reacquire_mutex(
673-
self: Box<Self>,
674-
this: &mut MiriInterpCx<'mir, 'tcx>,
675-
retval: Scalar<Provenance>,
676-
) -> InterpResult<'tcx> {
677-
if this.mutex_is_locked(self.mutex) {
678-
assert_ne!(this.mutex_get_owner(self.mutex), this.active_thread());
679-
this.mutex_enqueue_and_block(self.mutex, retval, self.dest);
680-
} else {
681-
// We can have it right now!
682-
this.mutex_lock(self.mutex);
683-
// Don't forget to write the return value.
684-
this.write_scalar(retval, &self.dest)?;
658+
callback!(
659+
@capture<'tcx> {
660+
condvar: CondvarId,
661+
mutex: MutexId,
662+
retval_succ: Scalar<Provenance>,
663+
retval_timeout: Scalar<Provenance>,
664+
dest: MPlaceTy<'tcx, Provenance>,
685665
}
686-
Ok(())
687-
}
688-
}
689-
impl<'mir, 'tcx: 'mir> UnblockCallback<'mir, 'tcx> for Callback<'tcx> {
690-
fn unblock(self: Box<Self>, this: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> {
691-
// The condvar was signaled. Make sure we get the clock for that.
692-
if let Some(data_race) = &this.machine.data_race {
693-
data_race.acquire_clock(
694-
&this.machine.sync.condvars[self.condvar].clock,
695-
&this.machine.threads,
696-
);
666+
@unblock = |this| {
667+
// The condvar was signaled. Make sure we get the clock for that.
668+
if let Some(data_race) = &this.machine.data_race {
669+
data_race.acquire_clock(
670+
&this.machine.sync.condvars[condvar].clock,
671+
&this.machine.threads,
672+
);
673+
}
674+
// Try to acquire the mutex.
675+
// The timeout only applies to the first wait (until the signal), not for mutex acquisition.
676+
this.condvar_reacquire_mutex(mutex, retval_succ, dest)
697677
}
698-
// Try to acquire the mutex.
699-
// The timeout only applies to the first wait (until the signal), not for mutex acquisition.
700-
let retval = self.retval_succ;
701-
self.reacquire_mutex(this, retval)
702-
}
703-
fn timeout(
704-
self: Box<Self>,
705-
this: &mut InterpCx<'mir, 'tcx, MiriMachine<'mir, 'tcx>>,
706-
) -> InterpResult<'tcx> {
707-
// We have to remove the waiter from the queue again.
708-
let thread = this.active_thread();
709-
let waiters = &mut this.machine.sync.condvars[self.condvar].waiters;
710-
waiters.retain(|waiter| *waiter != thread);
711-
// Now get back the lock.
712-
let retval = self.retval_timeout;
713-
self.reacquire_mutex(this, retval)
714-
}
715-
}
678+
@timeout = |this| {
679+
// We have to remove the waiter from the queue again.
680+
let thread = this.active_thread();
681+
let waiters = &mut this.machine.sync.condvars[condvar].waiters;
682+
waiters.retain(|waiter| *waiter != thread);
683+
// Now get back the lock.
684+
this.condvar_reacquire_mutex(mutex, retval_timeout, dest)
685+
}
686+
),
687+
);
688+
return Ok(());
716689
}
717690

718691
/// Wake up some thread (if there is any) sleeping on the conditional
@@ -755,50 +728,36 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
755728
this.block_thread(
756729
BlockReason::Futex { addr },
757730
timeout,
758-
Callback { addr, retval_succ, retval_timeout, dest, errno_timeout },
759-
);
760-
761-
struct Callback<'tcx> {
762-
addr: u64,
763-
retval_succ: Scalar<Provenance>,
764-
retval_timeout: Scalar<Provenance>,
765-
dest: MPlaceTy<'tcx, Provenance>,
766-
errno_timeout: Scalar<Provenance>,
767-
}
768-
impl<'tcx> VisitProvenance for Callback<'tcx> {
769-
fn visit_provenance(&self, visit: &mut VisitWith<'_>) {
770-
let Callback { addr: _, retval_succ, retval_timeout, dest, errno_timeout } = self;
771-
retval_succ.visit_provenance(visit);
772-
retval_timeout.visit_provenance(visit);
773-
dest.visit_provenance(visit);
774-
errno_timeout.visit_provenance(visit);
775-
}
776-
}
777-
impl<'mir, 'tcx: 'mir> UnblockCallback<'mir, 'tcx> for Callback<'tcx> {
778-
fn unblock(self: Box<Self>, this: &mut MiriInterpCx<'mir, 'tcx>) -> InterpResult<'tcx> {
779-
let futex = this.machine.sync.futexes.get(&self.addr).unwrap();
780-
// Acquire the clock of the futex.
781-
if let Some(data_race) = &this.machine.data_race {
782-
data_race.acquire_clock(&futex.clock, &this.machine.threads);
731+
callback!(
732+
@capture<'tcx> {
733+
addr: u64,
734+
retval_succ: Scalar<Provenance>,
735+
retval_timeout: Scalar<Provenance>,
736+
dest: MPlaceTy<'tcx, Provenance>,
737+
errno_timeout: Scalar<Provenance>,
783738
}
784-
// Write the return value.
785-
this.write_scalar(self.retval_succ, &self.dest)?;
786-
Ok(())
787-
}
788-
fn timeout(
789-
self: Box<Self>,
790-
this: &mut InterpCx<'mir, 'tcx, MiriMachine<'mir, 'tcx>>,
791-
) -> InterpResult<'tcx> {
792-
// Remove the waiter from the futex.
793-
let thread = this.active_thread();
794-
let futex = this.machine.sync.futexes.get_mut(&self.addr).unwrap();
795-
futex.waiters.retain(|waiter| waiter.thread != thread);
796-
// Set errno and write return value.
797-
this.set_last_error(self.errno_timeout)?;
798-
this.write_scalar(self.retval_timeout, &self.dest)?;
799-
Ok(())
800-
}
801-
}
739+
@unblock = |this| {
740+
let futex = this.machine.sync.futexes.get(&addr).unwrap();
741+
// Acquire the clock of the futex.
742+
if let Some(data_race) = &this.machine.data_race {
743+
data_race.acquire_clock(&futex.clock, &this.machine.threads);
744+
}
745+
// Write the return value.
746+
this.write_scalar(retval_succ, &dest)?;
747+
Ok(())
748+
}
749+
@timeout = |this| {
750+
// Remove the waiter from the futex.
751+
let thread = this.active_thread();
752+
let futex = this.machine.sync.futexes.get_mut(&addr).unwrap();
753+
futex.waiters.retain(|waiter| waiter.thread != thread);
754+
// Set errno and write return value.
755+
this.set_last_error(errno_timeout)?;
756+
this.write_scalar(retval_timeout, &dest)?;
757+
Ok(())
758+
}
759+
),
760+
);
802761
}
803762

804763
/// Returns whether anything was woken.

0 commit comments

Comments
 (0)