Skip to content

Commit 9c0db2b

Browse files
authored
Merge pull request #152 from klihub/devel/unary-interceptor-chaining
Implement support for unary interceptor chaining.
2 parents b2f0ada + 40f227d commit 9c0db2b

3 files changed

Lines changed: 321 additions & 3 deletions

File tree

client.go

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,42 @@ func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
7171
}
7272
}
7373

74+
// WithChainUnaryClientInterceptor sets the provided chain of client interceptors
75+
func WithChainUnaryClientInterceptor(interceptors ...UnaryClientInterceptor) ClientOpts {
76+
return func(c *Client) {
77+
if len(interceptors) == 0 {
78+
return
79+
}
80+
if c.interceptor != nil {
81+
interceptors = append([]UnaryClientInterceptor{c.interceptor}, interceptors...)
82+
}
83+
c.interceptor = func(
84+
ctx context.Context,
85+
req *Request,
86+
reply *Response,
87+
info *UnaryClientInfo,
88+
final Invoker,
89+
) error {
90+
return interceptors[0](ctx, req, reply, info,
91+
chainUnaryInterceptors(interceptors[1:], final, info))
92+
}
93+
}
94+
}
95+
96+
func chainUnaryInterceptors(interceptors []UnaryClientInterceptor, final Invoker, info *UnaryClientInfo) Invoker {
97+
if len(interceptors) == 0 {
98+
return final
99+
}
100+
return func(
101+
ctx context.Context,
102+
req *Request,
103+
reply *Response,
104+
) error {
105+
return interceptors[0](ctx, req, reply, info,
106+
chainUnaryInterceptors(interceptors[1:], final, info))
107+
}
108+
}
109+
74110
// NewClient creates a new ttrpc client using the given connection
75111
func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
76112
ctx, cancel := context.WithCancel(context.Background())
@@ -85,13 +121,16 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
85121
ctx: ctx,
86122
userCloseFunc: func() {},
87123
userCloseWaitCh: make(chan struct{}),
88-
interceptor: defaultClientInterceptor,
89124
}
90125

91126
for _, o := range opts {
92127
o(c)
93128
}
94129

130+
if c.interceptor == nil {
131+
c.interceptor = defaultClientInterceptor
132+
}
133+
95134
go c.run()
96135
return c
97136
}

config.go

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616

1717
package ttrpc
1818

19-
import "errors"
19+
import (
20+
"context"
21+
"errors"
22+
)
2023

2124
type serverConfig struct {
2225
handshaker Handshaker
@@ -44,9 +47,40 @@ func WithServerHandshaker(handshaker Handshaker) ServerOpt {
4447
func WithUnaryServerInterceptor(i UnaryServerInterceptor) ServerOpt {
4548
return func(c *serverConfig) error {
4649
if c.interceptor != nil {
47-
return errors.New("only one interceptor allowed per server")
50+
return errors.New("only one unchained interceptor allowed per server")
4851
}
4952
c.interceptor = i
5053
return nil
5154
}
5255
}
56+
57+
// WithChainUnaryServerInterceptor sets the provided chain of server interceptors
58+
func WithChainUnaryServerInterceptor(interceptors ...UnaryServerInterceptor) ServerOpt {
59+
return func(c *serverConfig) error {
60+
if len(interceptors) == 0 {
61+
return nil
62+
}
63+
if c.interceptor != nil {
64+
interceptors = append([]UnaryServerInterceptor{c.interceptor}, interceptors...)
65+
}
66+
c.interceptor = func(
67+
ctx context.Context,
68+
unmarshal Unmarshaler,
69+
info *UnaryServerInfo,
70+
method Method) (interface{}, error) {
71+
return interceptors[0](ctx, unmarshal, info,
72+
chainUnaryServerInterceptors(info, method, interceptors[1:]))
73+
}
74+
return nil
75+
}
76+
}
77+
78+
func chainUnaryServerInterceptors(info *UnaryServerInfo, method Method, interceptors []UnaryServerInterceptor) Method {
79+
if len(interceptors) == 0 {
80+
return method
81+
}
82+
return func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) {
83+
return interceptors[0](ctx, unmarshal, info,
84+
chainUnaryServerInterceptors(info, method, interceptors[1:]))
85+
}
86+
}

interceptor_test.go

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
/*
2+
Copyright The containerd Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package ttrpc
18+
19+
import (
20+
"context"
21+
"reflect"
22+
"strings"
23+
"testing"
24+
25+
"github.com/containerd/ttrpc/internal"
26+
)
27+
28+
func TestUnaryClientInterceptor(t *testing.T) {
29+
var (
30+
intercepted = false
31+
interceptor = func(ctx context.Context, req *Request, reply *Response, ci *UnaryClientInfo, i Invoker) error {
32+
intercepted = true
33+
return i(ctx, req, reply)
34+
}
35+
36+
ctx = context.Background()
37+
server = mustServer(t)(NewServer())
38+
testImpl = &testingServer{}
39+
addr, listener = newTestListener(t)
40+
client, cleanup = newTestClient(t, addr, WithUnaryClientInterceptor(interceptor))
41+
message = strings.Repeat("a", 16)
42+
reply = strings.Repeat(message, 2)
43+
)
44+
45+
defer listener.Close()
46+
defer cleanup()
47+
48+
registerTestingService(server, testImpl)
49+
50+
go server.Serve(ctx, listener)
51+
defer server.Shutdown(ctx)
52+
53+
request := &internal.TestPayload{
54+
Foo: message,
55+
}
56+
response := &internal.TestPayload{}
57+
58+
if err := client.Call(ctx, serviceName, "Test", request, response); err != nil {
59+
t.Fatalf("unexpected error: %v", err)
60+
}
61+
62+
if !intercepted {
63+
t.Fatalf("ttrpc client call not intercepted")
64+
}
65+
66+
if response.Foo != reply {
67+
t.Fatalf("unexpected test service reply: %q != %q", response.Foo, reply)
68+
}
69+
}
70+
71+
func TestChainUnaryClientInterceptor(t *testing.T) {
72+
var (
73+
orderIdx = 0
74+
recorded = []string{}
75+
intercept = func(idx int, tag string) UnaryClientInterceptor {
76+
return func(ctx context.Context, req *Request, reply *Response, ci *UnaryClientInfo, i Invoker) error {
77+
if idx != orderIdx {
78+
t.Fatalf("unexpected interceptor invocation order (%d != %d)", orderIdx, idx)
79+
}
80+
recorded = append(recorded, tag)
81+
orderIdx++
82+
return i(ctx, req, reply)
83+
}
84+
}
85+
86+
ctx = context.Background()
87+
server = mustServer(t)(NewServer())
88+
testImpl = &testingServer{}
89+
addr, listener = newTestListener(t)
90+
client, cleanup = newTestClient(t, addr,
91+
WithChainUnaryClientInterceptor(),
92+
WithChainUnaryClientInterceptor(
93+
intercept(0, "seen it"),
94+
intercept(1, "been"),
95+
intercept(2, "there"),
96+
intercept(3, "done"),
97+
intercept(4, "that"),
98+
),
99+
)
100+
expected = []string{
101+
"seen it",
102+
"been",
103+
"there",
104+
"done",
105+
"that",
106+
}
107+
message = strings.Repeat("a", 16)
108+
reply = strings.Repeat(message, 2)
109+
)
110+
111+
defer listener.Close()
112+
defer cleanup()
113+
114+
registerTestingService(server, testImpl)
115+
116+
go server.Serve(ctx, listener)
117+
defer server.Shutdown(ctx)
118+
119+
request := &internal.TestPayload{
120+
Foo: message,
121+
}
122+
response := &internal.TestPayload{}
123+
if err := client.Call(ctx, serviceName, "Test", request, response); err != nil {
124+
t.Fatalf("unexpected error: %v", err)
125+
}
126+
127+
if !reflect.DeepEqual(recorded, expected) {
128+
t.Fatalf("unexpected ttrpc chained client unary interceptor order (%s != %s)",
129+
strings.Join(recorded, " "), strings.Join(expected, " "))
130+
}
131+
132+
if response.Foo != reply {
133+
t.Fatalf("unexpected test service reply: %q != %q", response.Foo, reply)
134+
}
135+
}
136+
137+
func TestUnaryServerInterceptor(t *testing.T) {
138+
var (
139+
intercepted = false
140+
interceptor = func(ctx context.Context, unmarshal Unmarshaler, _ *UnaryServerInfo, method Method) (interface{}, error) {
141+
intercepted = true
142+
return method(ctx, unmarshal)
143+
}
144+
145+
ctx = context.Background()
146+
server = mustServer(t)(NewServer(WithUnaryServerInterceptor(interceptor)))
147+
testImpl = &testingServer{}
148+
addr, listener = newTestListener(t)
149+
client, cleanup = newTestClient(t, addr)
150+
message = strings.Repeat("a", 16)
151+
reply = strings.Repeat(message, 2)
152+
)
153+
154+
defer listener.Close()
155+
defer cleanup()
156+
157+
registerTestingService(server, testImpl)
158+
159+
go server.Serve(ctx, listener)
160+
defer server.Shutdown(ctx)
161+
162+
request := &internal.TestPayload{
163+
Foo: message,
164+
}
165+
response := &internal.TestPayload{}
166+
if err := client.Call(ctx, serviceName, "Test", request, response); err != nil {
167+
t.Fatalf("unexpected error: %v", err)
168+
}
169+
170+
if !intercepted {
171+
t.Fatalf("ttrpc server call not intercepted")
172+
}
173+
174+
if response.Foo != reply {
175+
t.Fatalf("unexpected test service reply: %q != %q", response.Foo, reply)
176+
}
177+
}
178+
179+
func TestChainUnaryServerInterceptor(t *testing.T) {
180+
var (
181+
orderIdx = 0
182+
recorded = []string{}
183+
intercept = func(idx int, tag string) UnaryServerInterceptor {
184+
return func(ctx context.Context, unmarshal Unmarshaler, _ *UnaryServerInfo, method Method) (interface{}, error) {
185+
if orderIdx != idx {
186+
t.Fatalf("unexpected interceptor invocation order (%d != %d)", orderIdx, idx)
187+
}
188+
recorded = append(recorded, tag)
189+
orderIdx++
190+
return method(ctx, unmarshal)
191+
}
192+
}
193+
194+
ctx = context.Background()
195+
server = mustServer(t)(NewServer(
196+
WithUnaryServerInterceptor(
197+
intercept(0, "seen it"),
198+
),
199+
WithChainUnaryServerInterceptor(
200+
intercept(1, "been"),
201+
intercept(2, "there"),
202+
intercept(3, "done"),
203+
intercept(4, "that"),
204+
),
205+
))
206+
expected = []string{
207+
"seen it",
208+
"been",
209+
"there",
210+
"done",
211+
"that",
212+
}
213+
testImpl = &testingServer{}
214+
addr, listener = newTestListener(t)
215+
client, cleanup = newTestClient(t, addr)
216+
message = strings.Repeat("a", 16)
217+
reply = strings.Repeat(message, 2)
218+
)
219+
220+
defer listener.Close()
221+
defer cleanup()
222+
223+
registerTestingService(server, testImpl)
224+
225+
go server.Serve(ctx, listener)
226+
defer server.Shutdown(ctx)
227+
228+
request := &internal.TestPayload{
229+
Foo: message,
230+
}
231+
response := &internal.TestPayload{}
232+
233+
if err := client.Call(ctx, serviceName, "Test", request, response); err != nil {
234+
t.Fatalf("unexpected error: %v", err)
235+
}
236+
237+
if !reflect.DeepEqual(recorded, expected) {
238+
t.Fatalf("unexpected ttrpc chained server unary interceptor order (%s != %s)",
239+
strings.Join(recorded, " "), strings.Join(expected, " "))
240+
}
241+
242+
if response.Foo != reply {
243+
t.Fatalf("unexpected test service reply: %q != %q", response.Foo, reply)
244+
}
245+
}

0 commit comments

Comments
 (0)