|
20 | 20 |
|
21 | 21 | import com.google.common.annotations.VisibleForTesting; |
22 | 22 | import com.google.common.net.HostAndPort; |
| 23 | +import com.google.common.util.concurrent.FutureCallback; |
| 24 | +import com.google.common.util.concurrent.Futures; |
| 25 | +import com.google.common.util.concurrent.ListenableFuture; |
| 26 | +import com.google.common.util.concurrent.ListeningExecutorService; |
| 27 | +import com.google.common.util.concurrent.MoreExecutors; |
23 | 28 | import com.google.errorprone.annotations.ThreadSafe; |
24 | 29 | import io.grpc.Channel; |
25 | | -import io.grpc.ChannelLogger; |
26 | 30 | import io.grpc.internal.ObjectPool; |
27 | 31 | import io.grpc.netty.GrpcHttp2ConnectionHandler; |
28 | 32 | import io.grpc.netty.InternalProtocolNegotiator; |
|
33 | 37 | import io.grpc.s2a.channel.S2AGrpcChannelPool; |
34 | 38 | import io.grpc.s2a.handshaker.S2AIdentity; |
35 | 39 | import io.netty.channel.ChannelHandler; |
| 40 | +import io.netty.channel.ChannelHandlerAdapter; |
36 | 41 | import io.netty.channel.ChannelHandlerContext; |
| 42 | +import io.netty.channel.ChannelInboundHandlerAdapter; |
37 | 43 | import io.netty.handler.ssl.SslContext; |
38 | 44 | import io.netty.util.AsciiString; |
39 | | -import java.io.IOException; |
40 | | -import java.security.GeneralSecurityException; |
41 | | -import java.security.KeyStoreException; |
42 | | -import java.security.NoSuchAlgorithmException; |
43 | | -import java.security.UnrecoverableKeyException; |
44 | | -import java.security.cert.CertificateException; |
| 45 | +import java.util.ArrayList; |
| 46 | +import java.util.List; |
45 | 47 | import java.util.Optional; |
| 48 | +import java.util.concurrent.Executors; |
46 | 49 | import org.checkerframework.checker.nullness.qual.Nullable; |
47 | 50 |
|
48 | 51 | /** Factory for performing negotiation of a secure channel using the S2A. */ |
@@ -96,6 +99,8 @@ static final class S2AProtocolNegotiator implements ProtocolNegotiator { |
96 | 99 |
|
97 | 100 | private final S2AChannelPool channelPool; |
98 | 101 | private final Optional<S2AIdentity> localIdentity; |
| 102 | + private final ListeningExecutorService service = |
| 103 | + MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1)); |
99 | 104 |
|
100 | 105 | static S2AProtocolNegotiator createForClient( |
101 | 106 | S2AChannelPool channelPool, Optional<S2AIdentity> localIdentity) { |
@@ -128,65 +133,111 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { |
128 | 133 | String hostname = getHostNameFromAuthority(grpcHandler.getAuthority()); |
129 | 134 | checkNotNull(hostname, "hostname should not be null."); |
130 | 135 | return new S2AProtocolNegotiationHandler( |
131 | | - InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler), |
132 | | - grpcHandler.getNegotiationLogger(), |
133 | | - channelPool, |
134 | | - localIdentity, |
135 | | - hostname, |
136 | | - grpcHandler); |
| 136 | + grpcHandler, channelPool, localIdentity, hostname, service); |
137 | 137 | } |
138 | 138 |
|
139 | 139 | @Override |
140 | 140 | public void close() { |
| 141 | + service.shutdown(); |
141 | 142 | channelPool.close(); |
142 | 143 | } |
143 | 144 | } |
144 | 145 |
|
| 146 | + @VisibleForTesting |
| 147 | + static class BufferReadsHandler extends ChannelInboundHandlerAdapter { |
| 148 | + private final List<Object> reads = new ArrayList<>(); |
| 149 | + private boolean readComplete; |
| 150 | + |
| 151 | + public List<Object> getReads() { |
| 152 | + return reads; |
| 153 | + } |
| 154 | + |
| 155 | + @Override |
| 156 | + public void channelRead(ChannelHandlerContext ctx, Object msg) { |
| 157 | + reads.add(msg); |
| 158 | + } |
| 159 | + |
| 160 | + @Override |
| 161 | + public void channelReadComplete(ChannelHandlerContext ctx) { |
| 162 | + readComplete = true; |
| 163 | + } |
| 164 | + |
| 165 | + @Override |
| 166 | + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { |
| 167 | + for (Object msg : reads) { |
| 168 | + super.channelRead(ctx, msg); |
| 169 | + } |
| 170 | + if (readComplete) { |
| 171 | + super.channelReadComplete(ctx); |
| 172 | + } |
| 173 | + } |
| 174 | + } |
| 175 | + |
145 | 176 | private static final class S2AProtocolNegotiationHandler extends ProtocolNegotiationHandler { |
146 | 177 | private final S2AChannelPool channelPool; |
147 | 178 | private final Optional<S2AIdentity> localIdentity; |
148 | 179 | private final String hostname; |
149 | | - private InternalProtocolNegotiator.ProtocolNegotiator negotiator; |
150 | 180 | private final GrpcHttp2ConnectionHandler grpcHandler; |
| 181 | + private final ListeningExecutorService service; |
151 | 182 |
|
152 | 183 | private S2AProtocolNegotiationHandler( |
153 | | - ChannelHandler next, |
154 | | - ChannelLogger negotiationLogger, |
| 184 | + GrpcHttp2ConnectionHandler grpcHandler, |
155 | 185 | S2AChannelPool channelPool, |
156 | 186 | Optional<S2AIdentity> localIdentity, |
157 | 187 | String hostname, |
158 | | - GrpcHttp2ConnectionHandler grpcHandler) { |
159 | | - super(next, negotiationLogger); |
| 188 | + ListeningExecutorService service) { |
| 189 | + super( |
| 190 | + // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' |
| 191 | + // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior |
| 192 | + // here and then manually add 'next' when we call fireProtocolNegotiationEvent() |
| 193 | + new ChannelHandlerAdapter() { |
| 194 | + @Override |
| 195 | + public void handlerAdded(ChannelHandlerContext ctx) { |
| 196 | + ctx.pipeline().remove(this); |
| 197 | + } |
| 198 | + }, |
| 199 | + grpcHandler.getNegotiationLogger()); |
| 200 | + this.grpcHandler = grpcHandler; |
160 | 201 | this.channelPool = channelPool; |
161 | 202 | this.localIdentity = localIdentity; |
162 | 203 | this.hostname = hostname; |
163 | | - this.grpcHandler = grpcHandler; |
| 204 | + checkNotNull(service, "service should not be null."); |
| 205 | + this.service = service; |
164 | 206 | } |
165 | 207 |
|
166 | 208 | @Override |
167 | | - protected void handlerAdded0(ChannelHandlerContext ctx) throws GeneralSecurityException { |
168 | | - SslContext sslContext; |
169 | | - try { |
170 | | - // Establish a stream to S2A server. |
171 | | - Channel ch = channelPool.getChannel(); |
172 | | - S2AServiceGrpc.S2AServiceStub stub = S2AServiceGrpc.newStub(ch); |
173 | | - S2AStub s2aStub = S2AStub.newInstance(stub); |
174 | | - sslContext = SslContextFactory.createForClient(s2aStub, hostname, localIdentity); |
175 | | - } catch (InterruptedException |
176 | | - | IOException |
177 | | - | IllegalArgumentException |
178 | | - | UnrecoverableKeyException |
179 | | - | CertificateException |
180 | | - | NoSuchAlgorithmException |
181 | | - | KeyStoreException e) { |
182 | | - // GeneralSecurityException is intentionally not caught, and rather propagated. This is done |
183 | | - // because throwing a GeneralSecurityException in this context indicates that we encountered |
184 | | - // a retryable error. |
185 | | - throw new IllegalArgumentException( |
186 | | - "Something went wrong during the initialization of SslContext.", e); |
187 | | - } |
188 | | - negotiator = InternalProtocolNegotiators.tls(sslContext); |
189 | | - ctx.pipeline().addBefore(ctx.name(), /* name= */ null, negotiator.newHandler(grpcHandler)); |
| 209 | + protected void handlerAdded0(ChannelHandlerContext ctx) { |
| 210 | + // Buffer all reads until the TLS Handler is added. |
| 211 | + BufferReadsHandler bufferReads = new BufferReadsHandler(); |
| 212 | + ctx.pipeline().addBefore(ctx.name(), /* name= */ null, bufferReads); |
| 213 | + |
| 214 | + Channel ch = channelPool.getChannel(); |
| 215 | + S2AServiceGrpc.S2AServiceStub stub = S2AServiceGrpc.newStub(ch); |
| 216 | + S2AStub s2aStub = S2AStub.newInstance(stub); |
| 217 | + |
| 218 | + ListenableFuture<SslContext> sslContextFuture = |
| 219 | + service.submit(() -> SslContextFactory.createForClient(s2aStub, hostname, localIdentity)); |
| 220 | + Futures.addCallback( |
| 221 | + sslContextFuture, |
| 222 | + new FutureCallback<SslContext>() { |
| 223 | + @Override |
| 224 | + public void onSuccess(SslContext sslContext) { |
| 225 | + ChannelHandler handler = |
| 226 | + InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler); |
| 227 | + |
| 228 | + // Remove the bufferReads handler and delegate the rest of the handshake to the TLS |
| 229 | + // handler. |
| 230 | + ctx.pipeline().addAfter(ctx.name(), /* name= */ null, handler); |
| 231 | + fireProtocolNegotiationEvent(ctx); |
| 232 | + ctx.pipeline().remove(bufferReads); |
| 233 | + } |
| 234 | + |
| 235 | + @Override |
| 236 | + public void onFailure(Throwable t) { |
| 237 | + ctx.fireExceptionCaught(t); |
| 238 | + } |
| 239 | + }, |
| 240 | + service); |
190 | 241 | } |
191 | 242 | } |
192 | 243 |
|
|
0 commit comments