@@ -20,11 +20,13 @@ import (
2020 "context"
2121 "flag"
2222 "fmt"
23+ "io"
2324 "net"
2425 "os"
2526 "runtime"
2627 "runtime/debug"
2728 "strings"
29+ "sync"
2830 "time"
2931
3032 v1 "github.com/containerd/containerd/api/services/ttrpc/events/v1"
@@ -46,8 +48,14 @@ type Client struct {
4648 signals chan os.Signal
4749}
4850
51+ // Publisher for events
52+ type Publisher interface {
53+ events.Publisher
54+ io.Closer
55+ }
56+
4957// Init func for the creation of a shim server
50- type Init func (context.Context , string , events. Publisher ) (Shim , error )
58+ type Init func (context.Context , string , Publisher , func () ) (Shim , error )
5159
5260// Shim server interface
5361type Shim interface {
@@ -156,24 +164,28 @@ func run(id string, initFunc Init, config Config) error {
156164 return err
157165 }
158166 }
159-
160- publisher := & remoteEventsPublisher {
161- address : fmt .Sprintf ("%s.ttrpc" , addressFlag ),
162- }
163- conn , err := connect (publisher .address , dialer )
167+ address := fmt .Sprintf ("%s.ttrpc" , addressFlag )
168+ conn , err := connect (address , dialer )
164169 if err != nil {
165170 return err
166171 }
167- defer conn .Close ()
172+ publisher := & remoteEventsPublisher {
173+ address : address ,
174+ conn : conn ,
175+ closed : make (chan struct {}),
176+ }
177+ defer publisher .Close ()
178+
168179 publisher .client = v1 .NewEventsClient (ttrpc .NewClient (conn ))
169180 if namespaceFlag == "" {
170181 return fmt .Errorf ("shim namespace cannot be empty" )
171182 }
172183 ctx := namespaces .WithNamespace (context .Background (), namespaceFlag )
173184 ctx = context .WithValue (ctx , OptsKey {}, Opts {BundlePath : bundlePath , Debug : debugFlag })
174185 ctx = log .WithLogger (ctx , log .G (ctx ).WithField ("runtime" , id ))
186+ ctx , cancel := context .WithCancel (ctx )
175187
176- service , err := initFunc (ctx , idFlag , publisher )
188+ service , err := initFunc (ctx , idFlag , publisher , cancel )
177189 if err != nil {
178190 return err
179191 }
@@ -183,7 +195,7 @@ func run(id string, initFunc Init, config Config) error {
183195 "pid" : os .Getpid (),
184196 "namespace" : namespaceFlag ,
185197 })
186- go handleSignals (logger , signals )
198+ go handleSignals (ctx , logger , signals )
187199 response , err := service .Cleanup (ctx )
188200 if err != nil {
189201 return err
@@ -210,7 +222,17 @@ func run(id string, initFunc Init, config Config) error {
210222 return err
211223 }
212224 client := NewShimClient (ctx , service , signals )
213- return client .Serve ()
225+ if err := client .Serve (); err != nil {
226+ if err != context .Canceled {
227+ return err
228+ }
229+ }
230+ select {
231+ case <- publisher .Done ():
232+ return nil
233+ case <- time .After (5 * time .Second ):
234+ return errors .New ("publisher not closed" )
235+ }
214236 }
215237}
216238
@@ -254,7 +276,7 @@ func (s *Client) Serve() error {
254276 dumpStacks (logger )
255277 }
256278 }()
257- return handleSignals (logger , s .signals )
279+ return handleSignals (s . context , logger , s .signals )
258280}
259281
260282// serve serves the ttrpc API over a unix socket at the provided path
@@ -291,7 +313,22 @@ func dumpStacks(logger *logrus.Entry) {
291313
292314type remoteEventsPublisher struct {
293315 address string
316+ conn net.Conn
294317 client v1.EventsService
318+ closed chan struct {}
319+ closer sync.Once
320+ }
321+
322+ func (l * remoteEventsPublisher ) Done () <- chan struct {} {
323+ return l .closed
324+ }
325+
326+ func (l * remoteEventsPublisher ) Close () (err error ) {
327+ l .closer .Do (func () {
328+ err = l .conn .Close ()
329+ close (l .closed )
330+ })
331+ return err
295332}
296333
297334func (l * remoteEventsPublisher ) Publish (ctx context.Context , topic string , event events.Event ) error {
0 commit comments