From c91aeea8e0723f760e20342bcad6bdf829b4472a Mon Sep 17 00:00:00 2001 From: Dmitry S <11892559+swift1337@users.noreply.github.com> Date: Wed, 8 Jan 2025 19:29:56 +0100 Subject: [PATCH] Implement pkg/fanout --- pkg/fanout/fanout.go | 66 +++++++++++++++++++++++++++++++++++ pkg/fanout/fanout_test.go | 72 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+) create mode 100644 pkg/fanout/fanout.go create mode 100644 pkg/fanout/fanout_test.go diff --git a/pkg/fanout/fanout.go b/pkg/fanout/fanout.go new file mode 100644 index 0000000000..7a5f277842 --- /dev/null +++ b/pkg/fanout/fanout.go @@ -0,0 +1,66 @@ +// Package fanout provides a fan-out pattern implementation. +// It allows one channel to stream data to multiple independent channels. +// Note that context handling is out of the scope of this package. +package fanout + +import "sync" + +const DefaultBuffer = 8 + +// FanOut is a fan-out pattern implementation. +// It is NOT a worker pool, so use it wisely. +type FanOut[T any] struct { + input <-chan T + outputs []chan T + + // outputBuffer chan buffer size for outputs channels. + // This helps with writing to chan in case of slow consumers. + outputBuffer int + + mu sync.RWMutex +} + +// New constructs FanOut +func New[T any](source <-chan T, buf int) *FanOut[T] { + return &FanOut[T]{ + input: source, + outputs: make([]chan T, 0), + outputBuffer: buf, + } +} + +func (f *FanOut[T]) Add() <-chan T { + out := make(chan T, f.outputBuffer) + + f.mu.Lock() + defer f.mu.Unlock() + + f.outputs = append(f.outputs, out) + + return out +} + +// Start starts the fan-out process +func (f *FanOut[T]) Start() { + go func() { + // loop for new data + for data := range f.input { + f.mu.RLock() + for _, output := range f.outputs { + // note that this might spawn lots of goroutines. + // it is a naive approach, but should be more than enough for our use cases. + go func(output chan<- T) { output <- data }(output) + } + f.mu.RUnlock() + } + + // at this point, the input was closed + f.mu.Lock() + defer f.mu.Unlock() + for _, out := range f.outputs { + close(out) + } + + f.outputs = nil + }() +} diff --git a/pkg/fanout/fanout_test.go b/pkg/fanout/fanout_test.go new file mode 100644 index 0000000000..884d122e30 --- /dev/null +++ b/pkg/fanout/fanout_test.go @@ -0,0 +1,72 @@ +package fanout + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestFanOut(t *testing.T) { + // ARRANGE + // Given an input + input := make(chan int) + + // Given a fanout + f := New(input, DefaultBuffer) + + // That has 3 outputs + out1 := f.Add() + out2 := f.Add() + out3 := f.Add() + + // Given a wait group + wg := sync.WaitGroup{} + wg.Add(3) + + // Given a sample number + var total int32 + + // Given a consumer + consumer := func(out <-chan int, name string, lag time.Duration) { + defer wg.Done() + var local int32 + for i := range out { + // simulate some work + time.Sleep(lag) + + local += int32(i) + t.Logf("%s: received %d", name, i) + } + + // add only if input was closed + atomic.AddInt32(&total, local) + } + + // ACT + f.Start() + + // Write to the channel + go func() { + for i := 1; i <= 10; i++ { + input <- i + t.Logf("fan-out: sent %d", i) + time.Sleep(50 * time.Millisecond) + } + + close(input) + }() + + go consumer(out1, "out1: fast consumer", 10*time.Millisecond) + go consumer(out2, "out2: average consumer", 60*time.Millisecond) + go consumer(out3, "out3: slow consumer", 150*time.Millisecond) + + wg.Wait() + + // ASSERT + // Check that total is valid + // total == sum(1...10) * 3 = n(n+1)/2 * 3 = 55 * 3 = 165 + require.Equal(t, int32(165), total) +}