diff --git a/data/txHandler.go b/data/txHandler.go
index 3d20e95acd..7851889378 100644
--- a/data/txHandler.go
+++ b/data/txHandler.go
@@ -599,6 +599,18 @@ func (handler *TxHandler) processIncomingTxn(rawmsg network.IncomingMessage) net
 
 	var err error
 	var capguard *util.ErlCapacityGuard
+	accepted := false
+	defer func() {
+		// if we failed to put the item onto the backlog, we should release the capacity if any
+		if !accepted {
+			if capguard != nil {
+				if capErr := capguard.Release(); capErr != nil {
+					logging.Base().Warnf("Failed to release capacity to ElasticRateLimiter: %v", capErr)
+				}
+			}
+		}
+	}()
+
 	if handler.erl != nil {
 		congestedERL := float64(cap(handler.backlogQueue))*handler.backlogCongestionThreshold < float64(len(handler.backlogQueue))
 		// consume a capacity unit
@@ -679,6 +691,7 @@ func (handler *TxHandler) processIncomingTxn(rawmsg network.IncomingMessage) net
 		unverifiedTxGroupHash: canonicalKey,
 		capguard:              capguard,
 	}:
+		accepted = true
 	default:
 		// if we failed here we want to increase the corresponding metric. It might suggest that we
 		// want to increase the queue size.
diff --git a/data/txHandler_test.go b/data/txHandler_test.go
index 5eb40741ee..896fbb161d 100644
--- a/data/txHandler_test.go
+++ b/data/txHandler_test.go
@@ -28,6 +28,7 @@ import (
 	"runtime/pprof"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -794,7 +795,7 @@ func TestTxHandlerProcessIncomingCensoring(t *testing.T) {
 // makeTestTxHandlerOrphaned creates a tx handler without any backlog consumer.
 // It is caller responsibility to run a consumer thread.
 func makeTestTxHandlerOrphaned(backlogSize int) *TxHandler {
-	return makeTestTxHandlerOrphanedWithContext(context.Background(), txBacklogSize, txBacklogSize, txHandlerConfig{true, false}, 0)
+	return makeTestTxHandlerOrphanedWithContext(context.Background(), backlogSize, backlogSize, txHandlerConfig{true, false}, 0)
 }
 
 func makeTestTxHandlerOrphanedWithContext(ctx context.Context, backlogSize int, cacheSize int, txHandlerConfig txHandlerConfig, refreshInterval time.Duration) *TxHandler {
@@ -2681,3 +2682,65 @@ func TestTxHandlerAppRateLimiter(t *testing.T) {
 	msg := <-handler.backlogQueue
 	require.Equal(t, msg.rawmsg.Data, blob, blob)
 }
+
+// TestTxHandlerCapGuard checks there is no cap guard leak in case of invalid input.
+func TestTxHandlerCapGuard(t *testing.T) {
+	partitiontest.PartitionTest(t)
+	t.Parallel()
+
+	const numUsers = 10
+	addresses, secrets, genesis := makeTestGenesisAccounts(t, numUsers)
+	genBal := bookkeeping.MakeGenesisBalances(genesis, sinkAddr, poolAddr)
+	ledgerName := fmt.Sprintf("%s-mem", t.Name())
+	const inMem = true
+	log := logging.TestingLog(t)
+	log.SetLevel(logging.Error)
+
+	cfg := config.GetDefaultLocal()
+	cfg.EnableTxBacklogRateLimiting = true
+	cfg.EnableTxBacklogAppRateLimiting = false
+	cfg.TxIncomingFilteringFlags = 0
+	cfg.TxBacklogServiceRateWindowSeconds = 1
+	cfg.TxBacklogReservedCapacityPerPeer = 1
+	cfg.IncomingConnectionsLimit = 1
+	cfg.TxBacklogSize = 3
+
+	ledger, err := LoadLedger(log, ledgerName, inMem, protocol.ConsensusCurrentVersion, genBal, genesisID, genesisHash, nil, cfg)
+	require.NoError(t, err)
+	defer ledger.Close()
+
+	handler, err := makeTestTxHandler(ledger, cfg)
+	require.NoError(t, err)
+	defer handler.txVerificationPool.Shutdown()
+	defer close(handler.streamVerifierDropped)
+
+	tx := transactions.Transaction{
+		Type: protocol.PaymentTx,
+		Header: transactions.Header{
+			Sender:     addresses[0],
+			Fee:        basics.MicroAlgos{Raw: proto.MinTxnFee * 2},
+			FirstValid: 0,
+			LastValid:  basics.Round(proto.MaxTxnLife),
+		},
+		PaymentTxnFields: transactions.PaymentTxnFields{
+			Receiver: addresses[1],
+			Amount:   basics.MicroAlgos{Raw: 1000},
+		},
+	}
+
+	signedTx := tx.Sign(secrets[0])
+	blob := protocol.Encode(&signedTx)
+	blob[0]++ // make it invalid
+
+	var completed atomic.Bool
+	go func() {
+		for i := 0; i < 10; i++ {
+			outgoing := handler.processIncomingTxn(network.IncomingMessage{Data: blob, Sender: mockSender{}})
+			require.Equal(t, network.OutgoingMessage{Action: network.Disconnect}, outgoing)
+			require.Equal(t, 0, len(handler.backlogQueue))
+		}
+		completed.Store(true)
+	}()
+
+	require.Eventually(t, func() bool { return completed.Load() }, 1*time.Second, 10*time.Millisecond)
+}