Skip to content

Commit 40f227d

Browse files
committed
server: implement UnaryServerInterceptor chaining.
Add a WithChainUnaryServerInterceptor server option to allow using more that one server side interceptor which will then get chained and invoked in the order given. This should allow us to implement opentelemetry instrumentation as interceptors while allowing users to keep intercepting their server side calls for other reasons at the same time. Signed-off-by: Krisztian Litkey <[email protected]>
1 parent f984c9b commit 40f227d

2 files changed

Lines changed: 146 additions & 2 deletions

File tree

config.go

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616

1717
package ttrpc
1818

19-
import "errors"
19+
import (
20+
"context"
21+
"errors"
22+
)
2023

2124
type serverConfig struct {
2225
handshaker Handshaker
@@ -44,9 +47,40 @@ func WithServerHandshaker(handshaker Handshaker) ServerOpt {
4447
func WithUnaryServerInterceptor(i UnaryServerInterceptor) ServerOpt {
4548
return func(c *serverConfig) error {
4649
if c.interceptor != nil {
47-
return errors.New("only one interceptor allowed per server")
50+
return errors.New("only one unchained interceptor allowed per server")
4851
}
4952
c.interceptor = i
5053
return nil
5154
}
5255
}
56+
57+
// WithChainUnaryServerInterceptor sets the provided chain of server interceptors
58+
func WithChainUnaryServerInterceptor(interceptors ...UnaryServerInterceptor) ServerOpt {
59+
return func(c *serverConfig) error {
60+
if len(interceptors) == 0 {
61+
return nil
62+
}
63+
if c.interceptor != nil {
64+
interceptors = append([]UnaryServerInterceptor{c.interceptor}, interceptors...)
65+
}
66+
c.interceptor = func(
67+
ctx context.Context,
68+
unmarshal Unmarshaler,
69+
info *UnaryServerInfo,
70+
method Method) (interface{}, error) {
71+
return interceptors[0](ctx, unmarshal, info,
72+
chainUnaryServerInterceptors(info, method, interceptors[1:]))
73+
}
74+
return nil
75+
}
76+
}
77+
78+
func chainUnaryServerInterceptors(info *UnaryServerInfo, method Method, interceptors []UnaryServerInterceptor) Method {
79+
if len(interceptors) == 0 {
80+
return method
81+
}
82+
return func(ctx context.Context, unmarshal func(interface{}) error) (interface{}, error) {
83+
return interceptors[0](ctx, unmarshal, info,
84+
chainUnaryServerInterceptors(info, method, interceptors[1:]))
85+
}
86+
}

interceptor_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,113 @@ func TestChainUnaryClientInterceptor(t *testing.T) {
133133
t.Fatalf("unexpected test service reply: %q != %q", response.Foo, reply)
134134
}
135135
}
136+
137+
func TestUnaryServerInterceptor(t *testing.T) {
138+
var (
139+
intercepted = false
140+
interceptor = func(ctx context.Context, unmarshal Unmarshaler, _ *UnaryServerInfo, method Method) (interface{}, error) {
141+
intercepted = true
142+
return method(ctx, unmarshal)
143+
}
144+
145+
ctx = context.Background()
146+
server = mustServer(t)(NewServer(WithUnaryServerInterceptor(interceptor)))
147+
testImpl = &testingServer{}
148+
addr, listener = newTestListener(t)
149+
client, cleanup = newTestClient(t, addr)
150+
message = strings.Repeat("a", 16)
151+
reply = strings.Repeat(message, 2)
152+
)
153+
154+
defer listener.Close()
155+
defer cleanup()
156+
157+
registerTestingService(server, testImpl)
158+
159+
go server.Serve(ctx, listener)
160+
defer server.Shutdown(ctx)
161+
162+
request := &internal.TestPayload{
163+
Foo: message,
164+
}
165+
response := &internal.TestPayload{}
166+
if err := client.Call(ctx, serviceName, "Test", request, response); err != nil {
167+
t.Fatalf("unexpected error: %v", err)
168+
}
169+
170+
if !intercepted {
171+
t.Fatalf("ttrpc server call not intercepted")
172+
}
173+
174+
if response.Foo != reply {
175+
t.Fatalf("unexpected test service reply: %q != %q", response.Foo, reply)
176+
}
177+
}
178+
179+
func TestChainUnaryServerInterceptor(t *testing.T) {
180+
var (
181+
orderIdx = 0
182+
recorded = []string{}
183+
intercept = func(idx int, tag string) UnaryServerInterceptor {
184+
return func(ctx context.Context, unmarshal Unmarshaler, _ *UnaryServerInfo, method Method) (interface{}, error) {
185+
if orderIdx != idx {
186+
t.Fatalf("unexpected interceptor invocation order (%d != %d)", orderIdx, idx)
187+
}
188+
recorded = append(recorded, tag)
189+
orderIdx++
190+
return method(ctx, unmarshal)
191+
}
192+
}
193+
194+
ctx = context.Background()
195+
server = mustServer(t)(NewServer(
196+
WithUnaryServerInterceptor(
197+
intercept(0, "seen it"),
198+
),
199+
WithChainUnaryServerInterceptor(
200+
intercept(1, "been"),
201+
intercept(2, "there"),
202+
intercept(3, "done"),
203+
intercept(4, "that"),
204+
),
205+
))
206+
expected = []string{
207+
"seen it",
208+
"been",
209+
"there",
210+
"done",
211+
"that",
212+
}
213+
testImpl = &testingServer{}
214+
addr, listener = newTestListener(t)
215+
client, cleanup = newTestClient(t, addr)
216+
message = strings.Repeat("a", 16)
217+
reply = strings.Repeat(message, 2)
218+
)
219+
220+
defer listener.Close()
221+
defer cleanup()
222+
223+
registerTestingService(server, testImpl)
224+
225+
go server.Serve(ctx, listener)
226+
defer server.Shutdown(ctx)
227+
228+
request := &internal.TestPayload{
229+
Foo: message,
230+
}
231+
response := &internal.TestPayload{}
232+
233+
if err := client.Call(ctx, serviceName, "Test", request, response); err != nil {
234+
t.Fatalf("unexpected error: %v", err)
235+
}
236+
237+
if !reflect.DeepEqual(recorded, expected) {
238+
t.Fatalf("unexpected ttrpc chained server unary interceptor order (%s != %s)",
239+
strings.Join(recorded, " "), strings.Join(expected, " "))
240+
}
241+
242+
if response.Foo != reply {
243+
t.Fatalf("unexpected test service reply: %q != %q", response.Foo, reply)
244+
}
245+
}

0 commit comments

Comments
 (0)