From 32b4455d4a5ae4e5eedd43f3e2bf1262b94b0778 Mon Sep 17 00:00:00 2001 From: Jeremy Gonzalez Date: Wed, 13 Apr 2022 09:45:42 -0600 Subject: [PATCH] fixes to batching (#20) * split init into functions; one for new batches and one for loading batch meta * update timeout msg * remove batches with invalid metadata * remove job state keyes on deletion * update all key expirations * lock batches before accessing them * move removeBatch locks to outside function to avoid deadlock * only lock when creating / deleting references * update reference to avoid deadlock * update child cmd to ensure batch is committed if it was opened * do not fail jobs when a batch has expired * use lockBatchIfExists * remove stale batches every 24 hours * move mutex back to batch struct * remove unlock since it is deferred * use defer where possible * refactor remove stale batches to avoid deadlocks * add tests ensuring no deadlocks * switch from defer * remove child when done * close stopper properly in tests * remove goroutine * avoid deadlock when committing parent --- batch/batch.go | 278 ++++++++++++++++++++++++++++--------------- batch/batch_test.go | 12 +- batch/child.go | 23 +++- batch/commands.go | 38 ++++-- batch/middlware.go | 46 ++++--- batch/stress_test.go | 1 + batch/subsystem.go | 4 +- 7 files changed, 277 insertions(+), 125 deletions(-) diff --git a/batch/batch.go b/batch/batch.go index b2d117d..2dadcfb 100644 --- a/batch/batch.go +++ b/batch/batch.go @@ -41,7 +41,7 @@ type batchManager struct { Batches map[string]*batch Subsystem *BatchSubsystem rclient *redis.Client - mu sync.RWMutex + mu sync.Mutex // this lock is only used for to lock access to Batches } const ( @@ -53,17 +53,12 @@ const ( CallbackJobSucceeded = "2" ) -func (m *batchManager) getBatchFromInterface(batchId interface{}) (*batch, error) { +func (m *batchManager) getBatchIdFromInterface(batchId interface{}) (string, error) { bid, ok := batchId.(string) if !ok { - return nil, errors.New("getBatchFromInterface: invalid custom bid value") + return "", errors.New("getBatchIdFromInterface: invalid custom bid value") } - batch, err := m.getBatch(bid) - if err != nil { - util.Warnf("getBatchFromInterface: Unable to retrieve batch: %v", err) - return nil, fmt.Errorf("getBatchFromInterface: unable to get batch: %s", bid) - } - return batch, nil + return bid, nil } func (m *batchManager) loadExistingBatches() error { @@ -72,12 +67,7 @@ func (m *batchManager) loadExistingBatches() error { return fmt.Errorf("loadExistingBatches: retrieve batches: %v", err) } for idx := range vals { - batch, err := m.newBatch(vals[idx], &batchMeta{}) - if err != nil { - util.Warnf("loadExistingBatches: error load batch (%s) %v", vals[idx], err) - continue - } - m.Batches[vals[idx]] = batch + m.loadBatch(vals[idx]) } // update parent and children @@ -112,13 +102,35 @@ func (m *batchManager) loadExistingBatches() error { m.handleBatchJobsCompleted(b, map[string]bool{b.Id: true}) } } - + util.Infof("Loaded %d batches", len(m.Batches)) return nil } +func (m *batchManager) lockBatchIfExists(batchId string) { + m.mu.Lock() + batchToLock, ok := m.Batches[batchId] + if !ok { + m.mu.Unlock() + return + } + m.mu.Unlock() + batchToLock.mu.Lock() +} + +func (m *batchManager) unlockBatchIfExists(batchId string) { + m.mu.Lock() + batchToLock, ok := m.Batches[batchId] + if !ok { + m.mu.Unlock() + return + } + m.mu.Unlock() + batchToLock.mu.Unlock() +} + func (m *batchManager) getBatch(batchId string) (*batch, error) { - m.mu.RLock() - defer m.mu.RUnlock() + m.mu.Lock() + defer m.mu.Unlock() if batchId == "" { return nil, fmt.Errorf("getBatch: batchId cannot be blank") } @@ -136,46 +148,74 @@ func (m *batchManager) getBatch(batchId string) (*batch, error) { } if exists == 0 { m.removeBatch(b) - return nil, fmt.Errorf("getBatch: batch was not committed within 2 hours") + return nil, fmt.Errorf("getBatch: batch has timed out") } - return b, nil } func (m *batchManager) removeBatch(batch *batch) { - m.mu.Lock() + // locking must be handled outside the function if err := m.remove(batch); err != nil { util.Warnf("removeBatch: unable to remove batch: %v", err) } delete(m.Batches, batch.Id) batch = nil - m.mu.Unlock() } func (m *batchManager) removeStaleBatches() { - util.Debugf("Checking for stale batches") + // in order to avoid dead locks + // Step 1 create a list of batches to delete + // Step 2 take a lock on each batch + // - this ensures we wait for any operations on a batch to finish + // Step 3 lock access to any batch + // - this way no other locks can be taken + // Step 4 delete batches + util.Infof("checking for stale batches") + var batchesToRemove []string + // Step 1 for _, b := range m.Batches { - createdAt, err := time.Parse(time.RFC3339Nano, b.Meta.CreatedAt) - if err != nil { - continue - } remove := false - uncommittedTimeout := time.Now().Add(-time.Duration(m.Subsystem.Options.UncommittedTimeoutMinutes) * time.Minute).UTC() - committedTimeout := time.Now().AddDate(0, 0, -m.Subsystem.Options.CommittedTimeoutDays).UTC() - if !b.Meta.Committed && createdAt.Before(uncommittedTimeout) { - remove = true - } else if b.Meta.Committed && createdAt.Before(committedTimeout) { + if b.Meta.CreatedAt != "" { + createdAt, err := time.Parse(time.RFC3339Nano, b.Meta.CreatedAt) + if err != nil { + continue + } + uncommittedTimeout := time.Now().Add(-time.Duration(m.Subsystem.Options.UncommittedTimeoutMinutes) * time.Minute).UTC() + committedTimeout := time.Now().AddDate(0, 0, -m.Subsystem.Options.CommittedTimeoutDays).UTC() + if !b.Meta.Committed && createdAt.Before(uncommittedTimeout) { + remove = true + } else if b.Meta.Committed && createdAt.Before(committedTimeout) { + remove = true + } + } else { remove = true } if remove { - util.Debugf("Removing stale batch %s", b.Id) + batchesToRemove = append(batchesToRemove, b.Id) + } + } + // Step 2 - lock each batch + for _, batchId := range batchesToRemove { + if b, ok := m.Batches[batchId]; ok { b.mu.Lock() - m.removeBatch(b) - b.mu.Unlock() } } + // Step 3 - lock access to all batches + m.mu.Lock() + defer m.mu.Unlock() + + // Step 4 - delete batches and unlock (in case another goroutines is waiting on a lock) + for _, batchId := range batchesToRemove { + func() { + if b, ok := m.Batches[batchId]; ok { + defer b.mu.Unlock() + m.removeBatch(b) + } + }() + } + util.Infof("Removed: %d stale batches", len(batchesToRemove)) } func (m *batchManager) newBatchMeta(description string, success string, complete string, childSearchDepth *int) *batchMeta { @@ -195,6 +235,25 @@ func (m *batchManager) newBatchMeta(description string, success string, complete } } +func (m *batchManager) loadBatch(batchId string) { + m.mu.Lock() + defer m.mu.Unlock() + batch := &batch{ + Id: batchId, + Parents: make([]*batch, 0), + Children: make([]*batch, 0), + Meta: &batchMeta{}, + mu: sync.Mutex{}, + } + + if err := m.loadMetadata(batch); err != nil { + util.Warnf("loadExistingBatches: error load batch (%s) %v", batchId, err) + m.remove(batch) + return + } + m.Batches[batchId] = batch +} + func (m *batchManager) newBatch(batchId string, meta *batchMeta) (*batch, error) { m.mu.Lock() defer m.mu.Unlock() @@ -238,55 +297,12 @@ func (m *batchManager) getCompleteJobStateKey(batchId string) string { return fmt.Sprintf("complete-st-%s", batchId) } -func (m *batchManager) init(batch *batch) error { +func (m *batchManager) loadMetadata(batch *batch) error { meta, err := m.rclient.HGetAll(m.getMetaKey(batch.Id)).Result() if err != nil { return fmt.Errorf("init: unable to retrieve meta: %v", err) } - if err := m.rclient.SAdd("batches", batch.Id).Err(); err != nil { - return fmt.Errorf("init: store batch: %v", err) - } - - expiration := time.Duration(m.Subsystem.Options.UncommittedTimeoutMinutes) * time.Minute - if err := m.rclient.SetNX(m.getBatchKey(batch.Id), batch.Id, expiration).Err(); err != nil { - return fmt.Errorf("init: set expiration: %v", err) - } - - if len(meta) == 0 { - // set default values - data := map[string]interface{}{ - "total": batch.Meta.Total, - "failed": batch.Meta.Failed, - "succeeded": batch.Meta.Succeeded, - "pending": batch.Meta.Pending, - "created_at": batch.Meta.CreatedAt, - "description": batch.Meta.Description, - "committed": batch.Meta.Committed, - "success_job": batch.Meta.SuccessJob, - "complete_job": batch.Meta.CompleteJob, - "child_count": batch.Meta.ChildCount, - } - if batch.Meta.ChildSearchDepth != nil { - data["child_search_depth"] = *batch.Meta.ChildSearchDepth - } - if err := m.rclient.HMSet(m.getMetaKey(batch.Id), data).Err(); err != nil { - return fmt.Errorf("init: could not load meta for batch: %s: %v", batch.Id, err) - } - if err := m.rclient.Expire(m.getMetaKey(batch.Id), expiration).Err(); err != nil { - return fmt.Errorf("init: could set expiration for batch meta: %v", err) - } - - timeout := time.Duration(m.Subsystem.Options.CommittedTimeoutDays) * 24 * time.Hour - if err := m.rclient.SetNX(m.getSuccessJobStateKey(batch.Id), CallbackJobPending, timeout).Err(); err != nil { - return fmt.Errorf("init: could not set success_st: %v", err) - } - if err := m.rclient.SetNX(m.getCompleteJobStateKey(batch.Id), CallbackJobPending, timeout).Err(); err != nil { - return fmt.Errorf("init: could not set complete_st: %v", err) - } - return nil - } - batch.Meta.Total, err = strconv.Atoi(meta["total"]) if err != nil { return fmt.Errorf("init: total: failed converting string to int: %v", err) @@ -342,6 +358,48 @@ func (m *batchManager) init(batch *batch) error { return nil } +func (m *batchManager) init(batch *batch) error { + if err := m.rclient.SAdd("batches", batch.Id).Err(); err != nil { + return fmt.Errorf("init: store batch: %v", err) + } + + expiration := time.Duration(m.Subsystem.Options.UncommittedTimeoutMinutes) * time.Minute + if err := m.rclient.SetNX(m.getBatchKey(batch.Id), batch.Id, expiration).Err(); err != nil { + return fmt.Errorf("init: set expiration: %v", err) + } + + // set default values + data := map[string]interface{}{ + "total": batch.Meta.Total, + "failed": batch.Meta.Failed, + "succeeded": batch.Meta.Succeeded, + "pending": batch.Meta.Pending, + "created_at": batch.Meta.CreatedAt, + "description": batch.Meta.Description, + "committed": batch.Meta.Committed, + "success_job": batch.Meta.SuccessJob, + "complete_job": batch.Meta.CompleteJob, + "child_count": batch.Meta.ChildCount, + } + if batch.Meta.ChildSearchDepth != nil { + data["child_search_depth"] = *batch.Meta.ChildSearchDepth + } + if err := m.rclient.HMSet(m.getMetaKey(batch.Id), data).Err(); err != nil { + return fmt.Errorf("init: could not load meta for batch: %s: %v", batch.Id, err) + } + if err := m.rclient.Expire(m.getMetaKey(batch.Id), expiration).Err(); err != nil { + return fmt.Errorf("init: could set expiration for batch meta: %v", err) + } + if err := m.rclient.SetNX(m.getSuccessJobStateKey(batch.Id), CallbackJobPending, expiration).Err(); err != nil { + return fmt.Errorf("init: could not set success_st: %v", err) + } + if err := m.rclient.SetNX(m.getCompleteJobStateKey(batch.Id), CallbackJobPending, expiration).Err(); err != nil { + return fmt.Errorf("init: could not set complete_st: %v", err) + } + + return nil +} + func (m *batchManager) commit(batch *batch) error { if err := m.updateCommitted(batch, true); err != nil { return fmt.Errorf("commit: %v", err) @@ -399,6 +457,12 @@ func (m *batchManager) remove(batch *batch) error { if err := m.rclient.Del(m.getChildKey(batch.Id)).Err(); err != nil { return fmt.Errorf("remove: batch children (%s), %v", batch.Id, err) } + if err := m.rclient.Del(m.getSuccessJobStateKey(batch.Id)).Err(); err != nil { + return fmt.Errorf("remove: could delete expire success_st: %v", err) + } + if err := m.rclient.Del(m.getCompleteJobStateKey(batch.Id)).Err(); err != nil { + return fmt.Errorf("updatedCommitted: could not deletecomplete_st: %v", err) + } return nil } @@ -407,38 +471,54 @@ func (m *batchManager) updateCommitted(batch *batch, committed bool) error { if err := m.rclient.HSet(m.getMetaKey(batch.Id), "committed", committed).Err(); err != nil { return fmt.Errorf("updateCommitted: could not update committed: %v", err) } - + var expiration time.Duration if committed { - // number of days a batch can exist - if err := m.rclient.Expire(m.getBatchKey(batch.Id), time.Duration(m.Subsystem.Options.CommittedTimeoutDays)*time.Hour*24).Err(); err != nil { - return fmt.Errorf("updatedCommitted: could not not expire after committed: %v", err) - } + expiration = time.Duration(m.Subsystem.Options.CommittedTimeoutDays) * time.Hour * 24 } else { - if err := m.rclient.Expire(m.getBatchKey(batch.Id), time.Duration(m.Subsystem.Options.UncommittedTimeoutMinutes)*time.Minute).Err(); err != nil { - return fmt.Errorf("updatedCommitted: could not expire: %v", err) - } + expiration = time.Duration(m.Subsystem.Options.UncommittedTimeoutMinutes) * time.Minute + } + if err := m.rclient.Expire(m.getBatchKey(batch.Id), expiration).Err(); err != nil { + return fmt.Errorf("updatedCommitted: could not set expire for batch: %v", err) + } + if err := m.rclient.Expire(m.getMetaKey(batch.Id), expiration).Err(); err != nil { + return fmt.Errorf("updatedCommitted: could not set expire for batch meta: %v", err) } + if err := m.rclient.Expire(m.getSuccessJobStateKey(batch.Id), expiration).Err(); err != nil { + return fmt.Errorf("updatedCommitted: could not set expire success_st: %v", err) + } + if err := m.rclient.Expire(m.getCompleteJobStateKey(batch.Id), expiration).Err(); err != nil { + return fmt.Errorf("updatedCommitted: could not set expire for complete_st: %v", err) + } + return nil } func (m *batchManager) updateJobCallbackState(batch *batch, callbackType string, state string) error { // locking must be handled outside of call - timeout := time.Duration(m.Subsystem.Options.CommittedTimeoutDays) * 24 * time.Hour + expire := time.Duration(m.Subsystem.Options.CommittedTimeoutDays) * 24 * time.Hour if callbackType == "success" { batch.Meta.SuccessJobState = state - if err := m.rclient.Set(m.getSuccessJobStateKey(batch.Id), state, timeout).Err(); err != nil { + if err := m.rclient.Set(m.getSuccessJobStateKey(batch.Id), state, expire).Err(); err != nil { return fmt.Errorf("updateJobCallbackState: could not set success_st: %v", err) } if state == CallbackJobSucceeded { - m.removeBatch(batch) + func() { + m.mu.Lock() + defer m.mu.Unlock() + m.removeBatch(batch) + }() } } else { batch.Meta.CompleteJobState = state - if err := m.rclient.Set(m.getCompleteJobStateKey(batch.Id), state, timeout).Err(); err != nil { + if err := m.rclient.Set(m.getCompleteJobStateKey(batch.Id), state, expire).Err(); err != nil { return fmt.Errorf("updateJobCallbackState: could not set completed_st: %v", err) } if _, areChildrenSucceeded := m.areChildrenFinished(batch); areChildrenSucceeded && batch.Meta.SuccessJob == "" && state == CallbackJobSucceeded { - m.removeBatch(batch) + func() { + m.mu.Lock() + defer m.mu.Unlock() + m.removeBatch(batch) + }() } } return nil @@ -498,10 +578,20 @@ func (m *batchManager) handleBatchJobsCompleted(batch *batch, parentsVisited map // parent has already been notified continue } - parent.mu.Lock() + m.lockBatchIfExists(parent.Id) parentsVisited[parent.Id] = true + m.mu.Lock() + if _, ok := m.Batches[parent.Id]; !ok { + if err := m.removeParent(batch, parent); err != nil { + util.Warnf("handleBatchJobsCompleted: unable to delete parent: %v", err) + } + m.mu.Unlock() + m.unlockBatchIfExists(parent.Id) + continue + } + m.mu.Unlock() m.handleChildComplete(parent, batch, areChildrenFinished, areChildrenSucceeded, parentsVisited) - parent.mu.Unlock() + m.unlockBatchIfExists(parent.Id) } } diff --git a/batch/batch_test.go b/batch/batch_test.go index 5a5ca97..2953556 100644 --- a/batch/batch_test.go +++ b/batch/batch_test.go @@ -399,11 +399,20 @@ func TestRemoveStaleBatches(t *testing.T) { _, err = batchSystem.batchManager.newBatch(uncommittedBatchId, uncommittedMeta) assert.Nil(t, err) + batchSystem.batchManager.lockBatchIfExists(uncommittedBatchId) + go func() { + time.Sleep(1) + batchSystem.batchManager.unlockBatchIfExists(uncommittedBatchId) + batchSystem.batchManager.lockBatchIfExists(uncommittedBatchId) + time.Sleep(1) + batchSystem.batchManager.unlockBatchIfExists(uncommittedBatchId) + }() batchSystem.batchManager.removeStaleBatches() _, err = batchSystem.batchManager.getBatch(committedBatchId) assert.EqualError(t, err, "getBatch: no batch found") + batchSystem.batchManager.lockBatchIfExists(uncommittedBatchId) _, err = batchSystem.batchManager.getBatch(uncommittedBatchId) assert.EqualError(t, err, "getBatch: no batch found") }) @@ -419,7 +428,6 @@ func withServer(batchSystem *BatchSubsystem, enabled bool, runner func(cl *clien panic(err) } defer stopper() - defer s.Stop(nil) go cli.HandleSignals(s) @@ -448,6 +456,8 @@ func withServer(batchSystem *BatchSubsystem, enabled bool, runner func(cl *clien } runner(cl) + close(s.Stopper()) + s.Stop(nil) } func getClient() (*client.Client, error) { diff --git a/batch/child.go b/batch/child.go index dbb0e82..ef46e0d 100644 --- a/batch/child.go +++ b/batch/child.go @@ -35,7 +35,7 @@ func (m *batchManager) addChild(batch *batch, childBatch *batch) error { return fmt.Errorf("addChild: erorr adding parent batch (%s) to child (%s): %v", batch.Id, childBatch.Id, err) } if m.areBatchJobsCompleted(batch) { - m.handleBatchJobsCompleted(batch, map[string]bool{batch.Id: true}) + m.handleBatchJobsCompleted(batch, map[string]bool{batch.Id: true, childBatch.Id: true}) } return nil } @@ -76,6 +76,20 @@ func (m *batchManager) removeParent(batch *batch, parentBatch *batch) error { return nil } +func (m *batchManager) removeChild(batch *batch, childBatch *batch) error { + batch.Meta.ChildCount -= 1 + for i, c := range batch.Children { + if c.Id == childBatch.Id { + batch.Children = append(batch.Children[:i], batch.Children[i+1:]...) + break + } + } + if err := m.rclient.HIncrBy(m.getMetaKey(batch.Id), "child_count", -1).Err(); err != nil { + return fmt.Errorf("handleChildComplete: cannot decrement cihldren_count to batch (%s) %v", batch.Id, err) + } + return nil +} + func (m *batchManager) removeChildren(b *batch) { // locking must be handled outside of function if len(b.Children) > 0 { @@ -96,6 +110,10 @@ func (m *batchManager) handleChildComplete(batch *batch, childBatch *batch, areC if err := m.removeParent(childBatch, batch); err != nil { util.Warnf("childCompleted: unable to remove parent (%s) from (%s): %v", batch.Id, childBatch.Id, err) } + // remove child + if err := m.removeChild(batch, childBatch); err != nil { + util.Warnf("childCompleted: unable to remove child (%s) from (%s): %v", childBatch.Id, batch.Id, err) + } } if m.areBatchJobsCompleted(batch) { m.handleBatchJobsCompleted(batch, parentsVisited) @@ -103,6 +121,9 @@ func (m *batchManager) handleChildComplete(batch *batch, childBatch *batch, areC } func (m *batchManager) areChildrenFinished(b *batch) (bool, bool) { + if len(b.Children) != b.Meta.ChildCount && b.Meta.ChildCount != 0 { + return false, false + } // iterate through children up to a certain depth // check to see if any batch still has jobs being processed currentDepth := 1 diff --git a/batch/commands.go b/batch/commands.go index dfb4e60..8ce70c4 100644 --- a/batch/commands.go +++ b/batch/commands.go @@ -76,13 +76,13 @@ func (b *BatchSubsystem) batchCommand(c *server.Connection, s *server.Server, cm case "OPEN": batchId := parts[1] + b.batchManager.lockBatchIfExists(batchId) + defer b.batchManager.unlockBatchIfExists(batchId) batch, err := b.batchManager.getBatch(batchId) if err != nil { _ = c.Error(cmd, fmt.Errorf("cannot get batch: %v", err)) return } - batch.mu.Lock() - defer batch.mu.Unlock() if b.batchManager.areBatchJobsCompleted(batch) { _ = c.Error(cmd, errors.New("batch has already finished")) @@ -104,13 +104,13 @@ func (b *BatchSubsystem) batchCommand(c *server.Connection, s *server.Server, cm _ = c.Error(cmd, errors.New("bid is required")) return } + b.batchManager.lockBatchIfExists(batchId) + defer b.batchManager.unlockBatchIfExists(batchId) batch, err := b.batchManager.getBatch(batchId) if err != nil { _ = c.Error(cmd, fmt.Errorf("cannot get batch: %v", err)) return } - batch.mu.Lock() - defer batch.mu.Unlock() if err := b.batchManager.commit(batch); err != nil { _ = c.Error(cmd, fmt.Errorf("cannot commit batch: %v", err)) @@ -127,41 +127,55 @@ func (b *BatchSubsystem) batchCommand(c *server.Connection, s *server.Server, cm } batchId := subParts[0] childBatchId := subParts[1] - + if childBatchId == batchId { + _ = c.Error(cmd, fmt.Errorf("child batch and parent batch cannot be the same value")) + return + } + b.batchManager.lockBatchIfExists(batchId) + defer b.batchManager.unlockBatchIfExists(batchId) batch, err := b.batchManager.getBatch(batchId) if err != nil { _ = c.Error(cmd, fmt.Errorf("cannot get batch: %v", err)) return } - batch.mu.Lock() - defer batch.mu.Unlock() opened := false if batch.Meta.Committed { + // open will check if the batch has already finished if err := b.batchManager.open(batch); err != nil { _ = c.Error(cmd, errors.New("cannot open committed batch")) + return } opened = true } - if b.batchManager.areBatchJobsCompleted(batch) { - _ = c.Error(cmd, errors.New("batch has already finished")) - return - } + b.batchManager.lockBatchIfExists(childBatchId) childBatch, err := b.batchManager.getBatch(childBatchId) + // ok is used so the batch can be closed + ok := true if err != nil { _ = c.Error(cmd, fmt.Errorf("cannot get child batch: %v", err)) + ok = false } else if err := b.batchManager.addChild(batch, childBatch); err != nil { _ = c.Error(cmd, fmt.Errorf("cannot add child (%s) to batch (%s): %v", childBatchId, batchId, err)) + ok = false } + // unlock child batch in the off chance it is a transitional parent of batch + b.batchManager.unlockBatchIfExists(childBatchId) + // ensure batch is committed if it was opened if opened { if err := b.batchManager.commit(batch); err != nil { _ = c.Error(cmd, errors.New("cannot commit batch")) + return } } - _ = c.Ok() + if ok { + _ = c.Ok() + } return case "STATUS": batchId := parts[1] + b.batchManager.lockBatchIfExists(batchId) + defer b.batchManager.unlockBatchIfExists(batchId) batch, err := b.batchManager.getBatch(batchId) if err != nil { _ = c.Error(cmd, fmt.Errorf("cannot find batch: %v", err)) diff --git a/batch/middlware.go b/batch/middlware.go index e0ba556..87ef58d 100644 --- a/batch/middlware.go +++ b/batch/middlware.go @@ -9,12 +9,16 @@ import ( func (b *BatchSubsystem) pushMiddleware(next func() error, ctx manager.Context) error { if bid, ok := ctx.Job().GetCustom("bid"); ok { - batch, err := b.batchManager.getBatchFromInterface(bid) + batchId, err := b.batchManager.getBatchIdFromInterface(bid) if err != nil { - return fmt.Errorf("pushMiddleware: unable to get batch %s", bid) + return fmt.Errorf("pushMiddleware: unable to parse batch id %s", bid) + } + b.batchManager.lockBatchIfExists(batchId) + defer b.batchManager.unlockBatchIfExists(batchId) + batch, err := b.batchManager.getBatch(batchId) + if err != nil { + return fmt.Errorf("pushMiddleware: unable to retrieve batch %s", bid) } - batch.mu.Lock() - defer batch.mu.Unlock() if err := b.batchManager.handleJobQueued(batch); err != nil { util.Warnf("unable to add batch %v", err) return fmt.Errorf("pushMiddleware: Unable to add job %s to batch %s", ctx.Job().Jid, bid) @@ -29,36 +33,48 @@ func (b *BatchSubsystem) handleJobFinished(success bool) func(next func() error, if success { // check if this is a success / complete job from batch if bid, ok := ctx.Job().GetCustom("_bid"); ok { - batch, err := b.batchManager.getBatchFromInterface(bid) + + batchId, err := b.batchManager.getBatchIdFromInterface(bid) + if err != nil { + util.Warnf("unable to parse batch id %s", bid) + return next() + } + b.batchManager.lockBatchIfExists(batchId) + defer b.batchManager.unlockBatchIfExists(batchId) + batch, err := b.batchManager.getBatch(batchId) if err != nil { - util.Warnf("Unable to retrieve batch %s: %v", bid, err) + util.Warnf("unable to retrieve batch %s: %v", bid, err) return next() } - batch.mu.Lock() - defer batch.mu.Unlock() cb, ok := ctx.Job().GetCustom("_cb") if !ok { util.Warnf("Batch (%s) callback job (%s) does not have _cb specified", bid, ctx.Job().Type) - return next() + return fmt.Errorf("handleJobFinished: callback job (%s) does not have _cb specified", ctx.Job().Type) } callbackType, ok := cb.(string) if !ok { util.Warnf("Error converting callback job type %s", cb) - return next() + return fmt.Errorf("handleJobFinished: invalid callback type %s", cb) } if err := b.batchManager.handleCallbackJobSucceeded(batch, callbackType); err != nil { util.Warnf("Unable to update batch") + return fmt.Errorf("handleJobFinished: unable to update batch %s", batch.Id) } return next() } } if bid, ok := ctx.Job().GetCustom("bid"); ok { - batch, err := b.batchManager.getBatchFromInterface(bid) + batchId, err := b.batchManager.getBatchIdFromInterface(bid) + if err != nil { + return fmt.Errorf("handleJobFinished: unable to parse batch id %s", bid) + } + b.batchManager.lockBatchIfExists(batchId) + defer b.batchManager.unlockBatchIfExists(batchId) + batch, err := b.batchManager.getBatch(batchId) if err != nil { - return fmt.Errorf("handleJobFinished: unable to retrieve batch %s", bid) + util.Warnf("handleJobFinished: unable to retrieve batch %s", bid) + return next() } - batch.mu.Lock() - defer batch.mu.Unlock() status := "succeeded" if !success { status = "failed" @@ -69,7 +85,7 @@ func (b *BatchSubsystem) handleJobFinished(success bool) func(next func() error, if err := b.batchManager.handleJobFinished(batch, ctx.Job().Jid, success, isRetry); err != nil { util.Warnf("error processing finished job for batch %v", err) - return fmt.Errorf("handleJobFinished: unable to process finished job %s for batch %s", ctx.Job().Jid, batch.Id) + return next() } } return next() diff --git a/batch/stress_test.go b/batch/stress_test.go index e6335ea..b8f6dc0 100644 --- a/batch/stress_test.go +++ b/batch/stress_test.go @@ -78,6 +78,7 @@ func TestBatchStress(t *testing.T) { batchQueue, err := s.Store().GetQueue("batch_load_complete") assert.Nil(t, err) assert.EqualValues(t, waitGroups*total, int(batchQueue.Size())) + close(s.Stopper()) s.Stop(nil) } diff --git a/batch/subsystem.go b/batch/subsystem.go index c4fdd4b..a253d5f 100644 --- a/batch/subsystem.go +++ b/batch/subsystem.go @@ -37,7 +37,7 @@ func (b *BatchSubsystem) Start(s *server.Server) error { b.Server = s b.batchManager = &batchManager{ Batches: make(map[string]*batch), - mu: sync.RWMutex{}, + mu: sync.Mutex{}, rclient: b.Server.Manager().Redis(), Subsystem: b, } @@ -48,7 +48,7 @@ func (b *BatchSubsystem) Start(s *server.Server) error { server.CommandSet["BATCH"] = b.batchCommand b.addMiddleware() - b.Server.AddTask(3600, &removeStaleBatches{b}) + b.Server.AddTask(3600*24, &removeStaleBatches{b}) // once a day util.Info("Loaded batching plugin") return nil }