Skip to content

Commit 3184cdc

Browse files
committed
Don't block on SslContext creation in Java S2A client.
1 parent 655f0bd commit 3184cdc

2 files changed

Lines changed: 110 additions & 42 deletions

File tree

s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java

Lines changed: 93 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,13 @@
2020

2121
import com.google.common.annotations.VisibleForTesting;
2222
import com.google.common.net.HostAndPort;
23+
import com.google.common.util.concurrent.FutureCallback;
24+
import com.google.common.util.concurrent.Futures;
25+
import com.google.common.util.concurrent.ListenableFuture;
26+
import com.google.common.util.concurrent.ListeningExecutorService;
27+
import com.google.common.util.concurrent.MoreExecutors;
2328
import com.google.errorprone.annotations.ThreadSafe;
2429
import io.grpc.Channel;
25-
import io.grpc.ChannelLogger;
2630
import io.grpc.internal.ObjectPool;
2731
import io.grpc.netty.GrpcHttp2ConnectionHandler;
2832
import io.grpc.netty.InternalProtocolNegotiator;
@@ -33,16 +37,15 @@
3337
import io.grpc.s2a.channel.S2AGrpcChannelPool;
3438
import io.grpc.s2a.handshaker.S2AIdentity;
3539
import io.netty.channel.ChannelHandler;
40+
import io.netty.channel.ChannelHandlerAdapter;
3641
import io.netty.channel.ChannelHandlerContext;
42+
import io.netty.channel.ChannelInboundHandlerAdapter;
3743
import io.netty.handler.ssl.SslContext;
3844
import io.netty.util.AsciiString;
39-
import java.io.IOException;
40-
import java.security.GeneralSecurityException;
41-
import java.security.KeyStoreException;
42-
import java.security.NoSuchAlgorithmException;
43-
import java.security.UnrecoverableKeyException;
44-
import java.security.cert.CertificateException;
45+
import java.util.ArrayList;
46+
import java.util.List;
4547
import java.util.Optional;
48+
import java.util.concurrent.Executors;
4649
import org.checkerframework.checker.nullness.qual.Nullable;
4750

4851
/** Factory for performing negotiation of a secure channel using the S2A. */
@@ -96,6 +99,8 @@ static final class S2AProtocolNegotiator implements ProtocolNegotiator {
9699

97100
private final S2AChannelPool channelPool;
98101
private final Optional<S2AIdentity> localIdentity;
102+
private final ListeningExecutorService service =
103+
MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1));
99104

100105
static S2AProtocolNegotiator createForClient(
101106
S2AChannelPool channelPool, Optional<S2AIdentity> localIdentity) {
@@ -128,65 +133,111 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
128133
String hostname = getHostNameFromAuthority(grpcHandler.getAuthority());
129134
checkNotNull(hostname, "hostname should not be null.");
130135
return new S2AProtocolNegotiationHandler(
131-
InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler),
132-
grpcHandler.getNegotiationLogger(),
133-
channelPool,
134-
localIdentity,
135-
hostname,
136-
grpcHandler);
136+
grpcHandler, channelPool, localIdentity, hostname, service);
137137
}
138138

139139
@Override
140140
public void close() {
141+
service.shutdown();
141142
channelPool.close();
142143
}
143144
}
144145

146+
@VisibleForTesting
147+
static class BufferReadsHandler extends ChannelInboundHandlerAdapter {
148+
private final List<Object> reads = new ArrayList<>();
149+
private boolean readComplete;
150+
151+
public List<Object> getReads() {
152+
return reads;
153+
}
154+
155+
@Override
156+
public void channelRead(ChannelHandlerContext ctx, Object msg) {
157+
reads.add(msg);
158+
}
159+
160+
@Override
161+
public void channelReadComplete(ChannelHandlerContext ctx) {
162+
readComplete = true;
163+
}
164+
165+
@Override
166+
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
167+
for (Object msg : reads) {
168+
super.channelRead(ctx, msg);
169+
}
170+
if (readComplete) {
171+
super.channelReadComplete(ctx);
172+
}
173+
}
174+
}
175+
145176
private static final class S2AProtocolNegotiationHandler extends ProtocolNegotiationHandler {
146177
private final S2AChannelPool channelPool;
147178
private final Optional<S2AIdentity> localIdentity;
148179
private final String hostname;
149-
private InternalProtocolNegotiator.ProtocolNegotiator negotiator;
150180
private final GrpcHttp2ConnectionHandler grpcHandler;
181+
private final ListeningExecutorService service;
151182

152183
private S2AProtocolNegotiationHandler(
153-
ChannelHandler next,
154-
ChannelLogger negotiationLogger,
184+
GrpcHttp2ConnectionHandler grpcHandler,
155185
S2AChannelPool channelPool,
156186
Optional<S2AIdentity> localIdentity,
157187
String hostname,
158-
GrpcHttp2ConnectionHandler grpcHandler) {
159-
super(next, negotiationLogger);
188+
ListeningExecutorService service) {
189+
super(
190+
// superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next'
191+
// handler but we don't have a next handler _yet_. So we "disable" superclass's behavior
192+
// here and then manually add 'next' when we call fireProtocolNegotiationEvent()
193+
new ChannelHandlerAdapter() {
194+
@Override
195+
public void handlerAdded(ChannelHandlerContext ctx) {
196+
ctx.pipeline().remove(this);
197+
}
198+
},
199+
grpcHandler.getNegotiationLogger());
200+
this.grpcHandler = grpcHandler;
160201
this.channelPool = channelPool;
161202
this.localIdentity = localIdentity;
162203
this.hostname = hostname;
163-
this.grpcHandler = grpcHandler;
204+
checkNotNull(service, "service should not be null.");
205+
this.service = service;
164206
}
165207

166208
@Override
167-
protected void handlerAdded0(ChannelHandlerContext ctx) throws GeneralSecurityException {
168-
SslContext sslContext;
169-
try {
170-
// Establish a stream to S2A server.
171-
Channel ch = channelPool.getChannel();
172-
S2AServiceGrpc.S2AServiceStub stub = S2AServiceGrpc.newStub(ch);
173-
S2AStub s2aStub = S2AStub.newInstance(stub);
174-
sslContext = SslContextFactory.createForClient(s2aStub, hostname, localIdentity);
175-
} catch (InterruptedException
176-
| IOException
177-
| IllegalArgumentException
178-
| UnrecoverableKeyException
179-
| CertificateException
180-
| NoSuchAlgorithmException
181-
| KeyStoreException e) {
182-
// GeneralSecurityException is intentionally not caught, and rather propagated. This is done
183-
// because throwing a GeneralSecurityException in this context indicates that we encountered
184-
// a retryable error.
185-
throw new IllegalArgumentException(
186-
"Something went wrong during the initialization of SslContext.", e);
187-
}
188-
negotiator = InternalProtocolNegotiators.tls(sslContext);
189-
ctx.pipeline().addBefore(ctx.name(), /* name= */ null, negotiator.newHandler(grpcHandler));
209+
protected void handlerAdded0(ChannelHandlerContext ctx) {
210+
// Buffer all reads until the TLS Handler is added.
211+
BufferReadsHandler bufferReads = new BufferReadsHandler();
212+
ctx.pipeline().addBefore(ctx.name(), /* name= */ null, bufferReads);
213+
214+
Channel ch = channelPool.getChannel();
215+
S2AServiceGrpc.S2AServiceStub stub = S2AServiceGrpc.newStub(ch);
216+
S2AStub s2aStub = S2AStub.newInstance(stub);
217+
218+
ListenableFuture<SslContext> sslContextFuture =
219+
service.submit(() -> SslContextFactory.createForClient(s2aStub, hostname, localIdentity));
220+
Futures.addCallback(
221+
sslContextFuture,
222+
new FutureCallback<SslContext>() {
223+
@Override
224+
public void onSuccess(SslContext sslContext) {
225+
ChannelHandler handler =
226+
InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler);
227+
228+
// Remove the bufferReads handler and delegate the rest of the handshake to the TLS
229+
// handler.
230+
ctx.pipeline().addAfter(ctx.name(), /* name= */ null, handler);
231+
fireProtocolNegotiationEvent(ctx);
232+
ctx.pipeline().remove(bufferReads);
233+
}
234+
235+
@Override
236+
public void onFailure(Throwable t) {
237+
ctx.fireExceptionCaught(t);
238+
}
239+
},
240+
service);
190241
}
191242
}
192243

s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import io.netty.channel.ChannelHandler;
4646
import io.netty.channel.ChannelHandlerContext;
4747
import io.netty.channel.ChannelPromise;
48+
import io.netty.channel.embedded.EmbeddedChannel;
4849
import io.netty.handler.codec.http2.Http2ConnectionDecoder;
4950
import io.netty.handler.codec.http2.Http2ConnectionEncoder;
5051
import io.netty.handler.codec.http2.Http2Settings;
@@ -85,6 +86,22 @@ public void tearDown() {
8586
fakeS2AServer.shutdown();
8687
}
8788

89+
@Test
90+
public void handlerRemoved_success() throws Exception {
91+
S2AProtocolNegotiatorFactory.BufferReadsHandler handler1 =
92+
new S2AProtocolNegotiatorFactory.BufferReadsHandler();
93+
S2AProtocolNegotiatorFactory.BufferReadsHandler handler2 =
94+
new S2AProtocolNegotiatorFactory.BufferReadsHandler();
95+
EmbeddedChannel channel = new EmbeddedChannel(handler1, handler2);
96+
channel.writeInbound("message1");
97+
channel.writeInbound("message2");
98+
channel.writeInbound("message3");
99+
assertThat(handler1.getReads()).hasSize(3);
100+
assertThat(handler2.getReads()).isEmpty();
101+
channel.pipeline().remove(handler1);
102+
assertThat(handler2.getReads()).hasSize(3);
103+
}
104+
88105
@Test
89106
public void createProtocolNegotiatorFactory_nullArgument() throws Exception {
90107
NullPointerTester tester = new NullPointerTester().setDefault(Optional.class, Optional.empty());

0 commit comments

Comments
 (0)