Skip to content

Commit

Permalink
Merge pull request #277 from OffchainLabs/better-recreate-state-for-rpc
Browse files Browse the repository at this point in the history
use StateAtBlock and reference states when recreating missing state
  • Loading branch information
tsahee authored Mar 13, 2024
2 parents a104baf + 088149d commit e5d8587
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 48 deletions.
88 changes: 79 additions & 9 deletions arbitrum/apibackend.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (
"github.com/ethereum/go-ethereum/eth"
"github.com/ethereum/go-ethereum/eth/tracers"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/trie"

"github.com/ethereum/go-ethereum/accounts"
"github.com/ethereum/go-ethereum/common"
Expand All @@ -22,6 +24,7 @@ import (
"github.com/ethereum/go-ethereum/core/bloombits"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/state/snapshot"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/eth/filters"
Expand All @@ -32,6 +35,13 @@ import (
"github.com/ethereum/go-ethereum/rpc"
)

// Metrics counters for state references taken and released by the API backend.
// The "live" counters are bumped for states that were available directly at the
// requested header; the "recreated" counters are bumped for states rebuilt via
// StateAtBlock starting from an earlier available state.
// Each "referenced" counter is incremented when a state is handed out, and the
// matching "dereferenced" counter is incremented from the state's finalizer.
var (
liveStatesReferencedCounter = metrics.NewRegisteredCounter("arb/apibackend/states/live/referenced", nil)
liveStatesDereferencedCounter = metrics.NewRegisteredCounter("arb/apibackend/states/live/dereferenced", nil)
recreatedStatesReferencedCounter = metrics.NewRegisteredCounter("arb/apibackend/states/recreated/referenced", nil)
recreatedStatesDereferencedCounter = metrics.NewRegisteredCounter("arb/apibackend/states/recreated/dereferenced", nil)
)

type APIBackend struct {
b *Backend

Expand Down Expand Up @@ -444,21 +454,75 @@ func (a *APIBackend) stateAndHeaderFromHeader(ctx context.Context, header *types
return nil, header, types.ErrUseFallback
}
bc := a.BlockChain()
stateFor := func(header *types.Header) (*state.StateDB, error) {
return bc.StateAt(header.Root)
stateFor := func(db state.Database, snapshots *snapshot.Tree) func(header *types.Header) (*state.StateDB, StateReleaseFunc, error) {
return func(header *types.Header) (*state.StateDB, StateReleaseFunc, error) {
if header.Root != (common.Hash{}) {
// Try referencing the root, if it isn't in dirties cache then Reference will have no effect
db.TrieDB().Reference(header.Root, common.Hash{})
}
statedb, err := state.New(header.Root, db, snapshots)
if err != nil {
return nil, nil, err
}
if header.Root != (common.Hash{}) {
headerRoot := header.Root
return statedb, func() { db.TrieDB().Dereference(headerRoot) }, nil
}
return statedb, NoopStateRelease, nil
}
}
state, lastHeader, err := FindLastAvailableState(ctx, bc, stateFor, header, nil, a.b.config.MaxRecreateStateDepth)
liveState, liveStateRelease, err := stateFor(bc.StateCache(), bc.Snapshots())(header)
if err == nil {
liveStatesReferencedCounter.Inc(1)
liveState.SetArbFinalizer(func(*state.ArbitrumExtraData) {
liveStateRelease()
liveStatesDereferencedCounter.Inc(1)
})
return liveState, header, nil
}
// else err != nil => we don't need to call liveStateRelease

// Create an ephemeral trie.Database for isolating the live one
// note: triedb cleans cache is disabled in trie.HashDefaults
// note: only states committed to diskdb can be found as we're creating new triedb
// note: snapshots are not used here
ephemeral := state.NewDatabaseWithConfig(a.ChainDb(), trie.HashDefaults)
lastState, lastHeader, lastStateRelease, err := FindLastAvailableState(ctx, bc, stateFor(ephemeral, nil), header, nil, a.b.config.MaxRecreateStateDepth)
if err != nil {
return nil, nil, err
}
// make sure that we haven't found the state in diskdb
if lastHeader == header {
return state, header, nil
}
state, err = AdvanceStateUpToBlock(ctx, bc, state, header, lastHeader, nil)
liveStatesReferencedCounter.Inc(1)
lastState.SetArbFinalizer(func(*state.ArbitrumExtraData) {
lastStateRelease()
liveStatesDereferencedCounter.Inc(1)
})
return lastState, header, nil
}
defer lastStateRelease()
targetBlock := bc.GetBlockByNumber(header.Number.Uint64())
if targetBlock == nil {
return nil, nil, errors.New("target block not found")
}
lastBlock := bc.GetBlockByNumber(lastHeader.Number.Uint64())
if lastBlock == nil {
return nil, nil, errors.New("last block not found")
}
reexec := uint64(0)
checkLive := false
preferDisk := false // preferDisk is ignored in this case
statedb, release, err := eth.NewArbEthereum(a.b.arb.BlockChain(), a.ChainDb()).StateAtBlock(ctx, targetBlock, reexec, lastState, lastBlock, checkLive, preferDisk)
if err != nil {
return nil, nil, err
return nil, nil, fmt.Errorf("failed to recreate state: %w", err)
}
return state, header, err
// we are setting finalizer instead of returning a StateReleaseFunc to avoid changing ethapi.Backend interface to minimize diff to upstream
recreatedStatesReferencedCounter.Inc(1)
statedb.SetArbFinalizer(func(*state.ArbitrumExtraData) {
release()
recreatedStatesDereferencedCounter.Inc(1)
})
return statedb, header, err
}

func (a *APIBackend) StateAndHeaderByNumber(ctx context.Context, number rpc.BlockNumber) (*state.StateDB, *types.Header, error) {
Expand All @@ -468,6 +532,12 @@ func (a *APIBackend) StateAndHeaderByNumber(ctx context.Context, number rpc.Bloc

// StateAndHeaderByNumberOrHash resolves blockNrOrHash to a header and returns the
// state at that header together with the header itself.
//
// When the block is requested by hash, is ahead of the current block, and the hash
// is not the canonical hash for its height, the request is rejected: such a state
// may not yet be referenced in the triedb or committed by
// BlockChain.writeBlockWithState.
//
// Any error from the header lookup is forwarded to stateAndHeaderFromHeader, which
// decides how to report it.
func (a *APIBackend) StateAndHeaderByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*state.StateDB, *types.Header, error) {
	header, err := a.HeaderByNumberOrHash(ctx, blockNrOrHash)
	hash, ishash := blockNrOrHash.Hash()
	bc := a.BlockChain()
	// The header lookup may have failed and returned a nil header; guard against it
	// so that we never dereference nil here and the lookup error is still reported
	// through stateAndHeaderFromHeader below.
	if ishash && header != nil && header.Number.Cmp(bc.CurrentBlock().Number) > 0 && bc.GetCanonicalHash(header.Number.Uint64()) != hash {
		return nil, nil, errors.New("requested block ahead of current block and the hash is not currently canonical")
	}
	return a.stateAndHeaderFromHeader(ctx, header, err)
}

Expand All @@ -476,7 +546,7 @@ func (a *APIBackend) StateAtBlock(ctx context.Context, block *types.Block, reexe
return nil, nil, types.ErrUseFallback
}
// DEV: This assumes that `StateAtBlock` only accesses the blockchain and chainDb fields
return eth.NewArbEthereum(a.b.arb.BlockChain(), a.ChainDb()).StateAtBlock(ctx, block, reexec, base, checkLive, preferDisk)
return eth.NewArbEthereum(a.b.arb.BlockChain(), a.ChainDb()).StateAtBlock(ctx, block, reexec, base, nil, checkLive, preferDisk)
}

func (a *APIBackend) StateAtTransaction(ctx context.Context, block *types.Block, txIndex int, reexec uint64) (*core.Message, vm.BlockContext, *state.StateDB, tracers.StateReleaseFunc, error) {
Expand Down
7 changes: 6 additions & 1 deletion arbitrum/recordingdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,12 @@ func (r *RecordingDatabase) PreimagesFromRecording(chainContextIf core.ChainCont
}

func (r *RecordingDatabase) GetOrRecreateState(ctx context.Context, header *types.Header, logFunc StateBuildingLogFunction) (*state.StateDB, error) {
state, currentHeader, err := FindLastAvailableState(ctx, r.bc, r.StateFor, header, logFunc, -1)
stateFor := func(header *types.Header) (*state.StateDB, StateReleaseFunc, error) {
state, err := r.StateFor(header)
// we don't use the release functor pattern here yet
return state, NoopStateRelease, err
}
state, currentHeader, _, err := FindLastAvailableState(ctx, r.bc, stateFor, header, logFunc, -1)
if err != nil {
return nil, err
}
Expand Down
24 changes: 15 additions & 9 deletions arbitrum/recreatestate.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,58 +9,64 @@ import (
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/eth/tracers"
"github.com/pkg/errors"
)

var (
ErrDepthLimitExceeded = errors.New("state recreation l2 gas depth limit exceeded")
)

// StateReleaseFunc is a callback that releases a state handed out by a
// StateForHeaderFunction; it has the same underlying type as tracers.StateReleaseFunc.
type StateReleaseFunc tracers.StateReleaseFunc

// NoopStateRelease is a StateReleaseFunc that does nothing, for states that
// hold nothing that needs releasing.
var NoopStateRelease StateReleaseFunc = func() {}

type StateBuildingLogFunction func(targetHeader, header *types.Header, hasState bool)
type StateForHeaderFunction func(header *types.Header) (*state.StateDB, error)
type StateForHeaderFunction func(header *types.Header) (*state.StateDB, StateReleaseFunc, error)

// FindLastAvailableState finds the last available state and its header, checking targetHeader first and then walking backwards through its ancestors.
// If maxDepthInL2Gas is positive, it limits the cumulative L2 gas used by the traversed blocks.
// If maxDepthInL2Gas is -1, the traversal depth is not limited.
// Otherwise only the state of targetHeader is checked and no search is performed.
func FindLastAvailableState(ctx context.Context, bc *core.BlockChain, stateFor StateForHeaderFunction, targetHeader *types.Header, logFunc StateBuildingLogFunction, maxDepthInL2Gas int64) (*state.StateDB, *types.Header, error) {
func FindLastAvailableState(ctx context.Context, bc *core.BlockChain, stateFor StateForHeaderFunction, targetHeader *types.Header, logFunc StateBuildingLogFunction, maxDepthInL2Gas int64) (*state.StateDB, *types.Header, StateReleaseFunc, error) {
genesis := bc.Config().ArbitrumChainParams.GenesisBlockNum
currentHeader := targetHeader
var state *state.StateDB
var err error
var l2GasUsed uint64
release := NoopStateRelease
for ctx.Err() == nil {
lastHeader := currentHeader
state, err = stateFor(currentHeader)
state, release, err = stateFor(currentHeader)
if err == nil {
break
}
if maxDepthInL2Gas > 0 {
receipts := bc.GetReceiptsByHash(currentHeader.Hash())
if receipts == nil {
return nil, lastHeader, fmt.Errorf("failed to get receipts for hash %v", currentHeader.Hash())
return nil, lastHeader, nil, fmt.Errorf("failed to get receipts for hash %v", currentHeader.Hash())
}
for _, receipt := range receipts {
l2GasUsed += receipt.GasUsed - receipt.GasUsedForL1
}
if l2GasUsed > uint64(maxDepthInL2Gas) {
return nil, lastHeader, ErrDepthLimitExceeded
return nil, lastHeader, nil, ErrDepthLimitExceeded
}
} else if maxDepthInL2Gas != InfiniteMaxRecreateStateDepth {
return nil, lastHeader, err
return nil, lastHeader, nil, err
}
if logFunc != nil {
logFunc(targetHeader, currentHeader, false)
}
if currentHeader.Number.Uint64() <= genesis {
return nil, lastHeader, errors.Wrap(err, fmt.Sprintf("moved beyond genesis looking for state %d, genesis %d", targetHeader.Number.Uint64(), genesis))
return nil, lastHeader, nil, errors.Wrap(err, fmt.Sprintf("moved beyond genesis looking for state %d, genesis %d", targetHeader.Number.Uint64(), genesis))
}
currentHeader = bc.GetHeader(currentHeader.ParentHash, currentHeader.Number.Uint64()-1)
if currentHeader == nil {
return nil, lastHeader, fmt.Errorf("chain doesn't contain parent of block %d hash %v", lastHeader.Number, lastHeader.Hash())
return nil, lastHeader, nil, fmt.Errorf("chain doesn't contain parent of block %d hash %v", lastHeader.Number, lastHeader.Hash())
}
}
return state, currentHeader, ctx.Err()
return state, currentHeader, release, ctx.Err()
}

func AdvanceStateByBlock(ctx context.Context, bc *core.BlockChain, state *state.StateDB, targetHeader *types.Header, blockToRecreate uint64, prevBlockHash common.Hash, logFunc StateBuildingLogFunction) (*state.StateDB, *types.Block, error) {
Expand Down
7 changes: 4 additions & 3 deletions core/blockchain.go
Original file line number Diff line number Diff line change
Expand Up @@ -1071,7 +1071,8 @@ func (bc *BlockChain) Stop() {
// - HEAD: So we don't need to reprocess any blocks in the general case
// - HEAD-1: So we don't do large reorgs if our HEAD becomes an uncle
// - HEAD-127: So we have a hard limit on the number of blocks reexecuted
if !bc.cacheConfig.TrieDirtyDisabled {
// It applies for both full node and sparse archive node
if !bc.cacheConfig.TrieDirtyDisabled || bc.cacheConfig.MaxNumberOfBlocksToSkipStateSaving > 0 || bc.cacheConfig.MaxAmountOfGasToSkipStateSaving > 0 {
triedb := bc.triedb

for _, offset := range []uint64{0, 1, bc.cacheConfig.TriesInMemory - 1, math.MaxUint64} {
Expand Down Expand Up @@ -1496,7 +1497,7 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types.
return nil
}
// If we're running an archive node, flush
// If MaxNumberOfBlocksToSkipStateSaving or MaxAmountOfGasToSkipStateSaving is not zero, then flushing of some blocks will be skipped:
// Sparse archive: if MaxNumberOfBlocksToSkipStateSaving or MaxAmountOfGasToSkipStateSaving is not zero, then flushing of some blocks will be skipped:
// * at most MaxNumberOfBlocksToSkipStateSaving block state commits will be skipped
// * sum of gas used in skipped blocks will be at most MaxAmountOfGasToSkipStateSaving
archiveNode := bc.cacheConfig.TrieDirtyDisabled
Expand Down Expand Up @@ -1526,7 +1527,7 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types.
// we are skipping saving the trie to diskdb, so we need to keep the trie in memory and garbage collect it later
}

// Full node or archive node that's not keeping all states, do proper garbage collection
// Full node or sparse archive node that's not keeping all states, do proper garbage collection
bc.triedb.Reference(root, common.Hash{}) // metadata reference to keep trie alive
bc.triegc.Push(trieGcEntry{root, block.Header().Time}, -int64(block.NumberU64()))

Expand Down
29 changes: 16 additions & 13 deletions core/state/statedb.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ type revision struct {
// must be created with new root and updated database for accessing post-
// commit states.
type StateDB struct {
// Arbitrum: track the total balance change across all accounts
unexpectedBalanceDelta *big.Int
arbExtraData *ArbitrumExtraData // must be a pointer - can't be a part of StateDB allocation, otherwise its finalizer might not get called

db Database
prefetcher *triePrefetcher
Expand Down Expand Up @@ -155,7 +154,9 @@ func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error)
return nil, err
}
sdb := &StateDB{
unexpectedBalanceDelta: new(big.Int),
arbExtraData: &ArbitrumExtraData{
unexpectedBalanceDelta: new(big.Int),
},

db: db,
trie: tr,
Expand Down Expand Up @@ -395,7 +396,7 @@ func (s *StateDB) HasSelfDestructed(addr common.Address) bool {
func (s *StateDB) AddBalance(addr common.Address, amount *big.Int) {
stateObject := s.GetOrNewStateObject(addr)
if stateObject != nil {
s.unexpectedBalanceDelta.Add(s.unexpectedBalanceDelta, amount)
s.arbExtraData.unexpectedBalanceDelta.Add(s.arbExtraData.unexpectedBalanceDelta, amount)
stateObject.AddBalance(amount)
}
}
Expand All @@ -404,7 +405,7 @@ func (s *StateDB) AddBalance(addr common.Address, amount *big.Int) {
func (s *StateDB) SubBalance(addr common.Address, amount *big.Int) {
stateObject := s.GetOrNewStateObject(addr)
if stateObject != nil {
s.unexpectedBalanceDelta.Sub(s.unexpectedBalanceDelta, amount)
s.arbExtraData.unexpectedBalanceDelta.Sub(s.arbExtraData.unexpectedBalanceDelta, amount)
stateObject.SubBalance(amount)
}
}
Expand All @@ -416,8 +417,8 @@ func (s *StateDB) SetBalance(addr common.Address, amount *big.Int) {
amount = big.NewInt(0)
}
prevBalance := stateObject.Balance()
s.unexpectedBalanceDelta.Add(s.unexpectedBalanceDelta, amount)
s.unexpectedBalanceDelta.Sub(s.unexpectedBalanceDelta, prevBalance)
s.arbExtraData.unexpectedBalanceDelta.Add(s.arbExtraData.unexpectedBalanceDelta, amount)
s.arbExtraData.unexpectedBalanceDelta.Sub(s.arbExtraData.unexpectedBalanceDelta, prevBalance)
stateObject.SetBalance(amount)
}
}
Expand All @@ -426,7 +427,7 @@ func (s *StateDB) ExpectBalanceBurn(amount *big.Int) {
if amount.Sign() < 0 {
panic(fmt.Sprintf("ExpectBalanceBurn called with negative amount %v", amount))
}
s.unexpectedBalanceDelta.Add(s.unexpectedBalanceDelta, amount)
s.arbExtraData.unexpectedBalanceDelta.Add(s.arbExtraData.unexpectedBalanceDelta, amount)
}

func (s *StateDB) SetNonce(addr common.Address, nonce uint64) {
Expand Down Expand Up @@ -488,7 +489,7 @@ func (s *StateDB) SelfDestruct(addr common.Address) {
})

stateObject.markSelfdestructed()
s.unexpectedBalanceDelta.Sub(s.unexpectedBalanceDelta, stateObject.data.Balance)
s.arbExtraData.unexpectedBalanceDelta.Sub(s.arbExtraData.unexpectedBalanceDelta, stateObject.data.Balance)

stateObject.data.Balance = new(big.Int)
}
Expand Down Expand Up @@ -726,7 +727,9 @@ func (s *StateDB) CreateAccount(addr common.Address) {
func (s *StateDB) Copy() *StateDB {
// Copy all the basic fields, initialize the memory ones
state := &StateDB{
unexpectedBalanceDelta: new(big.Int).Set(s.unexpectedBalanceDelta),
arbExtraData: &ArbitrumExtraData{
unexpectedBalanceDelta: new(big.Int).Set(s.arbExtraData.unexpectedBalanceDelta),
},

db: s.db,
trie: s.db.CopyTrie(s.trie),
Expand Down Expand Up @@ -831,7 +834,7 @@ func (s *StateDB) Copy() *StateDB {
func (s *StateDB) Snapshot() int {
id := s.nextRevisionId
s.nextRevisionId++
s.validRevisions = append(s.validRevisions, revision{id, s.journal.length(), new(big.Int).Set(s.unexpectedBalanceDelta)})
s.validRevisions = append(s.validRevisions, revision{id, s.journal.length(), new(big.Int).Set(s.arbExtraData.unexpectedBalanceDelta)})
return id
}

Expand All @@ -846,7 +849,7 @@ func (s *StateDB) RevertToSnapshot(revid int) {
}
revision := s.validRevisions[idx]
snapshot := revision.journalIndex
s.unexpectedBalanceDelta = new(big.Int).Set(revision.unexpectedBalanceDelta)
s.arbExtraData.unexpectedBalanceDelta = new(big.Int).Set(revision.unexpectedBalanceDelta)

// Replay the journal to undo changes and remove invalidated snapshots
s.journal.revert(s, snapshot)
Expand Down Expand Up @@ -1322,7 +1325,7 @@ func (s *StateDB) Commit(block uint64, deleteEmptyObjects bool) (common.Hash, er
s.snap = nil
}

s.unexpectedBalanceDelta.Set(new(big.Int))
s.arbExtraData.unexpectedBalanceDelta.Set(new(big.Int))

if root == (common.Hash{}) {
root = types.EmptyRootHash
Expand Down
12 changes: 11 additions & 1 deletion core/state/statedb_arbitrum.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,30 @@ package state

import (
"math/big"
"runtime"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/trie"
)

// ArbitrumExtraData holds the Arbitrum-specific bookkeeping of a StateDB.
// It is a separate type so that a runtime finalizer can be attached to it
// (see SetArbFinalizer).
type ArbitrumExtraData struct {
// unexpectedBalanceDelta tracks the total balance change across all accounts
unexpectedBalanceDelta *big.Int
}

// SetArbFinalizer registers f as the runtime finalizer of the ArbitrumExtraData
// object held by this StateDB. The finalizer is set on the extra-data object,
// not on the StateDB itself.
func (s *StateDB) SetArbFinalizer(f func(*ArbitrumExtraData)) {
	extra := s.arbExtraData
	runtime.SetFinalizer(extra, f)
}

// GetCurrentTxLogs returns the logs stored under the transaction hash the
// StateDB is currently set to (s.thash).
func (s *StateDB) GetCurrentTxLogs() []*types.Log {
	currentTx := s.thash
	return s.logs[currentTx]
}

// GetUnexpectedBalanceDelta returns the total unexpected change in balances since the last commit to the database.
func (s *StateDB) GetUnexpectedBalanceDelta() *big.Int {
return new(big.Int).Set(s.unexpectedBalanceDelta)
return new(big.Int).Set(s.arbExtraData.unexpectedBalanceDelta)
}

func (s *StateDB) GetSelfDestructs() []common.Address {
Expand Down
2 changes: 1 addition & 1 deletion eth/api_backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ func (b *EthAPIBackend) StartMining() error {
}

func (b *EthAPIBackend) StateAtBlock(ctx context.Context, block *types.Block, reexec uint64, base *state.StateDB, readOnly bool, preferDisk bool) (*state.StateDB, tracers.StateReleaseFunc, error) {
return b.eth.stateAtBlock(ctx, block, reexec, base, readOnly, preferDisk)
return b.eth.stateAtBlock(ctx, block, reexec, base, nil, readOnly, preferDisk)
}

func (b *EthAPIBackend) StateAtTransaction(ctx context.Context, block *types.Block, txIndex int, reexec uint64) (*core.Message, vm.BlockContext, *state.StateDB, tracers.StateReleaseFunc, error) {
Expand Down
Loading

0 comments on commit e5d8587

Please sign in to comment.