Skip to content

Commit

Permalink
addressing review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
3vilhamster committed Nov 1, 2024
1 parent a85110d commit 71dd6cb
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 17 deletions.
27 changes: 15 additions & 12 deletions internal/internal_task_pollers.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"context"
"errors"
"fmt"
"go.uber.org/yarpc"
"sync"
"time"

Expand Down Expand Up @@ -356,7 +357,7 @@ func (wtp *workflowTaskPoller) processWorkflowTask(task *workflowTask) error {
func(response interface{}, startTime time.Time) (*workflowTask, error) {
wtp.logger.Debug("Force RespondDecisionTaskCompleted.", zap.Int64("TaskStartedEventID", task.task.GetStartedEventId()))
wtp.metricsScope.Counter(metrics.DecisionTaskForceCompleted).Inc(1)
heartbeatResponse, err := wtp.RespondTaskCompleted(response, nil, task.task, startTime)
heartbeatResponse, err := wtp.RespondTaskCompletedWithMetrics(response, nil, task.task, startTime)
if err != nil {
return nil, err
}
Expand All @@ -375,7 +376,7 @@ func (wtp *workflowTaskPoller) processWorkflowTask(task *workflowTask) error {
if errors.As(err, new(*decisionHeartbeatError)) {
return err
}
response, err = wtp.RespondTaskCompleted(completedRequest, err, task.task, startTime)
response, err = wtp.RespondTaskCompletedWithMetrics(completedRequest, err, task.task, startTime)
if err != nil {
return err
}
Expand Down Expand Up @@ -404,7 +405,7 @@ func (wtp *workflowTaskPoller) processResetStickinessTask(rst *resetStickinessTa
return nil
}

func (wtp *workflowTaskPoller) RespondTaskCompleted(completedRequest interface{}, taskErr error, task *s.PollForDecisionTaskResponse, startTime time.Time) (response *s.RespondDecisionTaskCompletedResponse, err error) {
func (wtp *workflowTaskPoller) RespondTaskCompletedWithMetrics(completedRequest interface{}, taskErr error, task *s.PollForDecisionTaskResponse, startTime time.Time) (response *s.RespondDecisionTaskCompletedResponse, err error) {
metricsScope := wtp.metricsScope.GetTaggedScope(tagWorkflowType, task.WorkflowType.GetName())
if taskErr != nil {
metricsScope.Counter(metrics.DecisionExecutionFailedCounter).Inc(1)
Expand Down Expand Up @@ -444,7 +445,7 @@ func (wtp *workflowTaskPoller) respondTaskCompleted(completedRequest interface{}
}

func (wtp *workflowTaskPoller) respondTaskCompletedAttempt(completedRequest interface{}, task *s.PollForDecisionTaskResponse) (*s.RespondDecisionTaskCompletedResponse, error) {
ctx, cancel, _ := newChannelContext(context.Background(), wtp.featureFlags)
ctx, cancel, opts := newChannelContext(context.Background(), wtp.featureFlags)
defer cancel()
var (
err error
Expand All @@ -453,36 +454,38 @@ func (wtp *workflowTaskPoller) respondTaskCompletedAttempt(completedRequest inte
)
switch request := completedRequest.(type) {
case *s.RespondDecisionTaskFailedRequest:
err = wtp.handleDecisionFailedRequest(ctx, task, request)
err = wtp.handleDecisionFailedRequest(ctx, task, request, opts...)
operation = "RespondDecisionTaskFailed"
case *s.RespondDecisionTaskCompletedRequest:
response, err = wtp.handleDecisionTaskCompletedRequest(ctx, task, request)
response, err = wtp.handleDecisionTaskCompletedRequest(ctx, task, request, opts...)
operation = "RespondDecisionTaskCompleted"
case *s.RespondQueryTaskCompletedRequest:
err = wtp.service.RespondQueryTaskCompleted(ctx, request, getYarpcCallOptions(wtp.featureFlags)...)
err = wtp.service.RespondQueryTaskCompleted(ctx, request, opts...)
operation = "RespondQueryTaskCompleted"
default:
// should not happen
panic("unknown request type from ProcessWorkflowTask()")
}

traceLog(func() {
wtp.logger.Debug("Call failed.", zap.Error(err), zap.String("Operation", operation))
if err != nil {
wtp.logger.Debug(fmt.Sprintf("%s failed.", operation), zap.Error(err))
}
})

return response, err
}

func (wtp *workflowTaskPoller) handleDecisionFailedRequest(ctx context.Context, task *s.PollForDecisionTaskResponse, request *s.RespondDecisionTaskFailedRequest) error {
func (wtp *workflowTaskPoller) handleDecisionFailedRequest(ctx context.Context, task *s.PollForDecisionTaskResponse, request *s.RespondDecisionTaskFailedRequest, opts ...yarpc.CallOption) error {
// Only fail decision on first attempt, subsequent failure on the same decision task will timeout.
// This is to avoid spin on the failed decision task. Checking Attempt not nil for older server.
if task.Attempt != nil && task.GetAttempt() == 0 {
return wtp.service.RespondDecisionTaskFailed(ctx, request, getYarpcCallOptions(wtp.featureFlags)...)
return wtp.service.RespondDecisionTaskFailed(ctx, request, opts...)
}
return nil
}

func (wtp *workflowTaskPoller) handleDecisionTaskCompletedRequest(ctx context.Context, task *s.PollForDecisionTaskResponse, request *s.RespondDecisionTaskCompletedRequest) (response *s.RespondDecisionTaskCompletedResponse, err error) {
func (wtp *workflowTaskPoller) handleDecisionTaskCompletedRequest(ctx context.Context, task *s.PollForDecisionTaskResponse, request *s.RespondDecisionTaskCompletedRequest, opts ...yarpc.CallOption) (response *s.RespondDecisionTaskCompletedResponse, err error) {
if request.StickyAttributes == nil && !wtp.disableStickyExecution {
request.StickyAttributes = &s.StickyExecutionAttributes{
WorkerTaskList: &s.TaskList{Name: common.StringPtr(getWorkerTaskList(wtp.stickyUUID))},
Expand Down Expand Up @@ -538,7 +541,7 @@ func (wtp *workflowTaskPoller) handleDecisionTaskCompletedRequest(ctx context.Co
}()
}

return wtp.service.RespondDecisionTaskCompleted(ctx, request, getYarpcCallOptions(wtp.featureFlags)...)
return wtp.service.RespondDecisionTaskCompleted(ctx, request, opts...)
}

func newLocalActivityPoller(params workerExecutionParameters, laTunnel *localActivityTunnel) *localActivityTaskPoller {
Expand Down
10 changes: 5 additions & 5 deletions internal/internal_task_pollers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func TestRespondTaskCompleted_failed(t *testing.T) {
BinaryChecksum: common.StringPtr(getBinaryChecksum()),
}, gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)

res, err := poller.RespondTaskCompleted(nil, assert.AnError, &s.PollForDecisionTaskResponse{
res, err := poller.RespondTaskCompletedWithMetrics(nil, assert.AnError, &s.PollForDecisionTaskResponse{
TaskToken: testTaskToken,
Attempt: common.Int64Ptr(0),
}, time.Now())
Expand Down Expand Up @@ -111,7 +111,7 @@ func TestRespondTaskCompleted_failed(t *testing.T) {
t.Run("fail skips sending for not the first attempt", func(t *testing.T) {
poller, _, _, _ := buildWorkflowTaskPoller(t)

res, err := poller.RespondTaskCompleted(nil, assert.AnError, &s.PollForDecisionTaskResponse{
res, err := poller.RespondTaskCompletedWithMetrics(nil, assert.AnError, &s.PollForDecisionTaskResponse{
Attempt: common.Int64Ptr(1),
}, time.Now())
assert.NoError(t, err)
Expand All @@ -122,8 +122,8 @@ func TestRespondTaskCompleted_failed(t *testing.T) {
func TestRespondTaskCompleted_Unsupported(t *testing.T) {
poller, _, _, _ := buildWorkflowTaskPoller(t)

assert.Panics(t, func() {
_, _ = poller.RespondTaskCompleted(assert.AnError, nil, &s.PollForDecisionTaskResponse{}, time.Now())
assert.PanicsWithValue(t, "unknown request type from ProcessWorkflowTask()", func() {
_, _ = poller.RespondTaskCompletedWithMetrics(assert.AnError, nil, &s.PollForDecisionTaskResponse{}, time.Now())
})
}

Expand All @@ -139,7 +139,7 @@ func TestProcessTask_failures(t *testing.T) {
})
t.Run("unsupported task type", func(t *testing.T) {
poller, _, _, _ := buildWorkflowTaskPoller(t)
assert.Panics(t, func() {
assert.PanicsWithValue(t, "unknown task type.", func() {
_ = poller.ProcessTask(10)
})
})
Expand Down

0 comments on commit 71dd6cb

Please sign in to comment.