Skip to content

Commit

Permalink
Implement the aggregation hook
Browse files Browse the repository at this point in the history
  • Loading branch information
remybar committed Jun 28, 2024
1 parent 6ef5aa2 commit b479967
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 4 deletions.
109 changes: 109 additions & 0 deletions contracts/src/contracts/hooks/aggregation_hook.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#[starknet::contract]
pub mod aggregation_hook {
use alexandria_bytes::{Bytes, BytesTrait};
use hyperlane_starknet::contracts::hooks::libs::standard_hook_metadata::standard_hook_metadata::{
StandardHookMetadata, VARIANT,
};
use hyperlane_starknet::contracts::libs::message::Message;
use hyperlane_starknet::interfaces::{
IPostDispatchHook, Types, IPostDispatchHookDispatcher, IPostDispatchHookDispatcherTrait, ETH_ADDRESS
};
use starknet::{ContractAddress, get_contract_address};
use openzeppelin::token::erc20::interface::{IERC20, IERC20Dispatcher, IERC20DispatcherTrait};

#[storage]
struct Storage {
hooks: LegacyMap::<usize, ContractAddress>,
hook_count: usize,
}

mod Errors {
pub const INVALID_METADATA_VARIANT: felt252 = 'Invalid metadata variant';
pub const INSUFFICIENT_BALANCE: felt252 = 'Insufficient balance';
pub const INSUFFICIENT_FUNDS: felt252 = 'Insufficient funds';
}

#[constructor]
fn constructor(ref self: ContractState, hooks: Span<ContractAddress>) {
let mut i = 0;
loop {
if i >= hooks.len() {
break;
}

self.hooks.write(i, *hooks.at(i));
i += 1;
};

self.hook_count.write(hooks.len());
}

#[abi(embed_v0)]
impl IPostDispatchHookImpl of IPostDispatchHook<ContractState> {
fn hook_type(self: @ContractState) -> Types {
Types::AGGREGATION(())
}

fn supports_metadata(self: @ContractState, _metadata: Bytes) -> bool {
_metadata.size() == 0 || StandardHookMetadata::variant(_metadata) == VARIANT.into()
}

fn post_dispatch(
ref self: ContractState, _metadata: Bytes, _message: Message, _fee_amount: u256
) {
assert(self.supports_metadata(_metadata.clone()), Errors::INVALID_METADATA_VARIANT);

let token_dispatcher = IERC20Dispatcher { contract_address: ETH_ADDRESS() };
let agg_hook_address = get_contract_address();

let balance = token_dispatcher.balance_of(agg_hook_address);
assert(balance >= _fee_amount, Errors::INSUFFICIENT_BALANCE);

let hook_count = self.hook_count.read();
let mut remaining_fees = _fee_amount;
let mut i = 0_usize;
loop {
if i >= hook_count {
break;
}

let hook_address = self.hooks.read(i);
let hook_dispatcher = IPostDispatchHookDispatcher { contract_address: hook_address };

let quote = hook_dispatcher.quote_dispatch(_metadata.clone(), _message.clone());
assert(quote <= remaining_fees, Errors::INSUFFICIENT_FUNDS);

token_dispatcher.transfer(hook_address, quote);
remaining_fees -= quote;

IPostDispatchHookDispatcher { contract_address: hook_address }
.post_dispatch(_metadata.clone(), _message.clone(), quote);

i += 1;
};
}

fn quote_dispatch(ref self: ContractState, _metadata: Bytes, _message: Message) -> u256 {
assert(self.supports_metadata(_metadata.clone()), Errors::INVALID_METADATA_VARIANT);

let hook_count = self.hook_count.read();
let mut i = 0_usize;
let mut total = 0_u256;
loop {
if i >= hook_count {
break;
}

let contract_address = self.hooks.read(i);

let value = IPostDispatchHookDispatcher { contract_address }
.quote_dispatch(_metadata.clone(), _message.clone());

total += value;
i += 1;
};

total
}
}
}
36 changes: 33 additions & 3 deletions contracts/src/contracts/mocks/hook.cairo
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
#[starknet::interface]
pub trait IMockHook<T> {
fn set_quote_dispatch(ref self: T, _value: u256);
fn get_post_dispatch_calls(self: @T) -> u8;
fn get_quote_dispatch_calls(self: @T) -> u8;
}

#[starknet::contract]
pub mod hook {
use alexandria_bytes::{Bytes, BytesTrait, BytesStore};
use hyperlane_starknet::contracts::libs::message::Message;
use hyperlane_starknet::interfaces::{
IPostDispatchHook, IPostDispatchHookDispatcher, IPostDispatchHookDispatcherTrait, Types
};
use super::IMockHook;

#[storage]
struct Storage {}
struct Storage {
quote_value: u256,
post_dispatch_calls: u8,
quote_dispatch_calls: u8,
}

#[abi(embed_v0)]
impl IPostDispatchHookImpl of IPostDispatchHook<ContractState> {
Expand All @@ -21,10 +33,28 @@ pub mod hook {

fn post_dispatch(
ref self: ContractState, _metadata: Bytes, _message: Message, _fee_amount: u256
) {}
) {
self.post_dispatch_calls.write(self.post_dispatch_calls.read() + 1);
}

fn quote_dispatch(ref self: ContractState, _metadata: Bytes, _message: Message) -> u256 {
0_u256
self.quote_dispatch_calls.write(self.quote_dispatch_calls.read() + 1);
self.quote_value.read()
}
}

#[abi(embed_v0)]
impl IMockHookImpl of IMockHook<ContractState> {
fn set_quote_dispatch(ref self: ContractState, _value: u256) {
self.quote_value.write(_value);
}

fn get_post_dispatch_calls(self: @ContractState) -> u8 {
self.post_dispatch_calls.read()
}

fn get_quote_dispatch_calls(self: @ContractState) -> u8 {
self.quote_dispatch_calls.read()
}
}
}
1 change: 0 additions & 1 deletion contracts/src/interfaces.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,6 @@ pub trait IProtocolFee<TContractState> {
fn collect_protocol_fees(ref self: TContractState);
}


#[starknet::interface]
pub trait IRoutingIsm<TContractState> {
fn route(self: @TContractState, _message: Message) -> ContractAddress;
Expand Down
2 changes: 2 additions & 0 deletions contracts/src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod contracts {
}
}
pub mod hooks {
pub mod aggregation_hook;
pub mod merkle_tree_hook;
pub mod protocol_fee;
pub mod libs {
Expand Down Expand Up @@ -64,6 +65,7 @@ mod tests {
pub mod test_messageid_multisig;
}
pub mod hooks {
pub mod test_aggregation_hook;
pub mod test_merkle_tree_hook;
pub mod test_protocol_fee;
}
Expand Down
142 changes: 142 additions & 0 deletions contracts/src/tests/hooks/test_aggregation_hook.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
use alexandria_bytes::{Bytes, BytesTrait};
use openzeppelin::token::erc20::interface::{IERC20, IERC20Dispatcher, IERC20DispatcherTrait};
use openzeppelin::access::ownable::interface::{IOwnableDispatcher, IOwnableDispatcherTrait};
use hyperlane_starknet::contracts::libs::message::{Message, MessageTrait};
use hyperlane_starknet::contracts::mocks::hook::{IMockHookDispatcher, IMockHookDispatcherTrait};
use hyperlane_starknet::interfaces::{
Types, IPostDispatchHookDispatcher, IPostDispatchHookDispatcherTrait,
};
use hyperlane_starknet::tests::setup::{setup_mock_token, setup_aggregation_hook, OWNER};
use snforge_std::{declare, ContractClassTrait, start_prank, CheatTarget, stop_prank};
use starknet::{ContractAddress};

fn _build_metadata() -> Bytes {
let mut metadata = BytesTrait::new_empty();
let variant = 1;
metadata.append_u16(variant);
metadata
}

fn _build_hook_list(quotes: @Array<u256>) -> Span<ContractAddress> {
let mut hooks = array![];
let mock_hook = declare("hook").unwrap();

let mut i = 0;
loop {
if i >= quotes.len() {
break;
}

let (contract_address, _) = mock_hook.deploy(@array![]).unwrap();
IMockHookDispatcher { contract_address }.set_quote_dispatch(*quotes.at(i));

hooks.append(contract_address);

i += 1;
};

hooks.span()
}

fn _setup_eth_balance(token_dispatcher: IERC20Dispatcher, recipient: ContractAddress, amount: u256) {
let ownable = IOwnableDispatcher { contract_address: token_dispatcher.contract_address };
start_prank(CheatTarget::One(ownable.contract_address), OWNER());
token_dispatcher.transfer(recipient, amount);
assert_eq!(token_dispatcher.balance_of(recipient), amount);
stop_prank(CheatTarget::One(ownable.contract_address));
}

#[test]
fn test_hook_type() {
setup_mock_token();
let hooks = _build_hook_list(@array![100_u256, 200_u256, 300_u256]);
let post_dispatch_hook = setup_aggregation_hook(@hooks);

assert_eq!(post_dispatch_hook.hook_type(), Types::AGGREGATION(()));
}

#[test]
fn test_aggregate_quote_dispatch() {
// arrange
setup_mock_token();
let hooks = _build_hook_list(@array![100_u256, 200_u256, 300_u256]);
let post_dispatch_hook = setup_aggregation_hook(@hooks);

let expected_quote = 600_u256;

// act
let quote = post_dispatch_hook.quote_dispatch(_build_metadata(), MessageTrait::default());

// assert
assert_eq!(quote, expected_quote);

let mut i = 0;
loop {
if i >= hooks.len() {
break;
}

let contract_address = *hooks.at(i);
assert_eq!(IMockHookDispatcher { contract_address }.get_quote_dispatch_calls(), 1);

i += 1;
}
}

#[test]
fn test_aggregate_post_dispatch() {
// arrange
let token_dispatcher = setup_mock_token();

let hooks = _build_hook_list(@array![100_u256, 200_u256, 300_u256]);
let fee_amount = 600_u256;
let post_dispatch_hook = setup_aggregation_hook(@hooks);

_setup_eth_balance(token_dispatcher, post_dispatch_hook.contract_address, fee_amount);

// act
post_dispatch_hook.post_dispatch(_build_metadata(), MessageTrait::default(), fee_amount);

// assert
let mut i = 0;
loop {
if i >= hooks.len() {
break;
}

let contract_address = *hooks.at(i);
assert_eq!(IMockHookDispatcher { contract_address }.get_post_dispatch_calls(), 1);

i += 1;
}
}

#[test]
#[should_panic(expected: ('Insufficient balance',))]
fn test_aggregate_post_dispatch_insufficient_balance() {
// arrange
setup_mock_token();

let hooks = _build_hook_list(@array![100_u256, 200_u256, 300_u256]);
let fee_amount = 600_u256;
let post_dispatch_hook = setup_aggregation_hook(@hooks);

// act
post_dispatch_hook.post_dispatch(_build_metadata(), MessageTrait::default(), fee_amount);
}

#[test]
#[should_panic(expected: ('Insufficient funds',))]
fn test_aggregate_post_dispatch_insufficient_funds() {
// arrange
let token_dispatcher = setup_mock_token();

let hooks = _build_hook_list(@array![100_u256, 200_u256, 300_u256]);
let fee_amount = 599_u256;
let post_dispatch_hook = setup_aggregation_hook(@hooks);

_setup_eth_balance(token_dispatcher, post_dispatch_hook.contract_address, fee_amount);

// act
post_dispatch_hook.post_dispatch(_build_metadata(), MessageTrait::default(), fee_amount);
}
9 changes: 9 additions & 0 deletions contracts/src/tests/setup.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -517,3 +517,12 @@ pub fn setup_protocol_fee() -> (IProtocolFeeDispatcher, IPostDispatchHookDispatc
IPostDispatchHookDispatcher { contract_address: protocol_fee_addr }
)
}

pub fn setup_aggregation_hook(hooks: @Span<ContractAddress>) -> IPostDispatchHookDispatcher {
let aggregation_class = declare("aggregation_hook").unwrap();
let mut ctor_data = array![];
hooks.serialize(ref ctor_data);
let (aggregation_addr, _) = aggregation_class.deploy(@ctor_data).unwrap();

IPostDispatchHookDispatcher { contract_address: aggregation_addr }
}

0 comments on commit b479967

Please sign in to comment.