Skip to content

Commit

Permalink
address pr review
Browse files Browse the repository at this point in the history
  • Loading branch information
enitrat committed May 2, 2024
1 parent aa75748 commit f763073
Show file tree
Hide file tree
Showing 10 changed files with 23 additions and 40 deletions.
11 changes: 1 addition & 10 deletions src/kakarot/account.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -530,16 +530,7 @@ namespace Account {
return starknet_account_exists;
}

// @notice Return the valid jumpdests related to this account's code, cached during the execution of the transaction.
// @param self The pointer to the Account
// @return The pointer to the dictionary of the valid jumpdests
func cached_jumpdests(self: model.Account*) -> (
valid_jumpdests_start: DictAccess*, valid_jumpdests: DictAccess*
) {
return (self.valid_jumpdests_start, self.valid_jumpdests);
}

func cache_valid_jumpdests{range_check_ptr}(
func set_valid_jumpdests{range_check_ptr}(
self: model.Account*, valid_jumpdests_start: DictAccess*, valid_jumpdests: DictAccess*
) -> model.Account* {
let (copy_start, copy) = default_dict_copy(valid_jumpdests_start, valid_jumpdests);
Expand Down
2 changes: 1 addition & 1 deletion src/kakarot/accounts/account_contract.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ func write_jumpdests{
// @param index The index of the jumpdest.
// @return is_valid 1 if the jumpdest is valid, 0 otherwise.
@view
func is_jumpdest_valid{
func is_valid_jumpdest{
syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*
}(index: felt) -> (is_valid: felt) {
let is_valid = AccountContract.is_valid_jumpdest(index);
Expand Down
14 changes: 5 additions & 9 deletions src/kakarot/evm.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ namespace EVM {

let valid_jumpdests = self.message.valid_jumpdests;
with valid_jumpdests {
let is_valid_jumpdest = Internals.is_jumpdest_valid(
let is_valid_jumpdest = Internals.is_valid_jumpdest(
self.message.code_address, self.message.is_create, new_pc_offset
);
}
Expand Down Expand Up @@ -289,7 +289,7 @@ namespace EVM {
}

namespace Internals {
func is_jumpdest_valid{
func is_valid_jumpdest{
syscall_ptr: felt*,
pedersen_ptr: HashBuiltin*,
range_check_ptr,
Expand All @@ -306,13 +306,9 @@ namespace Internals {
}

let code_starknet_address = Account.compute_starknet_address(code_address);
let (is_valid) = IAccount.is_jumpdest_valid(code_starknet_address, index);
let (is_valid) = IAccount.is_valid_jumpdest(code_starknet_address, index);
dict_write{dict_ptr=valid_jumpdests}(index, is_valid);

if (is_valid == 0) {
return FALSE;
}

dict_write{dict_ptr=valid_jumpdests}(index, TRUE);
return TRUE;
return is_valid;
}
}
12 changes: 5 additions & 7 deletions src/kakarot/instructions/system_operations.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -1177,9 +1177,6 @@ namespace CallHelper {
let stack = Stack.init();
let memory = Memory.init();

// Use the cached jumpdests from previous calls
let (valid_jumpdests_start, valid_jumpdests) = Account.cached_jumpdests(code_account);

if (is_staticcall != FALSE) {
tempvar read_only = TRUE;
} else {
Expand All @@ -1190,8 +1187,9 @@ namespace CallHelper {
tempvar message = new model.Message(
bytecode=code,
bytecode_len=code_len,
valid_jumpdests_start=valid_jumpdests_start,
valid_jumpdests=valid_jumpdests,
// Use the cached jumpdests from previous calls
valid_jumpdests_start=code_account.valid_jumpdests_start,
valid_jumpdests=code_account.valid_jumpdests,
calldata=calldata,
calldata_len=args_size.low,
value=value,
Expand Down Expand Up @@ -1245,7 +1243,7 @@ namespace CallHelper {

// Write the valid jumpdests cached during the call in the state
let code_account = State.get_account(evm.message.code_address);
let code_account = Account.cache_valid_jumpdests(
let code_account = Account.set_valid_jumpdests(
code_account, evm.message.valid_jumpdests_start, evm.message.valid_jumpdests
);
State.update_account(code_account);
Expand Down Expand Up @@ -1515,7 +1513,7 @@ namespace CreateHelper {
let (valid_jumpdests_start, valid_jumpdests) = Helpers.initialize_jumpdests(
account.code_len, account.code
);
let account = Account.cache_valid_jumpdests(
let account = Account.set_valid_jumpdests(
account, valid_jumpdests_start, valid_jumpdests
);

Expand Down
2 changes: 1 addition & 1 deletion src/kakarot/interfaces/interfaces.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ namespace IAccount {
func set_nonce(nonce: felt) {
}

func is_jumpdest_valid(index: felt) -> (is_valid: felt) {
func is_valid_jumpdest(index: felt) -> (is_valid: felt) {
}

func write_jumpdests(jumpdests_len: felt, jumpdests: felt*) {
Expand Down
6 changes: 2 additions & 4 deletions src/kakarot/interpreter.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -1040,17 +1040,15 @@ namespace Internals {
return evm;
}

// Write bytecode and cached the final code valid jumpdests to Account
// Write bytecode and cache the final code valid jumpdests to Account
let account = State.get_account(evm.message.address.evm);
let account = Account.set_code(account, evm.return_data_len, evm.return_data);
let account = Account.set_created(account, TRUE);

let (valid_jumpdests_start, valid_jumpdests) = Helpers.initialize_jumpdests(
account.code_len, account.code
);
let account = Account.cache_valid_jumpdests(
account, valid_jumpdests_start, valid_jumpdests
);
let account = Account.set_valid_jumpdests(account, valid_jumpdests_start, valid_jumpdests);

State.update_account(account);
State.finalize();
Expand Down
2 changes: 1 addition & 1 deletion tests/end_to_end/test_kakarot.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def class_hashes():


@pytest_asyncio.fixture(scope="session")
async def origin(evm: Contract, addresses):
async def origin(evm: Contract):
"""
Deploys the origin's Starknet contract to the correct address.
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/src/backend/test_starknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_should_save_jumpdests_to_storage(self, cairo_run):
)

assert valid_indexes == [
valid_jumpdests[i][0] for i in range(len(valid_jumpdests))
valid_jumpdest[0] for valid_jumpdest in valid_jumpdests
]

SyscallHandler.mock_call.assert_any_call(
Expand Down
4 changes: 2 additions & 2 deletions tests/src/kakarot/test_evm.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func test__jump{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}
return evm;
}

func test__is_jumpdest_valid{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
func test__is_valid_jumpdest{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
) -> felt {
alloc_locals;

Expand All @@ -42,7 +42,7 @@ func test__is_jumpdest_valid{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, ran
);

with valid_jumpdests {
let result = Internals.is_jumpdest_valid(0, 0, index);
let result = Internals.is_valid_jumpdest(0, 0, index);
}

return result;
Expand Down
8 changes: 4 additions & 4 deletions tests/src/kakarot/test_evm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class TestExecutionContext:
def test_jump(self, cairo_run, bytecode, jumpdest, new_pc, expected_return_data):

with SyscallHandler.patch(
"IAccount.is_jumpdest_valid",
"IAccount.is_valid_jumpdest",
lambda addr, data: [1 if len(expected_return_data) == 0 else 0],
):
evm = cairo_run("test__jump", bytecode=bytecode, jumpdest=jumpdest)
Expand All @@ -43,15 +43,15 @@ def test_should_return_cached_valid_jumpdest(
):
assert (
cairo_run(
"test__is_jumpdest_valid",
"test__is_valid_jumpdest",
cached_jumpdests=cached_jumpdests,
index=index,
)
== expected
)

@SyscallHandler.patch(
"IAccount.is_jumpdest_valid",
"IAccount.is_valid_jumpdest",
lambda addr, data: [1 if data == [0x10] else 0],
)
@pytest.mark.parametrize(
Expand All @@ -66,7 +66,7 @@ def test_should_return_non_cached_valid_jumpdest(
):
assert (
cairo_run(
"test__is_jumpdest_valid",
"test__is_valid_jumpdest",
cached_jumpdests=cached_jumpdests,
index=index,
)
Expand Down

0 comments on commit f763073

Please sign in to comment.