@@ -375,6 +375,52 @@ class ZstdSpec extends FlatSpec with Checkers {
375375 input.toSeq == output.toSeq
376376 }
377377 }
378+
379+ " ZstdInputStream in continuous mode" should s " not block when the stream ends unexpectedly at level $level" in {
380+ check { input : Array [Byte ] =>
381+ val size = input.length
382+ val os = new ByteArrayOutputStream (Zstd .compressBound(size.toLong).toInt)
383+ val zos = new ZstdOutputStream (os, level)
384+ zos.write(input)
385+ zos.close
386+ val compressed = os.toByteArray
387+ // Cut the stream arbitrarily short by returning only part of the available data at first.
388+ var releaseRemainingData = false
389+ class IncrementalInputStream (bytes : Array [Byte ], truncationAmount : Int ) extends ByteArrayInputStream (bytes) {
390+ var firstRead = true
391+ override def read (b : Array [Byte ], off : Int , len : Int ): Int = {
392+ if (firstRead) {
393+ firstRead = false
394+ super .read(b, off, Math .max(available() - truncationAmount, 0 ))
395+ } else if (releaseRemainingData) {
396+ super .read(b, off, truncationAmount)
397+ } else {
398+ - 1
399+ }
400+ }
401+
402+ override def read (): Int = {
403+ throw new IllegalStateException ()
404+ }
405+ }
406+ val arbitraryTruncationAmount = 7
407+ val is = new IncrementalInputStream (compressed, arbitraryTruncationAmount)
408+ val zis = new ZstdInputStream (is).setContinuous(true );
409+ val output = Array .fill[Byte ](size)(0 )
410+ // Read the incomplete data.
411+ val amountRead = Math .max(0 , zis.read(output))
412+ // Read the rest of the data and assert that the entire input was decompressed.
413+ releaseRemainingData = true
414+ zis.read(output, amountRead, size - amountRead)
415+ zis.close
416+ if (input.toSeq != output.toSeq) {
417+ println(s " AT SIZE $size" )
418+ println(input.toSeq + " !=" + output.toSeq)
419+ println(" COMPRESSED: " + compressed.toSeq)
420+ }
421+ input.toSeq == output.toSeq
422+ }
423+ }
378424 }
379425
380426 for (level <- levels)
0 commit comments