diff --git a/mqtt/handler_test.go b/mqtt/handler_test.go index 8fe2b384e73..7ecce2b9649 100644 --- a/mqtt/handler_test.go +++ b/mqtt/handler_test.go @@ -64,7 +64,7 @@ var ( ) func TestAuthConnect(t *testing.T) { - handler, _ := newHandler(t) + handler, _, eventStore := newHandler(t) cases := []struct { desc string @@ -107,17 +107,19 @@ func TestAuthConnect(t *testing.T) { } for _, tc := range cases { + repoCall := eventStore.On("Connect", mock.Anything, mock.Anything).Return(nil) ctx := context.TODO() if tc.session != nil { ctx = session.NewContext(ctx, tc.session) } err := handler.AuthConnect(ctx) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() } } func TestAuthPublish(t *testing.T) { - handler, auth := newHandler(t) + handler, auth, _ := newHandler(t) cases := []struct { desc string @@ -169,7 +171,7 @@ func TestAuthPublish(t *testing.T) { } func TestAuthSubscribe(t *testing.T) { - handler, auth := newHandler(t) + handler, auth, _ := newHandler(t) cases := []struct { desc string @@ -222,7 +224,7 @@ func TestAuthSubscribe(t *testing.T) { } func TestConnect(t *testing.T) { - handler, _ := newHandler(t) + handler, _, _ := newHandler(t) logBuffer.Reset() cases := []struct { @@ -256,7 +258,7 @@ func TestConnect(t *testing.T) { } func TestPublish(t *testing.T) { - handler, _ := newHandler(t) + handler, _, _ := newHandler(t) logBuffer.Reset() malformedSubtopics := topic + "/" + subtopic + "%" @@ -335,7 +337,7 @@ func TestPublish(t *testing.T) { } func TestSubscribe(t *testing.T) { - handler, _ := newHandler(t) + handler, _, _ := newHandler(t) logBuffer.Reset() cases := []struct { @@ -371,7 +373,7 @@ func TestSubscribe(t *testing.T) { } func TestUnsubscribe(t *testing.T) { - handler, _ := newHandler(t) + handler, _, _ := newHandler(t) logBuffer.Reset() cases := []struct { @@ -407,7 +409,7 @@ func TestUnsubscribe(t *testing.T) { } func TestDisconnect(t *testing.T) { - handler, _ := newHandler(t) + handler, _, eventStore := newHandler(t) logBuffer.Reset() cases := []struct { @@ -432,6 +434,7 @@ func TestDisconnect(t *testing.T) { } for _, tc := range cases { + repoCall := eventStore.On("Disconnect", mock.Anything, mock.Anything).Return(nil) ctx := context.TODO() if tc.session != nil { ctx = session.NewContext(ctx, tc.session) @@ -439,15 +442,16 @@ func TestDisconnect(t *testing.T) { err := handler.Disconnect(ctx) assert.Contains(t, logBuffer.String(), tc.logMsg) assert.Equal(t, tc.err, err) + repoCall.Unset() } } -func newHandler(t *testing.T) (session.Handler, *authmocks.AuthClient) { +func newHandler(t *testing.T) (session.Handler, *authmocks.AuthClient, *mocks.EventStore) { logger, err := mglog.New(&logBuffer, "debug") if err != nil { log.Fatalf("failed to create logger: %s", err) } auth := new(authmocks.AuthClient) eventStore := mocks.NewEventStore(t) - return mqtt.NewHandler(mocks.NewPublisher(), eventStore, logger, auth), auth + return mqtt.NewHandler(mocks.NewPublisher(), eventStore, logger, auth), auth, eventStore }