Skip to content

Commit

Permalink
fix: tests race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
rolznz committed Aug 20, 2024
1 parent e26e304 commit 7c35191
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 20 deletions.
13 changes: 10 additions & 3 deletions tests/mock_event_consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,27 @@ package tests

import (
"context"
"time"

"github.com/getAlby/hub/events"
)

type mockEventConsumer struct {
ConsumedEvents []*events.Event
consumedEvents []*events.Event
}

func NewMockEventConsumer() *mockEventConsumer {
return &mockEventConsumer{
ConsumedEvents: []*events.Event{},
consumedEvents: []*events.Event{},
}
}

func (e *mockEventConsumer) ConsumeEvent(ctx context.Context, event *events.Event, globalProperties map[string]interface{}) {
e.ConsumedEvents = append(e.ConsumedEvents, event)
e.consumedEvents = append(e.consumedEvents, event)
}

func (e *mockEventConsumer) GetConsumeEvents() []*events.Event {
// events are consumed async - give it a bit of time for tests
time.Sleep(1 * time.Millisecond)
return e.consumedEvents
}
10 changes: 5 additions & 5 deletions transactions/app_payments_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ func TestSendPaymentSync_App_BudgetExceeded(t *testing.T) {
assert.ErrorIs(t, err, NewQuotaExceededError())
assert.Nil(t, transaction)

assert.Equal(t, 1, len(mockEventConsumer.ConsumedEvents))
assert.Equal(t, "nwc_permission_denied", mockEventConsumer.ConsumedEvents[0].Event)
assert.Equal(t, app.Name, mockEventConsumer.ConsumedEvents[0].Properties.(map[string]interface{})["app_name"])
assert.Equal(t, constants.ERROR_QUOTA_EXCEEDED, mockEventConsumer.ConsumedEvents[0].Properties.(map[string]interface{})["code"])
assert.Equal(t, NewQuotaExceededError().Error(), mockEventConsumer.ConsumedEvents[0].Properties.(map[string]interface{})["message"])
assert.Equal(t, 1, len(mockEventConsumer.GetConsumeEvents()))
assert.Equal(t, "nwc_permission_denied", mockEventConsumer.GetConsumeEvents()[0].Event)
assert.Equal(t, app.Name, mockEventConsumer.GetConsumeEvents()[0].Properties.(map[string]interface{})["app_name"])
assert.Equal(t, constants.ERROR_QUOTA_EXCEEDED, mockEventConsumer.GetConsumeEvents()[0].Properties.(map[string]interface{})["code"])
assert.Equal(t, NewQuotaExceededError().Error(), mockEventConsumer.GetConsumeEvents()[0].Properties.(map[string]interface{})["message"])
}

func TestSendPaymentSync_App_BudgetExceeded_SettledPayment(t *testing.T) {
Expand Down
10 changes: 5 additions & 5 deletions transactions/isolated_app_payments_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ func TestSendPaymentSync_IsolatedApp_BalanceInsufficient(t *testing.T) {
assert.ErrorIs(t, err, NewInsufficientBalanceError())
assert.Nil(t, transaction)

assert.Equal(t, 1, len(mockEventConsumer.ConsumedEvents))
assert.Equal(t, "nwc_permission_denied", mockEventConsumer.ConsumedEvents[0].Event)
assert.Equal(t, app.Name, mockEventConsumer.ConsumedEvents[0].Properties.(map[string]interface{})["app_name"])
assert.Equal(t, constants.ERROR_INSUFFICIENT_BALANCE, mockEventConsumer.ConsumedEvents[0].Properties.(map[string]interface{})["code"])
assert.Equal(t, NewInsufficientBalanceError().Error(), mockEventConsumer.ConsumedEvents[0].Properties.(map[string]interface{})["message"])
assert.Equal(t, 1, len(mockEventConsumer.GetConsumeEvents()))
assert.Equal(t, "nwc_permission_denied", mockEventConsumer.GetConsumeEvents()[0].Event)
assert.Equal(t, app.Name, mockEventConsumer.GetConsumeEvents()[0].Properties.(map[string]interface{})["app_name"])
assert.Equal(t, constants.ERROR_INSUFFICIENT_BALANCE, mockEventConsumer.GetConsumeEvents()[0].Properties.(map[string]interface{})["code"])
assert.Equal(t, NewInsufficientBalanceError().Error(), mockEventConsumer.GetConsumeEvents()[0].Properties.(map[string]interface{})["message"])
}

func TestSendPaymentSync_IsolatedApp_BalanceSufficient(t *testing.T) {
Expand Down
70 changes: 63 additions & 7 deletions transactions/payments_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ func TestMarkSettled_Sent(t *testing.T) {

assert.Nil(t, err)
assert.Equal(t, constants.TRANSACTION_STATE_SETTLED, dbTransaction.State)
assert.Equal(t, 1, len(mockEventConsumer.ConsumedEvents))
assert.Equal(t, "nwc_payment_sent", mockEventConsumer.ConsumedEvents[0].Event)
settledTransaction := mockEventConsumer.ConsumedEvents[0].Properties.(*db.Transaction)
assert.Equal(t, 1, len(mockEventConsumer.GetConsumeEvents()))
assert.Equal(t, "nwc_payment_sent", mockEventConsumer.GetConsumeEvents()[0].Event)
settledTransaction := mockEventConsumer.GetConsumeEvents()[0].Properties.(*db.Transaction)
assert.Equal(t, &dbTransaction, settledTransaction)
}

Expand All @@ -103,9 +103,9 @@ func TestMarkSettled_Received(t *testing.T) {

assert.Nil(t, err)
assert.Equal(t, constants.TRANSACTION_STATE_SETTLED, dbTransaction.State)
assert.Equal(t, 1, len(mockEventConsumer.ConsumedEvents))
assert.Equal(t, "nwc_payment_received", mockEventConsumer.ConsumedEvents[0].Event)
settledTransaction := mockEventConsumer.ConsumedEvents[0].Properties.(*db.Transaction)
assert.Equal(t, 1, len(mockEventConsumer.GetConsumeEvents()))
assert.Equal(t, "nwc_payment_received", mockEventConsumer.GetConsumeEvents()[0].Event)
settledTransaction := mockEventConsumer.GetConsumeEvents()[0].Properties.(*db.Transaction)
assert.Equal(t, &dbTransaction, settledTransaction)
}

Expand Down Expand Up @@ -133,7 +133,63 @@ func TestDoNotMarkSettledTwice(t *testing.T) {

assert.Nil(t, err)
assert.Equal(t, settledAt, *dbTransaction.SettledAt)
assert.Zero(t, len(mockEventConsumer.ConsumedEvents))
assert.Zero(t, len(mockEventConsumer.GetConsumeEvents()))
}

func TestMarkFailed(t *testing.T) {
defer tests.RemoveTestService()
svc, err := tests.CreateTestService()
assert.NoError(t, err)

dbTransaction := db.Transaction{
State: constants.TRANSACTION_STATE_PENDING,
Type: constants.TRANSACTION_TYPE_OUTGOING,
PaymentHash: tests.MockLNClientTransaction.PaymentHash,
AmountMsat: 123000,
}
svc.DB.Create(&dbTransaction)

mockEventConsumer := tests.NewMockEventConsumer()
svc.EventPublisher.RegisterSubscriber(mockEventConsumer)
transactionsService := NewTransactionsService(svc.DB, svc.EventPublisher)
err = svc.DB.Transaction(func(tx *gorm.DB) error {
return transactionsService.markPaymentFailed(tx, &dbTransaction, "some routing error")
})

assert.Nil(t, err)
assert.Equal(t, constants.TRANSACTION_STATE_FAILED, dbTransaction.State)
assert.Equal(t, 1, len(mockEventConsumer.GetConsumeEvents()))
assert.Equal(t, "nwc_payment_failed", mockEventConsumer.GetConsumeEvents()[0].Event)
settledTransaction := mockEventConsumer.GetConsumeEvents()[0].Properties.(*db.Transaction)
assert.Equal(t, &dbTransaction, settledTransaction)
assert.Equal(t, "some routing error", settledTransaction.FailureReason)
}

func TestDoNotMarkFailedTwice(t *testing.T) {
defer tests.RemoveTestService()
svc, err := tests.CreateTestService()
assert.NoError(t, err)

updatedAt := time.Now().Add(time.Duration(-1) * time.Minute)
dbTransaction := db.Transaction{
State: constants.TRANSACTION_STATE_FAILED,
Type: constants.TRANSACTION_TYPE_OUTGOING,
PaymentHash: tests.MockLNClientTransaction.PaymentHash,
AmountMsat: 123000,
UpdatedAt: updatedAt,
}
svc.DB.Create(&dbTransaction)

mockEventConsumer := tests.NewMockEventConsumer()
svc.EventPublisher.RegisterSubscriber(mockEventConsumer)
transactionsService := NewTransactionsService(svc.DB, svc.EventPublisher)
err = svc.DB.Transaction(func(tx *gorm.DB) error {
return transactionsService.markPaymentFailed(tx, &dbTransaction, "some routing error")
})

assert.Nil(t, err)
assert.Equal(t, updatedAt, dbTransaction.UpdatedAt)
assert.Zero(t, len(mockEventConsumer.GetConsumeEvents()))
}

func TestSendPaymentSync_FailedRemovesFeeReserve(t *testing.T) {
Expand Down

0 comments on commit 7c35191

Please sign in to comment.