Skip to content

Commit f984c9b

Browse files
committed
client: implement UnaryClientInterceptor chaining.
Add a WithChainUnaryClientInterceptor client option to allow using more that one client call interceptor which will then get chained and invoked in the order given. This should allow us to implement opentelemetry instrumentation as interceptors while allowing users to keep intercepting their client calls for other reasons at the same time. Signed-off-by: Krisztian Litkey <[email protected]>
1 parent 05e0d07 commit f984c9b

2 files changed

Lines changed: 175 additions & 1 deletion

File tree

client.go

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,42 @@ func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
7171
}
7272
}
7373

74+
// WithChainUnaryClientInterceptor sets the provided chain of client interceptors
75+
func WithChainUnaryClientInterceptor(interceptors ...UnaryClientInterceptor) ClientOpts {
76+
return func(c *Client) {
77+
if len(interceptors) == 0 {
78+
return
79+
}
80+
if c.interceptor != nil {
81+
interceptors = append([]UnaryClientInterceptor{c.interceptor}, interceptors...)
82+
}
83+
c.interceptor = func(
84+
ctx context.Context,
85+
req *Request,
86+
reply *Response,
87+
info *UnaryClientInfo,
88+
final Invoker,
89+
) error {
90+
return interceptors[0](ctx, req, reply, info,
91+
chainUnaryInterceptors(interceptors[1:], final, info))
92+
}
93+
}
94+
}
95+
96+
func chainUnaryInterceptors(interceptors []UnaryClientInterceptor, final Invoker, info *UnaryClientInfo) Invoker {
97+
if len(interceptors) == 0 {
98+
return final
99+
}
100+
return func(
101+
ctx context.Context,
102+
req *Request,
103+
reply *Response,
104+
) error {
105+
return interceptors[0](ctx, req, reply, info,
106+
chainUnaryInterceptors(interceptors[1:], final, info))
107+
}
108+
}
109+
74110
// NewClient creates a new ttrpc client using the given connection
75111
func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
76112
ctx, cancel := context.WithCancel(context.Background())
@@ -85,13 +121,16 @@ func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
85121
ctx: ctx,
86122
userCloseFunc: func() {},
87123
userCloseWaitCh: make(chan struct{}),
88-
interceptor: defaultClientInterceptor,
89124
}
90125

91126
for _, o := range opts {
92127
o(c)
93128
}
94129

130+
if c.interceptor == nil {
131+
c.interceptor = defaultClientInterceptor
132+
}
133+
95134
go c.run()
96135
return c
97136
}

interceptor_test.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
Copyright The containerd Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package ttrpc
18+
19+
import (
20+
"context"
21+
"reflect"
22+
"strings"
23+
"testing"
24+
25+
"github.com/containerd/ttrpc/internal"
26+
)
27+
28+
func TestUnaryClientInterceptor(t *testing.T) {
29+
var (
30+
intercepted = false
31+
interceptor = func(ctx context.Context, req *Request, reply *Response, ci *UnaryClientInfo, i Invoker) error {
32+
intercepted = true
33+
return i(ctx, req, reply)
34+
}
35+
36+
ctx = context.Background()
37+
server = mustServer(t)(NewServer())
38+
testImpl = &testingServer{}
39+
addr, listener = newTestListener(t)
40+
client, cleanup = newTestClient(t, addr, WithUnaryClientInterceptor(interceptor))
41+
message = strings.Repeat("a", 16)
42+
reply = strings.Repeat(message, 2)
43+
)
44+
45+
defer listener.Close()
46+
defer cleanup()
47+
48+
registerTestingService(server, testImpl)
49+
50+
go server.Serve(ctx, listener)
51+
defer server.Shutdown(ctx)
52+
53+
request := &internal.TestPayload{
54+
Foo: message,
55+
}
56+
response := &internal.TestPayload{}
57+
58+
if err := client.Call(ctx, serviceName, "Test", request, response); err != nil {
59+
t.Fatalf("unexpected error: %v", err)
60+
}
61+
62+
if !intercepted {
63+
t.Fatalf("ttrpc client call not intercepted")
64+
}
65+
66+
if response.Foo != reply {
67+
t.Fatalf("unexpected test service reply: %q != %q", response.Foo, reply)
68+
}
69+
}
70+
71+
func TestChainUnaryClientInterceptor(t *testing.T) {
72+
var (
73+
orderIdx = 0
74+
recorded = []string{}
75+
intercept = func(idx int, tag string) UnaryClientInterceptor {
76+
return func(ctx context.Context, req *Request, reply *Response, ci *UnaryClientInfo, i Invoker) error {
77+
if idx != orderIdx {
78+
t.Fatalf("unexpected interceptor invocation order (%d != %d)", orderIdx, idx)
79+
}
80+
recorded = append(recorded, tag)
81+
orderIdx++
82+
return i(ctx, req, reply)
83+
}
84+
}
85+
86+
ctx = context.Background()
87+
server = mustServer(t)(NewServer())
88+
testImpl = &testingServer{}
89+
addr, listener = newTestListener(t)
90+
client, cleanup = newTestClient(t, addr,
91+
WithChainUnaryClientInterceptor(),
92+
WithChainUnaryClientInterceptor(
93+
intercept(0, "seen it"),
94+
intercept(1, "been"),
95+
intercept(2, "there"),
96+
intercept(3, "done"),
97+
intercept(4, "that"),
98+
),
99+
)
100+
expected = []string{
101+
"seen it",
102+
"been",
103+
"there",
104+
"done",
105+
"that",
106+
}
107+
message = strings.Repeat("a", 16)
108+
reply = strings.Repeat(message, 2)
109+
)
110+
111+
defer listener.Close()
112+
defer cleanup()
113+
114+
registerTestingService(server, testImpl)
115+
116+
go server.Serve(ctx, listener)
117+
defer server.Shutdown(ctx)
118+
119+
request := &internal.TestPayload{
120+
Foo: message,
121+
}
122+
response := &internal.TestPayload{}
123+
if err := client.Call(ctx, serviceName, "Test", request, response); err != nil {
124+
t.Fatalf("unexpected error: %v", err)
125+
}
126+
127+
if !reflect.DeepEqual(recorded, expected) {
128+
t.Fatalf("unexpected ttrpc chained client unary interceptor order (%s != %s)",
129+
strings.Join(recorded, " "), strings.Join(expected, " "))
130+
}
131+
132+
if response.Foo != reply {
133+
t.Fatalf("unexpected test service reply: %q != %q", response.Foo, reply)
134+
}
135+
}

0 commit comments

Comments
 (0)