Skip to content

Commit

Permalink
Implement pkg/fanout
Browse files Browse the repository at this point in the history
  • Loading branch information
swift1337 committed Jan 8, 2025
1 parent a04244b commit c91aeea
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 0 deletions.
66 changes: 66 additions & 0 deletions pkg/fanout/fanout.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Package fanout provides a fan-out pattern implementation.
// It allows one channel to stream data to multiple independent channels.
// Note that context handling is out of the scope of this package.
package fanout

import "sync"

const DefaultBuffer = 8

// FanOut is a fan-out pattern implementation.
// It is NOT a worker pool, so use it wisely.
type FanOut[T any] struct {
input <-chan T
outputs []chan T

// outputBuffer chan buffer size for outputs channels.
// This helps with writing to chan in case of slow consumers.
outputBuffer int

mu sync.RWMutex
}

// New constructs FanOut
func New[T any](source <-chan T, buf int) *FanOut[T] {
return &FanOut[T]{
input: source,
outputs: make([]chan T, 0),
outputBuffer: buf,
}
}

func (f *FanOut[T]) Add() <-chan T {
out := make(chan T, f.outputBuffer)

f.mu.Lock()
defer f.mu.Unlock()

f.outputs = append(f.outputs, out)

return out
}

// Start starts the fan-out process
func (f *FanOut[T]) Start() {
go func() {
// loop for new data
for data := range f.input {
f.mu.RLock()
for _, output := range f.outputs {
// note that this might spawn lots of goroutines.
// it is a naive approach, but should be more than enough for our use cases.
go func(output chan<- T) { output <- data }(output)
}
f.mu.RUnlock()
}

// at this point, the input was closed
f.mu.Lock()
defer f.mu.Unlock()
for _, out := range f.outputs {
close(out)
}

f.outputs = nil
}()
}
72 changes: 72 additions & 0 deletions pkg/fanout/fanout_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package fanout

import (
"sync"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestFanOut(t *testing.T) {
// ARRANGE
// Given an input
input := make(chan int)

// Given a fanout
f := New(input, DefaultBuffer)

// That has 3 outputs
out1 := f.Add()
out2 := f.Add()
out3 := f.Add()

// Given a wait group
wg := sync.WaitGroup{}
wg.Add(3)

// Given a sample number
var total int32

// Given a consumer
consumer := func(out <-chan int, name string, lag time.Duration) {
defer wg.Done()
var local int32
for i := range out {
// simulate some work
time.Sleep(lag)

local += int32(i)
t.Logf("%s: received %d", name, i)
}

// add only if input was closed
atomic.AddInt32(&total, local)
}

// ACT
f.Start()

// Write to the channel
go func() {
for i := 1; i <= 10; i++ {
input <- i
t.Logf("fan-out: sent %d", i)
time.Sleep(50 * time.Millisecond)
}

close(input)
}()

go consumer(out1, "out1: fast consumer", 10*time.Millisecond)
go consumer(out2, "out2: average consumer", 60*time.Millisecond)
go consumer(out3, "out3: slow consumer", 150*time.Millisecond)

wg.Wait()

// ASSERT
// Check that total is valid
// total == sum(1...10) * 3 = n(n+1)/2 * 3 = 55 * 3 = 165
require.Equal(t, int32(165), total)
}

0 comments on commit c91aeea

Please sign in to comment.