From 328270908288d70116e692224ba91ba7b3da4924 Mon Sep 17 00:00:00 2001 From: Haythem Sellami <17862704+haythemsellami@users.noreply.github.com> Date: Tue, 18 Jun 2024 12:40:42 +0300 Subject: [PATCH] bytecode under limit --- src/Dispatch.sol | 6 +- src/FourSixTwoSixAgg.sol | 86 +++------------ src/FourSixTwoSixAggFactory.sol | 41 +++---- src/Shared.sol | 96 +++++++++++++++++ src/lib/ErrorsLib.sol | 4 + src/modules/AllocationPoints.sol | 109 +++++++++++++++++++ src/modules/Fee.sol | 51 +++++++++ src/modules/Hooks.sol | 29 +++++ src/modules/Rewards.sol | 144 +++++++++++++++++++++++++ test/common/FourSixTwoSixAggBase.t.sol | 4 + test/e2e/BalanceForwarderE2ETest.t.sol | 1 + test/e2e/HooksE2ETest.t.sol | 8 +- 12 files changed, 485 insertions(+), 94 deletions(-) create mode 100644 src/Shared.sol create mode 100644 src/modules/AllocationPoints.sol create mode 100644 src/modules/Fee.sol create mode 100644 src/modules/Hooks.sol create mode 100644 src/modules/Rewards.sol diff --git a/src/Dispatch.sol b/src/Dispatch.sol index 60e96a49..16a75f16 100644 --- a/src/Dispatch.sol +++ b/src/Dispatch.sol @@ -12,11 +12,15 @@ abstract contract Dispatch is RewardsModule, HooksModule { address public immutable MODULE_REWARDS; address public immutable MODULE_HOOKS; address public immutable MODULE_FEE; + address public immutable MODULE_ALLOCATION_POINTS; - constructor(address _rewardsModule, address _hooksModule, address _feeModule) Shared() { + constructor(address _rewardsModule, address _hooksModule, address _feeModule, address _allocationPointsModule) + Shared() + { MODULE_REWARDS = _rewardsModule; MODULE_HOOKS = _hooksModule; MODULE_FEE = _feeModule; + MODULE_ALLOCATION_POINTS = _allocationPointsModule; } // Modifier proxies the function call to a module and low-level returns the result diff --git a/src/FourSixTwoSixAgg.sol b/src/FourSixTwoSixAgg.sol index ee506d51..3d518e81 100644 --- a/src/FourSixTwoSixAgg.sol +++ b/src/FourSixTwoSixAgg.sol @@ -54,8 +54,8 @@ contract FourSixTwoSixAgg is ERC4626Upgradeable, AccessControlEnumerableUpgradea event SetStrategyCap(address indexed strategy, uint256 cap); event Rebalance(address indexed strategy, uint256 _amountToRebalance, bool _isDeposit); - constructor(address _rewardsModule, address _hooksModule, address _feeModule) - Dispatch(_rewardsModule, _hooksModule, _feeModule) + constructor(address _rewardsModule, address _hooksModule, address _feeModule, address _allocationPointsModule) + Dispatch(_rewardsModule, _hooksModule, _feeModule, _allocationPointsModule) {} struct InitParams { @@ -195,91 +195,35 @@ contract FourSixTwoSixAgg is ERC4626Upgradeable, AccessControlEnumerableUpgradea /// @param _newPoints new strategy's points function adjustAllocationPoints(address _strategy, uint256 _newPoints) external - nonReentrant + use(MODULE_ALLOCATION_POINTS) onlyRole(STRATEGY_MANAGER) - { - AggregationVaultStorage storage $ = StorageLib._getAggregationVaultStorage(); - - Strategy memory strategyDataCache = $.strategies[_strategy]; - - if (!strategyDataCache.active) { - revert ErrorsLib.InactiveStrategy(); - } - - $.strategies[_strategy].allocationPoints = _newPoints.toUint120(); - $.totalAllocationPoints = $.totalAllocationPoints + _newPoints - strategyDataCache.allocationPoints; - - emit AdjustAllocationPoints(_strategy, strategyDataCache.allocationPoints, _newPoints); - } + {} /// @notice Set cap on strategy allocated amount. /// @dev By default, cap is set to 0, not activated. /// @param _strategy Strategy address. /// @param _cap Cap amount - function setStrategyCap(address _strategy, uint256 _cap) external nonReentrant onlyRole(STRATEGY_MANAGER) { - AggregationVaultStorage storage $ = StorageLib._getAggregationVaultStorage(); - - if (!$.strategies[_strategy].active) { - revert ErrorsLib.InactiveStrategy(); - } - - $.strategies[_strategy].cap = _cap.toUint120(); - - emit SetStrategyCap(_strategy, _cap); - } + function setStrategyCap(address _strategy, uint256 _cap) + external + use(MODULE_ALLOCATION_POINTS) + onlyRole(STRATEGY_MANAGER) + {} /// @notice Add new strategy with it's allocation points. /// @dev Can only be called by an address that have STRATEGY_ADDER. /// @param _strategy Address of the strategy /// @param _allocationPoints Strategy's allocation points - function addStrategy(address _strategy, uint256 _allocationPoints) external nonReentrant onlyRole(STRATEGY_ADDER) { - AggregationVaultStorage storage $ = StorageLib._getAggregationVaultStorage(); - - if ($.strategies[_strategy].active) { - revert ErrorsLib.StrategyAlreadyExist(); - } - - if (IERC4626(_strategy).asset() != asset()) { - revert ErrorsLib.InvalidStrategyAsset(); - } - - _callHooksTarget(ADD_STRATEGY, _msgSender()); - - $.strategies[_strategy] = - Strategy({allocated: 0, allocationPoints: _allocationPoints.toUint120(), active: true, cap: 0}); - - $.totalAllocationPoints += _allocationPoints; - IWithdrawalQueue($.withdrawalQueue).addStrategyToWithdrawalQueue(_strategy); - - emit AddStrategy(_strategy, _allocationPoints); - } + function addStrategy(address _strategy, uint256 _allocationPoints) + external + use(MODULE_ALLOCATION_POINTS) + onlyRole(STRATEGY_ADDER) + {} /// @notice Remove strategy and set its allocation points to zero. /// @dev This function does not pull funds, `harvest()` needs to be called to withdraw /// @dev Can only be called by an address that have the STRATEGY_REMOVER /// @param _strategy Address of the strategy - function removeStrategy(address _strategy) external nonReentrant onlyRole(STRATEGY_REMOVER) { - if (_strategy == address(0)) revert ErrorsLib.CanNotRemoveCashReserve(); - - AggregationVaultStorage storage $ = StorageLib._getAggregationVaultStorage(); - - Strategy storage strategyStorage = $.strategies[_strategy]; - - if (!strategyStorage.active) { - revert ErrorsLib.AlreadyRemoved(); - } - - _callHooksTarget(REMOVE_STRATEGY, _msgSender()); - - $.totalAllocationPoints -= strategyStorage.allocationPoints; - strategyStorage.active = false; - strategyStorage.allocationPoints = 0; - - // remove from withdrawalQueue - IWithdrawalQueue($.withdrawalQueue).removeStrategyFromWithdrawalQueue(_strategy); - - emit RemoveStrategy(_strategy); - } + function removeStrategy(address _strategy) external use(MODULE_ALLOCATION_POINTS) onlyRole(STRATEGY_REMOVER) {} /// @notice update accrued interest function updateInterestAccrued() external { diff --git a/src/FourSixTwoSixAggFactory.sol b/src/FourSixTwoSixAggFactory.sol index 5e9a6ac5..00465d44 100644 --- a/src/FourSixTwoSixAggFactory.sol +++ b/src/FourSixTwoSixAggFactory.sol @@ -15,32 +15,35 @@ contract FourSixTwoSixAggFactory { address public immutable evc; address public immutable balanceTracker; /// core modules - address public immutable rewardsImpl; - address public immutable hooksImpl; + address public immutable rewardsModuleImpl; + address public immutable hooksModuleImpl; address public immutable feeModuleImpl; + address public immutable allocationpointsModuleImpl; /// peripheries /// @dev Rebalancer periphery, one instance can serve different aggregation vaults - address public immutable rebalancer; + address public immutable rebalancerAddr; /// @dev Withdrawal queue perihperhy, need to be deployed per aggregation vault - address public immutable withdrawalQueueImpl; + address public immutable withdrawalQueueAddr; constructor( address _evc, address _balanceTracker, - address _rewardsImpl, - address _hooksImpl, + address _rewardsModuleImpl, + address _hooksModuleImpl, address _feeModuleImpl, - address _rebalancer, - address _withdrawalQueueImpl + address _allocationpointsModuleImpl, + address _rebalancerAddr, + address _withdrawalQueueAddr ) { evc = _evc; balanceTracker = _balanceTracker; - rewardsImpl = _rewardsImpl; - hooksImpl = _hooksImpl; + rewardsModuleImpl = _rewardsModuleImpl; + hooksModuleImpl = _hooksModuleImpl; feeModuleImpl = _feeModuleImpl; + allocationpointsModuleImpl = _allocationpointsModuleImpl; - rebalancer = _rebalancer; - withdrawalQueueImpl = _withdrawalQueueImpl; + rebalancerAddr = _rebalancerAddr; + withdrawalQueueAddr = _withdrawalQueueAddr; } // TODO: decrease bytecode size, use clones or something @@ -51,21 +54,23 @@ contract FourSixTwoSixAggFactory { uint256 _initialCashAllocationPoints ) external returns (address) { // cloning core modules - address rewardsImplAddr = Clones.clone(rewardsImpl); - address hooks = Clones.clone(hooksImpl); - address feeModule = Clones.clone(feeModuleImpl); + address rewardsModuleAddr = Clones.clone(rewardsModuleImpl); + address hooksModuleAddr = Clones.clone(hooksModuleImpl); + address feeModuleAddr = Clones.clone(feeModuleImpl); + address allocationpointsModuleAddr = Clones.clone(allocationpointsModuleImpl); // cloning peripheries - WithdrawalQueue withdrawalQueue = WithdrawalQueue(Clones.clone(withdrawalQueueImpl)); + WithdrawalQueue withdrawalQueue = WithdrawalQueue(Clones.clone(withdrawalQueueAddr)); // deploy new aggregation vault - FourSixTwoSixAgg fourSixTwoSixAgg = new FourSixTwoSixAgg(rewardsImplAddr, hooks, feeModule); + FourSixTwoSixAgg fourSixTwoSixAgg = + new FourSixTwoSixAgg(rewardsModuleAddr, hooksModuleAddr, feeModuleAddr, allocationpointsModuleAddr); FourSixTwoSixAgg.InitParams memory aggregationVaultInitParams = FourSixTwoSixAgg.InitParams({ evc: evc, balanceTracker: balanceTracker, withdrawalQueuePeriphery: address(withdrawalQueue), - rebalancerPerihpery: rebalancer, + rebalancerPerihpery: rebalancerAddr, aggregationVaultOwner: msg.sender, asset: _asset, name: _name, diff --git a/src/Shared.sol b/src/Shared.sol new file mode 100644 index 00000000..d9e29d19 --- /dev/null +++ b/src/Shared.sol @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +pragma solidity ^0.8.0; + +import {StorageLib, AggregationVaultStorage, HooksStorage} from "./lib/StorageLib.sol"; +import {ErrorsLib} from "./lib/ErrorsLib.sol"; +import {IEVC} from "ethereum-vault-connector/utils/EVCUtil.sol"; +import {IHookTarget} from "evk/src/interfaces/IHookTarget.sol"; +import {HooksLib} from "./lib/HooksLib.sol"; + +contract Shared { + using HooksLib for uint32; + + uint8 internal constant REENTRANCYLOCK__UNLOCKED = 1; + uint8 internal constant REENTRANCYLOCK__LOCKED = 2; + + uint32 public constant DEPOSIT = 1 << 0; + uint32 public constant WITHDRAW = 1 << 1; + uint32 public constant ADD_STRATEGY = 1 << 2; + uint32 public constant REMOVE_STRATEGY = 1 << 3; + + uint32 constant ACTIONS_COUNTER = 1 << 4; + uint256 constant HOOKS_MASK = 0x00000000000000000000000000000000000000000000000000000000FFFFFFFF; + + event SetHooksConfig(address indexed hooksTarget, uint32 hookedFns); + + modifier nonReentrant() { + AggregationVaultStorage storage $ = StorageLib._getAggregationVaultStorage(); + + if ($.locked == REENTRANCYLOCK__LOCKED) revert ErrorsLib.Reentrancy(); + + $.locked = REENTRANCYLOCK__LOCKED; + _; + $.locked = REENTRANCYLOCK__UNLOCKED; + } + + function _msgSender() internal view virtual returns (address) { + address sender = msg.sender; + AggregationVaultStorage storage $ = StorageLib._getAggregationVaultStorage(); + + if (sender == address($.evc)) { + (sender,) = IEVC($.evc).getCurrentOnBehalfOfAccount(address(0)); + } + + return sender; + } + + function _setHooksConfig(address _hooksTarget, uint32 _hookedFns) internal { + if (_hooksTarget != address(0) && IHookTarget(_hooksTarget).isHookTarget() != IHookTarget.isHookTarget.selector) + { + revert ErrorsLib.NotHooksContract(); + } + if (_hookedFns != 0 && _hooksTarget == address(0)) { + revert ErrorsLib.InvalidHooksTarget(); + } + if (_hookedFns >= ACTIONS_COUNTER) revert ErrorsLib.InvalidHookedFns(); + + HooksStorage storage $ = StorageLib._getHooksStorage(); + $.hooksConfig = (uint256(uint160(_hooksTarget)) << 32) | uint256(_hookedFns); + + emit SetHooksConfig(_hooksTarget, _hookedFns); + } + + /// @notice Checks whether a hook has been installed for the function and if so, invokes the hook target. + /// @param _fn Function to call the hook for. + /// @param _caller Caller's address. + function _callHooksTarget(uint32 _fn, address _caller) internal { + HooksStorage storage $ = StorageLib._getHooksStorage(); + + (address target, uint32 hookedFns) = _getHooksConfig($.hooksConfig); + + if (hookedFns.isNotSet(_fn)) return; + + (bool success, bytes memory data) = target.call(abi.encodePacked(msg.data, _caller)); + + if (!success) _revertBytes(data); + } + + /// @notice Get the hooks contract and the hooked functions. + /// @return address Hooks contract. + /// @return uint32 Hooked functions. + function _getHooksConfig(uint256 _hooksConfig) internal pure returns (address, uint32) { + return (address(uint160(_hooksConfig >> 32)), uint32(_hooksConfig & HOOKS_MASK)); + } + + /// @dev Revert with call error or EmptyError + /// @param _errorMsg call revert message + function _revertBytes(bytes memory _errorMsg) private pure { + if (_errorMsg.length > 0) { + assembly { + revert(add(32, _errorMsg), mload(_errorMsg)) + } + } + + revert ErrorsLib.EmptyError(); + } +} diff --git a/src/lib/ErrorsLib.sol b/src/lib/ErrorsLib.sol index 3b57e3b8..d7599b39 100644 --- a/src/lib/ErrorsLib.sol +++ b/src/lib/ErrorsLib.sol @@ -19,4 +19,8 @@ library ErrorsLib { error NotSupported(); error AlreadyEnabled(); error AlreadyDisabled(); + error InvalidHooksTarget(); + error NotHooksContract(); + error InvalidHookedFns(); + error EmptyError(); } diff --git a/src/modules/AllocationPoints.sol b/src/modules/AllocationPoints.sol new file mode 100644 index 00000000..6029741e --- /dev/null +++ b/src/modules/AllocationPoints.sol @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +pragma solidity ^0.8.0; + +import {Shared} from "../Shared.sol"; +// internal dep +import {StorageLib, AggregationVaultStorage, Strategy} from "../lib/StorageLib.sol"; +import {ErrorsLib} from "../lib/ErrorsLib.sol"; +import {IERC4626} from "@openzeppelin-upgradeable/token/ERC20/extensions/ERC4626Upgradeable.sol"; +import {IWithdrawalQueue} from "../interface/IWithdrawalQueue.sol"; +import {SafeCast} from "@openzeppelin/contracts/utils/math/SafeCast.sol"; + +abstract contract AllocationPointsModule is Shared { + using SafeCast for uint256; + + event AdjustAllocationPoints(address indexed strategy, uint256 oldPoints, uint256 newPoints); + event AddStrategy(address indexed strategy, uint256 allocationPoints); + event RemoveStrategy(address indexed _strategy); + event SetStrategyCap(address indexed strategy, uint256 cap); + + /// @notice Adjust a certain strategy's allocation points. + /// @dev Can only be called by an address that have the STRATEGY_MANAGER + /// @param _strategy address of strategy + /// @param _newPoints new strategy's points + function adjustAllocationPoints(address _strategy, uint256 _newPoints) external virtual nonReentrant { + AggregationVaultStorage storage $ = StorageLib._getAggregationVaultStorage(); + + Strategy memory strategyDataCache = $.strategies[_strategy]; + + if (!strategyDataCache.active) { + revert ErrorsLib.InactiveStrategy(); + } + + $.strategies[_strategy].allocationPoints = _newPoints.toUint120(); + $.totalAllocationPoints = $.totalAllocationPoints + _newPoints - strategyDataCache.allocationPoints; + + emit AdjustAllocationPoints(_strategy, strategyDataCache.allocationPoints, _newPoints); + } + + /// @notice Set cap on strategy allocated amount. + /// @dev By default, cap is set to 0, not activated. + /// @param _strategy Strategy address. + /// @param _cap Cap amount + function setStrategyCap(address _strategy, uint256 _cap) external virtual nonReentrant { + AggregationVaultStorage storage $ = StorageLib._getAggregationVaultStorage(); + + if (!$.strategies[_strategy].active) { + revert ErrorsLib.InactiveStrategy(); + } + + $.strategies[_strategy].cap = _cap.toUint120(); + + emit SetStrategyCap(_strategy, _cap); + } + + /// @notice Add new strategy with it's allocation points. + /// @dev Can only be called by an address that have STRATEGY_ADDER. + /// @param _strategy Address of the strategy + /// @param _allocationPoints Strategy's allocation points + function addStrategy(address _strategy, uint256 _allocationPoints) external virtual nonReentrant { + AggregationVaultStorage storage $ = StorageLib._getAggregationVaultStorage(); + + if ($.strategies[_strategy].active) { + revert ErrorsLib.StrategyAlreadyExist(); + } + + if (IERC4626(_strategy).asset() != IERC4626(address(this)).asset()) { + revert ErrorsLib.InvalidStrategyAsset(); + } + + _callHooksTarget(ADD_STRATEGY, _msgSender()); + + $.strategies[_strategy] = + Strategy({allocated: 0, allocationPoints: _allocationPoints.toUint120(), active: true, cap: 0}); + + $.totalAllocationPoints += _allocationPoints; + IWithdrawalQueue($.withdrawalQueue).addStrategyToWithdrawalQueue(_strategy); + + emit AddStrategy(_strategy, _allocationPoints); + } + + /// @notice Remove strategy and set its allocation points to zero. + /// @dev This function does not pull funds, `harvest()` needs to be called to withdraw + /// @dev Can only be called by an address that have the STRATEGY_REMOVER + /// @param _strategy Address of the strategy + function removeStrategy(address _strategy) external virtual nonReentrant { + if (_strategy == address(0)) revert ErrorsLib.CanNotRemoveCashReserve(); + + AggregationVaultStorage storage $ = StorageLib._getAggregationVaultStorage(); + + Strategy storage strategyStorage = $.strategies[_strategy]; + + if (!strategyStorage.active) { + revert ErrorsLib.AlreadyRemoved(); + } + + _callHooksTarget(REMOVE_STRATEGY, _msgSender()); + + $.totalAllocationPoints -= strategyStorage.allocationPoints; + strategyStorage.active = false; + strategyStorage.allocationPoints = 0; + + // remove from withdrawalQueue + IWithdrawalQueue($.withdrawalQueue).removeStrategyFromWithdrawalQueue(_strategy); + + emit RemoveStrategy(_strategy); + } +} + +contract AllocationPoints is AllocationPointsModule {} diff --git a/src/modules/Fee.sol b/src/modules/Fee.sol new file mode 100644 index 00000000..6707efa5 --- /dev/null +++ b/src/modules/Fee.sol @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +pragma solidity ^0.8.0; + +import {Shared} from "../Shared.sol"; +import {StorageLib, BalanceForwarderStorage, AggregationVaultStorage} from "../lib/StorageLib.sol"; +import {IBalanceForwarder} from "../interface/IBalanceForwarder.sol"; +import {IBalanceTracker} from "reward-streams/interfaces/IBalanceTracker.sol"; +import {ErrorsLib} from "../lib/ErrorsLib.sol"; +import {IRewardStreams} from "reward-streams/interfaces/IRewardStreams.sol"; +import {IERC20} from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; + +abstract contract FeeModule is Shared { + using StorageLib for *; + + /// @dev The maximum performanceFee the vault can have is 50% + uint256 internal constant MAX_PERFORMANCE_FEE = 0.5e18; + + event SetFeeRecipient(address indexed oldRecipient, address indexed newRecipient); + event SetPerformanceFee(uint256 oldFee, uint256 newFee); + + /// @notice Set performance fee recipient address + /// @notice @param _newFeeRecipient Recipient address + function setFeeRecipient(address _newFeeRecipient) external { + AggregationVaultStorage storage $ = StorageLib._getAggregationVaultStorage(); + address feeRecipientCached = $.feeRecipient; + + if (_newFeeRecipient == feeRecipientCached) revert ErrorsLib.FeeRecipientAlreadySet(); + + emit SetFeeRecipient(feeRecipientCached, _newFeeRecipient); + + $.feeRecipient = _newFeeRecipient; + } + + /// @notice Set performance fee (1e18 == 100%) + /// @notice @param _newFee Fee rate + function setPerformanceFee(uint256 _newFee) external { + AggregationVaultStorage storage $ = StorageLib._getAggregationVaultStorage(); + + uint256 performanceFeeCached = $.performanceFee; + + if (_newFee > MAX_PERFORMANCE_FEE) revert ErrorsLib.MaxPerformanceFeeExceeded(); + if ($.feeRecipient == address(0)) revert ErrorsLib.FeeRecipientNotSet(); + if (_newFee == performanceFeeCached) revert ErrorsLib.PerformanceFeeAlreadySet(); + + emit SetPerformanceFee(performanceFeeCached, _newFee); + + $.performanceFee = _newFee; + } +} + +contract Fee is FeeModule {} diff --git a/src/modules/Hooks.sol b/src/modules/Hooks.sol new file mode 100644 index 00000000..8fee680b --- /dev/null +++ b/src/modules/Hooks.sol @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +pragma solidity ^0.8.0; + +import {Shared} from "../Shared.sol"; +import {StorageLib, HooksStorage} from "../lib/StorageLib.sol"; +import {HooksLib} from "../lib/HooksLib.sol"; +import {IHookTarget} from "evk/src/interfaces/IHookTarget.sol"; + +abstract contract HooksModule is Shared { + using HooksLib for uint32; + + /// @notice Set hooks contract and hooked functions. + /// @param _hooksTarget Hooks contract. + /// @param _hookedFns Hooked functions. + function setHooksConfig(address _hooksTarget, uint32 _hookedFns) external virtual nonReentrant { + _setHooksConfig(_hooksTarget, _hookedFns); + } + + /// @notice Get the hooks contract and the hooked functions. + /// @return address Hooks contract. + /// @return uint32 Hooked functions. + function getHooksConfig() external view returns (address, uint32) { + HooksStorage storage $ = StorageLib._getHooksStorage(); + + return _getHooksConfig($.hooksConfig); + } +} + +contract Hooks is HooksModule {} diff --git a/src/modules/Rewards.sol b/src/modules/Rewards.sol new file mode 100644 index 00000000..5afb4a72 --- /dev/null +++ b/src/modules/Rewards.sol @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +pragma solidity ^0.8.0; + +import {Shared} from "../Shared.sol"; +import {StorageLib, BalanceForwarderStorage, AggregationVaultStorage} from "../lib/StorageLib.sol"; +import {IBalanceForwarder} from "../interface/IBalanceForwarder.sol"; +import {IBalanceTracker} from "reward-streams/interfaces/IBalanceTracker.sol"; +import {ErrorsLib} from "../lib/ErrorsLib.sol"; +import {IRewardStreams} from "reward-streams/interfaces/IRewardStreams.sol"; +import {IERC20} from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; + +/// @title BalanceForwarder contract +/// @custom:security-contact security@euler.xyz +/// @author Euler Labs (https://www.eulerlabs.com/) +/// @notice A generic contract to integrate with https://github.com/euler-xyz/reward-streams +abstract contract RewardsModule is IBalanceForwarder, Shared { + using StorageLib for *; + + event OptInStrategyRewards(address indexed strategy); + event OptOutStrategyRewards(address indexed strategy); + event EnableBalanceForwarder(address indexed _user); + event DisableBalanceForwarder(address indexed _user); + + /// @notice Opt in to strategy rewards + /// @param _strategy Strategy address + function optInStrategyRewards(address _strategy) external virtual nonReentrant { + AggregationVaultStorage storage $ = StorageLib._getAggregationVaultStorage(); + + if (!$.strategies[_strategy].active) revert ErrorsLib.InactiveStrategy(); + + IBalanceForwarder(_strategy).enableBalanceForwarder(); + + emit OptInStrategyRewards(_strategy); + } + + /// @notice Opt out of strategy rewards + /// @param _strategy Strategy address + function optOutStrategyRewards(address _strategy) external virtual nonReentrant { + IBalanceForwarder(_strategy).disableBalanceForwarder(); + + emit OptOutStrategyRewards(_strategy); + } + + /// @notice Claim a specific strategy rewards + /// @param _strategy Strategy address. + /// @param _reward The address of the reward token. + /// @param _recipient The address to receive the claimed reward tokens. + /// @param _forfeitRecentReward Whether to forfeit the recent rewards and not update the accumulator. + function claimStrategyReward(address _strategy, address _reward, address _recipient, bool _forfeitRecentReward) + external + virtual + nonReentrant + { + address rewardStreams = IBalanceForwarder(_strategy).balanceTrackerAddress(); + + IRewardStreams(rewardStreams).claimReward(_strategy, _reward, _recipient, _forfeitRecentReward); + } + + /// @notice Enables balance forwarding for the authenticated account + /// @dev Only the authenticated account can enable balance forwarding for itself + /// @dev Should call the IBalanceTracker hook with the current account's balance + function enableBalanceForwarder() external virtual nonReentrant { + address user = _msgSender(); + uint256 userBalance = IERC20(address(this)).balanceOf(user); + + _enableBalanceForwarder(user, userBalance); + } + + /// @notice Disables balance forwarding for the authenticated account + /// @dev Only the authenticated account can disable balance forwarding for itself + /// @dev Should call the IBalanceTracker hook with the account's balance of 0 + function disableBalanceForwarder() external virtual nonReentrant { + _disableBalanceForwarder(_msgSender()); + } + + /// @notice Retrieve the address of rewards contract, tracking changes in account's balances + /// @return The balance tracker address + function balanceTrackerAddress() external view returns (address) { + BalanceForwarderStorage storage $ = StorageLib._getBalanceForwarderStorage(); + + return address($.balanceTracker); + } + + /// @notice Retrieves boolean indicating if the account opted in to forward balance changes to the rewards contract + /// @param _account Address to query + /// @return True if balance forwarder is enabled + function balanceForwarderEnabled(address _account) external view returns (bool) { + return _balanceForwarderEnabled(_account); + } + + function _enableBalanceForwarder(address _sender, uint256 _senderBalance) internal { + BalanceForwarderStorage storage $ = StorageLib._getBalanceForwarderStorage(); + IBalanceTracker balanceTrackerCached = IBalanceTracker($.balanceTracker); + + if (address(balanceTrackerCached) == address(0)) revert ErrorsLib.NotSupported(); + if ($.isBalanceForwarderEnabled[_sender]) revert ErrorsLib.AlreadyEnabled(); + + $.isBalanceForwarderEnabled[_sender] = true; + balanceTrackerCached.balanceTrackerHook(_sender, _senderBalance, false); + + emit EnableBalanceForwarder(_sender); + } + + /// @notice Disables balance forwarding for the authenticated account + /// @dev Only the authenticated account can disable balance forwarding for itself + /// @dev Should call the IBalanceTracker hook with the account's balance of 0 + function _disableBalanceForwarder(address _sender) internal { + BalanceForwarderStorage storage $ = StorageLib._getBalanceForwarderStorage(); + IBalanceTracker balanceTrackerCached = IBalanceTracker($.balanceTracker); + + if (address(balanceTrackerCached) == address(0)) revert ErrorsLib.NotSupported(); + if (!$.isBalanceForwarderEnabled[_sender]) revert ErrorsLib.AlreadyDisabled(); + + $.isBalanceForwarderEnabled[_sender] = false; + balanceTrackerCached.balanceTrackerHook(_sender, 0, false); + + emit DisableBalanceForwarder(_sender); + } + + function _setBalanceTracker(address _balancerTracker) internal { + BalanceForwarderStorage storage $ = StorageLib._getBalanceForwarderStorage(); + + $.balanceTracker = _balancerTracker; + } + + /// @notice Retrieves boolean indicating if the account opted in to forward balance changes to the rewards contract + /// @param _account Address to query + /// @return True if balance forwarder is enabled + function _balanceForwarderEnabled(address _account) internal view returns (bool) { + BalanceForwarderStorage storage $ = StorageLib._getBalanceForwarderStorage(); + + return $.isBalanceForwarderEnabled[_account]; + } + + /// @notice Retrieve the address of rewards contract, tracking changes in account's balances + /// @return The balance tracker address + function _balanceTrackerAddress() internal view returns (address) { + BalanceForwarderStorage storage $ = StorageLib._getBalanceForwarderStorage(); + + return address($.balanceTracker); + } +} + +contract Rewards is RewardsModule {} diff --git a/test/common/FourSixTwoSixAggBase.t.sol b/test/common/FourSixTwoSixAggBase.t.sol index 6de1dd67..0268dcd7 100644 --- a/test/common/FourSixTwoSixAggBase.t.sol +++ b/test/common/FourSixTwoSixAggBase.t.sol @@ -12,6 +12,7 @@ import {FourSixTwoSixAggFactory} from "../../src/FourSixTwoSixAggFactory.sol"; import {WithdrawalQueue} from "../../src/WithdrawalQueue.sol"; import {IWithdrawalQueue} from "../../src/interface/IWithdrawalQueue.sol"; import {ErrorsLib} from "../../src/lib/ErrorsLib.sol"; +import {AllocationPoints} from "../../src/modules/AllocationPoints.sol"; contract FourSixTwoSixAggBase is EVaultTestBase { uint256 public constant CASH_RESERVE_ALLOCATION_POINTS = 1000e18; @@ -25,6 +26,7 @@ contract FourSixTwoSixAggBase is EVaultTestBase { Rewards rewardsImpl; Hooks hooksImpl; Fee feeModuleImpl; + AllocationPoints allocationPointsModuleImpl; // peripheries Rebalancer rebalancer; WithdrawalQueue withdrawalQueueImpl; @@ -45,6 +47,7 @@ contract FourSixTwoSixAggBase is EVaultTestBase { rewardsImpl = new Rewards(); hooksImpl = new Hooks(); feeModuleImpl = new Fee(); + allocationPointsModuleImpl = new AllocationPoints(); rebalancer = new Rebalancer(); withdrawalQueueImpl = new WithdrawalQueue(); @@ -54,6 +57,7 @@ contract FourSixTwoSixAggBase is EVaultTestBase { address(rewardsImpl), address(hooksImpl), address(feeModuleImpl), + address(allocationPointsModuleImpl), address(rebalancer), address(withdrawalQueueImpl) ); diff --git a/test/e2e/BalanceForwarderE2ETest.t.sol b/test/e2e/BalanceForwarderE2ETest.t.sol index b02fca97..9e45b94b 100644 --- a/test/e2e/BalanceForwarderE2ETest.t.sol +++ b/test/e2e/BalanceForwarderE2ETest.t.sol @@ -31,6 +31,7 @@ contract BalanceForwarderE2ETest is FourSixTwoSixAggBase { address(rewardsImpl), address(hooksImpl), address(feeModuleImpl), + address(allocationPointsModuleImpl), address(rebalancer), address(withdrawalQueueImpl) ); diff --git a/test/e2e/HooksE2ETest.t.sol b/test/e2e/HooksE2ETest.t.sol index a0e966c0..a7570b48 100644 --- a/test/e2e/HooksE2ETest.t.sol +++ b/test/e2e/HooksE2ETest.t.sol @@ -10,7 +10,7 @@ import { IRMTestDefault, TestERC20, IHookTarget, - HooksModule + ErrorsLib } from "../common/FourSixTwoSixAggBase.t.sol"; contract HooksE2ETest is FourSixTwoSixAggBase { @@ -45,7 +45,7 @@ contract HooksE2ETest is FourSixTwoSixAggBase { | fourSixTwoSixAgg.ADD_STRATEGY() | fourSixTwoSixAgg.REMOVE_STRATEGY(); vm.startPrank(manager); - vm.expectRevert(HooksModule.InvalidHooksTarget.selector); + vm.expectRevert(ErrorsLib.InvalidHooksTarget.selector); fourSixTwoSixAgg.setHooksConfig(address(0), expectedHookedFns); vm.stopPrank(); } @@ -56,7 +56,7 @@ contract HooksE2ETest is FourSixTwoSixAggBase { vm.startPrank(manager); address hooksContract = address(new NotHooksContract()); - vm.expectRevert(HooksModule.NotHooksContract.selector); + vm.expectRevert(ErrorsLib.NotHooksContract.selector); fourSixTwoSixAgg.setHooksConfig(hooksContract, expectedHookedFns); vm.stopPrank(); } @@ -65,7 +65,7 @@ contract HooksE2ETest is FourSixTwoSixAggBase { uint32 expectedHookedFns = 1 << 5; vm.startPrank(manager); address hooksContract = address(new HooksContract()); - vm.expectRevert(HooksModule.InvalidHookedFns.selector); + vm.expectRevert(ErrorsLib.InvalidHookedFns.selector); fourSixTwoSixAgg.setHooksConfig(hooksContract, expectedHookedFns); vm.stopPrank(); }