Skip to content

Commit 3d51bdc

Browse files
committed
Fix infinite loop in continuous mode when reading an incomplete frame.
1 parent 3d16e51 commit 3d51bdc

2 files changed

Lines changed: 51 additions & 1 deletion

File tree

src/main/java/com/github/luben/zstd/ZstdInputStream.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,11 @@ int readInternal(byte[] dst, int offset, int len) throws IOException {
149149
if (frameFinished) {
150150
return -1;
151151
} else if (isContinuous) {
152-
return (int)(dstPos - offset);
152+
srcSize = (int)(dstPos - offset);
153+
if (srcSize > 0) {
154+
return (int) srcSize;
155+
}
156+
return -1;
153157
} else {
154158
throw new IOException("Read error or truncated source");
155159
}

src/test/scala/Zstd.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,52 @@ class ZstdSpec extends FlatSpec with Checkers {
375375
input.toSeq == output.toSeq
376376
}
377377
}
378+
379+
"ZstdInputStream in continuous mode" should s"not block when the stream ends unexpectedly at level $level" in {
380+
check { input: Array[Byte] =>
381+
val size = input.length
382+
val os = new ByteArrayOutputStream(Zstd.compressBound(size.toLong).toInt)
383+
val zos = new ZstdOutputStream(os, level)
384+
zos.write(input)
385+
zos.close
386+
val compressed = os.toByteArray
387+
// Cut the stream arbitrarily short by returning only part of the available data at first.
388+
var releaseRemainingData = false
389+
class IncrementalInputStream(bytes: Array[Byte], truncationAmount: Int) extends ByteArrayInputStream(bytes) {
390+
var firstRead = true
391+
override def read(b: Array[Byte], off: Int, len: Int): Int = {
392+
if (firstRead) {
393+
firstRead = false
394+
super.read(b, off, Math.max(available() - truncationAmount, 0))
395+
} else if (releaseRemainingData) {
396+
super.read(b, off, truncationAmount)
397+
} else {
398+
-1
399+
}
400+
}
401+
402+
override def read(): Int = {
403+
throw new IllegalStateException()
404+
}
405+
}
406+
val arbitraryTruncationAmount = 7
407+
val is = new IncrementalInputStream(compressed, arbitraryTruncationAmount)
408+
val zis = new ZstdInputStream(is).setContinuous(true);
409+
val output = Array.fill[Byte](size)(0)
410+
// Read the incomplete data.
411+
val amountRead = Math.max(0, zis.read(output))
412+
// Read the rest of the data and assert that the entire input was decompressed.
413+
releaseRemainingData = true
414+
zis.read(output, amountRead, size - amountRead)
415+
zis.close
416+
if (input.toSeq != output.toSeq) {
417+
println(s"AT SIZE $size")
418+
println(input.toSeq + "!=" + output.toSeq)
419+
println("COMPRESSED: " + compressed.toSeq)
420+
}
421+
input.toSeq == output.toSeq
422+
}
423+
}
378424
}
379425

380426
for (level <- levels)

0 commit comments

Comments
 (0)