2222import io .netty .channel .ChannelPipeline ;
2323import io .netty .channel .ChannelPromise ;
2424import io .netty .handler .codec .http .DefaultFullHttpResponse ;
25- import io .netty .handler .codec .http .FullHttpRequest ;
2625import io .netty .handler .codec .http .HttpHeaderNames ;
26+ import io .netty .handler .codec .http .HttpObject ;
2727import io .netty .handler .codec .http .HttpRequest ;
2828import io .netty .handler .codec .http .HttpResponse ;
2929import io .netty .handler .codec .http .websocketx .WebSocketServerProtocolHandler .ServerHandshakeStateEvent ;
3030import io .netty .handler .ssl .SslHandler ;
31+ import io .netty .util .ReferenceCountUtil ;
3132import io .netty .util .concurrent .Future ;
3233import io .netty .util .concurrent .FutureListener ;
3334
@@ -47,6 +48,7 @@ class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapt
4748 private final WebSocketServerProtocolConfig serverConfig ;
4849 private ChannelHandlerContext ctx ;
4950 private ChannelPromise handshakePromise ;
51+ private boolean isWebSocketPath ;
5052
5153 WebSocketServerProtocolHandshakeHandler (WebSocketServerProtocolConfig serverConfig ) {
5254 this .serverConfig = checkNotNull (serverConfig , "serverConfig" );
@@ -60,60 +62,69 @@ public void handlerAdded(ChannelHandlerContext ctx) {
6062
6163 @ Override
6264 public void channelRead (final ChannelHandlerContext ctx , Object msg ) throws Exception {
63- final FullHttpRequest req = (FullHttpRequest ) msg ;
64- if (!isWebSocketPath (req )) {
65- ctx .fireChannelRead (msg );
66- return ;
67- }
65+ final HttpObject httpObject = (HttpObject ) msg ;
6866
69- try {
70- if (!GET .equals (req .method ())) {
71- sendHttpResponse (ctx , req , new DefaultFullHttpResponse (HTTP_1_1 , FORBIDDEN , ctx .alloc ().buffer (0 )));
67+ if (httpObject instanceof HttpRequest ) {
68+ final HttpRequest req = (HttpRequest ) httpObject ;
69+ isWebSocketPath = isWebSocketPath (req );
70+ if (!isWebSocketPath ) {
71+ ctx .fireChannelRead (msg );
7272 return ;
7373 }
7474
75- final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory (
76- getWebSocketLocation (ctx .pipeline (), req , serverConfig .websocketPath ()),
77- serverConfig .subprotocols (), serverConfig .decoderConfig ());
78- final WebSocketServerHandshaker handshaker = wsFactory .newHandshaker (req );
79- final ChannelPromise localHandshakePromise = handshakePromise ;
80- if (handshaker == null ) {
81- WebSocketServerHandshakerFactory .sendUnsupportedVersionResponse (ctx .channel ());
82- } else {
83- // Ensure we set the handshaker and replace this handler before we
84- // trigger the actual handshake. Otherwise we may receive websocket bytes in this handler
85- // before we had a chance to replace it.
86- //
87- // See https://github.com/netty/netty/issues/9471.
88- WebSocketServerProtocolHandler .setHandshaker (ctx .channel (), handshaker );
89- ctx .pipeline ().remove (this );
90-
91- final ChannelFuture handshakeFuture = handshaker .handshake (ctx .channel (), req );
92- handshakeFuture .addListener (new ChannelFutureListener () {
93- @ Override
94- public void operationComplete (ChannelFuture future ) {
95- if (!future .isSuccess ()) {
96- localHandshakePromise .tryFailure (future .cause ());
97- ctx .fireExceptionCaught (future .cause ());
98- } else {
99- localHandshakePromise .trySuccess ();
100- // Kept for compatibility
101- ctx .fireUserEventTriggered (
102- WebSocketServerProtocolHandler .ServerHandshakeStateEvent .HANDSHAKE_COMPLETE );
103- ctx .fireUserEventTriggered (
104- new WebSocketServerProtocolHandler .HandshakeComplete (
105- req .uri (), req .headers (), handshaker .selectedSubprotocol ()));
75+ try {
76+ if (!GET .equals (req .method ())) {
77+ sendHttpResponse (ctx , req , new DefaultFullHttpResponse (HTTP_1_1 , FORBIDDEN , ctx .alloc ().buffer (0 )));
78+ return ;
79+ }
80+
81+ final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory (
82+ getWebSocketLocation (ctx .pipeline (), req , serverConfig .websocketPath ()),
83+ serverConfig .subprotocols (), serverConfig .decoderConfig ());
84+ final WebSocketServerHandshaker handshaker = wsFactory .newHandshaker (req );
85+ final ChannelPromise localHandshakePromise = handshakePromise ;
86+ if (handshaker == null ) {
87+ WebSocketServerHandshakerFactory .sendUnsupportedVersionResponse (ctx .channel ());
88+ } else {
89+ // Ensure we set the handshaker and replace this handler before we
90+ // trigger the actual handshake. Otherwise we may receive websocket bytes in this handler
91+ // before we had a chance to replace it.
92+ //
93+ // See https://github.com/netty/netty/issues/9471.
94+ WebSocketServerProtocolHandler .setHandshaker (ctx .channel (), handshaker );
95+ ctx .pipeline ().remove (this );
96+
97+ final ChannelFuture handshakeFuture = handshaker .handshake (ctx .channel (), req );
98+ handshakeFuture .addListener (new ChannelFutureListener () {
99+ @ Override
100+ public void operationComplete (ChannelFuture future ) {
101+ if (!future .isSuccess ()) {
102+ localHandshakePromise .tryFailure (future .cause ());
103+ ctx .fireExceptionCaught (future .cause ());
104+ } else {
105+ localHandshakePromise .trySuccess ();
106+ // Kept for compatibility
107+ ctx .fireUserEventTriggered (
108+ WebSocketServerProtocolHandler .ServerHandshakeStateEvent .HANDSHAKE_COMPLETE );
109+ ctx .fireUserEventTriggered (
110+ new WebSocketServerProtocolHandler .HandshakeComplete (
111+ req .uri (), req .headers (), handshaker .selectedSubprotocol ()));
112+ }
106113 }
107- }
108- });
109- applyHandshakeTimeout ();
114+ });
115+ applyHandshakeTimeout ();
116+ }
117+ } finally {
118+ ReferenceCountUtil .release (req );
110119 }
111- } finally {
112- req .release ();
120+ } else if (!isWebSocketPath ) {
121+ ctx .fireChannelRead (msg );
122+ } else {
123+ ReferenceCountUtil .release (msg );
113124 }
114125 }
115126
116- private boolean isWebSocketPath (FullHttpRequest req ) {
127+ private boolean isWebSocketPath (HttpRequest req ) {
117128 String websocketPath = serverConfig .websocketPath ();
118129 String uri = req .uri ();
119130 boolean checkStartUri = uri .startsWith (websocketPath );
0 commit comments