Skip to content

Commit 91d2c0a

Browse files
authored
Add contexts that use FakeClock rather than the system time. (#92)
1 parent 7e524bd commit 91d2c0a

4 files changed

Lines changed: 416 additions & 41 deletions

File tree

clockwork.go

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -157,44 +157,53 @@ func (fc *FakeClock) NewTicker(d time.Duration) Ticker {
157157
ft = &fakeTicker{
158158
firer: newFirer(),
159159
d: d,
160-
reset: func(d time.Duration) { fc.set(ft, d) },
161-
stop: func() { fc.stop(ft) },
160+
reset: func(d time.Duration) {
161+
fc.l.Lock()
162+
defer fc.l.Unlock()
163+
fc.setExpirer(ft, d)
164+
},
165+
stop: func() { fc.stop(ft) },
162166
}
163-
fc.set(ft, d)
167+
fc.l.Lock()
168+
defer fc.l.Unlock()
169+
fc.setExpirer(ft, d)
164170
return ft
165171
}
166172

167173
// NewTimer returns a Timer that will fire only after calls to
168174
// fakeClock.Advance() have moved the clock past the given duration.
169175
func (fc *FakeClock) NewTimer(d time.Duration) Timer {
170-
return fc.newTimer(d, nil)
176+
t, _ := fc.newTimer(d, nil)
177+
return t
171178
}
172179

173180
// AfterFunc mimics [time.AfterFunc]; it returns a Timer that will invoke the
174181
// given function only after calls to fakeClock.Advance() have moved the clock
175182
// past the given duration.
176183
func (fc *FakeClock) AfterFunc(d time.Duration, f func()) Timer {
177-
return fc.newTimer(d, f)
184+
t, _ := fc.newTimer(d, f)
185+
return t
178186
}
179187

180-
// newTimer returns a new timer, using an optional afterFunc.
181-
func (fc *FakeClock) newTimer(d time.Duration, afterfunc func()) *fakeTimer {
182-
var ft *fakeTimer
183-
ft = &fakeTimer{
184-
firer: newFirer(),
185-
reset: func(d time.Duration) bool {
186-
fc.l.Lock()
187-
defer fc.l.Unlock()
188-
// fc.l must be held across the calls to stopExpirer & setExpirer.
189-
stopped := fc.stopExpirer(ft)
190-
fc.setExpirer(ft, d)
191-
return stopped
192-
},
193-
stop: func() bool { return fc.stop(ft) },
188+
// newTimer returns a new timer using an optional afterFunc and the time that
189+
// timer expires.
190+
func (fc *FakeClock) newTimer(d time.Duration, afterfunc func()) (*fakeTimer, time.Time) {
191+
ft := newFakeTimer(fc, afterfunc)
192+
fc.l.Lock()
193+
defer fc.l.Unlock()
194+
fc.setExpirer(ft, d)
195+
return ft, ft.expiry()
196+
}
194197

195-
afterFunc: afterfunc,
196-
}
197-
fc.set(ft, d)
198+
// newTimerAtTime is like newTimer, but uses a time instead of a duration.
199+
//
200+
// It is used to ensure FakeClock's lock is held constant through calling
201+
// fc.After(t.Sub(fc.Now())). It should not be exposed externally.
202+
func (fc *FakeClock) newTimerAtTime(t time.Time, afterfunc func()) *fakeTimer {
203+
ft := newFakeTimer(fc, afterfunc)
204+
fc.l.Lock()
205+
defer fc.l.Unlock()
206+
fc.setExpirer(ft, t.Sub(fc.time))
198207
return ft
199208
}
200209

@@ -289,13 +298,6 @@ func (fc *FakeClock) stopExpirer(e expirer) bool {
289298
return true
290299
}
291300

292-
// set sets an expirer to expire at a future point in time.
293-
func (fc *FakeClock) set(e expirer, d time.Duration) {
294-
fc.l.Lock()
295-
defer fc.l.Unlock()
296-
fc.setExpirer(e, d)
297-
}
298-
299301
// setExpirer sets an expirer to expire at a future point in time.
300302
//
301303
// The caller must hold fc.l.
@@ -316,16 +318,14 @@ func (fc *FakeClock) setExpirer(e expirer, d time.Duration) {
316318
})
317319

318320
// Notify blockers of our new waiter.
319-
var blocked []*blocker
320321
count := len(fc.waiters)
321-
for _, b := range fc.blockers {
322+
fc.blockers = slices.DeleteFunc(fc.blockers, func(b *blocker) bool {
322323
if b.count <= count {
323324
close(b.ch)
324-
continue
325+
return true
325326
}
326-
blocked = append(blocked, b)
327-
}
328-
fc.blockers = blocked
327+
return false
328+
})
329329
}
330330

331331
// firer is used by fakeTimer and fakeTicker used to help implement expirer.

context.go

Lines changed: 141 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@ package clockwork
22

33
import (
44
"context"
5+
"fmt"
6+
"sync"
7+
"time"
58
)
69

710
// contextKey is private to this package so we can ensure uniqueness here. This
811
// type identifies context values provided by this package.
912
type contextKey string
1013

11-
// keyClock provides a clock for injecting during tests. If absent, a real clock should be used.
14+
// keyClock provides a clock for injecting during tests. If absent, a real clock
15+
// should be used.
1216
var keyClock = contextKey("clock") // clockwork.Clock
1317

1418
// AddToContext creates a derived context that references the specified clock.
@@ -21,10 +25,145 @@ func AddToContext(ctx context.Context, clock Clock) context.Context {
2125
return context.WithValue(ctx, keyClock, clock)
2226
}
2327

24-
// FromContext extracts a clock from the context. If not present, a real clock is returned.
28+
// FromContext extracts a clock from the context. If not present, a real clock
29+
// is returned.
2530
func FromContext(ctx context.Context) Clock {
2631
if clock, ok := ctx.Value(keyClock).(Clock); ok {
2732
return clock
2833
}
2934
return NewRealClock()
3035
}
36+
37+
// ErrFakeClockDeadlineExceeded is the error returned by [context.Context] when
38+
// the deadline passes on a context which uses a [FakeClock].
39+
//
40+
// It wraps a [context.DeadlineExceeded] error, i.e.:
41+
//
42+
// // The following is true for any Context whose deadline has been exceeded,
43+
// // including contexts made with clockwork.WithDeadline or clockwork.WithTimeout.
44+
//
45+
// errors.Is(ctx.Err(), context.DeadlineExceeded)
46+
//
47+
// // The following can only be true for contexts made
48+
// // with clockwork.WithDeadline or clockwork.WithTimeout.
49+
//
50+
// errors.Is(ctx.Err(), clockwork.ErrFakeClockDeadlineExceeded)
51+
var ErrFakeClockDeadlineExceeded error = fmt.Errorf("clockwork.FakeClock: %w", context.DeadlineExceeded)
52+
53+
// WithDeadline returns a context with a deadline based on a [FakeClock].
54+
//
55+
// The returned context ignores parent cancelation if the parent was cancelled
56+
// with a [context.DeadlineExceeded] error. Any other error returned by the
57+
// parent is treated normally, cancelling the returned context.
58+
//
59+
// If the parent is cancelled with a [context.DeadlineExceeded] error, the only
60+
// way to then cancel the returned context is by calling the returned
61+
// context.CancelFunc.
62+
func WithDeadline(parent context.Context, clock Clock, t time.Time) (context.Context, context.CancelFunc) {
63+
if fc, ok := clock.(*FakeClock); ok {
64+
return newFakeClockContext(parent, t, fc.newTimerAtTime(t, nil).Chan())
65+
}
66+
return context.WithDeadline(parent, t)
67+
}
68+
69+
// WithTimeout returns a context with a timeout based on a [FakeClock].
70+
//
71+
// The returned context follows the same behaviors as [WithDeadline].
72+
func WithTimeout(parent context.Context, clock Clock, d time.Duration) (context.Context, context.CancelFunc) {
73+
if fc, ok := clock.(*FakeClock); ok {
74+
t, deadline := fc.newTimer(d, nil)
75+
return newFakeClockContext(parent, deadline, t.Chan())
76+
}
77+
return context.WithTimeout(parent, d)
78+
}
79+
80+
// fakeClockContext implements context.Context, using a fake clock for its
81+
// deadline.
82+
//
83+
// It ignores parent cancellation if the parent is cancelled with
84+
// context.DeadlineExceeded.
85+
type fakeClockContext struct {
86+
parent context.Context
87+
deadline time.Time // The user-facing deadline based on the fake clock's time.
88+
89+
// Tracks timeout/deadline cancellation.
90+
timerDone <-chan time.Time
91+
92+
// Tracks manual calls to the cancel function.
93+
cancel func() // Closes cancelCalled wrapped in a sync.Once.
94+
cancelCalled chan struct{}
95+
96+
// The user-facing data from the context.Context interface.
97+
ctxDone chan struct{} // Returned by Done().
98+
err error // nil until ctxDone is ready to be closed.
99+
}
100+
101+
func newFakeClockContext(parent context.Context, deadline time.Time, timer <-chan time.Time) (context.Context, context.CancelFunc) {
102+
cancelCalled := make(chan struct{})
103+
ctx := &fakeClockContext{
104+
parent: parent,
105+
deadline: deadline,
106+
timerDone: timer,
107+
cancelCalled: cancelCalled,
108+
ctxDone: make(chan struct{}),
109+
cancel: sync.OnceFunc(func() {
110+
close(cancelCalled)
111+
}),
112+
}
113+
ready := make(chan struct{}, 1)
114+
go ctx.runCancel(ready)
115+
<-ready // Wait until the cancellation goroutine is running.
116+
return ctx, ctx.cancel
117+
}
118+
119+
func (c *fakeClockContext) Deadline() (time.Time, bool) {
120+
return c.deadline, true
121+
}
122+
123+
func (c *fakeClockContext) Done() <-chan struct{} {
124+
return c.ctxDone
125+
}
126+
127+
func (c *fakeClockContext) Err() error {
128+
<-c.Done() // Don't return the error before it is ready.
129+
return c.err
130+
}
131+
132+
func (c *fakeClockContext) Value(key any) any {
133+
return c.parent.Value(key)
134+
}
135+
136+
// runCancel runs the fakeClockContext's cancel goroutine and returns the
137+
// fakeClockContext's cancel function.
138+
//
139+
// fakeClockContext is then cancelled when any of the following occur:
140+
//
141+
// - The fakeClockContext.done channel is closed by its timer.
142+
// - The returned CancelFunc is executed.
143+
// - The fakeClockContext's parent context is cancelled with an error other
144+
// than context.DeadlineExceeded.
145+
func (c *fakeClockContext) runCancel(ready chan struct{}) {
146+
parentDone := c.parent.Done()
147+
148+
// Close ready when done, just in case the ready signal races with other
149+
// branches of our select statement below.
150+
defer close(ready)
151+
152+
for c.err == nil {
153+
select {
154+
case <-c.timerDone:
155+
c.err = ErrFakeClockDeadlineExceeded
156+
case <-c.cancelCalled:
157+
c.err = context.Canceled
158+
case <-parentDone:
159+
c.err = c.parent.Err()
160+
161+
case ready <- struct{}{}:
162+
// Signals the cancellation goroutine has begun, in an attempt to minimize
163+
// race conditions related to goroutine startup time.
164+
ready = nil // This case statement can only fire once.
165+
}
166+
}
167+
close(c.ctxDone)
168+
return
169+
}

0 commit comments

Comments
 (0)