@@ -43,10 +43,13 @@ type Client struct {
4343 channel * channel
4444 calls chan * callRequest
4545
46- closed chan struct {}
47- closeOnce sync.Once
48- closeFunc func ()
49- done chan struct {}
46+ ctx context.Context
47+ closed func ()
48+
49+ closeOnce sync.Once
50+ userCloseFunc func ()
51+
52+ errOnce sync.Once
5053 err error
5154 interceptor UnaryClientInterceptor
5255}
@@ -57,7 +60,7 @@ type ClientOpts func(c *Client)
5760// WithOnClose sets the close func whenever the client's Close() method is called
5861func WithOnClose (onClose func ()) ClientOpts {
5962 return func (c * Client ) {
60- c .closeFunc = onClose
63+ c .userCloseFunc = onClose
6164 }
6265}
6366
@@ -69,15 +72,16 @@ func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
6972}
7073
7174func NewClient (conn net.Conn , opts ... ClientOpts ) * Client {
75+ ctx , cancel := context .WithCancel (context .Background ())
7276 c := & Client {
73- codec : codec {},
74- conn : conn ,
75- channel : newChannel (conn ),
76- calls : make (chan * callRequest ),
77- closed : make ( chan struct {}) ,
78- done : make ( chan struct {}) ,
79- closeFunc : func () {},
80- interceptor : defaultClientInterceptor ,
77+ codec : codec {},
78+ conn : conn ,
79+ channel : newChannel (conn ),
80+ calls : make (chan * callRequest ),
81+ closed : cancel ,
82+ ctx : ctx ,
83+ userCloseFunc : func () {},
84+ interceptor : defaultClientInterceptor ,
8185 }
8286
8387 for _ , o := range opts {
@@ -150,25 +154,24 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
150154 case <- ctx .Done ():
151155 return ctx .Err ()
152156 case c .calls <- call :
153- case <- c .done :
154- return c .err
157+ case <- c .ctx . Done () :
158+ return c .error ()
155159 }
156160
157161 select {
158162 case <- ctx .Done ():
159163 return ctx .Err ()
160164 case err := <- errs :
161165 return filterCloseErr (err )
162- case <- c .done :
163- return c .err
166+ case <- c .ctx . Done () :
167+ return c .error ()
164168 }
165169}
166170
167171func (c * Client ) Close () error {
168172 c .closeOnce .Do (func () {
169- close ( c .closed )
173+ c .closed ( )
170174 })
171-
172175 return nil
173176}
174177
@@ -178,51 +181,82 @@ type message struct {
178181 err error
179182}
180183
181- func (c * Client ) run () {
182- var (
183- streamID uint32 = 1
184- waiters = make (map [uint32 ]* callRequest )
185- calls = c .calls
186- incoming = make (chan * message )
187- shutdown = make (chan struct {})
188- shutdownErr error
189- )
184+ type receiver struct {
185+ wg * sync.WaitGroup
186+ messages chan * message
187+ err error
188+ }
190189
191- go func ( ) {
192- defer close ( shutdown )
190+ func ( r * receiver ) run ( ctx context. Context , c * channel ) {
191+ defer r . wg . Done ( )
193192
194- // start one more goroutine to recv messages without blocking.
195- for {
196- mh , p , err := c .channel .recv (context .TODO ())
193+ for {
194+ select {
195+ case <- ctx .Done ():
196+ r .err = ctx .Err ()
197+ return
198+ default :
199+ mh , p , err := c .recv ()
197200 if err != nil {
198201 _ , ok := status .FromError (err )
199202 if ! ok {
200203 // treat all errors that are not an rpc status as terminal.
201204 // all others poison the connection.
202- shutdownErr = err
205+ r . err = filterCloseErr ( err )
203206 return
204207 }
205208 }
206209 select {
207- case incoming <- & message {
210+ case r . messages <- & message {
208211 messageHeader : mh ,
209212 p : p [:mh .Length ],
210213 err : err ,
211214 }:
212- case <- c .done :
215+ case <- ctx .Done ():
216+ r .err = ctx .Err ()
213217 return
214218 }
215219 }
220+ }
221+ }
222+
223+ func (c * Client ) run () {
224+ var (
225+ streamID uint32 = 1
226+ waiters = make (map [uint32 ]* callRequest )
227+ calls = c .calls
228+ incoming = make (chan * message )
229+ receiversDone = make (chan struct {})
230+ wg sync.WaitGroup
231+ )
232+
233+ // broadcast the shutdown error to the remaining waiters.
234+ abortWaiters := func (wErr error ) {
235+ for _ , waiter := range waiters {
236+ waiter .errs <- wErr
237+ }
238+ }
239+ recv := & receiver {
240+ wg : & wg ,
241+ messages : incoming ,
242+ }
243+ wg .Add (1 )
244+
245+ go func () {
246+ wg .Wait ()
247+ close (receiversDone )
216248 }()
249+ go recv .run (c .ctx , c .channel )
217250
218- defer c .conn .Close ()
219- defer close (c .done )
220- defer c .closeFunc ()
251+ defer func () {
252+ c .conn .Close ()
253+ c .userCloseFunc ()
254+ }()
221255
222256 for {
223257 select {
224258 case call := <- calls :
225- if err := c .send (call . ctx , streamID , messageTypeRequest , call .req ); err != nil {
259+ if err := c .send (streamID , messageTypeRequest , call .req ); err != nil {
226260 call .errs <- err
227261 continue
228262 }
@@ -238,41 +272,42 @@ func (c *Client) run() {
238272
239273 call .errs <- c .recv (call .resp , msg )
240274 delete (waiters , msg .StreamID )
241- case <- shutdown :
242- if shutdownErr != nil {
243- shutdownErr = filterCloseErr (shutdownErr )
244- } else {
245- shutdownErr = ErrClosed
246- }
247-
248- shutdownErr = errors .Wrapf (shutdownErr , "ttrpc: client shutting down" )
249-
250- c .err = shutdownErr
251- for _ , waiter := range waiters {
252- waiter .errs <- shutdownErr
275+ case <- receiversDone :
276+ // all the receivers have exited
277+ if recv .err != nil {
278+ c .setError (recv .err )
253279 }
280+ // don't return out, let the close of the context trigger the abort of waiters
254281 c .Close ()
255- return
256- case <- c .closed :
257- if c .err == nil {
258- c .err = ErrClosed
259- }
260- // broadcast the shutdown error to the remaining waiters.
261- for _ , waiter := range waiters {
262- waiter .errs <- c .err
263- }
282+ case <- c .ctx .Done ():
283+ abortWaiters (c .error ())
264284 return
265285 }
266286 }
267287}
268288
269- func (c * Client ) send (ctx context.Context , streamID uint32 , mtype messageType , msg interface {}) error {
289+ func (c * Client ) error () error {
290+ c .errOnce .Do (func () {
291+ if c .err == nil {
292+ c .err = ErrClosed
293+ }
294+ })
295+ return c .err
296+ }
297+
298+ func (c * Client ) setError (err error ) {
299+ c .errOnce .Do (func () {
300+ c .err = err
301+ })
302+ }
303+
304+ func (c * Client ) send (streamID uint32 , mtype messageType , msg interface {}) error {
270305 p , err := c .codec .Marshal (msg )
271306 if err != nil {
272307 return err
273308 }
274309
275- return c .channel .send (ctx , streamID , mtype , p )
310+ return c .channel .send (streamID , mtype , p )
276311}
277312
278313func (c * Client ) recv (resp * Response , msg * message ) error {
@@ -293,22 +328,21 @@ func (c *Client) recv(resp *Response, msg *message) error {
293328//
294329// This purposely ignores errors with a wrapped cause.
295330func filterCloseErr (err error ) error {
296- if err == nil {
331+ switch {
332+ case err == nil :
297333 return nil
298- }
299-
300- if err == io .EOF {
334+ case err == io .EOF :
301335 return ErrClosed
302- }
303-
304- if strings .Contains (err .Error (), "use of closed network connection" ) {
336+ case errors .Cause (err ) == io .EOF :
305337 return ErrClosed
306- }
307-
308- // if we have an epipe on a write, we cast to errclosed
309- if oerr , ok := err .(* net.OpError ); ok && oerr .Op == "write" {
310- if serr , ok := oerr .Err .(* os.SyscallError ); ok && serr .Err == syscall .EPIPE {
311- return ErrClosed
338+ case strings .Contains (err .Error (), "use of closed network connection" ):
339+ return ErrClosed
340+ default :
341+ // if we have an epipe on a write, we cast to errclosed
342+ if oerr , ok := err .(* net.OpError ); ok && oerr .Op == "write" {
343+ if serr , ok := oerr .Err .(* os.SyscallError ); ok && serr .Err == syscall .EPIPE {
344+ return ErrClosed
345+ }
312346 }
313347 }
314348
0 commit comments