-
Notifications
You must be signed in to change notification settings - Fork 106
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add control interface for backpressure handling #175
Signed-off-by: tsaikd <[email protected]>
- Loading branch information
Showing
67 changed files
with
1,060 additions
and
117 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package config | ||
|
||
import ( | ||
"context" | ||
"sync/atomic" | ||
) | ||
|
||
type Control interface { | ||
RequestPause(ctx context.Context) error | ||
RequestResume(ctx context.Context) error | ||
PauseSignal() <-chan struct{} | ||
ResumeSignal() <-chan struct{} | ||
} | ||
|
||
func (t *Config) RequestPause(ctx context.Context) error { | ||
if atomic.CompareAndSwapInt32(&t.state, stateNormal, statePause) { | ||
return t.signalPause.Broadcast(ctx) | ||
} else { | ||
return ErrorInvalidState.New(nil) | ||
} | ||
} | ||
func (t *Config) RequestResume(ctx context.Context) error { | ||
if atomic.CompareAndSwapInt32(&t.state, statePause, stateNormal) { | ||
return t.signalResume.Broadcast(ctx) | ||
} else { | ||
return ErrorInvalidState.New(nil) | ||
} | ||
} | ||
|
||
func (t *Config) PauseSignal() <-chan struct{} { return t.signalPause.Channel() } | ||
func (t *Config) ResumeSignal() <-chan struct{} { return t.signalResume.Channel() } | ||
|
||
const ( | ||
stateNormal = iota | ||
statePause | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
package ctxutil | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/subchen/go-trylock/v2" | ||
) | ||
|
||
type Broadcaster struct { | ||
mutex trylock.TryLocker | ||
channel chan struct{} | ||
} | ||
|
||
func NewBroadcaster() *Broadcaster { | ||
return &Broadcaster{ | ||
mutex: trylock.New(), | ||
channel: make(chan struct{}), | ||
} | ||
} | ||
|
||
func (t *Broadcaster) Wait(ctx context.Context) error { | ||
select { | ||
case <-ctx.Done(): | ||
return context.DeadlineExceeded | ||
case <-t.Channel(): | ||
return nil | ||
} | ||
} | ||
|
||
func (t *Broadcaster) Channel() <-chan struct{} { | ||
t.mutex.RLock() | ||
defer t.mutex.RUnlock() | ||
|
||
return t.channel | ||
} | ||
|
||
// Signal wakes one goroutine waiting on broadcaster, if there is any. | ||
func (t *Broadcaster) Signal(ctx context.Context) error { | ||
if !t.mutex.RTryLock(ctx) { | ||
return context.DeadlineExceeded | ||
} | ||
defer t.mutex.RUnlock() | ||
|
||
select { | ||
case <-ctx.Done(): | ||
return context.DeadlineExceeded | ||
case t.channel <- struct{}{}: | ||
default: | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// Broadcast wakes all goroutines waiting on broadcaster, if there is any. | ||
func (t *Broadcaster) Broadcast(ctx context.Context) error { | ||
newChannel := make(chan struct{}) | ||
|
||
if !t.mutex.TryLock(ctx) { | ||
return context.DeadlineExceeded | ||
} | ||
channel := t.channel | ||
t.channel = newChannel | ||
t.mutex.Unlock() | ||
|
||
// send broadcast signal | ||
close(channel) | ||
|
||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
package ctxutil | ||
|
||
import ( | ||
"context" | ||
"sync/atomic" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestBroadcaster(t *testing.T) { | ||
t.Parallel() | ||
assert := assert.New(t) | ||
assert.NotNil(assert) | ||
require := require.New(t) | ||
require.NotNil(require) | ||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||
defer cancel() | ||
|
||
cg := NewCancelGroup(ctx) | ||
b := NewBroadcaster() | ||
|
||
var broadcastCount int32 | ||
listenReady := make(chan struct{}, 1) | ||
for i := 0; i < 5; i++ { | ||
cg.Go(func(ctx context.Context) error { | ||
listenReady <- struct{}{} | ||
select { | ||
case <-ctx.Done(): | ||
t.Fatal("wait for broadcast signal timeout") | ||
case <-b.Channel(): | ||
atomic.AddInt32(&broadcastCount, 1) | ||
} | ||
|
||
return nil | ||
}) | ||
} | ||
for i := 0; i < 5; i++ { | ||
<-listenReady | ||
} | ||
|
||
require.False(Sleep(ctx, 500*time.Millisecond)) | ||
require.NoError(b.Signal(ctx)) | ||
require.False(Sleep(ctx, 500*time.Millisecond)) | ||
require.EqualValues(1, atomic.LoadInt32(&broadcastCount)) | ||
require.NoError(b.Broadcast(ctx)) | ||
|
||
require.NoError(cg.Wait()) | ||
require.EqualValues(5, atomic.LoadInt32(&broadcastCount)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
package ctxutil | ||
|
||
import ( | ||
"context" | ||
"sync" | ||
"time" | ||
) | ||
|
||
func NewCancelGroup(parent context.Context) *CancelGroup { | ||
ctx, cancel := context.WithCancel(parent) | ||
|
||
return &CancelGroup{ | ||
ctx: ctx, | ||
cancel: cancel, | ||
} | ||
} | ||
|
||
type CancelGroup struct { | ||
ctx context.Context | ||
cancel func() | ||
|
||
mutex sync.Mutex | ||
done chan error | ||
|
||
wg sync.WaitGroup | ||
|
||
errOnce sync.Once | ||
err error | ||
} | ||
|
||
func (t *CancelGroup) Wait() error { | ||
t.wg.Wait() | ||
t.cancel() | ||
|
||
return t.err | ||
} | ||
|
||
func (t *CancelGroup) Done() <-chan error { | ||
t.mutex.Lock() | ||
if t.done == nil { | ||
t.done = make(chan error) | ||
go func() { | ||
t.wg.Wait() | ||
t.cancel() | ||
t.done <- t.err | ||
}() | ||
} | ||
d := t.done | ||
t.mutex.Unlock() | ||
|
||
return d | ||
} | ||
|
||
func (t *CancelGroup) Go(f func(context.Context) error) { | ||
t.wg.Add(1) | ||
|
||
go func() { | ||
defer t.wg.Done() | ||
|
||
if err := f(t.ctx); err != nil { | ||
t.CancelError(err) | ||
} | ||
}() | ||
} | ||
|
||
// GoCancel go with cancel | ||
func (t *CancelGroup) GoCancel(f func(context.Context) error) context.CancelFunc { | ||
t.wg.Add(1) | ||
|
||
ctx, cancel := context.WithCancel(t.ctx) | ||
|
||
go func() { | ||
defer t.wg.Done() | ||
|
||
if err := f(ctx); err != nil { | ||
t.CancelError(err) | ||
} | ||
}() | ||
|
||
return cancel | ||
} | ||
|
||
// GoTimeout go with timeout | ||
func (t *CancelGroup) GoTimeout(timeout time.Duration, f func(context.Context) error) context.CancelFunc { | ||
t.wg.Add(1) | ||
|
||
ctx, cancel := context.WithTimeout(t.ctx, timeout) | ||
|
||
go func() { | ||
defer t.wg.Done() | ||
|
||
if err := f(ctx); err != nil { | ||
t.CancelError(err) | ||
} | ||
}() | ||
|
||
return cancel | ||
} | ||
|
||
// Fork goroutine will disconnect context propagation | ||
func (t *CancelGroup) Fork(f func(context.Context) error) { | ||
t.wg.Add(1) | ||
|
||
go func() { | ||
defer t.wg.Done() | ||
|
||
ctx := DisconnectContext(t.ctx) | ||
|
||
if err := f(ctx); err != nil { | ||
t.CancelError(err) | ||
} | ||
}() | ||
} | ||
|
||
// ForkTimeout fork with cancel | ||
func (t *CancelGroup) ForkCancel(f func(context.Context) error) context.CancelFunc { | ||
t.wg.Add(1) | ||
|
||
ctx, cancel := DisconnectContextWithCancel(t.ctx) | ||
|
||
go func() { | ||
defer t.wg.Done() | ||
|
||
if err := f(ctx); err != nil { | ||
t.CancelError(err) | ||
} | ||
}() | ||
|
||
return cancel | ||
} | ||
|
||
// ForkTimeout fork with timeout | ||
func (t *CancelGroup) ForkTimeout(timeout time.Duration, f func(context.Context) error) context.CancelFunc { | ||
t.wg.Add(1) | ||
|
||
ctx, cancel := DisconnectContextWithTimeout(t.ctx, timeout) | ||
|
||
go func() { | ||
defer t.wg.Done() | ||
|
||
if err := f(ctx); err != nil { | ||
t.CancelError(err) | ||
} | ||
}() | ||
|
||
return cancel | ||
} | ||
|
||
func (t *CancelGroup) Context() context.Context { | ||
return t.ctx | ||
} | ||
|
||
func (t *CancelGroup) Cancel() { | ||
t.cancel() | ||
} | ||
|
||
func (t *CancelGroup) CancelError(err error) { | ||
t.errOnce.Do(func() { | ||
t.err = err | ||
t.cancel() | ||
}) | ||
} | ||
|
||
func (t *CancelGroup) Close() (err error) { | ||
t.cancel() | ||
t.wg.Wait() | ||
|
||
return t.err | ||
} |
Oops, something went wrong.