Skip to content

Commit fd05a33

Browse files
authored
WebSocketServerProtocolHandshakeHandler should work without aggregation (#11976)
Motivation: WebSocketServerProtocolHandshakeHandler assumes the msg is FullHttpRequest which is not really needed. Modifications: - Rework logic so we not depend on the fact that aggregation is used. - Add test Result: Fixes #11952
1 parent 1cbd3af commit fd05a33

File tree

2 files changed

+91
-49
lines changed

2 files changed

+91
-49
lines changed

codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java

Lines changed: 57 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@
2222
import io.netty.channel.ChannelPipeline;
2323
import io.netty.channel.ChannelPromise;
2424
import io.netty.handler.codec.http.DefaultFullHttpResponse;
25-
import io.netty.handler.codec.http.FullHttpRequest;
2625
import io.netty.handler.codec.http.HttpHeaderNames;
26+
import io.netty.handler.codec.http.HttpObject;
2727
import io.netty.handler.codec.http.HttpRequest;
2828
import io.netty.handler.codec.http.HttpResponse;
2929
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler.ServerHandshakeStateEvent;
3030
import io.netty.handler.ssl.SslHandler;
31+
import io.netty.util.ReferenceCountUtil;
3132
import io.netty.util.concurrent.Future;
3233
import io.netty.util.concurrent.FutureListener;
3334

@@ -47,6 +48,7 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt
4748
private final WebSocketServerProtocolConfig serverConfig;
4849
private ChannelHandlerContext ctx;
4950
private ChannelPromise handshakePromise;
51+
private boolean isWebSocketPath;
5052

5153
WebSocketServerProtocolHandshakeHandler(WebSocketServerProtocolConfig serverConfig) {
5254
this.serverConfig = checkNotNull(serverConfig, "serverConfig");
@@ -60,60 +62,69 @@ public void handlerAdded(ChannelHandlerContext ctx) {
6062

6163
@Override
6264
public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
63-
final FullHttpRequest req = (FullHttpRequest) msg;
64-
if (!isWebSocketPath(req)) {
65-
ctx.fireChannelRead(msg);
66-
return;
67-
}
65+
final HttpObject httpObject = (HttpObject) msg;
6866

69-
try {
70-
if (!GET.equals(req.method())) {
71-
sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN, ctx.alloc().buffer(0)));
67+
if (httpObject instanceof HttpRequest) {
68+
final HttpRequest req = (HttpRequest) httpObject;
69+
isWebSocketPath = isWebSocketPath(req);
70+
if (!isWebSocketPath) {
71+
ctx.fireChannelRead(msg);
7272
return;
7373
}
7474

75-
final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
76-
getWebSocketLocation(ctx.pipeline(), req, serverConfig.websocketPath()),
77-
serverConfig.subprotocols(), serverConfig.decoderConfig());
78-
final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
79-
final ChannelPromise localHandshakePromise = handshakePromise;
80-
if (handshaker == null) {
81-
WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
82-
} else {
83-
// Ensure we set the handshaker and replace this handler before we
84-
// trigger the actual handshake. Otherwise we may receive websocket bytes in this handler
85-
// before we had a chance to replace it.
86-
//
87-
// See https://github.com/netty/netty/issues/9471.
88-
WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
89-
ctx.pipeline().remove(this);
90-
91-
final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
92-
handshakeFuture.addListener(new ChannelFutureListener() {
93-
@Override
94-
public void operationComplete(ChannelFuture future) {
95-
if (!future.isSuccess()) {
96-
localHandshakePromise.tryFailure(future.cause());
97-
ctx.fireExceptionCaught(future.cause());
98-
} else {
99-
localHandshakePromise.trySuccess();
100-
// Kept for compatibility
101-
ctx.fireUserEventTriggered(
102-
WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
103-
ctx.fireUserEventTriggered(
104-
new WebSocketServerProtocolHandler.HandshakeComplete(
105-
req.uri(), req.headers(), handshaker.selectedSubprotocol()));
75+
try {
76+
if (!GET.equals(req.method())) {
77+
sendHttpResponse(ctx, req, new DefaultFullHttpResponse(HTTP_1_1, FORBIDDEN, ctx.alloc().buffer(0)));
78+
return;
79+
}
80+
81+
final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(
82+
getWebSocketLocation(ctx.pipeline(), req, serverConfig.websocketPath()),
83+
serverConfig.subprotocols(), serverConfig.decoderConfig());
84+
final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);
85+
final ChannelPromise localHandshakePromise = handshakePromise;
86+
if (handshaker == null) {
87+
WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
88+
} else {
89+
// Ensure we set the handshaker and replace this handler before we
90+
// trigger the actual handshake. Otherwise we may receive websocket bytes in this handler
91+
// before we had a chance to replace it.
92+
//
93+
// See https://github.com/netty/netty/issues/9471.
94+
WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker);
95+
ctx.pipeline().remove(this);
96+
97+
final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
98+
handshakeFuture.addListener(new ChannelFutureListener() {
99+
@Override
100+
public void operationComplete(ChannelFuture future) {
101+
if (!future.isSuccess()) {
102+
localHandshakePromise.tryFailure(future.cause());
103+
ctx.fireExceptionCaught(future.cause());
104+
} else {
105+
localHandshakePromise.trySuccess();
106+
// Kept for compatibility
107+
ctx.fireUserEventTriggered(
108+
WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE);
109+
ctx.fireUserEventTriggered(
110+
new WebSocketServerProtocolHandler.HandshakeComplete(
111+
req.uri(), req.headers(), handshaker.selectedSubprotocol()));
112+
}
106113
}
107-
}
108-
});
109-
applyHandshakeTimeout();
114+
});
115+
applyHandshakeTimeout();
116+
}
117+
} finally {
118+
ReferenceCountUtil.release(req);
110119
}
111-
} finally {
112-
req.release();
120+
} else if (!isWebSocketPath) {
121+
ctx.fireChannelRead(msg);
122+
} else {
123+
ReferenceCountUtil.release(msg);
113124
}
114125
}
115126

116-
private boolean isWebSocketPath(FullHttpRequest req) {
127+
private boolean isWebSocketPath(HttpRequest req) {
117128
String websocketPath = serverConfig.websocketPath();
118129
String uri = req.uri();
119130
boolean checkStartUri = uri.startsWith(websocketPath);

codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import io.netty.channel.ChannelPromise;
2424
import io.netty.channel.embedded.EmbeddedChannel;
2525

26+
import io.netty.handler.codec.http.DefaultHttpContent;
27+
import io.netty.handler.codec.http.DefaultHttpRequest;
2628
import io.netty.handler.codec.http.HttpClientCodec;
2729
import io.netty.handler.codec.http.HttpHeaderValues;
2830
import io.netty.handler.codec.http.HttpRequestDecoder;
@@ -34,6 +36,7 @@
3436
import io.netty.handler.codec.http.HttpObjectAggregator;
3537
import io.netty.handler.codec.http.HttpServerCodec;
3638
import io.netty.handler.codec.http.HttpHeaderNames;
39+
import io.netty.handler.codec.http.LastHttpContent;
3740
import io.netty.util.CharsetUtil;
3841
import io.netty.util.ReferenceCountUtil;
3942
import org.junit.jupiter.api.BeforeEach;
@@ -60,10 +63,19 @@ public void setUp() {
6063
}
6164

6265
@Test
63-
public void testHttpUpgradeRequest() {
66+
public void testHttpUpgradeRequestFull() {
67+
testHttpUpgradeRequest0(true);
68+
}
69+
70+
@Test
71+
public void testHttpUpgradeRequestNonFull() {
72+
testHttpUpgradeRequest0(false);
73+
}
74+
75+
private void testHttpUpgradeRequest0(boolean full) {
6476
EmbeddedChannel ch = createChannel(new MockOutboundHandler());
6577
ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class);
66-
writeUpgradeRequest(ch);
78+
writeUpgradeRequest(ch, full);
6779

6880
FullHttpResponse response = responses.remove();
6981
assertEquals(SWITCHING_PROTOCOLS, response.status());
@@ -451,7 +463,26 @@ private EmbeddedChannel createChannel(WebSocketServerProtocolConfig serverConfig
451463
}
452464

453465
private static void writeUpgradeRequest(EmbeddedChannel ch) {
454-
ch.writeInbound(WebSocketRequestBuilder.successful());
466+
writeUpgradeRequest(ch, true);
467+
}
468+
469+
private static void writeUpgradeRequest(EmbeddedChannel ch, boolean full) {
470+
HttpRequest request = WebSocketRequestBuilder.successful();
471+
if (full) {
472+
ch.writeInbound(request);
473+
} else {
474+
if (request instanceof FullHttpRequest) {
475+
FullHttpRequest fullHttpRequest = (FullHttpRequest) request;
476+
HttpRequest req = new DefaultHttpRequest(fullHttpRequest.protocolVersion(), fullHttpRequest.method(),
477+
fullHttpRequest.uri(), fullHttpRequest.headers().copy());
478+
ch.writeInbound(req);
479+
ch.writeInbound(new DefaultHttpContent(fullHttpRequest.content().copy()));
480+
ch.writeInbound(LastHttpContent.EMPTY_LAST_CONTENT);
481+
fullHttpRequest.release();
482+
} else {
483+
ch.writeInbound(request);
484+
}
485+
}
455486
}
456487

457488
private static String getResponseMessage(FullHttpResponse response) {

0 commit comments

Comments
 (0)