Skip to content

Commit

Permalink
Merge pull request #2292 from OffchainLabs/wasmstore
Browse files Browse the repository at this point in the history
Wasmstore
  • Loading branch information
PlasmaPower authored May 16, 2024
2 parents 95910cf + cd09ae8 commit 6f73839
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 24 deletions.
90 changes: 78 additions & 12 deletions arbos/programs/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/log"
Expand Down Expand Up @@ -53,6 +54,24 @@ func activateProgram(
debug bool,
burner burn.Burner,
) (*activationInfo, error) {
info, asm, module, err := activateProgramInternal(db, program, codehash, wasm, page_limit, version, debug, burner.GasLeft())
if err != nil {
return nil, err
}
db.ActivateWasm(info.moduleHash, asm, module)
return info, nil
}

func activateProgramInternal(
db vm.StateDB,
program common.Address,
codehash common.Hash,
wasm []byte,
page_limit uint16,
version uint16,
debug bool,
gasLeft *uint64,
) (*activationInfo, []byte, []byte, error) {
output := &rustBytes{}
asmLen := usize(0)
moduleHash := &bytes32{}
Expand All @@ -69,7 +88,7 @@ func activateProgram(
&codeHash,
moduleHash,
stylusData,
(*u64)(burner.GasLeft()),
(*u64)(gasLeft),
))

data, msg, err := status.toResult(output.intoBytes(), debug)
Expand All @@ -78,9 +97,9 @@ func activateProgram(
log.Warn("activation failed", "err", err, "msg", msg, "program", program)
}
if errors.Is(err, vm.ErrExecutionReverted) {
return nil, fmt.Errorf("%w: %s", ErrProgramActivation, msg)
return nil, nil, nil, fmt.Errorf("%w: %s", ErrProgramActivation, msg)
}
return nil, err
return nil, nil, nil, err
}

hash := moduleHash.toHash()
Expand All @@ -95,13 +114,57 @@ func activateProgram(
asmEstimate: uint32(stylusData.asm_estimate),
footprint: uint16(stylusData.footprint),
}
db.ActivateWasm(hash, asm, module)
return info, err
return info, asm, module, err
}

func getLocalAsm(statedb vm.StateDB, moduleHash common.Hash, address common.Address, pagelimit uint16, time uint64, debugMode bool, program Program) ([]byte, error) {
localAsm, err := statedb.TryGetActivatedAsm(moduleHash)
if err == nil && len(localAsm) > 0 {
return localAsm, nil
}

codeHash := statedb.GetCodeHash(address)

wasm, err := getWasm(statedb, address)
if err != nil {
log.Error("Failed to reactivate program: getWasm", "address", address, "expected moduleHash", moduleHash, "err", err)
return nil, fmt.Errorf("failed to reactivate program address: %v err: %w", address, err)
}

unlimitedGas := uint64(0xffffffffffff)
// we know program is activated, so it must be in correct version and not use too much memory
info, asm, module, err := activateProgramInternal(statedb, address, codeHash, wasm, pagelimit, program.version, debugMode, &unlimitedGas)
if err != nil {
log.Error("failed to reactivate program", "address", address, "expected moduleHash", moduleHash, "err", err)
return nil, fmt.Errorf("failed to reactivate program address: %v err: %w", address, err)
}

if info.moduleHash != moduleHash {
log.Error("failed to reactivate program", "address", address, "expected moduleHash", moduleHash, "got", info.moduleHash)
return nil, fmt.Errorf("failed to reactivate program. address: %v, expected ModuleHash: %v", address, moduleHash)
}

currentHoursSince := hoursSinceArbitrum(time)
if currentHoursSince > program.activatedAt {
// stylus program is active on-chain, and was activated in the past
// so we store it directly to database
batch := statedb.Database().WasmStore().NewBatch()
rawdb.WriteActivation(batch, moduleHash, asm, module)
if err := batch.Write(); err != nil {
log.Error("failed writing re-activation to state", "address", address, "err", err)
}
} else {
// program activated recently, possibly in this eth_call
// store it to statedb. It will be stored to database if statedb is commited
statedb.ActivateWasm(info.moduleHash, asm, module)
}
return asm, nil
}

func callProgram(
address common.Address,
moduleHash common.Hash,
localAsm []byte,
scope *vm.ScopeContext,
interpreter *vm.EVMInterpreter,
tracingInfo *util.TracingInfo,
Expand All @@ -111,10 +174,9 @@ func callProgram(
memoryModel *MemoryModel,
) ([]byte, error) {
db := interpreter.Evm().StateDB
asm := db.GetActivatedAsm(moduleHash)
debug := stylusParams.DebugMode

if len(asm) == 0 {
if len(localAsm) == 0 {
log.Error("missing asm", "program", address, "module", moduleHash)
panic("missing asm")
}
Expand All @@ -128,7 +190,7 @@ func callProgram(

output := &rustBytes{}
status := userStatus(C.stylus_call(
goSlice(asm),
goSlice(localAsm),
goSlice(calldata),
stylusParams.encode(),
evmApi.cNative,
Expand Down Expand Up @@ -159,11 +221,15 @@ func handleReqImpl(apiId usize, req_type u32, data *rustSlice, costPtr *u64, out

// Caches a program in Rust. We write a record so that we can undo on revert.
// For gas estimation and eth_call, we ignore permanent updates and rely on Rust's LRU.
func cacheProgram(db vm.StateDB, module common.Hash, version uint16, debug bool, runMode core.MessageRunMode) {
func cacheProgram(db vm.StateDB, module common.Hash, program Program, params *StylusParams, debug bool, time uint64, runMode core.MessageRunMode) {
if runMode == core.MessageCommitMode {
asm := db.GetActivatedAsm(module)
state.CacheWasmRust(asm, module, version, debug)
db.RecordCacheWasm(state.CacheWasm{ModuleHash: module, Version: version, Debug: debug})
// address is only used for logging
asm, err := getLocalAsm(db, module, common.Address{}, params.PageLimit, time, debug, program)
if err != nil {
panic("unable to recreate wasm")
}
state.CacheWasmRust(asm, module, program.version, debug)
db.RecordCacheWasm(state.CacheWasm{ModuleHash: module, Version: program.version, Debug: debug})
}
}

Expand Down
18 changes: 14 additions & 4 deletions arbos/programs/programs.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,13 @@ func (p Programs) ActivateProgram(evm *vm.EVM, address common.Address, runMode c
return 0, codeHash, common.Hash{}, nil, true, err
}

// replace the cached asm
// remove prev asm
if cached {
oldModuleHash, err := p.moduleHashes.Get(codeHash)
if err != nil {
return 0, codeHash, common.Hash{}, nil, true, err
}
evictProgram(statedb, oldModuleHash, currentVersion, debugMode, runMode, expired)
cacheProgram(statedb, info.moduleHash, stylusVersion, debugMode, runMode)
}
if err := p.moduleHashes.Set(codeHash, info.moduleHash); err != nil {
return 0, codeHash, common.Hash{}, nil, true, err
Expand All @@ -152,6 +151,11 @@ func (p Programs) ActivateProgram(evm *vm.EVM, address common.Address, runMode c
activatedAt: hoursSinceArbitrum(time),
cached: cached,
}
// replace the cached asm
if cached {
cacheProgram(statedb, info.moduleHash, programData, params, debugMode, time, runMode)
}

return stylusVersion, codeHash, info.moduleHash, dataFee, false, p.setProgram(codeHash, programData)
}

Expand Down Expand Up @@ -205,6 +209,12 @@ func (p Programs) CallProgram(
statedb.AddStylusPages(program.footprint)
defer statedb.SetStylusPagesOpen(open)

localAsm, err := getLocalAsm(statedb, moduleHash, contract.Address(), params.PageLimit, evm.Context.Time, debugMode, program)
if err != nil {
log.Crit("failed to get local wasm for activated program", "program", contract.Address())
return nil, err
}

evmData := &EvmData{
blockBasefee: common.BigToHash(evm.Context.BaseFee),
chainId: evm.ChainConfig().ChainID.Uint64(),
Expand All @@ -227,7 +237,7 @@ func (p Programs) CallProgram(
if contract.CodeAddr != nil {
address = *contract.CodeAddr
}
return callProgram(address, moduleHash, scope, interpreter, tracingInfo, calldata, evmData, goParams, model)
return callProgram(address, moduleHash, localAsm, scope, interpreter, tracingInfo, calldata, evmData, goParams, model)
}

func getWasm(statedb vm.StateDB, program common.Address) ([]byte, error) {
Expand Down Expand Up @@ -380,7 +390,7 @@ func (p Programs) SetProgramCached(
return err
}
if cache {
cacheProgram(db, moduleHash, program.version, debug, runMode)
cacheProgram(db, moduleHash, program, params, debug, time, runMode)
} else {
evictProgram(db, moduleHash, program.version, debug, runMode, expired)
}
Expand Down
7 changes: 6 additions & 1 deletion arbos/programs/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func activateProgram(
}

// stub any non-consensus, Rust-side caching updates
func cacheProgram(db vm.StateDB, module common.Hash, version uint16, debug bool, mode core.MessageRunMode) {
func cacheProgram(db vm.StateDB, module common.Hash, program Program, params *StylusParams, debug bool, time uint64, runMode core.MessageRunMode) {
}
func evictProgram(db vm.StateDB, module common.Hash, version uint16, debug bool, mode core.MessageRunMode, forever bool) {
}
Expand Down Expand Up @@ -128,9 +128,14 @@ func startProgram(module uint32) uint32
//go:wasmimport programs send_response
func sendResponse(req_id uint32) uint32

func getLocalAsm(statedb vm.StateDB, moduleHash common.Hash, address common.Address, pagelimit uint16, time uint64, debugMode bool, program Program) ([]byte, error) {
return nil, nil
}

func callProgram(
address common.Address,
moduleHash common.Hash,
_localAsm []byte,
scope *vm.ScopeContext,
interpreter *vm.EVMInterpreter,
tracingInfo *util.TracingInfo,
Expand Down
18 changes: 14 additions & 4 deletions cmd/nitro/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,15 @@ func openInitializeChainDb(ctx context.Context, stack *node.Node, config *NodeCo
if !arbmath.BigEquals(chainConfig.ChainID, chainId) {
return nil, nil, fmt.Errorf("database has chain ID %v but config has chain ID %v (are you sure this database is for the right chain?)", chainConfig.ChainID, chainId)
}
chainDb, err := stack.OpenDatabaseWithFreezer("l2chaindata", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, config.Persistent.Ancient, "l2chaindata/", false)
chainData, err := stack.OpenDatabaseWithFreezer("l2chaindata", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, config.Persistent.Ancient, "l2chaindata/", false)
if err != nil {
return chainDb, nil, err
return nil, nil, err
}
wasmDb, err := stack.OpenDatabase("wasm", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, "wasm/", false)
if err != nil {
return nil, nil, err
}
chainDb := rawdb.WrapDatabaseWithWasm(chainData, wasmDb)
err = pruning.PruneChainDb(ctx, chainDb, stack, &config.Init, cacheConfig, l1Client, rollupAddrs, config.Node.ValidatorRequired())
if err != nil {
return chainDb, nil, fmt.Errorf("error pruning: %w", err)
Expand Down Expand Up @@ -230,10 +235,15 @@ func openInitializeChainDb(ctx context.Context, stack *node.Node, config *NodeCo

var initDataReader statetransfer.InitDataReader = nil

chainDb, err := stack.OpenDatabaseWithFreezer("l2chaindata", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, config.Persistent.Ancient, "l2chaindata/", false)
chainData, err := stack.OpenDatabaseWithFreezer("l2chaindata", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, config.Persistent.Ancient, "l2chaindata/", false)
if err != nil {
return chainDb, nil, err
return nil, nil, err
}
wasmDb, err := stack.OpenDatabase("wasm", config.Execution.Caching.DatabaseCache, config.Persistent.Handles, "wasm/", false)
if err != nil {
return nil, nil, err
}
chainDb := rawdb.WrapDatabaseWithWasm(chainData, wasmDb)

if config.Init.ImportFile != "" {
initDataReader, err = statetransfer.NewJsonInitDataReader(config.Init.ImportFile)
Expand Down
12 changes: 10 additions & 2 deletions system_tests/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/crypto"
Expand Down Expand Up @@ -772,8 +773,11 @@ func createL2BlockChainWithStackConfig(
stack, err = node.New(stackConfig)
Require(t, err)

chainDb, err := stack.OpenDatabase("l2chaindata", 0, 0, "l2chaindata/", false)
chainData, err := stack.OpenDatabase("l2chaindata", 0, 0, "l2chaindata/", false)
Require(t, err)
wasmData, err := stack.OpenDatabase("wasm", 0, 0, "wasm/", false)
Require(t, err)
chainDb := rawdb.WrapDatabaseWithWasm(chainData, wasmData)
arbDb, err := stack.OpenDatabase("arbitrumdata", 0, 0, "arbitrumdata/", false)
Require(t, err)

Expand Down Expand Up @@ -976,8 +980,12 @@ func Create2ndNodeWithConfig(
l2stack, err := node.New(stackConfig)
Require(t, err)

l2chainDb, err := l2stack.OpenDatabase("l2chaindata", 0, 0, "l2chaindata/", false)
l2chainData, err := l2stack.OpenDatabase("l2chaindata", 0, 0, "l2chaindata/", false)
Require(t, err)
wasmData, err := l2stack.OpenDatabase("wasm", 0, 0, "wasm/", false)
Require(t, err)
l2chainDb := rawdb.WrapDatabaseWithWasm(l2chainData, wasmData)

l2arbDb, err := l2stack.OpenDatabase("arbitrumdata", 0, 0, "arbitrumdata/", false)
Require(t, err)
initReader := statetransfer.NewMemoryInitDataReader(l2InitData)
Expand Down
78 changes: 78 additions & 0 deletions system_tests/program_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1458,3 +1458,81 @@ func formatTime(duration time.Duration) string {
}
return fmt.Sprintf("%.2f%s", span, units[unit])
}

func TestWasmRecreate(t *testing.T) {
builder, auth, cleanup := setupProgramTest(t, true)
ctx := builder.ctx
l2info := builder.L2Info
l2client := builder.L2.Client
defer cleanup()

storage := deployWasm(t, ctx, auth, l2client, rustFile("storage"))

zero := common.Hash{}
val := common.HexToHash("0x121233445566")

// do an onchain call - store value
storeTx := l2info.PrepareTxTo("Owner", &storage, l2info.TransferGas, nil, argsForStorageWrite(zero, val))
Require(t, l2client.SendTransaction(ctx, storeTx))
_, err := EnsureTxSucceeded(ctx, l2client, storeTx)
Require(t, err)

testDir := t.TempDir()
nodeBStack := createStackConfigForTest(testDir)
nodeB, cleanupB := builder.Build2ndNode(t, &SecondNodeParams{stackConfig: nodeBStack})

_, err = EnsureTxSucceeded(ctx, nodeB.Client, storeTx)
Require(t, err)

// make sure reading 2nd value succeeds from 2nd node
loadTx := l2info.PrepareTxTo("Owner", &storage, l2info.TransferGas, nil, argsForStorageRead(zero))
result, err := arbutil.SendTxAsCall(ctx, nodeB.Client, loadTx, l2info.GetAddress("Owner"), nil, true)
Require(t, err)
if common.BytesToHash(result) != val {
Fatal(t, "got wrong value")
}
// close nodeB
cleanupB()

// delete wasm dir of nodeB

wasmPath := filepath.Join(testDir, "system_tests.test", "wasm")
dirContents, err := os.ReadDir(wasmPath)
Require(t, err)
if len(dirContents) == 0 {
Fatal(t, "not contents found before delete")
}
os.RemoveAll(wasmPath)

// recreate nodeB - using same source dir (wasm deleted)
nodeB, cleanupB = builder.Build2ndNode(t, &SecondNodeParams{stackConfig: nodeBStack})

// test nodeB - sees existing transaction
_, err = EnsureTxSucceeded(ctx, nodeB.Client, storeTx)
Require(t, err)

// test nodeB - answers eth_call (requires reloading wasm)
result, err = arbutil.SendTxAsCall(ctx, nodeB.Client, loadTx, l2info.GetAddress("Owner"), nil, true)
Require(t, err)
if common.BytesToHash(result) != val {
Fatal(t, "got wrong value")
}

// send new tx (requires wasm) and check nodeB sees it as well
Require(t, l2client.SendTransaction(ctx, loadTx))

_, err = EnsureTxSucceeded(ctx, l2client, loadTx)
Require(t, err)

_, err = EnsureTxSucceeded(ctx, nodeB.Client, loadTx)
Require(t, err)

cleanupB()
dirContents, err = os.ReadDir(wasmPath)
Require(t, err)
if len(dirContents) == 0 {
Fatal(t, "not contents found before delete")
}
os.RemoveAll(wasmPath)

}

0 comments on commit 6f73839

Please sign in to comment.