diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 42cc7f36ac..0e15354274 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1680,6 +1680,13 @@ package object config {
       .bytesConf(ByteUnit.BYTE)
       .createWithDefaultString("32k")
 
+  private[spark] val IO_COMPRESSION_ZSTD_BUFFERPOOL_ENABLED =
+    ConfigBuilder("spark.io.compression.zstd.bufferPool.enabled")
+      .doc("If true, enable buffer pool of ZSTD JNI library.")
+      .version("3.2.0")
+      .booleanConf
+      .createWithDefault(false)
+
   private[spark] val IO_COMPRESSION_ZSTD_LEVEL =
     ConfigBuilder("spark.io.compression.zstd.level")
       .doc("Compression level for Zstd compression codec. Increasing the compression " +
diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
index fa663a32d4..7394f752f0 100644
--- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
+++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.io
 import java.io._
 import java.util.Locale
 
-import com.github.luben.zstd.{ZstdInputStream, ZstdOutputStream}
+import com.github.luben.zstd.{NoPool, RecyclingBufferPool, ZstdInputStream, ZstdOutputStream}
 import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
 import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream, LZ4Factory}
 import net.jpountz.xxhash.XXHashFactory
@@ -217,22 +217,30 @@ class ZStdCompressionCodec(conf: SparkConf) extends CompressionCodec {
   // fastest of all with reasonably high compression ratio.
   private val level = conf.get(IO_COMPRESSION_ZSTD_LEVEL)
 
+  private val bufferPool = if (conf.get(IO_COMPRESSION_ZSTD_BUFFERPOOL_ENABLED)) {
+    RecyclingBufferPool.INSTANCE
+  } else {
+    NoPool.INSTANCE
+  }
+
   override def compressedOutputStream(s: OutputStream): OutputStream = {
     // Wrap the zstd output stream in a buffered output stream, so that we can
     // avoid overhead excessive of JNI call while trying to compress small amount of data.
-    new BufferedOutputStream(new ZstdOutputStream(s, level), bufferSize)
+    val os = new ZstdOutputStream(s, bufferPool).setLevel(level)
+    new BufferedOutputStream(os, bufferSize)
   }
 
   override private[spark] def compressedContinuousOutputStream(s: OutputStream) = {
     // SPARK-29322: Set "closeFrameOnFlush" to 'true' to let continuous input stream not being
     // stuck on reading open frame.
-    new BufferedOutputStream(new ZstdOutputStream(s, level).setCloseFrameOnFlush(true), bufferSize)
+    val os = new ZstdOutputStream(s, bufferPool).setLevel(level).setCloseFrameOnFlush(true)
+    new BufferedOutputStream(os, bufferSize)
   }
 
   override def compressedInputStream(s: InputStream): InputStream = {
     // Wrap the zstd input stream in a buffered input stream so that we can
     // avoid overhead excessive of JNI call while trying to uncompress small amount of data.
-    new BufferedInputStream(new ZstdInputStream(s), bufferSize)
+    new BufferedInputStream(new ZstdInputStream(s, bufferPool), bufferSize)
   }
 
   override def compressedContinuousInputStream(s: InputStream): InputStream = {
@@ -240,6 +248,6 @@ class ZStdCompressionCodec(conf: SparkConf) extends CompressionCodec {
     // Reading). By default `isContinuous` is false, and when we try to read from open frames,
     // `compressedInputStream` method above throws truncated error exception. This method set
     // `isContinuous` true to allow reading from open frames.
-    new BufferedInputStream(new ZstdInputStream(s).setContinuous(true), bufferSize)
+    new BufferedInputStream(new ZstdInputStream(s, bufferPool).setContinuous(true), bufferSize)
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
index 4b27396e6a..18520ff96a 100644
--- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
@@ -22,6 +22,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 import com.google.common.io.ByteStreams
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.internal.config.IO_COMPRESSION_ZSTD_BUFFERPOOL_ENABLED
 
 class CompressionCodecSuite extends SparkFunSuite {
   val conf = new SparkConf(false)
@@ -105,9 +106,12 @@ class CompressionCodecSuite extends SparkFunSuite {
   }
 
   test("zstd compression codec") {
-    val codec = CompressionCodec.createCodec(conf, classOf[ZStdCompressionCodec].getName)
-    assert(codec.getClass === classOf[ZStdCompressionCodec])
-    testCodec(codec)
+    Seq("true", "false").foreach { flag =>
+      val conf = new SparkConf(false).set(IO_COMPRESSION_ZSTD_BUFFERPOOL_ENABLED.key, flag)
+      val codec = CompressionCodec.createCodec(conf, classOf[ZStdCompressionCodec].getName)
+      assert(codec.getClass === classOf[ZStdCompressionCodec])
+      testCodec(codec)
+    }
   }
 
   test("zstd compression codec short form") {