@@ -9,8 +9,9 @@ import java.nio.channels.FileChannel
99import java .nio .channels .FileChannel .MapMode
1010import java .nio .charset .Charset
1111import java .nio .file .StandardOpenOption
12- import scala .io . _
12+ import scala .annotation . unused
1313import scala .collection .mutable .WrappedArray
14+ import scala .io ._
1415import scala .util .Using
1516
1617class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
@@ -1105,7 +1106,7 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
11051106 }
11061107 }
11071108
1108- " streaming compressiong and decompression" should " roundtrip" in {
1109+ " streaming compression and decompression" should " roundtrip" in {
11091110 Using .Manager { use =>
11101111 val cctx = use(new ZstdCompressCtx ())
11111112 val dctx = use(new ZstdDecompressCtx ())
@@ -1149,7 +1150,7 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
11491150 decompressedBuffer.flip()
11501151
11511152 val comparison = inputBuffer.compareTo(decompressedBuffer)
1152- comparison == 0 && Zstd .decompressedSize(compressedBuffer) == size && Zstd .getFrameContentSize(compressedBuffer) == size
1153+ assert( comparison == 0 && Zstd .decompressedSize(compressedBuffer) == size && Zstd .getFrameContentSize(compressedBuffer) == size)
11531154 }
11541155 }
11551156 }.get
@@ -1211,4 +1212,180 @@ class ZstdSpec extends AnyFlatSpec with ScalaCheckPropertyChecks {
12111212 }
12121213 }
12131214 }.get
1215+
1216+ it should " be able to use a sequence producer" in {
1217+ Using .Manager { use =>
1218+ val cctx = use(new ZstdCompressCtx ())
1219+ val cctx2 = use(new ZstdCompressCtx ())
1220+ val dctx = use(new ZstdDecompressCtx ())
1221+
1222+ forAll { input : Array [Byte ] =>
1223+ {
1224+ val size = input.length
1225+ val inputBuffer = ByteBuffer .allocateDirect(size)
1226+ inputBuffer.put(input)
1227+ inputBuffer.flip()
1228+ cctx.reset()
1229+ cctx.setLevel(9 )
1230+ val seqProd = new SequenceProducer {
1231+ def getFunctionPointer (): Long = {
1232+ Zstd .getBuiltinSequenceProducer()
1233+ }
1234+
1235+ def createState (): Long = {
1236+ cctx2.getNativePtr()
1237+ }
1238+
1239+ def freeState (@ unused state : Long ) = {}
1240+ }
1241+ cctx.registerSequenceProducer(seqProd)
1242+ cctx.setValidateSequences(true )
1243+ cctx.setSequenceProducerFallback(false )
1244+ cctx.setPledgedSrcSize(size)
1245+ val compressedBuffer = ByteBuffer .allocateDirect(Zstd .compressBound(size).toInt)
1246+ while (inputBuffer.hasRemaining) {
1247+ compressedBuffer.limit(compressedBuffer.position() + 1 )
1248+ cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective .CONTINUE )
1249+ }
1250+
1251+ var frameProgression = cctx.getFrameProgression()
1252+ assert(frameProgression.getIngested() == size)
1253+ assert(frameProgression.getFlushed() == compressedBuffer.position())
1254+
1255+ compressedBuffer.limit(compressedBuffer.capacity())
1256+ val done = cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective .END )
1257+ assert(done)
1258+
1259+ frameProgression = cctx.getFrameProgression()
1260+ assert(frameProgression.getConsumed() == size)
1261+
1262+ compressedBuffer.flip()
1263+ val decompressedBuffer = ByteBuffer .allocateDirect(size)
1264+ dctx.reset()
1265+ while (compressedBuffer.hasRemaining) {
1266+ if (decompressedBuffer.limit() < decompressedBuffer.position()) {
1267+ decompressedBuffer.limit(compressedBuffer.position() + 1 )
1268+ }
1269+ dctx.decompressDirectByteBufferStream(decompressedBuffer, compressedBuffer)
1270+ }
1271+
1272+ inputBuffer.rewind()
1273+ compressedBuffer.rewind()
1274+ decompressedBuffer.flip()
1275+
1276+ val comparison = inputBuffer.compareTo(decompressedBuffer)
1277+ assert(comparison == 0 && Zstd .decompressedSize(compressedBuffer) == size && Zstd .getFrameContentSize(compressedBuffer) == size)
1278+ }
1279+ }
1280+ }.get
1281+ }
1282+
1283+ it should " fail with a stub sequence producer" in {
1284+ Using .Manager { use =>
1285+ val cctx = use(new ZstdCompressCtx ())
1286+
1287+ forAll(minSize(32 )) { input : Array [Byte ] =>
1288+ {
1289+ val size = input.length
1290+ val inputBuffer = ByteBuffer .allocateDirect(size)
1291+ inputBuffer.put(input)
1292+ inputBuffer.flip()
1293+ cctx.reset()
1294+ cctx.setLevel(9 )
1295+
1296+ val seqProd = new SequenceProducer {
1297+ def getFunctionPointer (): Long = {
1298+ Zstd .getStubSequenceProducer()
1299+ }
1300+
1301+ def createState (): Long = { 0 }
1302+ def freeState (@ unused state : Long ) = { 0 }
1303+ }
1304+
1305+ cctx.registerSequenceProducer(seqProd)
1306+ cctx.setValidateSequences(true )
1307+ cctx.setSequenceProducerFallback(false )
1308+ cctx.setPledgedSrcSize(size)
1309+
1310+ val compressedBuffer = ByteBuffer .allocateDirect(Zstd .compressBound(size).toInt)
1311+ try {
1312+ while (inputBuffer.hasRemaining) {
1313+ compressedBuffer.limit(compressedBuffer.position() + 1 )
1314+ cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective .CONTINUE )
1315+ }
1316+ cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective .END )
1317+ fail(" compression succeeded, but should have failed" )
1318+ } catch {
1319+ case _ : ZstdException => // compression should throw a ZstdException
1320+ }
1321+ }
1322+ }
1323+ }.get
1324+ }
1325+
1326+ it should " succeed with a stub sequence producer and software fallback" in {
1327+ Using .Manager { use =>
1328+ val cctx = use(new ZstdCompressCtx ())
1329+ val dctx = use(new ZstdDecompressCtx ())
1330+
1331+ forAll { input : Array [Byte ] =>
1332+ {
1333+ val size = input.length
1334+ val inputBuffer = ByteBuffer .allocateDirect(size)
1335+ inputBuffer.put(input)
1336+ inputBuffer.flip()
1337+ cctx.reset()
1338+ cctx.setLevel(9 )
1339+
1340+ val seqProd = new SequenceProducer {
1341+ def getFunctionPointer (): Long = {
1342+ Zstd .getStubSequenceProducer()
1343+ }
1344+
1345+ def createState (): Long = { 0 }
1346+ def freeState (@ unused state : Long ) = { 0 }
1347+ }
1348+
1349+ cctx.registerSequenceProducer(seqProd)
1350+ cctx.setValidateSequences(true )
1351+ cctx.setSequenceProducerFallback(true ) // !!
1352+ cctx.setPledgedSrcSize(size)
1353+
1354+ val compressedBuffer = ByteBuffer .allocateDirect(Zstd .compressBound(size).toInt)
1355+ while (inputBuffer.hasRemaining) {
1356+ compressedBuffer.limit(compressedBuffer.position() + 1 )
1357+ cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective .CONTINUE )
1358+ }
1359+
1360+ var frameProgression = cctx.getFrameProgression()
1361+ assert(frameProgression.getIngested() == size)
1362+ assert(frameProgression.getFlushed() == compressedBuffer.position())
1363+
1364+ compressedBuffer.limit(compressedBuffer.capacity())
1365+ val done = cctx.compressDirectByteBufferStream(compressedBuffer, inputBuffer, EndDirective .END )
1366+ assert(done)
1367+
1368+ frameProgression = cctx.getFrameProgression()
1369+ assert(frameProgression.getConsumed() == size)
1370+
1371+ compressedBuffer.flip()
1372+ val decompressedBuffer = ByteBuffer .allocateDirect(size)
1373+ dctx.reset()
1374+ while (compressedBuffer.hasRemaining) {
1375+ if (decompressedBuffer.limit() < decompressedBuffer.position()) {
1376+ decompressedBuffer.limit(compressedBuffer.position() + 1 )
1377+ }
1378+ dctx.decompressDirectByteBufferStream(decompressedBuffer, compressedBuffer)
1379+ }
1380+
1381+ inputBuffer.rewind()
1382+ compressedBuffer.rewind()
1383+ decompressedBuffer.flip()
1384+
1385+ val comparison = inputBuffer.compareTo(decompressedBuffer)
1386+ assert(comparison == 0 && Zstd .decompressedSize(compressedBuffer) == size && Zstd .getFrameContentSize(compressedBuffer) == size)
1387+ }
1388+ }
1389+ }.get
1390+ }
12141391}
0 commit comments