diff --git a/action/protocol/execution/evm/contract.go b/action/protocol/execution/evm/contract.go index bca8a17e4a..137347c325 100644 --- a/action/protocol/execution/evm/contract.go +++ b/action/protocol/execution/evm/contract.go @@ -44,14 +44,15 @@ type ( contract struct { *state.Account - async bool - dirtyCode bool // contract's code has been set - dirtyState bool // contract's account state has changed - code protocol.SerializableBytes // contract byte-code - root hash.Hash256 - committed map[hash.Hash256][]byte - sm protocol.StateManager - trie trie.Trie // storage trie of the contract + async bool + dirtyCode bool // contract's code has been set + dirtyState bool // contract's account state has changed + code protocol.SerializableBytes // contract byte-code + root hash.Hash256 + committed map[hash.Hash256][]byte + sm protocol.StateManager + trie trie.Trie // storage trie of the contract + storeForTrie *protocol.KvStoreForTrie } ) @@ -145,6 +146,11 @@ func (c *contract) Commit() error { } c.dirtyCode = false } + for _, opts := range c.storeForTrie.Stales() { + if _, err := c.sm.DelState(opts...); err != nil { + return errors.Wrapf(err, "Failed to delete stale key") + } + } return nil } @@ -180,14 +186,15 @@ func (c *contract) Snapshot() Contract { // newContract returns a Contract instance func newContract(addr hash.Hash160, account *state.Account, sm protocol.StateManager, enableAsync bool) (Contract, error) { c := &contract{ - Account: account, - root: account.Root, - committed: make(map[hash.Hash256][]byte), - sm: sm, - async: enableAsync, + Account: account, + root: account.Root, + committed: make(map[hash.Hash256][]byte), + sm: sm, + async: enableAsync, + storeForTrie: protocol.NewKVStoreForTrieWithStateManager(ContractKVNameSpace, sm), } options := []mptrie.Option{ - mptrie.KVStoreOption(protocol.NewKVStoreForTrieWithStateManager(ContractKVNameSpace, sm)), + mptrie.KVStoreOption(c.storeForTrie), mptrie.KeyLengthOption(len(hash.Hash256{})), mptrie.HashFuncOption(func(data []byte) []byte { h := hash.Hash256b(append(addr[:], data...)) diff --git a/action/protocol/kvstorefortrie.go b/action/protocol/kvstorefortrie.go index 78b8cc0a5a..ca7abf5259 100644 --- a/action/protocol/kvstorefortrie.go +++ b/action/protocol/kvstorefortrie.go @@ -16,9 +16,10 @@ import ( ) type ( - kvStoreForTrie struct { - nsOpt StateOption - sm StateManager + KvStoreForTrie struct { + nsOpt StateOption + sm StateManager + staleKeys map[string]struct{} } kvStoreForTrieWithStateReader struct { nsOpt StateOption @@ -27,36 +28,50 @@ type ( ) // NewKVStoreForTrieWithStateManager creates a trie.KVStore with state manager -func NewKVStoreForTrieWithStateManager(ns string, sm StateManager) trie.KVStore { - return &kvStoreForTrie{nsOpt: NamespaceOption(ns), sm: sm} +func NewKVStoreForTrieWithStateManager(ns string, sm StateManager) *KvStoreForTrie { + return &KvStoreForTrie{nsOpt: NamespaceOption(ns), sm: sm, staleKeys: make(map[string]struct{})} } -func (kv *kvStoreForTrie) Start(context.Context) error { +func (kv *KvStoreForTrie) Start(context.Context) error { return nil } -func (kv *kvStoreForTrie) Stop(context.Context) error { +func (kv *KvStoreForTrie) Stop(context.Context) error { return nil } -func (kv *kvStoreForTrie) Put(key []byte, value []byte) error { +func (kv *KvStoreForTrie) Put(key []byte, value []byte) error { var sb SerializableBytes sb = make([]byte, len(value)) copy(sb, value) _, err := kv.sm.PutState(sb, KeyOption(key), kv.nsOpt) + if err != nil { + return err + } + dk := string(key) + if _, ok := kv.staleKeys[dk]; ok { + delete(kv.staleKeys, dk) + } + return nil +} - return err +func (kv *KvStoreForTrie) Delete(key []byte) error { + dk := string(key) + if _, ok := kv.staleKeys[dk]; !ok { + kv.staleKeys[dk] = struct{}{} + } + return nil } -func (kv *kvStoreForTrie) Delete(key []byte) error { - _, err := kv.sm.DelState(KeyOption(key), kv.nsOpt) - if errors.Cause(err) == state.ErrStateNotExist { - return nil +func (kv *KvStoreForTrie) Stales() [][]StateOption { + var keys [][]StateOption + for k := range kv.staleKeys { + keys = append(keys, []StateOption{KeyOption([]byte(k)), kv.nsOpt}) } - return err + return keys } -func (kv *kvStoreForTrie) Get(key []byte) ([]byte, error) { +func (kv *KvStoreForTrie) Get(key []byte) ([]byte, error) { var value SerializableBytes _, err := kv.sm.State(&value, KeyOption(key), kv.nsOpt) switch errors.Cause(err) { diff --git a/action/protocol/kvstorefortrie_test.go b/action/protocol/kvstorefortrie_test.go index 59e1782a21..aecd56b19f 100644 --- a/action/protocol/kvstorefortrie_test.go +++ b/action/protocol/kvstorefortrie_test.go @@ -11,8 +11,11 @@ import ( "encoding/hex" "testing" - "github.com/iotexproject/iotex-core/v2/state" "github.com/stretchr/testify/require" + + "github.com/iotexproject/iotex-core/v2/db/trie" + "github.com/iotexproject/iotex-core/v2/db/trie/mptrie" + "github.com/iotexproject/iotex-core/v2/state" ) type inMemStateManager struct { @@ -113,7 +116,8 @@ func TestKVStoreForTrie(t *testing.T) { key := []byte("key") value := SerializableBytes("value") sm := newInMemStateManager() - kvstore := NewKVStoreForTrieWithStateManager(ns, sm) + var kvstore trie.KVStore + kvstore = NewKVStoreForTrieWithStateManager(ns, sm) require.NoError(kvstore.Start(context.Background())) require.NoError(kvstore.Stop(context.Background())) _, err := kvstore.Get(key) @@ -139,3 +143,38 @@ func TestKVStoreForTrie(t *testing.T) { require.True(bytes.Equal(fromStore, value)) } + +func TestHistoryKVStoreForTrie(t *testing.T) { + r := require.New(t) + ns := "namespace" + key := []byte("key") + value := SerializableBytes("value") + sm := newInMemStateManager() + kvstore := NewKVStoreForTrieWithStateManager(ns, sm) + + trie, err := mptrie.New(mptrie.KVStoreOption(kvstore), mptrie.KeyLengthOption(3)) + r.NoError(err) + r.NoError(trie.Start(context.Background())) + defer trie.Stop(context.Background()) + + root0, err := trie.RootHash() + r.NoError(err) + t.Logf("root: %x\n", root0) + + r.NoError(trie.Upsert(key, value)) + root1, err := trie.RootHash() + r.NoError(err) + t.Logf("root: %x\n", root1) + + r.NoError(trie.Upsert(key, SerializableBytes("value2"))) + root2, err := trie.RootHash() + r.NoError(err) + t.Logf("root: %x\n", root2) + + r.NoError(trie.SetRootHash(root1)) + for _, opt := range kvstore.Stales() { + for _, o := range opt { + t.Logf("stale: %#v\n", o) + } + } +}