From 4c34cdd3d365adb4876558de715d579a10f71cfa Mon Sep 17 00:00:00 2001 From: theskyinflames-macos Date: Tue, 20 Dec 2022 20:28:14 +0100 Subject: [PATCH] Add Ch/Qh multi-middleware --- pkg/cqrs/cqrs.go | 28 ++++++++ pkg/cqrs/cqrs_test.go | 80 +++++++++++++++++++++ pkg/cqrs/zmock_cqrs_event_test.go | 112 ++++++++++++++++++++++++++++++ 3 files changed, 220 insertions(+) create mode 100644 pkg/cqrs/zmock_cqrs_event_test.go diff --git a/pkg/cqrs/cqrs.go b/pkg/cqrs/cqrs.go index 05c225e..fcb642b 100644 --- a/pkg/cqrs/cqrs.go +++ b/pkg/cqrs/cqrs.go @@ -10,6 +10,8 @@ import ( "github.com/google/uuid" ) +//go:generate moq -stub -out zmock_cqrs_event_test.go -pkg cqrs_test . Event + // Logger is an interface type Logger interface { Printf(format string, v ...interface{}) @@ -42,6 +44,19 @@ func (chf CommandHandlerFunc) Handle(ctx context.Context, cmd Command) ([]Event, // CommandHandlerMiddleware is self-described type CommandHandlerMiddleware func(CommandHandler) CommandHandler +// CommandHandlerMultiMiddleware is self-described +func CommandHandlerMultiMiddleware(mws ...CommandHandlerMiddleware) CommandHandlerMiddleware { + return func(ch CommandHandler) CommandHandler { + return CommandHandlerFunc(func(ctx context.Context, cmd Command) ([]Event, error) { + mw := mws[0](ch) + for _, outerMw := range mws[1:] { + mw = outerMw(mw) + } + return mw.Handle(ctx, cmd) + }) + } +} + // ChErrMw is a command handler middleware func ChErrMw(l Logger) CommandHandlerMiddleware { return func(ch CommandHandler) CommandHandler { @@ -80,6 +95,19 @@ func (chf QueryHandlerFunc) Handle(ctx context.Context, q Query) (QueryResult, e // QueryHandlerMiddleware is self-described type QueryHandlerMiddleware func(QueryHandler) QueryHandler +// QueryHandlerMultiMiddleware is self-described +func QueryHandlerMultiMiddleware(mws ...QueryHandlerMiddleware) QueryHandlerMiddleware { + return func(ch QueryHandler) QueryHandler { + return QueryHandlerFunc(func(ctx context.Context, cmd Query) (QueryResult, error) { + mw := mws[0](ch) + for _, outerMw := range mws[1:] { + mw = outerMw(mw) + } + return mw.Handle(ctx, cmd) + }) + } +} + // QhErrMw is a query handler middleware func QhErrMw(l Logger) QueryHandlerMiddleware { return func(ch QueryHandler) QueryHandler { diff --git a/pkg/cqrs/cqrs_test.go b/pkg/cqrs/cqrs_test.go index 1a78a80..861864a 100644 --- a/pkg/cqrs/cqrs_test.go +++ b/pkg/cqrs/cqrs_test.go @@ -45,3 +45,83 @@ func TestQhErrMw(t *testing.T) { require.Len(t, logger.PrintfCalls(), 1) }) } + +func TestCommandHandlerMultiMiddleware(t *testing.T) { + t.Run(`Given a sequence of ch middlewares, + when it's called, + then the ch is executed wrapped by all middlewares in the right order`, func(t *testing.T) { + var ( + calls []string + chMw1 = chTestMw("mw1", &calls) + chMw2 = chTestMw("mw2", &calls) + chMw3 = chTestMw("mw3", &calls) + chMw4 = chTestMw("mw4", &calls) + + ev = &EventMock{} + + ch = &CommandHandlerMock{ + HandleFunc: func(_ context.Context, _ cqrs.Command) ([]cqrs.Event, error) { + return []cqrs.Event{ev}, nil + }, + } + ) + + multiChMw := cqrs.CommandHandlerMultiMiddleware(chMw1, chMw2, chMw3, chMw4) + + evs, err := multiChMw(ch).Handle(context.Background(), &CommandMock{}) + + require.Len(t, ch.HandleCalls(), 1) + require.NoError(t, err) + require.Len(t, evs, 1) + require.Equal(t, ev, evs[0]) + require.Equal(t, []string{"mw4", "mw3", "mw2", "mw1"}, calls) + }) +} + +func chTestMw(name string, calls *[]string) cqrs.CommandHandlerMiddleware { + return func(ch cqrs.CommandHandler) cqrs.CommandHandler { + return cqrs.CommandHandlerFunc(func(ctx context.Context, cmd cqrs.Command) ([]cqrs.Event, error) { + *calls = append(*calls, name) + return ch.Handle(ctx, cmd) + }) + } +} + +func TestQueryHandlerMultiMiddleware(t *testing.T) { + t.Run(`Given a sequence of ch middlewares, + when it's called, + then the ch is executed wrapped by all middlewares in the right order`, func(t *testing.T) { + var ( + calls []string + qhMw1 = qhTestMw("mw1", &calls) + qhMw2 = qhTestMw("mw2", &calls) + qhMw3 = qhTestMw("mw3", &calls) + qhMw4 = qhTestMw("mw4", &calls) + + queryResult = "result" + + qh = &QueryHandlerMock{ + HandleFunc: func(_ context.Context, _ cqrs.Query) (cqrs.QueryResult, error) { + return queryResult, nil + }, + } + ) + + multiQhMw := cqrs.QueryHandlerMultiMiddleware(qhMw1, qhMw2, qhMw3, qhMw4) + + qrs, err := multiQhMw(qh).Handle(context.Background(), &QueryMock{}) + + require.Len(t, qh.HandleCalls(), 1) + require.NoError(t, err) + require.Equal(t, queryResult, qrs) + }) +} + +func qhTestMw(name string, calls *[]string) cqrs.QueryHandlerMiddleware { + return func(ch cqrs.QueryHandler) cqrs.QueryHandler { + return cqrs.QueryHandlerFunc(func(ctx context.Context, cmd cqrs.Query) (cqrs.QueryResult, error) { + *calls = append(*calls, name) + return ch.Handle(ctx, cmd) + }) + } +} diff --git a/pkg/cqrs/zmock_cqrs_event_test.go b/pkg/cqrs/zmock_cqrs_event_test.go new file mode 100644 index 0000000..f077ca1 --- /dev/null +++ b/pkg/cqrs/zmock_cqrs_event_test.go @@ -0,0 +1,112 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package cqrs_test + +import ( + "github.com/google/uuid" + "github.com/theskyinflames/cqrs-eda/pkg/cqrs" + "sync" +) + +// Ensure, that EventMock does implement cqrs.Event. +// If this is not the case, regenerate this file with moq. +var _ cqrs.Event = &EventMock{} + +// EventMock is a mock implementation of cqrs.Event. +// +// func TestSomethingThatUsesEvent(t *testing.T) { +// +// // make and configure a mocked cqrs.Event +// mockedEvent := &EventMock{ +// AggregateIDFunc: func() uuid.UUID { +// panic("mock out the AggregateID method") +// }, +// NameFunc: func() string { +// panic("mock out the Name method") +// }, +// } +// +// // use mockedEvent in code that requires cqrs.Event +// // and then make assertions. +// +// } +type EventMock struct { + // AggregateIDFunc mocks the AggregateID method. + AggregateIDFunc func() uuid.UUID + + // NameFunc mocks the Name method. + NameFunc func() string + + // calls tracks calls to the methods. + calls struct { + // AggregateID holds details about calls to the AggregateID method. + AggregateID []struct { + } + // Name holds details about calls to the Name method. + Name []struct { + } + } + lockAggregateID sync.RWMutex + lockName sync.RWMutex +} + +// AggregateID calls AggregateIDFunc. +func (mock *EventMock) AggregateID() uuid.UUID { + callInfo := struct { + }{} + mock.lockAggregateID.Lock() + mock.calls.AggregateID = append(mock.calls.AggregateID, callInfo) + mock.lockAggregateID.Unlock() + if mock.AggregateIDFunc == nil { + var ( + uUIDOut uuid.UUID + ) + return uUIDOut + } + return mock.AggregateIDFunc() +} + +// AggregateIDCalls gets all the calls that were made to AggregateID. +// Check the length with: +// +// len(mockedEvent.AggregateIDCalls()) +func (mock *EventMock) AggregateIDCalls() []struct { +} { + var calls []struct { + } + mock.lockAggregateID.RLock() + calls = mock.calls.AggregateID + mock.lockAggregateID.RUnlock() + return calls +} + +// Name calls NameFunc. +func (mock *EventMock) Name() string { + callInfo := struct { + }{} + mock.lockName.Lock() + mock.calls.Name = append(mock.calls.Name, callInfo) + mock.lockName.Unlock() + if mock.NameFunc == nil { + var ( + sOut string + ) + return sOut + } + return mock.NameFunc() +} + +// NameCalls gets all the calls that were made to Name. +// Check the length with: +// +// len(mockedEvent.NameCalls()) +func (mock *EventMock) NameCalls() []struct { +} { + var calls []struct { + } + mock.lockName.RLock() + calls = mock.calls.Name + mock.lockName.RUnlock() + return calls +}