Skip to content

Commit 4f0aeb5

Browse files
committed
client: Handle sending/receiving in separate goroutines
Changes the TTRPC client logic so that sending and receiving with the server are in completely independent goroutines, with shared state guarded by a mutex. Previously, sending/receiving were tied together by reliance on a coordinator goroutine. This led to issues where if the server was not reading from the connection, the client could get stuck sending a request, causing the client to not read responses from the server. See [1] for more details.

The new design sets up separate sending/receiving goroutines. These share state in the form of the set of active calls that have been made to the server. This state is encapsulated in the callMap type and access is guarded by a mutex.

The main event loop in `run` previously handled a lot of state management for the client. Now that most state is tracked by the callMap, it mostly exists to notice when the client is closed and take appropriate action to clean up.

Also did some minor code cleanup. For instance, the code was previously written to support multiple receiver goroutines, though this was not actually used. I've removed this for now, since the code is simpler this way, and it's easy to add back if we actually need it in the future.

[1] #72

Signed-off-by: Kevin Parsons <[email protected]>
1 parent 77fc8a4 commit 4f0aeb5

1 file changed

Lines changed: 116 additions & 75 deletions

File tree

client.go

Lines changed: 116 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -194,72 +194,131 @@ type message struct {
194194
err error
195195
}
196196

197-
type receiver struct {
198-
wg *sync.WaitGroup
199-
messages chan *message
200-
err error
197+
// callMap holds the set of calls that are currently in flight, keyed by
// stream ID. Its methods (set, get, abort) take the mutex m before touching
// activeCalls or closeErr.
198+
type callMap struct {
199+
// m guards activeCalls and closeErr.
m sync.Mutex
200+
// activeCalls maps a stream ID to the call waiting on that stream.
activeCalls map[uint32]*callRequest
201+
// closeErr is the error passed to abort. While it is nil the map is
// considered open; once non-nil, set, get and abort all return it.
closeErr error
201202
}
202203

203-
func (r *receiver) run(ctx context.Context, c *channel) {
204-
defer r.wg.Done()
204+
// newCallMap returns a new callMap with an empty, non-nil set of active
// calls and no close error recorded.
205+
func newCallMap() *callMap {
206+
return &callMap{
207+
// Allocate the map up front so set can assign into it directly.
activeCalls: make(map[uint32]*callRequest),
208+
}
209+
}
205210

206-
for {
207-
select {
208-
case <-ctx.Done():
209-
r.err = ctx.Err()
210-
return
211-
default:
212-
mh, p, err := c.recv()
213-
if err != nil {
214-
_, ok := status.FromError(err)
215-
if !ok {
216-
// treat all errors that are not an rpc status as terminal.
217-
// all others poison the connection.
218-
r.err = filterCloseErr(err)
219-
return
220-
}
221-
}
222-
select {
223-
case r.messages <- &message{
224-
messageHeader: mh,
225-
p: p[:mh.Length],
226-
err: err,
227-
}:
228-
case <-ctx.Done():
229-
r.err = ctx.Err()
230-
return
231-
}
232-
}
211+
// set adds a call entry to the map with the given streamID key.
//
// If abort has already recorded an error, the entry is not added and that
// error is returned. Otherwise cr is stored under streamID (replacing any
// entry already stored under the same key) and set returns nil.
212+
func (cm *callMap) set(streamID uint32, cr *callRequest) error {
213+
cm.m.Lock()
214+
defer cm.m.Unlock()
215+
// A non-nil closeErr means the map has been aborted; reject new calls.
if cm.closeErr != nil {
216+
return cm.closeErr
233217
}
218+
cm.activeCalls[streamID] = cr
219+
return nil
220+
}
221+
222+
// get looks up the call entry for the given streamID key, then removes it
223+
// from the map and returns it.
//
// Results:
//   - map aborted:     (nil, false, closeErr); the map is not consulted.
//   - entry present:   (entry, true, nil); the entry is deleted from the map.
//   - entry missing:   (nil, false, nil); the map is left unchanged.
224+
func (cm *callMap) get(streamID uint32) (cr *callRequest, ok bool, err error) {
225+
cm.m.Lock()
226+
defer cm.m.Unlock()
227+
if cm.closeErr != nil {
228+
return nil, false, cm.closeErr
229+
}
230+
cr, ok = cm.activeCalls[streamID]
231+
// Only delete when the key exists, so a lookup is a one-shot "take".
if ok {
232+
delete(cm.activeCalls, streamID)
233+
}
234+
// Naked return: yields the named results cr and ok set above; err is
// still its zero value (nil) here.
return
235+
}
236+
237+
// abort sends the given error to each active call, and clears the map.
238+
// Once abort has been called, any subsequent calls to the callMap will return the error passed to abort.
//
// The first call that records an error returns nil. Later calls return the
// previously recorded error and do nothing else.
//
// Because a nil closeErr is the "not aborted" marker, calling abort with a
// nil err still sends nil to every active call and empties the map, but does
// not mark the map as closed.
//
// NOTE(review): the sends on call.errs happen while cm.m is held. This
// appears to rely on errs having free buffer space (the sender goroutine in
// Client.run comments that errs is buffered) - confirm that a call cannot
// still be in the map with an error already queued and unread, otherwise
// this send would block with the mutex held.
239+
func (cm *callMap) abort(err error) error {
240+
cm.m.Lock()
241+
defer cm.m.Unlock()
242+
if cm.closeErr != nil {
243+
return cm.closeErr
244+
}
245+
// Deleting the current key while ranging over a map is permitted in Go.
for streamID, call := range cm.activeCalls {
246+
call.errs <- err
247+
delete(cm.activeCalls, streamID)
248+
}
249+
cm.closeErr = err
250+
return nil
234251
}
235252

236253
func (c *Client) run() {
237254
var (
238-
streamID uint32 = 1
239-
waiters = make(map[uint32]*callRequest)
240-
calls = c.calls
241-
incoming = make(chan *message)
242-
receiversDone = make(chan struct{})
243-
wg sync.WaitGroup
255+
waiters = newCallMap()
256+
receiverDone = make(chan struct{})
244257
)
245258

246-
// broadcast the shutdown error to the remaining waiters.
247-
abortWaiters := func(wErr error) {
248-
for _, waiter := range waiters {
249-
waiter.errs <- wErr
259+
// Sender goroutine
260+
// Receives calls from dispatch, adds them to the set of active calls, and sends them
261+
// to the server.
262+
go func() {
263+
var streamID uint32 = 1
264+
for {
265+
select {
266+
case <-c.ctx.Done():
267+
return
268+
case call := <-c.calls:
269+
id := streamID
270+
streamID += 2 // enforce odd client initiated request ids
271+
if err := waiters.set(id, call); err != nil {
272+
call.errs <- err // errs is buffered so should not block.
273+
continue
274+
}
275+
if err := c.send(id, messageTypeRequest, call.req); err != nil {
276+
call.errs <- err // errs is buffered so should not block.
277+
waiters.get(id) // remove from waiters set
278+
}
279+
}
250280
}
251-
}
252-
recv := &receiver{
253-
wg: &wg,
254-
messages: incoming,
255-
}
256-
wg.Add(1)
281+
}()
257282

283+
// Receiver goroutine
284+
// Receives responses from the server, looks up the call info in the set of active calls,
285+
// and notifies the caller of the response.
258286
go func() {
259-
wg.Wait()
260-
close(receiversDone)
287+
defer close(receiverDone)
288+
for {
289+
select {
290+
case <-c.ctx.Done():
291+
c.setError(c.ctx.Err())
292+
return
293+
default:
294+
mh, p, err := c.channel.recv()
295+
if err != nil {
296+
_, ok := status.FromError(err)
297+
if !ok {
298+
// treat all errors that are not an rpc status as terminal.
299+
// all others poison the connection.
300+
c.setError(filterCloseErr(err))
301+
return
302+
}
303+
}
304+
msg := &message{
305+
messageHeader: mh,
306+
p: p[:mh.Length],
307+
err: err,
308+
}
309+
call, ok, err := waiters.get(mh.StreamID)
310+
if err != nil {
311+
logrus.Errorf("ttrpc: failed to look up active call: %s", err)
312+
continue
313+
}
314+
if !ok {
315+
logrus.Errorf("ttrpc: received message for unknown channel %v", mh.StreamID)
316+
continue
317+
}
318+
call.errs <- c.recv(call.resp, msg)
319+
}
320+
}
261321
}()
262-
go recv.run(c.ctx, c.channel)
263322

264323
defer func() {
265324
c.conn.Close()
@@ -269,32 +328,14 @@ func (c *Client) run() {
269328

270329
for {
271330
select {
272-
case call := <-calls:
273-
if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
274-
call.errs <- err
275-
continue
276-
}
277-
278-
waiters[streamID] = call
279-
streamID += 2 // enforce odd client initiated request ids
280-
case msg := <-incoming:
281-
call, ok := waiters[msg.StreamID]
282-
if !ok {
283-
logrus.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID)
284-
continue
285-
}
286-
287-
call.errs <- c.recv(call.resp, msg)
288-
delete(waiters, msg.StreamID)
289-
case <-receiversDone:
290-
// all the receivers have exited
291-
if recv.err != nil {
292-
c.setError(recv.err)
293-
}
331+
case <-receiverDone:
332+
// The receiver has exited.
294333
// don't return out, let the close of the context trigger the abort of waiters
295334
c.Close()
296335
case <-c.ctx.Done():
297-
abortWaiters(c.error())
336+
// Abort all active calls. This will also prevent any new calls from being added
337+
// to waiters.
338+
waiters.abort(c.error())
298339
return
299340
}
300341
}

0 commit comments

Comments
 (0)