Skip to content

Commit 1fb3814

Browse files
authored
Merge pull request #42 from crosbymichael/client
Refactor close handling for ttrpc clients
2 parents 5829a06 + 3afb82b commit 1fb3814

5 files changed

Lines changed: 119 additions & 89 deletions

File tree

channel.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ package ttrpc
1818

1919
import (
2020
"bufio"
21-
"context"
2221
"encoding/binary"
2322
"io"
2423
"net"
@@ -98,7 +97,7 @@ func newChannel(conn net.Conn) *channel {
9897
// returned will be valid and caller should send that along to
9998
// the correct consumer. The bytes on the underlying channel
10099
// will be discarded.
101-
func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
100+
func (ch *channel) recv() (messageHeader, []byte, error) {
102101
mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
103102
if err != nil {
104103
return messageHeader{}, nil, err
@@ -120,7 +119,7 @@ func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
120119
return mh, p, nil
121120
}
122121

123-
func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
122+
func (ch *channel) send(streamID uint32, t messageType, p []byte) error {
124123
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
125124
return err
126125
}

channel_test.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ package ttrpc
1818

1919
import (
2020
"bytes"
21-
"context"
2221
"io"
2322
"net"
2423
"reflect"
@@ -31,7 +30,6 @@ import (
3130

3231
func TestReadWriteMessage(t *testing.T) {
3332
var (
34-
ctx = context.Background()
3533
w, r = net.Pipe()
3634
ch = newChannel(w)
3735
rch = newChannel(r)
@@ -46,7 +44,7 @@ func TestReadWriteMessage(t *testing.T) {
4644

4745
go func() {
4846
for i, msg := range messages {
49-
if err := ch.send(ctx, uint32(i), 1, msg); err != nil {
47+
if err := ch.send(uint32(i), 1, msg); err != nil {
5048
errs <- err
5149
return
5250
}
@@ -56,7 +54,7 @@ func TestReadWriteMessage(t *testing.T) {
5654
}()
5755

5856
for {
59-
_, p, err := rch.recv(ctx)
57+
_, p, err := rch.recv()
6058
if err != nil {
6159
if errors.Cause(err) != io.EOF {
6260
t.Fatal(err)
@@ -91,20 +89,19 @@ func TestReadWriteMessage(t *testing.T) {
9189

9290
func TestMessageOversize(t *testing.T) {
9391
var (
94-
ctx = context.Background()
9592
w, r = net.Pipe()
9693
wch, rch = newChannel(w), newChannel(r)
9794
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
9895
errs = make(chan error, 1)
9996
)
10097

10198
go func() {
102-
if err := wch.send(ctx, 1, 1, msg); err != nil {
99+
if err := wch.send(1, 1, msg); err != nil {
103100
errs <- err
104101
}
105102
}()
106103

107-
_, _, err := rch.recv(ctx)
104+
_, _, err := rch.recv()
108105
if err == nil {
109106
t.Fatalf("error expected reading with small buffer")
110107
}

client.go

Lines changed: 110 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,13 @@ type Client struct {
4343
channel *channel
4444
calls chan *callRequest
4545

46-
closed chan struct{}
47-
closeOnce sync.Once
48-
closeFunc func()
49-
done chan struct{}
46+
ctx context.Context
47+
closed func()
48+
49+
closeOnce sync.Once
50+
userCloseFunc func()
51+
52+
errOnce sync.Once
5053
err error
5154
interceptor UnaryClientInterceptor
5255
}
@@ -57,7 +60,7 @@ type ClientOpts func(c *Client)
5760
// WithOnClose sets the close func whenever the client's Close() method is called
5861
func WithOnClose(onClose func()) ClientOpts {
5962
return func(c *Client) {
60-
c.closeFunc = onClose
63+
c.userCloseFunc = onClose
6164
}
6265
}
6366

@@ -69,15 +72,16 @@ func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
6972
}
7073

7174
func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
75+
ctx, cancel := context.WithCancel(context.Background())
7276
c := &Client{
73-
codec: codec{},
74-
conn: conn,
75-
channel: newChannel(conn),
76-
calls: make(chan *callRequest),
77-
closed: make(chan struct{}),
78-
done: make(chan struct{}),
79-
closeFunc: func() {},
80-
interceptor: defaultClientInterceptor,
77+
codec: codec{},
78+
conn: conn,
79+
channel: newChannel(conn),
80+
calls: make(chan *callRequest),
81+
closed: cancel,
82+
ctx: ctx,
83+
userCloseFunc: func() {},
84+
interceptor: defaultClientInterceptor,
8185
}
8286

8387
for _, o := range opts {
@@ -150,25 +154,24 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
150154
case <-ctx.Done():
151155
return ctx.Err()
152156
case c.calls <- call:
153-
case <-c.done:
154-
return c.err
157+
case <-c.ctx.Done():
158+
return c.error()
155159
}
156160

157161
select {
158162
case <-ctx.Done():
159163
return ctx.Err()
160164
case err := <-errs:
161165
return filterCloseErr(err)
162-
case <-c.done:
163-
return c.err
166+
case <-c.ctx.Done():
167+
return c.error()
164168
}
165169
}
166170

167171
func (c *Client) Close() error {
168172
c.closeOnce.Do(func() {
169-
close(c.closed)
173+
c.closed()
170174
})
171-
172175
return nil
173176
}
174177

@@ -178,51 +181,82 @@ type message struct {
178181
err error
179182
}
180183

181-
func (c *Client) run() {
182-
var (
183-
streamID uint32 = 1
184-
waiters = make(map[uint32]*callRequest)
185-
calls = c.calls
186-
incoming = make(chan *message)
187-
shutdown = make(chan struct{})
188-
shutdownErr error
189-
)
184+
type receiver struct {
185+
wg *sync.WaitGroup
186+
messages chan *message
187+
err error
188+
}
190189

191-
go func() {
192-
defer close(shutdown)
190+
func (r *receiver) run(ctx context.Context, c *channel) {
191+
defer r.wg.Done()
193192

194-
// start one more goroutine to recv messages without blocking.
195-
for {
196-
mh, p, err := c.channel.recv(context.TODO())
193+
for {
194+
select {
195+
case <-ctx.Done():
196+
r.err = ctx.Err()
197+
return
198+
default:
199+
mh, p, err := c.recv()
197200
if err != nil {
198201
_, ok := status.FromError(err)
199202
if !ok {
200203
// treat all errors that are not an rpc status as terminal.
201204
// all others poison the connection.
202-
shutdownErr = err
205+
r.err = filterCloseErr(err)
203206
return
204207
}
205208
}
206209
select {
207-
case incoming <- &message{
210+
case r.messages <- &message{
208211
messageHeader: mh,
209212
p: p[:mh.Length],
210213
err: err,
211214
}:
212-
case <-c.done:
215+
case <-ctx.Done():
216+
r.err = ctx.Err()
213217
return
214218
}
215219
}
220+
}
221+
}
222+
223+
func (c *Client) run() {
224+
var (
225+
streamID uint32 = 1
226+
waiters = make(map[uint32]*callRequest)
227+
calls = c.calls
228+
incoming = make(chan *message)
229+
receiversDone = make(chan struct{})
230+
wg sync.WaitGroup
231+
)
232+
233+
// broadcast the shutdown error to the remaining waiters.
234+
abortWaiters := func(wErr error) {
235+
for _, waiter := range waiters {
236+
waiter.errs <- wErr
237+
}
238+
}
239+
recv := &receiver{
240+
wg: &wg,
241+
messages: incoming,
242+
}
243+
wg.Add(1)
244+
245+
go func() {
246+
wg.Wait()
247+
close(receiversDone)
216248
}()
249+
go recv.run(c.ctx, c.channel)
217250

218-
defer c.conn.Close()
219-
defer close(c.done)
220-
defer c.closeFunc()
251+
defer func() {
252+
c.conn.Close()
253+
c.userCloseFunc()
254+
}()
221255

222256
for {
223257
select {
224258
case call := <-calls:
225-
if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
259+
if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
226260
call.errs <- err
227261
continue
228262
}
@@ -238,41 +272,42 @@ func (c *Client) run() {
238272

239273
call.errs <- c.recv(call.resp, msg)
240274
delete(waiters, msg.StreamID)
241-
case <-shutdown:
242-
if shutdownErr != nil {
243-
shutdownErr = filterCloseErr(shutdownErr)
244-
} else {
245-
shutdownErr = ErrClosed
246-
}
247-
248-
shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")
249-
250-
c.err = shutdownErr
251-
for _, waiter := range waiters {
252-
waiter.errs <- shutdownErr
275+
case <-receiversDone:
276+
// all the receivers have exited
277+
if recv.err != nil {
278+
c.setError(recv.err)
253279
}
280+
// don't return out, let the close of the context trigger the abort of waiters
254281
c.Close()
255-
return
256-
case <-c.closed:
257-
if c.err == nil {
258-
c.err = ErrClosed
259-
}
260-
// broadcast the shutdown error to the remaining waiters.
261-
for _, waiter := range waiters {
262-
waiter.errs <- c.err
263-
}
282+
case <-c.ctx.Done():
283+
abortWaiters(c.error())
264284
return
265285
}
266286
}
267287
}
268288

269-
func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
289+
func (c *Client) error() error {
290+
c.errOnce.Do(func() {
291+
if c.err == nil {
292+
c.err = ErrClosed
293+
}
294+
})
295+
return c.err
296+
}
297+
298+
func (c *Client) setError(err error) {
299+
c.errOnce.Do(func() {
300+
c.err = err
301+
})
302+
}
303+
304+
func (c *Client) send(streamID uint32, mtype messageType, msg interface{}) error {
270305
p, err := c.codec.Marshal(msg)
271306
if err != nil {
272307
return err
273308
}
274309

275-
return c.channel.send(ctx, streamID, mtype, p)
310+
return c.channel.send(streamID, mtype, p)
276311
}
277312

278313
func (c *Client) recv(resp *Response, msg *message) error {
@@ -293,22 +328,21 @@ func (c *Client) recv(resp *Response, msg *message) error {
293328
//
294329
// This purposely ignores errors with a wrapped cause.
295330
func filterCloseErr(err error) error {
296-
if err == nil {
331+
switch {
332+
case err == nil:
297333
return nil
298-
}
299-
300-
if err == io.EOF {
334+
case err == io.EOF:
301335
return ErrClosed
302-
}
303-
304-
if strings.Contains(err.Error(), "use of closed network connection") {
336+
case errors.Cause(err) == io.EOF:
305337
return ErrClosed
306-
}
307-
308-
// if we have an epipe on a write, we cast to errclosed
309-
if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
310-
if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
311-
return ErrClosed
338+
case strings.Contains(err.Error(), "use of closed network connection"):
339+
return ErrClosed
340+
default:
341+
// if we have an epipe on a write, we cast to errclosed
342+
if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
343+
if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
344+
return ErrClosed
345+
}
312346
}
313347
}
314348

server.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ func (c *serverConn) run(sctx context.Context) {
344344
default: // proceed
345345
}
346346

347-
mh, p, err := ch.recv(ctx)
347+
mh, p, err := ch.recv()
348348
if err != nil {
349349
status, ok := status.FromError(err)
350350
if !ok {
@@ -441,7 +441,7 @@ func (c *serverConn) run(sctx context.Context) {
441441
return
442442
}
443443

444-
if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
444+
if err := ch.send(response.id, messageTypeResponse, p); err != nil {
445445
logrus.WithError(err).Error("failed sending message on channel")
446446
return
447447
}

server_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ func TestClientEOF(t *testing.T) {
351351
}
352352

353353
// shutdown the server so the client stops receiving stuff.
354-
if err := server.Shutdown(ctx); err != nil {
354+
if err := server.Close(); err != nil {
355355
t.Fatal(err)
356356
}
357357
if err := <-errs; err != ErrServerClosed {

0 commit comments

Comments
 (0)