diff --git a/service/history/handler/handler_test.go b/service/history/handler/handler_test.go index e07a4c5ec97..930a4262a79 100644 --- a/service/history/handler/handler_test.go +++ b/service/history/handler/handler_test.go @@ -2464,6 +2464,358 @@ func (s *handlerSuite) TestRecordChildExecutionCompleted() { } } +func (s *handlerSuite) TestResetStickyTaskList() { + validInput := &types.HistoryResetStickyTaskListRequest{ + DomainUUID: testDomainID, + Execution: &types.WorkflowExecution{ + WorkflowID: testWorkflowID, + RunID: testValidUUID, + }, + } + + testInput := map[string]struct { + input *types.HistoryResetStickyTaskListRequest + expectedError bool + mockFn func() + }{ + "shutting down": { + input: validInput, + expectedError: true, + mockFn: func() { + s.handler.shuttingDown = int32(1) + }, + }, + "empty domainID": { + input: &types.HistoryResetStickyTaskListRequest{ + DomainUUID: "", + }, + expectedError: true, + mockFn: func() {}, + }, + "ratelimit exceeded": { + input: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(false).Times(1) + }, + }, + "getEngine error": { + input: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(nil, errors.New("error")).Times(1) + }, + }, + "resetStickyTaskList error": { + input: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().ResetStickyTaskList(gomock.Any(), validInput).Return(nil, errors.New("error")).Times(1) + }, + }, + "success": { + input: validInput, + expectedError: false, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().ResetStickyTaskList(gomock.Any(), validInput).Return(&types.HistoryResetStickyTaskListResponse{}, nil).Times(1) + }, + }, + } + + for name, input := range testInput { + s.Run(name, func() { + input.mockFn() + resp, err := s.handler.ResetStickyTaskList(context.Background(), input.input) + s.handler.shuttingDown = int32(0) + if input.expectedError { + s.Nil(resp) + s.Error(err) + } else { + s.NotNil(resp) + s.NoError(err) + } + }) + + } +} + +func (s *handlerSuite) TestReplicateEventsV2() { + validInput := &types.ReplicateEventsV2Request{ + DomainUUID: testDomainID, + WorkflowExecution: &types.WorkflowExecution{ + WorkflowID: testWorkflowID, + RunID: testValidUUID, + }, + VersionHistoryItems: []*types.VersionHistoryItem{ + { + EventID: 1, + Version: 1, + }, + }, + Events: &types.DataBlob{ + EncodingType: types.EncodingTypeThriftRW.Ptr(), + Data: []byte{1, 2, 3}, + }, + } + + testInput := map[string]struct { + input *types.ReplicateEventsV2Request + expectedError bool + mockFn func() + }{ + "shutting down": { + input: validInput, + expectedError: true, + mockFn: func() { + s.handler.shuttingDown = int32(1) + }, + }, + "empty domainID": { + input: &types.ReplicateEventsV2Request{ + DomainUUID: "", + }, + expectedError: true, + mockFn: func() {}, + }, + "ratelimit exceeded": { + input: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(false).Times(1) + }, + }, + "getEngine error": { + input: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(nil, errors.New("error")).Times(1) + }, + }, + "replicateEventsV2 error": { + input: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().ReplicateEventsV2(gomock.Any(), validInput).Return(errors.New("error")).Times(1) + }, + }, + "success": { + input: validInput, + expectedError: false, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().ReplicateEventsV2(gomock.Any(), validInput).Return(nil).Times(1) + }, + }, + } + + for name, input := range testInput { + s.Run(name, func() { + input.mockFn() + err := s.handler.ReplicateEventsV2(context.Background(), input.input) + s.handler.shuttingDown = int32(0) + if input.expectedError { + s.Error(err) + } else { + s.NoError(err) + } + }) + } +} + +func (s *handlerSuite) TestSyncShardStatus() { + validInput := &types.SyncShardStatusRequest{ + SourceCluster: "test", + ShardID: 1, + Timestamp: common.Int64Ptr(time.Now().UnixNano()), + } + + testInput := map[string]struct { + input *types.SyncShardStatusRequest + expectedError bool + mockFn func() + }{ + "shutting down": { + input: validInput, + expectedError: true, + mockFn: func() { + s.handler.shuttingDown = int32(1) + }, + }, + "ratelimit exceeded": { + input: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(false).Times(1) + }, + }, + "get shard engine": { + input: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngineForShard(int(validInput.ShardID)).Return(nil, errors.New("error")).Times(1) + }, + }, + "syncShardStatus error": { + input: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngineForShard(int(validInput.ShardID)).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().SyncShardStatus(gomock.Any(), validInput).Return(errors.New("error")).Times(1) + }, + }, + "success": { + input: validInput, + expectedError: false, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngineForShard(int(validInput.ShardID)).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().SyncShardStatus(gomock.Any(), validInput).Return(nil).Times(1) + }, + }, + "empty sourceCluster": { + input: &types.SyncShardStatusRequest{ + SourceCluster: "", + }, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + }, + }, + "missing timestamp": { + input: &types.SyncShardStatusRequest{ + SourceCluster: "test", + }, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + }, + }, + } + + for name, input := range testInput { + s.Run(name, func() { + input.mockFn() + err := s.handler.SyncShardStatus(context.Background(), input.input) + s.handler.shuttingDown = int32(0) + if input.expectedError { + s.Error(err) + } else { + s.NoError(err) + } + }) + } +} + +func (s *handlerSuite) TestSyncActivity() { + validInput := &types.SyncActivityRequest{ + DomainID: testDomainID, + WorkflowID: testWorkflowID, + RunID: testValidUUID, + Version: 1, + ScheduledID: 1, + Details: []byte{1, 2, 3}, + } + + testInput := map[string]struct { + input *types.SyncActivityRequest + expectedError bool + mockFn func() + }{ + "shutting down": { + input: validInput, + expectedError: true, + mockFn: func() { + s.handler.shuttingDown = int32(1) + }, + }, + "ratelimit exceeded": { + input: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(false).Times(1) + }, + }, + "empty domainID": { + input: &types.SyncActivityRequest{ + DomainID: "", + }, + expectedError: true, + mockFn: func() {}, + }, + "empty workflowID": { + input: &types.SyncActivityRequest{ + DomainID: testDomainID, + WorkflowID: "", + }, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + }, + }, + "empty runID": { + input: &types.SyncActivityRequest{ + DomainID: testDomainID, + WorkflowID: testWorkflowID, + RunID: "", + }, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + }, + }, + "cannot get engine": { + input: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(nil, errors.New("error")).Times(1) + }, + }, + "syncActivity error": { + input: validInput, + expectedError: true, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().SyncActivity(gomock.Any(), validInput).Return(errors.New("error")).Times(1) + }, + }, + "success": { + input: validInput, + expectedError: false, + mockFn: func() { + s.mockRatelimiter.EXPECT().Allow().Return(true).Times(1) + s.mockShardController.EXPECT().GetEngine(testWorkflowID).Return(s.mockEngine, nil).Times(1) + s.mockEngine.EXPECT().SyncActivity(gomock.Any(), validInput).Return(nil).Times(1) + }, + }, + } + + for name, input := range testInput { + s.Run(name, func() { + input.mockFn() + err := s.handler.SyncActivity(context.Background(), input.input) + s.handler.shuttingDown = int32(0) + if input.expectedError { + s.Error(err) + } else { + s.NoError(err) + } + }) + } +} + func (s *handlerSuite) TestGetCrossClusterTasks() { numShards := 10 targetCluster := cluster.TestAlternativeClusterName