diff --git a/x/evm/keeper/keeper_test.go b/x/evm/keeper/keeper_test.go index e4b3eb5c85..f61bf68a5c 100644 --- a/x/evm/keeper/keeper_test.go +++ b/x/evm/keeper/keeper_test.go @@ -32,10 +32,10 @@ import ( "github.com/evmos/ethermint/crypto/ethsecp256k1" "github.com/evmos/ethermint/server/config" "github.com/evmos/ethermint/tests" + "github.com/evmos/ethermint/testutil" ethermint "github.com/evmos/ethermint/types" "github.com/evmos/ethermint/x/evm/statedb" "github.com/evmos/ethermint/x/evm/types" - evmtypes "github.com/evmos/ethermint/x/evm/types" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" @@ -50,8 +50,6 @@ import ( "github.com/cometbft/cometbft/version" ) -var testTokens = sdkmath.NewIntWithDecimal(1000, 18) - type KeeperTestSuite struct { suite.Suite @@ -225,7 +223,7 @@ func (suite *KeeperTestSuite) SetupAppWithT(checkTx bool, t require.TestingT) { suite.clientCtx = client.Context{}.WithTxConfig(suite.app.TxConfig()) suite.ethSigner = ethtypes.LatestSignerForChainID(suite.app.EvmKeeper.ChainID()) suite.appCodec = suite.app.AppCodec() - suite.denom = evmtypes.DefaultEVMDenom + suite.denom = types.DefaultEVMDenom } func (suite *KeeperTestSuite) EvmDenom() string { @@ -535,3 +533,58 @@ func (suite *KeeperTestSuite) TestGetAccountOrEmpty() { }) } } + +func (suite *KeeperTestSuite) TestRevertByPrecompileSnapshot() { + db := suite.StateDB() + + // snapshot id for journal + rev := db.Snapshot() + + ctx, err := db.GetCacheContext() + suite.NoError(err) + + snapshotMultiStore, err := db.MultiStoreSnapshot() + suite.NoError(err) + snapshotEvents := suite.ctx.EventManager().Events() + db.AddPrecompileSnapshot(snapshotMultiStore, snapshotEvents) + + // manipulate statedb(evm) + evmAddr, priv := tests.NewAddrKey() + key1 := common.BigToHash(big.NewInt(1)) + value1 := common.BigToHash(big.NewInt(2)) + key2 := common.BigToHash(big.NewInt(3)) + value2 := common.BigToHash(big.NewInt(4)) + db.SetState(evmAddr, key1, value1) + db.SetState(evmAddr, key2, value2) + + suite.Equal(value1, db.GetState(evmAddr, key1)) + suite.Equal(value2, db.GetState(evmAddr, key2)) + + // manipulate bank keeper(sdk) + addr1 := sdk.AccAddress(priv.PubKey().Address().Bytes()) + denom := "testdenom" + testutil.FundAccount(suite.app.BankKeeper, ctx, addr1, sdk.NewCoins(sdk.NewCoin(denom, sdkmath.NewInt(1000000)))) + addr2 := sdk.AccAddress(tests.GenerateAddress().Bytes()) + suite.app.BankKeeper.SendCoins(ctx, addr1, addr2, sdk.NewCoins(sdk.NewCoin(denom, sdkmath.NewInt(1000)))) + + suite.Equal(sdk.NewCoin(denom, sdkmath.NewInt(999000)), suite.app.BankKeeper.GetBalance(ctx, addr1, denom)) + suite.Equal(sdk.NewCoin(denom, sdkmath.NewInt(1000)), suite.app.BankKeeper.GetBalance(ctx, addr2, denom)) + + // revert to snapshot + db.RevertToSnapshot(rev) + + suite.Equal(common.Hash{}, db.GetState(evmAddr, key1)) + suite.Equal(common.Hash{}, db.GetState(evmAddr, key2)) + + suite.Equal(sdk.NewCoin(denom, sdkmath.NewInt(999000)), suite.app.BankKeeper.GetBalance(ctx, addr1, denom)) + suite.Equal(sdk.NewCoin(denom, sdkmath.NewInt(1000)), suite.app.BankKeeper.GetBalance(ctx, addr2, denom)) + + // commit changes(revert sdk state) + db.Commit() + + suite.Equal(common.Hash{}, db.GetState(evmAddr, key1)) + suite.Equal(common.Hash{}, db.GetState(evmAddr, key2)) + + suite.Equal(sdk.NewCoin(denom, sdkmath.ZeroInt()), suite.app.BankKeeper.GetBalance(suite.ctx, addr1, denom)) + suite.Equal(sdk.NewCoin(denom, sdkmath.ZeroInt()), suite.app.BankKeeper.GetBalance(suite.ctx, addr2, denom)) +} diff --git a/x/evm/statedb/journal.go b/x/evm/statedb/journal.go index bbc9d39866..a67143585d 100644 --- a/x/evm/statedb/journal.go +++ b/x/evm/statedb/journal.go @@ -252,7 +252,13 @@ func (ch accessListAddSlotChange) Dirtied() *common.Address { } func (ch precompileChange) Revert(s *StateDB) { - s.RevertWithMultiStoreSnapshot(ch.ms) + s.cacheCtx = s.cacheCtx.WithMultiStore(ch.ms) + + // Necessary to revert the sdk state + s.writeCache = func() { + s.cacheCtx.EventManager().EmitEvents(ch.es) + ch.ms.CacheMultiStore().Write() + } } func (ch precompileChange) Dirtied() *common.Address { diff --git a/x/evm/statedb/statedb.go b/x/evm/statedb/statedb.go index 8d1801be40..b91ebb4108 100644 --- a/x/evm/statedb/statedb.go +++ b/x/evm/statedb/statedb.go @@ -329,32 +329,25 @@ func (s *StateDB) MultiStoreSnapshot() (storetypes.CacheMultiStore, error) { return snapshot, nil } -func (s *StateDB) RevertWithMultiStoreSnapshot(snapshot storetypes.MultiStore) { - s.cacheCtx = s.cacheCtx. - WithMultiStore(snapshot). - WithEventManager(sdk.NewEventManager()) -} - -// If revert is occurred, the snapshot of journal is overwrited to MultiStore of ctx. -// The events is just for debug. -func (s *StateDB) PostPrecompileProcessing(snapshot storetypes.MultiStore, events sdk.Events, contract common.Address, converter EventConverter) { +func (s *StateDB) ProcessPrecompileEvents(contract common.Address, events sdk.Events, converter EventConverter) { // convert native events to evm logs - if converter != nil && len(events) > 0 { - for _, event := range events { - log, err := converter(event) - if err != nil { - s.ctx.Logger().Error("failed to convert event", "err", err) - continue - } - if log == nil { - continue - } - - log.Address = contract - s.AddLog(log) + for _, event := range events { + log, err := converter(event) + if err != nil { + s.ctx.Logger().Error("failed to convert event", "err", err) + continue + } + if log == nil { + continue } + + log.Address = contract + s.AddLog(log) } +} +// If revert is occurred, the snapshot of journal is overwritten. +func (s *StateDB) AddPrecompileSnapshot(snapshot storetypes.MultiStore, events sdk.Events) { s.journal.append(precompileChange{snapshot, events}) }