diff --git a/README.md b/README.md index 4f33cce..963db39 100644 --- a/README.md +++ b/README.md @@ -17,13 +17,14 @@ import ( "log" "context" "time" + "github.com/hmoragrega/workers" ) func main() { - job := func(ctx context.Context) { + job := workers.JobFunc(func(ctx context.Context) { // my job code - } + }) pool := workers.Must(workers.New(job)) @@ -31,15 +32,14 @@ func main() { log.Fatal("cannot start pool", err) } - // program continues... - - // program shutdown - ctx, cancel := context.WithTimeout(context.Background(), 5 * time.Second) - defer cancel() + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() - if err := pool.Close(ctx); err != nil { - log.Fatal("cannot close pool", err) - } + if err := pool.Close(ctx); err != nil { + log.Fatal("cannot close pool", err) + } + }() } ``` @@ -149,15 +149,24 @@ The operation will fail if: A job is a simple function that accepts only one parameter, the worker context. ```go -// Job is a function that does work. -// -// The only parameter that will receive is a context, the job -// should try to honor the context cancellation signal as soon -// as possible. -// -// The context will be cancelled when removing workers from -// the pool or stopping the pool completely. -type Job = func(ctx context.Context) +// Job represent some work that needs to be done non-stop. +type Job interface { + // Do executes the job. + // + // The only parameter that will receive is a context, the job + // should try to honor the context cancellation signal as soon + // as possible. + // + // The context will be cancelled when removing workers from + // the pool or stopping the pool completely. + Do(ctx context.Context) +} +``` + +Simple jobs can use the helper `JobFunc` to comply with the interface +```go +// JobFunc is a helper function that is a job. +type JobFunc func(ctx context.Context) ``` There are two ways of extending the job functionality @@ -171,6 +180,7 @@ type Middleware interface { Next(job Job) Job } ``` + The helper `MiddlewareFunc` can be used to wrap simple middleware functions ```go diff --git a/middleware/counter.go b/middleware/counter.go index 15bd4ce..d5a98ef 100644 --- a/middleware/counter.go +++ b/middleware/counter.go @@ -3,6 +3,8 @@ package middleware import ( "context" "sync/atomic" + + "github.com/hmoragrega/workers" ) // Counter count how many jobs have started and finished. @@ -12,12 +14,12 @@ type Counter struct { } // Wrap wraps the job adding counters. -func (c *Counter) Wrap(next func(context.Context)) func(context.Context) { - return func(ctx context.Context) { +func (c *Counter) Wrap(next workers.Job) workers.Job { + return workers.JobFunc(func(ctx context.Context) { atomic.AddUint64(&c.started, 1) - next(ctx) + next.Do(ctx) atomic.AddUint64(&c.finished, 1) - } + }) } // Started returns the number of jobs that have been started. diff --git a/middleware/counter_test.go b/middleware/counter_test.go index d2c0167..b20e30c 100644 --- a/middleware/counter_test.go +++ b/middleware/counter_test.go @@ -14,14 +14,14 @@ func TestCounterMiddleware(t *testing.T) { stop = make(chan struct{}) ) - job := func(ctx context.Context) { + job := workers.JobFunc(func(ctx context.Context) { if counter.Started() == stopAt { // trigger the stop of the pool an wait for // pool context cancellation to prevent new jobs close(stop) <-ctx.Done() } - } + }) p := workers.Must(workers.New(job, &counter)) if err := p.Start(); err != nil { diff --git a/middleware/elapsed.go b/middleware/elapsed.go index df4354d..a3315cf 100644 --- a/middleware/elapsed.go +++ b/middleware/elapsed.go @@ -4,6 +4,8 @@ import ( "context" "sync" "time" + + "github.com/hmoragrega/workers" ) // Elapsed is a job middleware that extends the simple counter @@ -20,7 +22,7 @@ type Elapsed struct { } // Wrap wraps the inner job to obtain job timing metrics. -func (e *Elapsed) Wrap(next func(context.Context)) func(context.Context) { +func (e *Elapsed) Wrap(next workers.Job) workers.Job { e.mx.Lock() if e.since == nil { e.since = time.Since @@ -29,9 +31,9 @@ func (e *Elapsed) Wrap(next func(context.Context)) func(context.Context) { // wrap incoming job with the counter. next = e.Counter.Wrap(next) - return func(ctx context.Context) { + return workers.JobFunc(func(ctx context.Context) { start := time.Now() - next(ctx) + next.Do(ctx) elapsed := e.since(start) count := e.Counter.Finished() @@ -40,7 +42,7 @@ func (e *Elapsed) Wrap(next func(context.Context)) func(context.Context) { e.total += e.last e.average = e.total / time.Duration(count) e.mx.Unlock() - } + }) } // Total returns the total time spent executing diff --git a/middleware/elapsed_test.go b/middleware/elapsed_test.go index 4c0aec6..2cf56ae 100644 --- a/middleware/elapsed_test.go +++ b/middleware/elapsed_test.go @@ -18,7 +18,7 @@ func TestElapsedMiddleware(t *testing.T) { last = 10 * time.Second ) - job := func(ctx context.Context) { + job := workers.JobFunc(func(ctx context.Context) { // make every job execution 1 second longer than the previous one. elapsed.since = func(time.Time) time.Duration { return time.Second * time.Duration(elapsed.Started()) @@ -27,7 +27,7 @@ func TestElapsedMiddleware(t *testing.T) { close(stop) <-ctx.Done() } - } + }) p := workers.Must(workers.New(job, &elapsed)) if err := p.Start(); err != nil { diff --git a/middleware/wait.go b/middleware/wait.go index 39ff134..e1565a0 100644 --- a/middleware/wait.go +++ b/middleware/wait.go @@ -3,12 +3,14 @@ package middleware import ( "context" "time" + + "github.com/hmoragrega/workers" ) // Wait will add a pause between calls to the next job. // The pause affects only jobs between the same worker. -func Wait(wait time.Duration) func(func(context.Context)) func(context.Context) { - return func(job func(context.Context)) func(context.Context) { +func Wait(wait time.Duration) workers.MiddlewareFunc { + return func(job workers.Job) workers.Job { var ( ticker *time.Ticker tick <-chan time.Time @@ -21,7 +23,7 @@ func Wait(wait time.Duration) func(func(context.Context)) func(context.Context) close(ch) tick = ch } - return func(ctx context.Context) { + return workers.JobFunc(func(ctx context.Context) { select { case <-ctx.Done(): if ticker != nil { @@ -29,8 +31,8 @@ func Wait(wait time.Duration) func(func(context.Context)) func(context.Context) } return case <-tick: - job(ctx) + job.Do(ctx) } - } + }) } } diff --git a/middleware/wait_test.go b/middleware/wait_test.go index ddaba95..19ff87e 100644 --- a/middleware/wait_test.go +++ b/middleware/wait_test.go @@ -29,12 +29,12 @@ func TestWaitMiddleware_Wait(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { stop := make(chan time.Time) - job := func(ctx context.Context) { + job := workers.JobFunc(func(ctx context.Context) { stop <- time.Now() <-ctx.Done() - } + }) - p := workers.Must(workers.New(job, workers.MiddlewareFunc(Wait(tc.wait)))) + p := workers.Must(workers.New(job, Wait(tc.wait))) poolStarted := time.Now() if err := p.Start(); err != nil { @@ -56,11 +56,11 @@ func TestWaitMiddleware_Wait(t *testing.T) { func TestWaitMiddleware_Cancelled(t *testing.T) { executed := make(chan struct{}) - job := func(ctx context.Context) { + job := workers.JobFunc(func(ctx context.Context) { close(executed) - } + }) - p := workers.Must(workers.New(job, workers.MiddlewareFunc(Wait(time.Second)))) + p := workers.Must(workers.New(job, Wait(time.Second))) if err := p.Start(); err != nil { t.Fatal("cannot start pool", err) } diff --git a/pool.go b/pool.go index eacfaf7..56cd907 100644 --- a/pool.go +++ b/pool.go @@ -21,15 +21,26 @@ var ( ErrMaxReached = errors.New("maximum number of workers reached") ) -// Job is a function that does work. -// -// The only parameter that will receive is a context, the job -// should try to honor the context cancellation signal as soon -// as possible. -// -// The context will be cancelled when removing workers from -// the pool or stopping the pool completely. -type Job = func(ctx context.Context) +// Job represents some work that needs to be done non-stop. +type Job interface { + // Do executes the job. + // + // The only parameter that will receive is the worker context, + // the job should try to honor the context cancellation signal + // as soon as possible. + // + // The context will be cancelled when removing workers from + // the pool or stopping the pool completely. + Do(ctx context.Context) +} + +// JobFunc is a helper function that is a job. +type JobFunc func(ctx context.Context) + +// Do executes the job work. +func (f JobFunc) Do(ctx context.Context) { + f(ctx) +} // Middleware is a function that wraps the job and can // be used to extend the functionality of the pool. @@ -309,7 +320,7 @@ func (w *worker) work(ctx context.Context, job Job, stopped chan<- struct{}) { case <-ctx.Done(): return default: - job(ctx) + job.Do(ctx) } } } diff --git a/pool_test.go b/pool_test.go index a3c0e7c..4967e66 100644 --- a/pool_test.go +++ b/pool_test.go @@ -12,10 +12,10 @@ import ( ) var ( - dummyJob = func(_ context.Context) {} - slowJob = func(_ context.Context) { + dummyJob = JobFunc(func(_ context.Context) {}) + slowJob = JobFunc(func(_ context.Context) { <-time.NewTimer(150 * time.Millisecond).C - } + }) ) func TestPool_New(t *testing.T) { @@ -321,12 +321,12 @@ func TestPool_Close(t *testing.T) { t.Run("close timeout error", func(t *testing.T) { running := make(chan struct{}) - p := Must(New(func(_ context.Context) { + p := Must(New(JobFunc(func(_ context.Context) { // signal that we are running the job running <- struct{}{} // block the job so the call to close times out running <- struct{}{} - })) + }))) if err := p.Start(); err != nil { t.Fatalf("unexpected error starting pool: %+v", err) } @@ -345,10 +345,10 @@ func TestPool_Close(t *testing.T) { }) t.Run("close cancelled error", func(t *testing.T) { - p := Must(New(func(_ context.Context) { + p := Must(New(JobFunc(func(_ context.Context) { block := make(chan struct{}) <-block - })) + }))) if err := p.Start(); err != nil { t.Fatalf("unexpected error starting pool: %+v", err) } @@ -423,7 +423,7 @@ func TestPool_ConcurrencySafety(t *testing.T) { startRemoving = make(chan struct{}) ) - p := Must(New(func(ctx context.Context) { + p := Must(New(JobFunc(func(ctx context.Context) { if atomic.AddUint32(&count, 1) == uint32(headStart) { close(startRemoving) } @@ -431,7 +431,7 @@ func TestPool_ConcurrencySafety(t *testing.T) { case <-ctx.Done(): case <-time.NewTimer(100 * time.Millisecond).C: } - })) + }))) if err := p.Start(); err != nil { t.Fatal("cannot start pool", err)