diff --git a/zstd.go b/zstd.go index 6073ce7c4..8897de45c 100644 --- a/zstd.go +++ b/zstd.go @@ -1,15 +1,12 @@ package sarama import ( + "runtime" "sync" "github.com/klauspost/compress/zstd" ) -// zstdMaxBufferedEncoders maximum number of not-in-use zstd encoders -// If the pool of encoders is exhausted then new encoders will be created on the fly -const zstdMaxBufferedEncoders = 1 - type ZstdEncoderParams struct { Level int } @@ -20,35 +17,65 @@ var zstdDecMap sync.Map var zstdAvailableEncoders sync.Map +var zstdCheckedOutEncoders int +var zstdMutex = &sync.Mutex{} +var zstdEncoderReturned = sync.NewCond(zstdMutex) +var zstdTestingDisableConcurrencyLimit bool + func getZstdEncoderChannel(params ZstdEncoderParams) chan *zstd.Encoder { if c, ok := zstdAvailableEncoders.Load(params); ok { return c.(chan *zstd.Encoder) } - c, _ := zstdAvailableEncoders.LoadOrStore(params, make(chan *zstd.Encoder, zstdMaxBufferedEncoders)) + limit := runtime.GOMAXPROCS(0) + c, _ := zstdAvailableEncoders.LoadOrStore(params, make(chan *zstd.Encoder, limit)) return c.(chan *zstd.Encoder) } +func newZstdEncoder(params ZstdEncoderParams) *zstd.Encoder { + encoderLevel := zstd.SpeedDefault + if params.Level != CompressionLevelDefault { + encoderLevel = zstd.EncoderLevelFromZstd(params.Level) + } + zstdEnc, _ := zstd.NewWriter(nil, zstd.WithZeroFrames(true), + zstd.WithEncoderLevel(encoderLevel), + zstd.WithEncoderConcurrency(1)) + return zstdEnc +} + func getZstdEncoder(params ZstdEncoderParams) *zstd.Encoder { + + zstdMutex.Lock() + defer zstdMutex.Unlock() + + limit := runtime.GOMAXPROCS(0) + for zstdCheckedOutEncoders >= limit && !zstdTestingDisableConcurrencyLimit { + zstdEncoderReturned.Wait() + limit = runtime.GOMAXPROCS(0) + } + + zstdCheckedOutEncoders += 1 + select { case enc := <-getZstdEncoderChannel(params): return enc default: - encoderLevel := zstd.SpeedDefault - if params.Level != CompressionLevelDefault { - encoderLevel = zstd.EncoderLevelFromZstd(params.Level) - } - zstdEnc, _ := zstd.NewWriter(nil, zstd.WithZeroFrames(true), - zstd.WithEncoderLevel(encoderLevel), - zstd.WithEncoderConcurrency(1)) - return zstdEnc + return newZstdEncoder(params) } } func releaseEncoder(params ZstdEncoderParams, enc *zstd.Encoder) { + zstdMutex.Lock() + + zstdCheckedOutEncoders -= 1 + select { case getZstdEncoderChannel(params) <- enc: default: } + + zstdEncoderReturned.Signal() + + zstdMutex.Unlock() } func getDecoder(params ZstdDecoderParams) *zstd.Decoder { diff --git a/zstd_test.go b/zstd_test.go index efdc6d83d..d8a9f96dd 100644 --- a/zstd_test.go +++ b/zstd_test.go @@ -2,9 +2,15 @@ package sarama import ( "runtime" + "sync" "testing" ) +// BenchmarkZstdMemoryConsumption benchmarks the memory consumption of the zstd encoder under the following constraints +// 1. The encoder is created with a high compression level +// 2. The encoder is used to compress a 1MB buffer +// 3. We emulate a 96 core system +// In other words: we test the compression memory and cpu efficiency under minimal parallelism func BenchmarkZstdMemoryConsumption(b *testing.B) { params := ZstdEncoderParams{Level: 9} buf := make([]byte, 1024*1024) @@ -15,15 +21,109 @@ func BenchmarkZstdMemoryConsumption(b *testing.B) { cpus := 96 gomaxprocsBackup := runtime.GOMAXPROCS(cpus) + defer runtime.GOMAXPROCS(gomaxprocsBackup) + + b.SetBytes(int64(len(buf) * 2 * cpus)) + b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { for j := 0; j < 2*cpus; j++ { _, _ = zstdCompress(params, nil, buf) } - // drain the buffered encoder - getZstdEncoder(params) - // previously this would be achieved with - // zstdEncMap.Delete(params) + // drain the buffered encoder so that we get a fresh one for the next run + zstdAvailableEncoders.Delete(params) + } + + b.ReportMetric(float64(cpus), "(gomaxprocs)") + b.ReportMetric(float64(1), "(goroutines)") +} + +// BenchmarkZstdMemoryConsumptionConcurrency benchmarks the memory consumption of the zstd encoder under the following constraints +// 1. The encoder is created with a high compression level +// 2. The encoder is used to compress a 1MB buffer +// 3. We emulate a 2 core system +// 4. We create 1000 goroutines that compress the buffer 2 times each +// In summary: we test the compression memory and cpu efficiency under extreme concurrency +func BenchmarkZstdMemoryConsumptionConcurrency(b *testing.B) { + params := ZstdEncoderParams{Level: 9} + buf := make([]byte, 1024*1024) + for i := 0; i < len(buf); i++ { + buf[i] = byte((i / 256) + (i * 257)) + } + + cpus := 4 + goroutines := 256 + + gomaxprocsBackup := runtime.GOMAXPROCS(cpus) + defer runtime.GOMAXPROCS(gomaxprocsBackup) + + b.ReportMetric(float64(cpus), "(gomaxprocs)") + b.ResetTimer() + b.SetBytes(int64(len(buf) * goroutines)) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + // create n goroutines, wait until all start and then signal them to start compressing + var start sync.WaitGroup + var done sync.WaitGroup + start.Add(goroutines) + done.Add(goroutines) + for j := 0; j < goroutines; j++ { + go func() { + start.Done() + start.Wait() + _, _ = zstdCompress(params, nil, buf) + done.Done() + }() + zstdAvailableEncoders.Delete(params) + } + done.Wait() } - runtime.GOMAXPROCS(gomaxprocsBackup) + + b.ReportMetric(float64(cpus), "(gomaxprocs)") + b.ReportMetric(float64(goroutines), "(goroutines)") +} + +// BenchmarkZstdMemoryNoConcurrencyLimit benchmarks the encoder behavior when the concurrency limit is disabled. +func BenchmarkZstdMemoryNoConcurrencyLimit(b *testing.B) { + zstdTestingDisableConcurrencyLimit = true + defer func() { + zstdTestingDisableConcurrencyLimit = false + }() + + params := ZstdEncoderParams{Level: 9} + buf := make([]byte, 1024*1024) + for i := 0; i < len(buf); i++ { + buf[i] = byte((i / 256) + (i * 257)) + } + + cpus := 4 + goroutines := 256 + + gomaxprocsBackup := runtime.GOMAXPROCS(cpus) + defer runtime.GOMAXPROCS(gomaxprocsBackup) + + b.ReportMetric(float64(cpus), "(gomaxprocs)") + b.ResetTimer() + b.SetBytes(int64(len(buf) * goroutines)) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + // create n goroutines, wait until all start and then signal them to start compressing + var start sync.WaitGroup + var done sync.WaitGroup + start.Add(goroutines) + done.Add(goroutines) + for j := 0; j < goroutines; j++ { + go func() { + start.Done() + start.Wait() + _, _ = zstdCompress(params, nil, buf) + done.Done() + }() + zstdAvailableEncoders.Delete(params) + } + done.Wait() + } + + b.ReportMetric(float64(cpus), "(gomaxprocs)") + b.ReportMetric(float64(goroutines), "(goroutines)") }