Skip to content

Commit

Permalink
feat: Add OnError and OnSuccess to a Job.
Browse files Browse the repository at this point in the history
We can now enqueue tasks when either job succeeds or fails.

Chains now used the `OnSuccess` slice to add the next jobs.
  • Loading branch information
iamd3vil committed Mar 14, 2024
1 parent 4e05534 commit 9c5080d
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 30 deletions.
2 changes: 0 additions & 2 deletions brokers/in-memory/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,11 @@ func (r *Broker) Consume(ctx context.Context, work chan []byte, queue string) {
ch, ok := r.queues[queue]
r.mu.RUnlock()

// If the queue isn't found, make a queue.
if !ok {
ch = make(chan []byte, 100)
r.mu.Lock()
r.queues[queue] = ch
r.mu.Unlock()

}

for {
Expand Down
6 changes: 3 additions & 3 deletions chains.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func NewChain(j []Job, opts ChainOpts) (Chain, error) {
// Set the on success tasks as the i+1 task,
// hence forming a "chain" of tasks.
for i := 0; i < len(j)-1; i++ {
j[i].OnSuccess = &j[i+1]
j[i].OnSuccess = append(j[i].OnSuccess, &j[i+1])
}

return Chain{Jobs: j, Opts: opts}, nil
Expand Down Expand Up @@ -114,10 +114,10 @@ checkJobs:
// to success. Otherwise update the current job and perform all the above checks.
case StatusDone:
c.PrevJobs = append(c.PrevJobs, currJob.ID)
if currJob.OnSuccessID == "" {
if len(currJob.OnSuccessIDs) == 0 {
c.Status = StatusDone
} else {
currJob, err = s.GetJob(ctx, currJob.OnSuccessID)
currJob, err = s.GetJob(ctx, currJob.OnSuccessIDs[0])
if err != nil {
return ChainMessage{}, nil
}
Expand Down
25 changes: 14 additions & 11 deletions jobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ const (
// It is the responsibility of the task handler to unmarshal (if required) the payload and process it in any manner.
type Job struct {
// If task is successful, the OnSuccess jobs are enqueued.
OnSuccess *Job
OnSuccess []*Job
Task string
Payload []byte

// If task fails, the OnError jobs are enqueued.
OnError []*Job

Opts JobOpts
}

Expand All @@ -42,15 +45,15 @@ type JobOpts struct {

// Meta contains fields related to a job. These are updated when a task is consumed.
type Meta struct {
ID string
OnSuccessID string
Status string
Queue string
Schedule string
MaxRetry uint32
Retried uint32
PrevErr string
ProcessedAt time.Time
ID string
OnSuccessIDs []string
Status string
Queue string
Schedule string
MaxRetry uint32
Retried uint32
PrevErr string
ProcessedAt time.Time

// PrevJobResults contains any job result set by the previous job in a chain.
// This will be nil if the previous job doesn't set the results on JobCtx.
Expand Down Expand Up @@ -151,7 +154,7 @@ func (s *Server) enqueueWithMeta(ctx context.Context, t Job, meta Meta) (string,
}

// Set current jobs OnSuccess as next job
t.OnSuccess = &j
t.OnSuccess = append(t.OnSuccess, &j)
// Set the next job's eta according to schedule
j.Opts.ETA = sch.Next(t.Opts.ETA)
}
Expand Down
38 changes: 38 additions & 0 deletions jobs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,44 @@ func TestDeleteJob(t *testing.T) {

}

func TestJobsOnError(t *testing.T) {
var (
srv = newServer(t, taskName, MockHandler)
)

hasErrored := make(chan bool, 1)

if err := srv.RegisterTask("error", func(b []byte, jc JobCtx) error {
t.Log("error task called")
hasErrored <- true
return nil
}, TaskOpts{
Queue: "error_task",
Concurrency: 1,
}); err != nil {
t.Fatal(err)
}

j := makeJob(t, taskName, true)

errJob, _ := NewJob("error", []byte{}, JobOpts{
Queue: "error_task",
})

j.OnError = append(j.OnError, &errJob)

if _, err := srv.Enqueue(context.Background(), j); err != nil {
t.Fatalf("error enqueuing job: %v", err)
}

go srv.Start(context.Background())

b := <-hasErrored
if !b {
t.Fatalf("error job didn't enqueue")
}
}

func makeJob(t *testing.T, taskName string, doErr bool) Job {
j, err := json.Marshal(MockPayload{ShouldErr: doErr})
if err != nil {
Expand Down
42 changes: 30 additions & 12 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ func (s *Server) process(ctx context.Context, w chan []byte) {
s.log.Error("error unmarshalling task", "error", err)
break
}

// Fetch the registered task handler.
task, err := s.getHandler(msg.Job.Task)
if err != nil {
Expand Down Expand Up @@ -365,6 +366,20 @@ func (s *Server) execJob(ctx context.Context, msg JobMessage, task Task) error {
if task.opts.FailedCB != nil {
task.opts.FailedCB(taskCtx, err)
}

// If there are jobs to enqueued after failure, enqueue them.
if msg.Job.OnError != nil {
// Extract OnErrorJob into a variable to get opts.
for _, j := range msg.Job.OnError {
nj := *j
meta := DefaultMeta(nj.Opts)

if _, err = s.enqueueWithMeta(ctx, nj, meta); err != nil {
return fmt.Errorf("error enqueuing jobs after failure: %w", err)
}
}
}

// If we hit max retries, set the task status as failed.
return s.statusFailed(ctx, msg)
}
Expand All @@ -376,19 +391,22 @@ func (s *Server) execJob(ctx context.Context, msg JobMessage, task Task) error {

// If the task contains OnSuccess task (part of a chain), enqueue them.
if msg.Job.OnSuccess != nil {
// Extract OnSuccessJob into a variable to get opts.
j := msg.Job.OnSuccess
nj := *j
meta := DefaultMeta(nj.Opts)
meta.PrevJobResult, err = s.GetResult(ctx, msg.ID)
if err != nil {
return fmt.Errorf("could not get result for id (%s) : %w", msg.ID, err)
}
for _, j := range msg.Job.OnSuccess {
// Extract OnSuccessJob into a variable to get opts.
nj := *j
meta := DefaultMeta(nj.Opts)
meta.PrevJobResult, err = s.GetResult(ctx, msg.ID)
if err != nil {
return err
}

// Set the ID of the next job in the chain
onSuccessID, err := s.enqueueWithMeta(ctx, nj, meta)
if err != nil {
return err
}

// Set the ID of the next job in the chain
msg.OnSuccessID, err = s.enqueueWithMeta(ctx, nj, meta)
if err != nil {
return fmt.Errorf("could not enqueue job id (%s) : %w", msg.ID, err)
msg.OnSuccessIDs = append(msg.OnSuccessIDs, onSuccessID)
}
}

Expand Down
7 changes: 5 additions & 2 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"log/slog"
"os"
"testing"
"time"

Expand All @@ -16,11 +17,13 @@ const (
)

func newServer(t *testing.T, taskName string, handler func([]byte, JobCtx) error) *Server {
lo := slog.Default().Handler()
lo := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: slog.LevelError,
}))
srv, err := NewServer(ServerOpts{
Broker: rb.New(),
Results: rr.New(),
Logger: lo,
Logger: lo.Handler(),
})
if err != nil {
t.Fatal(err)
Expand Down

0 comments on commit 9c5080d

Please sign in to comment.