Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions context_watchdog.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@

package clickhouse

import "context"

// contextWatchdog is a helper function to run a callback when the context is done.
// it has a cancellation function to prevent the callback from running.
// It has a cancellation function to prevent the callback from running.
// Useful for interrupting some logic when the context is done,
// but you want to not bother about context cancellation if your logic is already done.
// Example:
Expand All @@ -15,17 +14,16 @@ func contextWatchdog(ctx context.Context, callback func()) (cancel func()) {
exit := make(chan struct{})

go func() {
for {
select {
case <-exit:
return
case <-ctx.Done():
callback()
}
select {
case <-exit:
return
case <-ctx.Done():
callback()
return
}
}()

return func() {
exit <- struct{}{}
close(exit)
}
}
89 changes: 89 additions & 0 deletions context_watchdog_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
//go:build go1.25
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we removed once we upgrade the Go tool chain to v1.25.
#1689


package clickhouse

import (
"context"
"sync/atomic"
"testing"
"testing/synctest"
"time"

"github.com/stretchr/testify/assert"
)

func TestContextWatchdog(t *testing.T) {
t.Run("callback should be called once", func(t *testing.T) {
called := atomic.Int32{}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()

stopCW := contextWatchdog(ctx, func() {
called.Add(1)
})

<-ctx.Done()

// Give it some more time to make sure watch dog has enough time to
// call callback multiple times
time.Sleep(100 * time.Millisecond)
assert.Equal(t, int32(1), called.Load(), "callback should be called only once")

stopCW()
assert.Equal(t, int32(1), called.Load(), "callback should be called only once even after stopping watchdog")
})

t.Run("callback should not be called during normal exit before context cancellation", func(t *testing.T) {
called := atomic.Int32{}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

stopCW := contextWatchdog(ctx, func() {
called.Add(1)
})
stopCW() // normal exit

assert.Equal(t, int32(0), called.Load(), "callback should not be called during normal exit")
})

t.Run("No goroutines should be left out after stopping ContextWatchdog", func(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
// when context is cancelled
called := atomic.Int32{}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()

stopCW := contextWatchdog(ctx, func() {
called.Add(1)
})

<-ctx.Done()

// Give it some more time to make sure watch dog has enough time to
// call callback multiple times
time.Sleep(100 * time.Millisecond)
assert.Equal(t, int32(1), called.Load(), "callback should be called only once")

stopCW()
assert.Equal(t, int32(1), called.Load(), "callback should be called only once even after stopping watchdog")

synctest.Wait()
})

synctest.Test(t, func(t *testing.T) {
// during normal exit
called := atomic.Int32{}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

stopCW := contextWatchdog(ctx, func() {
called.Add(1)
})
stopCW() // normal exit

assert.Equal(t, int32(0), called.Load(), "callback should not be called during normal exit")

synctest.Wait()
})
})
}
Loading