Skip to content

Commit 50d4175

Browse files
committed
Add client and service side apis for limiting the send/recv msg size. Update MethodConfig struct
1 parent cdee119 commit 50d4175

6 files changed

Lines changed: 868 additions & 150 deletions

File tree

call.go

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ import (
5151
//
5252
// TODO(zhaoq): Check whether the received message sequence is valid.
5353
// TODO ctx is used for stats collection and processing. It is the context passed from the application.
54-
func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) {
54+
func recvResponse(ctx context.Context, dopts dialOptions, msgSizeLimit int, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) {
5555
// Try to acquire header metadata from the server if there is any.
5656
defer func() {
5757
if err != nil {
@@ -72,7 +72,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran
7272
}
7373
}
7474
for {
75-
if err = recv(p, dopts.codec, stream, dopts.dc, reply, dopts.maxMsgSize, inPayload); err != nil {
75+
if err = recv(p, dopts.codec, stream, dopts.dc, reply, msgSizeLimit, inPayload); err != nil {
7676
if err == io.EOF {
7777
break
7878
}
@@ -92,7 +92,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran
9292
}
9393

9494
// sendRequest writes out various information of an RPC such as Context and Message.
95-
func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
95+
func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, msgSizeLimit int, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
9696
stream, err := t.NewStream(ctx, callHdr)
9797
if err != nil {
9898
return nil, err
@@ -121,6 +121,9 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor,
121121
if err != nil {
122122
return nil, Errorf(codes.Internal, "grpc: %v", err)
123123
}
124+
if len(outBuf) > msgSizeLimit {
125+
return nil, Errorf(codes.InvalidArgument, "Sent message larger than max (%d vs. %d)", len(outBuf), msgSizeLimit)
126+
}
124127
err = t.Write(stream, outBuf, opts)
125128
if err == nil && outPayload != nil {
126129
outPayload.SentTime = time.Now()
@@ -146,15 +149,49 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
146149
return invoke(ctx, method, args, reply, cc, opts...)
147150
}
148151

152+
const defaultClientMaxReceiveMessageSize = 1024 * 1024 * 4
153+
const defaultClientMaxSendMessageSize = 1024 * 1024 * 4
154+
155+
func min(a, b int) int {
156+
if a < b {
157+
return a
158+
}
159+
return b
160+
}
161+
149162
func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) {
150163
c := defaultCallInfo
151-
if mc, ok := cc.getMethodConfig(method); ok {
152-
c.failFast = !mc.WaitForReady
153-
if mc.Timeout > 0 {
164+
maxReceiveMessageSize := defaultClientMaxReceiveMessageSize
165+
maxSendMessageSize := defaultClientMaxSendMessageSize
166+
if mc, ok := cc.GetMethodConfig(method); ok {
167+
if mc.WaitForReady != nil {
168+
c.failFast = !*mc.WaitForReady
169+
}
170+
171+
if mc.Timeout != nil && *mc.Timeout >= 0 {
154172
var cancel context.CancelFunc
155-
ctx, cancel = context.WithTimeout(ctx, mc.Timeout)
173+
ctx, cancel = context.WithTimeout(ctx, *mc.Timeout)
156174
defer cancel()
157175
}
176+
177+
if mc.MaxReqSize != nil && cc.dopts.maxSendMessageSize >= 0 {
178+
maxSendMessageSize = min(*mc.MaxReqSize, cc.dopts.maxSendMessageSize)
179+
} else if mc.MaxReqSize != nil {
180+
maxSendMessageSize = *mc.MaxReqSize
181+
}
182+
183+
if mc.MaxRespSize != nil && cc.dopts.maxReceiveMessageSize >= 0 {
184+
maxReceiveMessageSize = min(*mc.MaxRespSize, cc.dopts.maxReceiveMessageSize)
185+
} else if mc.MaxRespSize != nil {
186+
maxReceiveMessageSize = *mc.MaxRespSize
187+
}
188+
} else {
189+
if cc.dopts.maxSendMessageSize >= 0 {
190+
maxSendMessageSize = cc.dopts.maxSendMessageSize
191+
}
192+
if cc.dopts.maxReceiveMessageSize >= 0 {
193+
maxReceiveMessageSize = cc.dopts.maxReceiveMessageSize
194+
}
158195
}
159196
for _, o := range opts {
160197
if err := o.before(&c); err != nil {
@@ -245,7 +282,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
245282
if c.traceInfo.tr != nil {
246283
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
247284
}
248-
stream, err = sendRequest(ctx, cc.dopts, cc.dopts.cp, callHdr, t, args, topts)
285+
stream, err = sendRequest(ctx, cc.dopts, cc.dopts.cp, maxSendMessageSize, callHdr, t, args, topts)
249286
if err != nil {
250287
if put != nil {
251288
put()
@@ -262,7 +299,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
262299
}
263300
return toRPCErr(err)
264301
}
265-
err = recvResponse(ctx, cc.dopts, t, &c, stream, reply)
302+
err = recvResponse(ctx, cc.dopts, maxReceiveMessageSize, t, &c, stream, reply)
266303
if err != nil {
267304
if put != nil {
268305
put()

clientconn.go

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -88,30 +88,45 @@ var (
8888
// dialOptions configure a Dial call. dialOptions are set by the DialOption
8989
// values passed to Dial.
9090
type dialOptions struct {
91-
unaryInt UnaryClientInterceptor
92-
streamInt StreamClientInterceptor
93-
codec Codec
94-
cp Compressor
95-
dc Decompressor
96-
bs backoffStrategy
97-
balancer Balancer
98-
block bool
99-
insecure bool
100-
timeout time.Duration
101-
scChan <-chan ServiceConfig
102-
copts transport.ConnectOptions
103-
maxMsgSize int
91+
unaryInt UnaryClientInterceptor
92+
streamInt StreamClientInterceptor
93+
codec Codec
94+
cp Compressor
95+
dc Decompressor
96+
bs backoffStrategy
97+
balancer Balancer
98+
block bool
99+
insecure bool
100+
timeout time.Duration
101+
scChan <-chan ServiceConfig
102+
copts transport.ConnectOptions
103+
maxReceiveMessageSize int
104+
maxSendMessageSize int
104105
}
105106

106107
const defaultClientMaxMsgSize = math.MaxInt32
107108

108109
// DialOption configures how we set up the connection.
109110
type DialOption func(*dialOptions)
110111

111-
// WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive.
112+
// WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive. This function is for backward API compatibility. It has essentially the same functionality as WithMaxReceiveMessageSize.
112113
func WithMaxMsgSize(s int) DialOption {
113114
return func(o *dialOptions) {
114-
o.maxMsgSize = s
115+
o.maxReceiveMessageSize = s
116+
}
117+
}
118+
119+
// WithMaxReceiveMessageSize returns a DialOption which sets the maximum message size the client can receive. Negative input is invalid and has the same effect as not setting the field.
120+
func WithMaxReceiveMessageSize(s int) DialOption {
121+
return func(o *dialOptions) {
122+
o.maxReceiveMessageSize = s
123+
}
124+
}
125+
126+
// WithMaxSendMessageSize returns a DialOption which sets the maximum message size the client can send. Negative input is invalid and has the same effect as not seeting the field.
127+
func WithMaxSendMessageSize(s int) DialOption {
128+
return func(o *dialOptions) {
129+
o.maxSendMessageSize = s
115130
}
116131
}
117132

@@ -307,7 +322,11 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
307322
conns: make(map[Address]*addrConn),
308323
}
309324
cc.ctx, cc.cancel = context.WithCancel(context.Background())
310-
cc.dopts.maxMsgSize = defaultClientMaxMsgSize
325+
326+
// initialize maxReceiveMessageSize and maxSendMessageSize to -1 before applying DialOption functions to distinguish whether the user set the message limit or not.
327+
cc.dopts.maxReceiveMessageSize = -1
328+
cc.dopts.maxSendMessageSize = -1
329+
311330
for _, opt := range opts {
312331
opt(&cc.dopts)
313332
}
@@ -609,11 +628,16 @@ func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr err
609628
return nil
610629
}
611630

631+
// GetMethodConfig gets the method config of the input method. If there's no exact match for the input method (i.e. /service/method), we will return the default config for all methods under the service (/service/).
612632
// TODO: Avoid the locking here.
613-
func (cc *ClientConn) getMethodConfig(method string) (m MethodConfig, ok bool) {
633+
func (cc *ClientConn) GetMethodConfig(method string) (m MethodConfig, ok bool) {
614634
cc.mu.RLock()
615635
defer cc.mu.RUnlock()
616636
m, ok = cc.sc.Methods[method]
637+
if !ok {
638+
i := strings.LastIndex(method, "/")
639+
m, ok = cc.sc.Methods[method[:i+1]]
640+
}
617641
return
618642
}
619643

rpc_util.go

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ type parser struct {
239239
// No other error values or types must be returned, which also means
240240
// that the underlying io.Reader must not return an incompatible
241241
// error.
242-
func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) {
242+
func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) {
243243
if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
244244
return 0, nil, err
245245
}
@@ -250,8 +250,8 @@ func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err erro
250250
if length == 0 {
251251
return pf, nil, nil
252252
}
253-
if length > uint32(maxMsgSize) {
254-
return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize)
253+
if length > uint32(maxReceiveMessageSize) {
254+
return 0, nil, Errorf(codes.InvalidArgument, "grpc: Received message larger than max (%d vs. %d)", length, maxReceiveMessageSize)
255255
}
256256
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
257257
// of making it for each message:
@@ -335,8 +335,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
335335
return nil
336336
}
337337

338-
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int, inPayload *stats.InPayload) error {
339-
pf, d, err := p.recvMsg(maxMsgSize)
338+
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, inPayload *stats.InPayload) error {
339+
pf, d, err := p.recvMsg(maxReceiveMessageSize)
340340
if err != nil {
341341
return err
342342
}
@@ -352,10 +352,10 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
352352
return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
353353
}
354354
}
355-
if len(d) > maxMsgSize {
355+
if len(d) > maxReceiveMessageSize {
356356
// TODO: Revisit the error code. Currently keep it consistent with java
357357
// implementation.
358-
return Errorf(codes.Internal, "grpc: received a message of %d bytes exceeding %d limit", len(d), maxMsgSize)
358+
return Errorf(codes.InvalidArgument, "grpc: Received message larger than max (%d vs. %d)", len(d), maxReceiveMessageSize)
359359
}
360360
if err := c.Unmarshal(d, m); err != nil {
361361
return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
@@ -489,24 +489,22 @@ type MethodConfig struct {
489489
// WaitForReady indicates whether RPCs sent to this method should wait until
490490
// the connection is ready by default (!failfast). The value specified via the
491491
// gRPC client API will override the value set here.
492-
WaitForReady bool
492+
WaitForReady *bool
493493
// Timeout is the default timeout for RPCs sent to this method. The actual
494494
// deadline used will be the minimum of the value specified here and the value
495495
// set by the application via the gRPC client API. If either one is not set,
496496
// then the other will be used. If neither is set, then the RPC has no deadline.
497-
Timeout time.Duration
497+
Timeout *time.Duration
498498
// MaxReqSize is the maximum allowed payload size for an individual request in a
499499
// stream (client->server) in bytes. The size which is measured is the serialized
500500
// payload after per-message compression (but before stream compression) in bytes.
501501
// The actual value used is the minumum of the value specified here and the value set
502502
// by the application via the gRPC client API. If either one is not set, then the other
503503
// will be used. If neither is set, then the built-in default is used.
504-
// TODO: support this.
505-
MaxReqSize uint32
504+
MaxReqSize *int
506505
// MaxRespSize is the maximum allowed payload size for an individual response in a
507506
// stream (server->client) in bytes.
508-
// TODO: support this.
509-
MaxRespSize uint32
507+
MaxRespSize *int
510508
}
511509

512510
// ServiceConfig is provided by the service provider and contains parameters for how

0 commit comments

Comments
 (0)