Skip to content

Commit 9d804c5

Browse files
normanmaureryawkat
andauthored
Merge commit from fork
Motivation: We should ensure our decompressing decoders will fire their buffers through the pipeliner as fast as possible and so allow the user to take ownership of these as fast as possible. This is needed to reduce the risk of OOME as otherwise a small input might produce a large amount of data that can't be processed until all the data was decompressed in a loop. Beside this we also should ensure that other handlers that uses these decompressors will not buffer all of the produced data before processing it, which was true for HTTP and HTTP2. Modifications: - Adjust affected decoders (Brotli, Zstd and ZLib) to fire buffers through the pipeline as soon as possible - Adjust HTTP / HTTP2 decompressors to do the same - Add testcase. Result: Less risk of OOME when doing decompressing Co-authored-by: yawkat <[email protected]>
1 parent edb55fd commit 9d804c5

13 files changed

Lines changed: 603 additions & 248 deletions

File tree

codec-compression/src/main/java/io/netty/handler/codec/compression/BrotliDecoder.java

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import com.aayushatharva.brotli4j.decoder.DecoderJNI;
2020
import io.netty.buffer.ByteBuf;
21-
import io.netty.buffer.ByteBufAllocator;
2221
import io.netty.channel.ChannelHandlerContext;
2322
import io.netty.handler.codec.ByteToMessageDecoder;
2423
import io.netty.util.internal.ObjectUtil;
@@ -48,6 +47,7 @@ private enum State {
4847
private final int inputBufferSize;
4948
private DecoderJNI.Wrapper decoder;
5049
private boolean destroyed;
50+
private boolean needsRead;
5151

5252
/**
5353
* Creates a new BrotliDecoder with a default 8kB input buffer
@@ -64,15 +64,16 @@ public BrotliDecoder(int inputBufferSize) {
6464
this.inputBufferSize = ObjectUtil.checkPositive(inputBufferSize, "inputBufferSize");
6565
}
6666

67-
private ByteBuf pull(ByteBufAllocator alloc) {
67+
private void forwardOutput(ChannelHandlerContext ctx) {
6868
ByteBuffer nativeBuffer = decoder.pull();
6969
// nativeBuffer actually wraps brotli's internal buffer so we need to copy its content
70-
ByteBuf copy = alloc.buffer(nativeBuffer.remaining());
70+
ByteBuf copy = ctx.alloc().buffer(nativeBuffer.remaining());
7171
copy.writeBytes(nativeBuffer);
72-
return copy;
72+
needsRead = false;
73+
ctx.fireChannelRead(copy);
7374
}
7475

75-
private State decompress(ByteBuf input, List<Object> output, ByteBufAllocator alloc) {
76+
private State decompress(ChannelHandlerContext ctx, ByteBuf input) {
7677
for (;;) {
7778
switch (decoder.getStatus()) {
7879
case DONE:
@@ -84,7 +85,7 @@ private State decompress(ByteBuf input, List<Object> output, ByteBufAllocator al
8485

8586
case NEEDS_MORE_INPUT:
8687
if (decoder.hasOutput()) {
87-
output.add(pull(alloc));
88+
forwardOutput(ctx);
8889
}
8990

9091
if (!input.isReadable()) {
@@ -98,7 +99,7 @@ private State decompress(ByteBuf input, List<Object> output, ByteBufAllocator al
9899
break;
99100

100101
case NEEDS_MORE_OUTPUT:
101-
output.add(pull(alloc));
102+
forwardOutput(ctx);
102103
break;
103104

104105
default:
@@ -123,6 +124,7 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
123124

124125
@Override
125126
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
127+
needsRead = true;
126128
if (destroyed) {
127129
// Skip data received after finished.
128130
in.skipBytes(in.readableBytes());
@@ -134,7 +136,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
134136
}
135137

136138
try {
137-
State state = decompress(in, out, ctx.alloc());
139+
State state = decompress(ctx, in);
138140
if (state == State.DONE) {
139141
destroy();
140142
} else if (state == State.ERROR) {
@@ -170,4 +172,15 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
170172
super.channelInactive(ctx);
171173
}
172174
}
175+
176+
@Override
177+
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
178+
// Discard bytes of the cumulation buffer if needed.
179+
discardSomeReadBytes();
180+
181+
if (needsRead && !ctx.channel().config().isAutoRead()) {
182+
ctx.read();
183+
}
184+
ctx.fireChannelReadComplete();
185+
}
173186
}

codec-compression/src/main/java/io/netty/handler/codec/compression/JZlibDecoder.java

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ public class JZlibDecoder extends ZlibDecoder {
2828

2929
private final Inflater z = new Inflater();
3030
private byte[] dictionary;
31+
private boolean needsRead;
3132
private volatile boolean finished;
3233

3334
/**
@@ -131,6 +132,7 @@ public boolean isClosed() {
131132

132133
@Override
133134
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
135+
needsRead = true;
134136
if (finished) {
135137
// Skip data received after finished.
136138
in.skipBytes(in.readableBytes());
@@ -172,6 +174,14 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
172174
int outputLength = z.next_out_index - oldNextOutIndex;
173175
if (outputLength > 0) {
174176
decompressed.writerIndex(decompressed.writerIndex() + outputLength);
177+
if (maxAllocation == 0) {
178+
// If we don't limit the maximum allocations we should just
179+
// forward the buffer directly.
180+
ByteBuf buffer = decompressed;
181+
decompressed = null;
182+
needsRead = false;
183+
ctx.fireChannelRead(buffer);
184+
}
175185
}
176186

177187
switch (resultCode) {
@@ -202,10 +212,13 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
202212
}
203213
} finally {
204214
in.skipBytes(z.next_in_index - oldNextInIndex);
205-
if (decompressed.isReadable()) {
206-
out.add(decompressed);
207-
} else {
208-
decompressed.release();
215+
if (decompressed != null) {
216+
if (decompressed.isReadable()) {
217+
needsRead = false;
218+
ctx.fireChannelRead(decompressed);
219+
} else {
220+
decompressed.release();
221+
}
209222
}
210223
}
211224
} finally {
@@ -218,6 +231,17 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
218231
}
219232
}
220233

234+
@Override
235+
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
236+
// Discard bytes of the cumulation buffer if needed.
237+
discardSomeReadBytes();
238+
239+
if (needsRead && !ctx.channel().config().isAutoRead()) {
240+
ctx.read();
241+
}
242+
ctx.fireChannelReadComplete();
243+
}
244+
221245
@Override
222246
protected void decompressionBufferExhausted(ByteBuf buffer) {
223247
finished = true;

codec-compression/src/main/java/io/netty/handler/codec/compression/JdkZlibDecoder.java

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ private enum GzipState {
5757
private GzipState gzipState = GzipState.HEADER_START;
5858
private int flags = -1;
5959
private int xlen = -1;
60+
private boolean needsRead;
6061

6162
private volatile boolean finished;
6263

@@ -195,6 +196,7 @@ public boolean isClosed() {
195196

196197
@Override
197198
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
199+
needsRead = true;
198200
if (finished) {
199201
// Skip data received after finished.
200202
in.skipBytes(in.readableBytes());
@@ -263,7 +265,15 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
263265
if (crc != null) {
264266
crc.update(outArray, outIndex, outputLength);
265267
}
266-
} else if (inflater.needsDictionary()) {
268+
if (maxAllocation == 0) {
269+
// If we don't limit the maximum allocations we should just
270+
// forward the buffer directly.
271+
ByteBuf buffer = decompressed;
272+
decompressed = null;
273+
needsRead = false;
274+
ctx.fireChannelRead(buffer);
275+
}
276+
} else if (inflater.needsDictionary()) {
267277
if (dictionary == null) {
268278
throw new DecompressionException(
269279
"decompression failure, unable to set dictionary as non was specified");
@@ -292,10 +302,13 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
292302
} catch (DataFormatException e) {
293303
throw new DecompressionException("decompression failure", e);
294304
} finally {
295-
if (decompressed.isReadable()) {
296-
out.add(decompressed);
297-
} else {
298-
decompressed.release();
305+
if (decompressed != null) {
306+
if (decompressed.isReadable()) {
307+
needsRead = false;
308+
ctx.fireChannelRead(decompressed);
309+
} else {
310+
decompressed.release();
311+
}
299312
}
300313
}
301314
}
@@ -517,4 +530,15 @@ private static boolean looksLikeZlib(short cmf_flg) {
517530
return (cmf_flg & 0x7800) == 0x7800 &&
518531
cmf_flg % 31 == 0;
519532
}
533+
534+
@Override
535+
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
536+
// Discard bytes of the cumulation buffer if needed.
537+
discardSomeReadBytes();
538+
539+
if (needsRead && !ctx.channel().config().isAutoRead()) {
540+
ctx.read();
541+
}
542+
ctx.fireChannelReadComplete();
543+
}
520544
}

codec-compression/src/main/java/io/netty/handler/codec/compression/ZstdDecoder.java

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
*/
1616
package io.netty.handler.codec.compression;
1717

18+
import com.github.luben.zstd.ZstdIOException;
1819
import com.github.luben.zstd.ZstdInputStreamNoFinalizer;
1920
import io.netty.buffer.ByteBuf;
2021
import io.netty.channel.ChannelHandlerContext;
2122
import io.netty.handler.codec.ByteToMessageDecoder;
23+
import io.netty.util.internal.ObjectUtil;
2224

2325
import java.io.Closeable;
2426
import java.io.IOException;
@@ -39,9 +41,11 @@ public final class ZstdDecoder extends ByteToMessageDecoder {
3941
}
4042
}
4143

44+
private final int maximumAllocationSize;
4245
private final MutableByteBufInputStream inputStream = new MutableByteBufInputStream();
4346
private ZstdInputStreamNoFinalizer zstdIs;
4447

48+
private boolean needsRead;
4549
private State currentState = State.DECOMPRESS_DATA;
4650

4751
/**
@@ -52,31 +56,55 @@ private enum State {
5256
CORRUPTED
5357
}
5458

59+
public ZstdDecoder() {
60+
this(4 * 1024 * 1024);
61+
}
62+
63+
public ZstdDecoder(int maximumAllocationSize) {
64+
this.maximumAllocationSize = ObjectUtil.checkPositiveOrZero(maximumAllocationSize, "maximumAllocationSize");
65+
}
66+
5567
@Override
5668
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
69+
needsRead = true;
5770
try {
5871
if (currentState == State.CORRUPTED) {
5972
in.skipBytes(in.readableBytes());
73+
6074
return;
6175
}
62-
final int compressedLength = in.readableBytes();
63-
6476
inputStream.current = in;
6577

6678
ByteBuf outBuffer = null;
79+
80+
final int compressedLength = in.readableBytes();
6781
try {
82+
long uncompressedLength;
83+
if (in.isDirect()) {
84+
uncompressedLength = com.github.luben.zstd.Zstd.getFrameContentSize(
85+
CompressionUtil.safeNioBuffer(in, in.readerIndex(), in.readableBytes()));
86+
} else {
87+
uncompressedLength = com.github.luben.zstd.Zstd.getFrameContentSize(
88+
in.array(), in.readerIndex() + in.arrayOffset(), in.readableBytes());
89+
}
90+
if (uncompressedLength <= 0) {
91+
// Let's start with the compressedLength * 2 as often we will not have everything
92+
// we need in the in buffer and don't want to reserve too much memory.
93+
uncompressedLength = compressedLength * 2L;
94+
}
95+
6896
int w;
6997
do {
7098
if (outBuffer == null) {
71-
// Let's start with the compressedLength * 2 as often we will not have everything
72-
// we need in the in buffer and don't want to reserve too much memory.
73-
outBuffer = ctx.alloc().heapBuffer(compressedLength * 2);
99+
outBuffer = ctx.alloc().heapBuffer((int) (maximumAllocationSize == 0 ?
100+
uncompressedLength : Math.min(maximumAllocationSize, uncompressedLength)));
74101
}
75102
do {
76103
w = outBuffer.writeBytes(zstdIs, outBuffer.writableBytes());
77104
} while (w != -1 && outBuffer.isWritable());
78105
if (outBuffer.isReadable()) {
79-
out.add(outBuffer);
106+
needsRead = false;
107+
ctx.fireChannelRead(outBuffer);
80108
outBuffer = null;
81109
}
82110
} while (w != -1);
@@ -93,6 +121,17 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
93121
}
94122
}
95123

124+
@Override
125+
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
126+
// Discard bytes of the cumulation buffer if needed.
127+
discardSomeReadBytes();
128+
129+
if (needsRead && !ctx.channel().config().isAutoRead()) {
130+
ctx.read();
131+
}
132+
ctx.fireChannelReadComplete();
133+
}
134+
96135
@Override
97136
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
98137
super.handlerAdded(ctx);

0 commit comments

Comments
 (0)