-
Notifications
You must be signed in to change notification settings - Fork 110
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
138 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
// Package fanout provides a fan-out pattern implementation. | ||
// It allows one channel to stream data to multiple independent channels. | ||
// Note that context handling is out of the scope of this package. | ||
package fanout | ||
|
||
import "sync" | ||
|
||
const DefaultBuffer = 8 | ||
|
||
// FanOut is a fan-out pattern implementation. | ||
// It is NOT a worker pool, so use it wisely. | ||
type FanOut[T any] struct { | ||
input <-chan T | ||
outputs []chan T | ||
|
||
// outputBuffer chan buffer size for outputs channels. | ||
// This helps with writing to chan in case of slow consumers. | ||
outputBuffer int | ||
|
||
mu sync.RWMutex | ||
} | ||
|
||
// New constructs FanOut | ||
func New[T any](source <-chan T, buf int) *FanOut[T] { | ||
return &FanOut[T]{ | ||
input: source, | ||
outputs: make([]chan T, 0), | ||
outputBuffer: buf, | ||
} | ||
} | ||
|
||
func (f *FanOut[T]) Add() <-chan T { | ||
out := make(chan T, f.outputBuffer) | ||
|
||
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
|
||
f.outputs = append(f.outputs, out) | ||
|
||
return out | ||
} | ||
|
||
// Start starts the fan-out process | ||
func (f *FanOut[T]) Start() { | ||
go func() { | ||
// loop for new data | ||
for data := range f.input { | ||
f.mu.RLock() | ||
for _, output := range f.outputs { | ||
// note that this might spawn lots of goroutines. | ||
// it is a naive approach, but should be more than enough for our use cases. | ||
go func(output chan<- T) { output <- data }(output) | ||
} | ||
f.mu.RUnlock() | ||
} | ||
|
||
// at this point, the input was closed | ||
f.mu.Lock() | ||
defer f.mu.Unlock() | ||
for _, out := range f.outputs { | ||
close(out) | ||
} | ||
|
||
f.outputs = nil | ||
}() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
package fanout | ||
|
||
import ( | ||
"sync" | ||
"sync/atomic" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestFanOut(t *testing.T) { | ||
// ARRANGE | ||
// Given an input | ||
input := make(chan int) | ||
|
||
// Given a fanout | ||
f := New(input, DefaultBuffer) | ||
|
||
// That has 3 outputs | ||
out1 := f.Add() | ||
out2 := f.Add() | ||
out3 := f.Add() | ||
|
||
// Given a wait group | ||
wg := sync.WaitGroup{} | ||
wg.Add(3) | ||
|
||
// Given a sample number | ||
var total int32 | ||
|
||
// Given a consumer | ||
consumer := func(out <-chan int, name string, lag time.Duration) { | ||
defer wg.Done() | ||
var local int32 | ||
for i := range out { | ||
// simulate some work | ||
time.Sleep(lag) | ||
|
||
local += int32(i) | ||
t.Logf("%s: received %d", name, i) | ||
} | ||
|
||
// add only if input was closed | ||
atomic.AddInt32(&total, local) | ||
} | ||
|
||
// ACT | ||
f.Start() | ||
|
||
// Write to the channel | ||
go func() { | ||
for i := 1; i <= 10; i++ { | ||
input <- i | ||
t.Logf("fan-out: sent %d", i) | ||
time.Sleep(50 * time.Millisecond) | ||
} | ||
|
||
close(input) | ||
}() | ||
|
||
go consumer(out1, "out1: fast consumer", 10*time.Millisecond) | ||
go consumer(out2, "out2: average consumer", 60*time.Millisecond) | ||
go consumer(out3, "out3: slow consumer", 150*time.Millisecond) | ||
|
||
wg.Wait() | ||
|
||
// ASSERT | ||
// Check that total is valid | ||
// total == sum(1...10) * 3 = n(n+1)/2 * 3 = 55 * 3 = 165 | ||
require.Equal(t, int32(165), total) | ||
} |