From 6d13781f9c28390811a1d339f904cc17b62293b9 Mon Sep 17 00:00:00 2001 From: Will Winder Date: Wed, 7 Aug 2024 14:55:15 -0400 Subject: [PATCH] exec: Add in plugin state and start using it. --- execute/exectypes/outcome.go | 68 ++++++++-- execute/exectypes/outcome_test.go | 53 ++++++++ execute/plugin.go | 202 ++++++++++++++++++------------ execute/plugin_test.go | 2 +- 4 files changed, 237 insertions(+), 88 deletions(-) create mode 100644 execute/exectypes/outcome_test.go diff --git a/execute/exectypes/outcome.go b/execute/exectypes/outcome.go index 6c39ad10..20dda005 100644 --- a/execute/exectypes/outcome.go +++ b/execute/exectypes/outcome.go @@ -7,8 +7,47 @@ import ( cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccipocr3" ) +type PluginState string + +const ( + // Unknown is the zero value, this allows a "Next" state transition for uninitialized values (i.e. the first round). + Unknown PluginState = "" + + // GetCommitReports is the first step, it is used to select commit reports from the destination chain. + GetCommitReports PluginState = "GetCommitReports" + + // GetMessages is the second step, given a set of commit reports it fetches the associated messages. + GetMessages PluginState = "GetMessages" + + // Filter is the final step, any additional destination data is collected to complete the execution report. + Filter PluginState = "Filter" +) + +// Next returns the next state for the plugin. The Unknown state is used to transition from uninitialized values. +func (p PluginState) Next() PluginState { + switch p { + case GetCommitReports: + return GetMessages + + case GetMessages: + // TODO: go to Filter after GetMessages + return GetCommitReports + + case Unknown: + fallthrough + case Filter: + return GetCommitReports + + default: + panic("unexpected execute plugin state") + } +} + // Outcome is the outcome of the ExecutePlugin. type Outcome struct { + // State that the outcome was generated for. + State PluginState + // PendingCommitReports are the oldest reports with pending commits. The slice is // sorted from oldest to newest. PendingCommitReports []CommitData `json:"commitReports"` @@ -17,28 +56,27 @@ type Outcome struct { Report cciptypes.ExecutePluginReport `json:"report"` } +// IsEmpty returns true if the outcome has no pending commit reports or chain reports. func (o Outcome) IsEmpty() bool { return len(o.PendingCommitReports) == 0 && len(o.Report.ChainReports) == 0 } +// NewOutcome creates a new Outcome with the pending commit reports and the chain reports sorted. func NewOutcome( + state PluginState, pendingCommits []CommitData, report cciptypes.ExecutePluginReport, ) Outcome { - return newSortedOutcome(pendingCommits, report) -} - -// Encode encodes the outcome by first sorting the pending commit reports and the chain reports -// and then JSON marshalling. -// The encoding MUST be deterministic. -func (o Outcome) Encode() ([]byte, error) { - // We sort again here in case construction is not via the constructor. - return json.Marshal(newSortedOutcome(o.PendingCommitReports, o.Report)) + return newSortedOutcome(state, pendingCommits, report) } +// newSortedOutcome ensures canonical ordering of the outcome. +// TODO: handle canonicalization in the encoder. func newSortedOutcome( + state PluginState, pendingCommits []CommitData, - report cciptypes.ExecutePluginReport) Outcome { + report cciptypes.ExecutePluginReport, +) Outcome { pendingCommitsCP := append([]CommitData{}, pendingCommits...) reportCP := append([]cciptypes.ExecutePluginReportSingleChain{}, report.ChainReports...) sort.Slice( @@ -52,11 +90,21 @@ func newSortedOutcome( return reportCP[i].SourceChainSelector < reportCP[j].SourceChainSelector }) return Outcome{ + State: state, PendingCommitReports: pendingCommitsCP, Report: cciptypes.ExecutePluginReport{ChainReports: reportCP}, } } +// Encode encodes the outcome by first sorting the pending commit reports and the chain reports +// and then JSON marshalling. +// The encoding MUST be deterministic. +func (o Outcome) Encode() ([]byte, error) { + // We sort again here in case construction is not via the constructor. + return json.Marshal(newSortedOutcome(o.State, o.PendingCommitReports, o.Report)) +} + +// DecodeOutcome decodes the outcome from JSON. func DecodeOutcome(b []byte) (Outcome, error) { o := Outcome{} err := json.Unmarshal(b, &o) diff --git a/execute/exectypes/outcome_test.go b/execute/exectypes/outcome_test.go new file mode 100644 index 00000000..35dd376b --- /dev/null +++ b/execute/exectypes/outcome_test.go @@ -0,0 +1,53 @@ +package exectypes + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPluginState_Next(t *testing.T) { + tests := []struct { + name string + p PluginState + want PluginState + isPanic bool + }{ + { + name: "Zero value", + p: Unknown, + want: GetCommitReports, + }, + { + name: "Phase 1 to 2", + p: GetCommitReports, + want: GetMessages, + }, + { + name: "Phase 2 to 1", + p: GetMessages, + want: GetCommitReports, + }, + { + name: "panic", + p: PluginState("ElToroLoco"), + isPanic: true, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if tt.isPanic { + require.Panics(t, func() { + tt.p.Next() + }) + return + } + + if got := tt.p.Next(); got != tt.want { + t.Errorf("Next() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/execute/plugin.go b/execute/plugin.go index c4609f74..7a75fcf3 100644 --- a/execute/plugin.go +++ b/execute/plugin.go @@ -164,69 +164,90 @@ func (p *Plugin) Observation( if err != nil { return types.Observation{}, fmt.Errorf("unable to decode previous outcome: %w", err) } + p.lggr.Infow("decoded previous outcome", "previousOutcome", previousOutcome) } - fetchFrom := time.Now().Add(-p.cfg.OffchainConfig.MessageVisibilityInterval.Duration()).UTC() - p.lggr.Infow("decoded previous outcome", "previousOutcome", previousOutcome) + state := previousOutcome.State.Next() + switch state { + case exectypes.GetCommitReports: + fetchFrom := time.Now().Add(-p.cfg.OffchainConfig.MessageVisibilityInterval.Duration()).UTC() - // Phase 1: Gather commit reports from the destination chain and determine which messages are required to build a - // valid execution report. - var groupedCommits exectypes.CommitObservations - supportsDest, err := p.supportsDestChain() - if err != nil { - return types.Observation{}, fmt.Errorf("unable to determine if the destination chain is supported: %w", err) - } - if supportsDest { - groupedCommits, err = getPendingExecutedReports(ctx, p.ccipReader, p.cfg.DestChain, fetchFrom, p.lggr) + // Phase 1: Gather commit reports from the destination chain and determine which messages are required to build + // a valid execution report. + supportsDest, err := p.supportsDestChain() if err != nil { - return types.Observation{}, err + return types.Observation{}, fmt.Errorf("unable to determine if the destination chain is supported: %w", err) } + if supportsDest { + groupedCommits, err := getPendingExecutedReports(ctx, p.ccipReader, p.cfg.DestChain, fetchFrom, p.lggr) + if err != nil { + return types.Observation{}, err + } - // TODO: truncate grouped commits to a maximum observation size. - // Cache everything which is not executed. - } - - // Phase 2: Gather messages from the source chains and build the execution report. - messages := make(exectypes.MessageObservations) - if len(previousOutcome.PendingCommitReports) == 0 { - p.lggr.Debug("TODO: No reports to execute. This is expected after a cold start.") - // No reports to execute. - // This is expected after a cold start. - } else { - commitReportCache := make(map[cciptypes.ChainSelector][]exectypes.CommitData) - for _, report := range previousOutcome.PendingCommitReports { - commitReportCache[report.SourceChain] = append(commitReportCache[report.SourceChain], report) + // TODO: truncate grouped to a maximum observation size? + return exectypes.NewObservation(groupedCommits, nil).Encode() } - for selector, reports := range commitReportCache { - if len(reports) == 0 { - continue + // No observation for non-dest readers. + return types.Observation{}, nil + case exectypes.GetMessages: + // Phase 2: Gather messages from the source chains and build the execution report. + messages := make(exectypes.MessageObservations) + if len(previousOutcome.PendingCommitReports) == 0 { + p.lggr.Debug("TODO: No reports to execute. This is expected after a cold start.") + // No reports to execute. + // This is expected after a cold start. + } else { + commitReportCache := make(map[cciptypes.ChainSelector][]exectypes.CommitData) + for _, report := range previousOutcome.PendingCommitReports { + commitReportCache[report.SourceChain] = append(commitReportCache[report.SourceChain], report) } - ranges, err := computeRanges(reports) - if err != nil { - return types.Observation{}, err - } + for selector, reports := range commitReportCache { + if len(reports) == 0 { + continue + } - // Read messages for each range. - for _, seqRange := range ranges { - msgs, err := p.ccipReader.MsgsBetweenSeqNums(ctx, selector, seqRange) + ranges, err := computeRanges(reports) if err != nil { - return nil, err + return types.Observation{}, err } - for _, msg := range msgs { - if _, ok := messages[selector]; !ok { - messages[selector] = make(map[cciptypes.SeqNum]cciptypes.Message) + + // Read messages for each range. + for _, seqRange := range ranges { + msgs, err := p.ccipReader.MsgsBetweenSeqNums(ctx, selector, seqRange) + if err != nil { + return nil, err + } + for _, msg := range msgs { + if _, ok := messages[selector]; !ok { + messages[selector] = make(map[cciptypes.SeqNum]cciptypes.Message) + } + messages[selector][msg.Header.SequenceNumber] = msg } - messages[selector][msg.Header.SequenceNumber] = msg } } } - } - // TODO: Fire off messages for an attestation check service. + // Regroup the commit reports back into the observation format. + // TODO: use same format for Observation and Outcome. + groupedCommits := make(exectypes.CommitObservations) + for _, report := range previousOutcome.PendingCommitReports { + if _, ok := groupedCommits[report.SourceChain]; !ok { + groupedCommits[report.SourceChain] = []exectypes.CommitData{} + } + groupedCommits[report.SourceChain] = append(groupedCommits[report.SourceChain], report) + } + + // TODO: Fire off messages for an attestation check service. + return exectypes.NewObservation(groupedCommits, messages).Encode() - return exectypes.NewObservation(groupedCommits, messages).Encode() + case exectypes.Filter: + // TODO: pass the previous two through, add in the nonces. + return types.Observation{}, fmt.Errorf("unknown state") + default: + return types.Observation{}, fmt.Errorf("unknown state") + } } func (p *Plugin) ValidateObservation( @@ -320,6 +341,18 @@ func selectReport( func (p *Plugin) Outcome( outctx ocr3types.OutcomeContext, query types.Query, aos []types.AttributedObservation, ) (ocr3types.Outcome, error) { + var previousOutcome exectypes.Outcome + if outctx.PreviousOutcome != nil { + var err error + previousOutcome, err = exectypes.DecodeOutcome(outctx.PreviousOutcome) + if err != nil { + return nil, fmt.Errorf("unable to decode previous outcome: %w", err) + } + } + + ///////////////////////////////////////////// + // Decode the observations and merge them. // + ///////////////////////////////////////////// decodedObservations, err := decodeAttributedObservations(aos) if err != nil { return ocr3types.Outcome{}, fmt.Errorf("unable to decode observations: %w", err) @@ -359,6 +392,10 @@ func (p *Plugin) Outcome( mergedCommitObservations, mergedMessageObservations) + ////////////////////////// + // common preprocessing // + ////////////////////////// + // flatten commit reports and sort by timestamp. var commitReports []exectypes.CommitData for _, report := range observation.CommitReports { @@ -372,46 +409,57 @@ func (p *Plugin) Outcome( fmt.Sprintf("[oracle %d] exec outcome: commit reports", p.reportingCfg.OracleID), "commitReports", commitReports) - // add messages to their commitReports. - for i, report := range commitReports { - report.Messages = nil - for i := report.SequenceNumberRange.Start(); i <= report.SequenceNumberRange.End(); i++ { - if msg, ok := observation.Messages[report.SourceChain][i]; ok { - report.Messages = append(report.Messages, msg) + state := previousOutcome.State.Next() + switch state { + case exectypes.GetCommitReports: + outcome := exectypes.NewOutcome(state, commitReports, cciptypes.ExecutePluginReport{}) + return outcome.Encode() + case exectypes.GetMessages: + // add messages to their commitReports. + for i, report := range commitReports { + report.Messages = nil + for i := report.SequenceNumberRange.Start(); i <= report.SequenceNumberRange.End(); i++ { + if msg, ok := observation.Messages[report.SourceChain][i]; ok { + report.Messages = append(report.Messages, msg) + } } + commitReports[i].Messages = report.Messages } - commitReports[i].Messages = report.Messages - } - // TODO: this function should be pure, a context should not be needed. - outcomeReports, commitReports, err := - selectReport( - context.Background(), - p.lggr, p.msgHasher, - p.reportCodec, - p.tokenDataReader, - p.estimateProvider, - commitReports, - maxReportSizeBytes, - p.cfg.OffchainConfig.BatchGasLimit) - if err != nil { - return ocr3types.Outcome{}, fmt.Errorf("unable to extract proofs: %w", err) - } + // TODO: this function should be pure, a context should not be needed. + outcomeReports, commitReports, err := + selectReport( + context.Background(), + p.lggr, p.msgHasher, + p.reportCodec, + p.tokenDataReader, + p.estimateProvider, + commitReports, + maxReportSizeBytes, + p.cfg.OffchainConfig.BatchGasLimit) + if err != nil { + return ocr3types.Outcome{}, fmt.Errorf("unable to extract proofs: %w", err) + } - execReport := cciptypes.ExecutePluginReport{ - ChainReports: outcomeReports, - } + execReport := cciptypes.ExecutePluginReport{ + ChainReports: outcomeReports, + } - outcome := exectypes.NewOutcome(commitReports, execReport) - if outcome.IsEmpty() { - return nil, nil - } + outcome := exectypes.NewOutcome(state, commitReports, execReport) + if outcome.IsEmpty() { + return nil, nil + } - p.lggr.Infow( - fmt.Sprintf("[oracle %d] exec outcome: generated outcome", p.reportingCfg.OracleID), - "outcome", outcome) + p.lggr.Infow( + fmt.Sprintf("[oracle %d] exec outcome: generated outcome", p.reportingCfg.OracleID), + "outcome", outcome) - return outcome.Encode() + return outcome.Encode() + case exectypes.Filter: + panic("not implemented") + default: + panic("unknown state") + } } func (p *Plugin) Reports(seqNr uint64, outcome ocr3types.Outcome) ([]ocr3types.ReportWithInfo[[]byte], error) { diff --git a/execute/plugin_test.go b/execute/plugin_test.go index 2b0fe365..3119940c 100644 --- a/execute/plugin_test.go +++ b/execute/plugin_test.go @@ -406,7 +406,7 @@ func TestPlugin_Reports_UnableToEncode(t *testing.T) { codec.On("Encode", mock.Anything, mock.Anything). Return(nil, fmt.Errorf("test error")) p := &Plugin{reportCodec: codec} - report, err := exectypes.NewOutcome(nil, cciptypes.ExecutePluginReport{}).Encode() + report, err := exectypes.NewOutcome(exectypes.Unknown, nil, cciptypes.ExecutePluginReport{}).Encode() require.NoError(t, err) _, err = p.Reports(0, report)