Skip to content

Commit

Permalink
feat: upgrade contracts invoked if outdated (#1153)
Browse files Browse the repository at this point in the history
<!--- Please provide a general summary of your changes in the title
above -->

<!-- Give an estimate of the time you spent on this PR in terms of work
days.
Did you spend 0.5 days on this PR or rather 2 days?  -->

Time spent on this PR: 0.3d

## Pull request type

<!-- Please try to limit your pull request to one type,
submit multiple pull requests if needed. -->

Please check the type of change your PR introduces:

- [ ] Bugfix
- [ ] Feature
- [ ] Code style update (formatting, renaming)
- [ ] Refactoring (no functional changes, no api changes)
- [ ] Build related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

<!-- Please describe the current behavior that you are modifying,
or link to a relevant issue. -->

Resolves #<Issue number>

## What is the new behavior?

<!-- Please describe the behavior or changes that are being added by
this PR. -->

-
-
-

<!-- Reviewable:start -->
- - -
This change is [<img src="https://reviewable.io/review_button.svg"
height="34" align="absmiddle"
alt="Reviewable"/>](https://reviewable.io/reviews/kkrt-labs/kakarot/1153)
<!-- Reviewable:end -->
  • Loading branch information
enitrat authored May 23, 2024
1 parent 3aacc3a commit c276e4c
Show file tree
Hide file tree
Showing 9 changed files with 154 additions and 6 deletions.
24 changes: 20 additions & 4 deletions src/kakarot/account.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,16 @@ namespace Account {
address=address, code_len=0, code=bytecode, nonce=0, balance=balance_ptr
);
return account;
} else {
tempvar address = new model.Address(starknet=starknet_address, evm=evm_address);
let balance = fetch_balance(address);
assert balance_ptr = new Uint256(balance.low, balance.high);
}

tempvar address = new model.Address(starknet=starknet_address, evm=evm_address);
let balance = fetch_balance(address);
assert balance_ptr = new Uint256(balance.low, balance.high);

// Upgrade the target starknet contract's class if it's not the latest one.
// The contract must be deployed on starknet already.
Internals.check_and_upgrade_account_class(address);

let (bytecode_len, bytecode) = IAccount.bytecode(contract_address=starknet_address);
let (nonce) = IAccount.get_nonce(contract_address=starknet_address);

Expand Down Expand Up @@ -671,4 +675,16 @@ namespace Internals {

return _cache_storage_keys(evm_address, storage_keys_len - 1, storage_keys + Uint256.SIZE);
}

func check_and_upgrade_account_class{
syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr
}(address: model.Address*) {
let (account_impl) = IAccount.get_implementation(address.starknet);
let (latest_impl) = Kakarot_account_contract_class_hash.read();
if (account_impl == latest_impl) {
return ();
}
IAccount.set_implementation(address.starknet, latest_impl);
return ();
}
}
14 changes: 14 additions & 0 deletions src/kakarot/accounts/account_contract.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ func get_evm_address{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check
return AccountContract.get_evm_address();
}

@view
func get_implementation{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> (
implementation: felt
) {
return AccountContract.get_implementation();
}

@external
func set_implementation{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
implementation_class: felt
) {
return AccountContract.set_implementation(implementation_class);
}

// @notice Checks if the account was initialized.
// @return is_initialized: 1 if the account has been initialized 0 otherwise.
@view
Expand Down
10 changes: 10 additions & 0 deletions src/kakarot/accounts/library.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,16 @@ namespace AccountContract {
return (implementation=implementation);
}

func set_implementation{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
new_implementation: felt
) {
// Access control check.
Ownable.assert_only_owner();
replace_class(new_implementation);
Account_implementation.write(new_implementation);
return ();
}

// @return address The EVM address of the account
func get_evm_address{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}() -> (
address: felt
Expand Down
1 change: 1 addition & 0 deletions src/kakarot/instructions/system_operations.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -1195,6 +1195,7 @@ namespace CallHelper {
Memory.load_n(args_size.low, calldata, args_offset.low);

// 2. Build child_evm

let code_account = State.get_account(code_address);
local code_len: felt = code_account.code_len;
local code: felt* = code_account.code;
Expand Down
6 changes: 6 additions & 0 deletions src/kakarot/interfaces/interfaces.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ namespace IAccount {
func get_evm_address() -> (evm_address: felt) {
}

func get_implementation() -> (implementation: felt) {
}

func set_implementation(implementation: felt) {
}

func version() -> (version: felt) {
}

Expand Down
93 changes: 92 additions & 1 deletion tests/end_to_end/test_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
import pytest_asyncio
from eth_utils import keccak
from starknet_py.net.full_node_client import FullNodeClient
from starkware.starknet.public.abi import get_storage_var_address

Expand Down Expand Up @@ -41,7 +42,7 @@ async def new_account(max_fee):
return account


@pytest_asyncio.fixture(scope="module")
@pytest_asyncio.fixture(scope="function")
async def counter(deploy_contract, new_account):
return await deploy_contract(
"PlainOpcodes",
Expand All @@ -50,9 +51,19 @@ async def counter(deploy_contract, new_account):
)


@pytest_asyncio.fixture(scope="module")
async def caller(deploy_contract, owner):
return await deploy_contract(
"PlainOpcodes",
"Caller",
caller_eoa=owner.starknet_contract,
)


@pytest.fixture(autouse=True)
async def cleanup(invoke, class_hashes):
yield

await invoke(
"kakarot",
"set_account_contract_class_hash",
Expand All @@ -72,6 +83,18 @@ async def assert_counter_transaction_success(counter, new_account):
assert await counter.count() == prev_count + 1


async def assert_caller_contract_increases_counter(caller, counter, new_account):
"""
Assert that the transaction sent, other than upgrading the account contract, is successful.
"""
prev_count = await counter.count()
inc_selector = keccak(b"inc()")[0:4]
await caller.call(
counter.address, inc_selector, caller_eoa=new_account.starknet_contract
)
assert await counter.count() == prev_count + 1


@pytest.mark.asyncio(scope="session")
@pytest.mark.AccountContract
class TestAccount:
Expand All @@ -83,6 +106,7 @@ async def test_should_upgrade_outdated_account_on_transfer(
counter,
new_account,
class_hashes,
cleanup,
):
prev_class = await starknet.get_class_hash_at(
new_account.starknet_contract.address
Expand Down Expand Up @@ -135,3 +159,70 @@ async def test_should_update_cairo1_helpers_class(
)
== target_class
)

class TestAutoUpgradeContracts:
async def test_should_upgrade_outdated_contract_transaction_target(
self,
starknet: FullNodeClient,
invoke,
call,
counter,
new_account,
class_hashes,
):
counter_starknet_address = (
await call(
"kakarot",
"get_starknet_address",
int(counter.address, 16),
)
).starknet_address
prev_class = await starknet.get_class_hash_at(counter_starknet_address)
target_class = class_hashes["account_contract_fixture"]
assert prev_class != target_class
assert prev_class == class_hashes["account_contract"]

await invoke(
"kakarot",
"set_account_contract_class_hash",
target_class,
)

await assert_counter_transaction_success(counter, new_account)

new_class = await starknet.get_class_hash_at(counter_starknet_address)
assert new_class == target_class

async def test_should_upgrade_outdated_contract_called_contract(
self,
starknet: FullNodeClient,
invoke,
counter,
call,
caller,
new_account,
class_hashes,
cleanup,
):
counter_starknet_address = (
await call(
"kakarot",
"get_starknet_address",
int(counter.address, 16),
)
).starknet_address
prev_class = await starknet.get_class_hash_at(counter_starknet_address)
target_class = class_hashes["account_contract_fixture"]
assert prev_class != target_class
assert prev_class == class_hashes["account_contract"]

await invoke(
"kakarot",
"set_account_contract_class_hash",
target_class,
)

await assert_caller_contract_increases_counter(caller, counter, new_account)

new_class = await starknet.get_class_hash_at(counter_starknet_address)
assert new_class == target_class
4 changes: 4 additions & 0 deletions tests/fixtures/account_contract_fixture.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ from kakarot.accounts.account_contract import (
storage,
get_nonce,
set_nonce,
get_implementation,
set_implementation,
is_valid_jumpdest,
write_jumpdests,
)

// make sure the class hash is different
Expand Down
2 changes: 2 additions & 0 deletions tests/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@

ZERO_ADDRESS = "0x" + 40 * "0"

ACCOUNT_CLASS_IMPLEMENTATION = 0xC0DEC1A55

BLOCK_NUMBER = 0x42
BLOCK_TIMESTAMP = int(time())

Expand Down
6 changes: 5 additions & 1 deletion tests/utils/syscall_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
get_storage_var_address,
)

from tests.utils.constants import CHAIN_ID
from tests.utils.constants import ACCOUNT_CLASS_IMPLEMENTATION, CHAIN_ID
from tests.utils.uint256 import int_to_uint256, uint256_to_int


Expand Down Expand Up @@ -173,6 +173,10 @@ class SyscallHandler:
get_selector_from_name(
"verify_signature_secp256r1"
): cairo_verify_signature_secp256r1,
get_selector_from_name("get_implementation"): lambda addr, data: [
ACCOUNT_CLASS_IMPLEMENTATION
],
get_selector_from_name("set_implementation"): lambda addr, data: [],
}

def get_contract_address(self, segments, syscall_ptr):
Expand Down

0 comments on commit c276e4c

Please sign in to comment.