Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First batch of fixes for k-internal review 2 #24

Merged
merged 11 commits into from
Jun 13, 2024
12 changes: 4 additions & 8 deletions src/BalanceForwarder.sol
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,11 @@ abstract contract BalanceForwarder is IBalanceForwarder {
}

function _enableBalanceForwarder(address _sender, uint256 _senderBalance) internal {
address cachedBalanceTracker = address(balanceTracker);

if (cachedBalanceTracker == address(0)) revert NotSupported();
if (address(balanceTracker) == address(0)) revert NotSupported();
if (isBalanceForwarderEnabled[_sender]) revert AlreadyEnabled();

isBalanceForwarderEnabled[_sender] = true;
IBalanceTracker(cachedBalanceTracker).balanceTrackerHook(_sender, _senderBalance, false);
balanceTracker.balanceTrackerHook(_sender, _senderBalance, false);

emit EnableBalanceForwarder(_sender);
}
Expand All @@ -63,13 +61,11 @@ abstract contract BalanceForwarder is IBalanceForwarder {
/// @dev Only the authenticated account can disable balance forwarding for itself
/// @dev Should call the IBalanceTracker hook with the account's balance of 0
function _disableBalanceForwarder(address _sender) internal {
address cachedBalanceTracker = address(balanceTracker);

if (cachedBalanceTracker == address(0)) revert NotSupported();
if (address(balanceTracker) == address(0)) revert NotSupported();
if (!isBalanceForwarderEnabled[_sender]) revert AlreadyDisabled();

isBalanceForwarderEnabled[_sender] = false;
IBalanceTracker(cachedBalanceTracker).balanceTrackerHook(_sender, 0, false);
balanceTracker.balanceTrackerHook(_sender, 0, false);

emit DisableBalanceForwarder(_sender);
}
Expand Down
153 changes: 56 additions & 97 deletions src/FourSixTwoSixAgg.sol

Large diffs are not rendered by default.

52 changes: 34 additions & 18 deletions src/Hooks.sol
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// SPDX-License-Identifier: GPL-2.0-or-later
pragma solidity ^0.8.0;

import {HooksLib, HooksType} from "./lib/HooksLib.sol";
import {HooksLib} from "./lib/HooksLib.sol";
import {IHookTarget} from "evk/src/interfaces/IHookTarget.sol";

abstract contract Hooks {
using HooksLib for HooksType;
using HooksLib for uint32;

error InvalidHooksTarget();
error NotHooksContract();
Expand All @@ -18,49 +18,65 @@ abstract contract Hooks {
uint32 public constant REMOVE_STRATEGY = 1 << 3;

uint32 constant ACTIONS_COUNTER = 1 << 4;
uint256 public constant HOOKS_MASK = 0x00000000000000000000000000000000000000000000000000000000FFFFFFFF;

/// @dev Contract with hooks implementation
address public hookTarget;
/// @dev Hooked functions
HooksType public hookedFns;
/// @dev storing the hooks target and kooked functions.
uint256 hooksConfig;

event SetHooksConfig(address indexed hooksTarget, uint32 hookedFns);

/// @notice Get the hooks contract and the hooked functions.
/// @return address Hooks contract.
/// @return uint32 Hooked functions.
function getHooksConfig() external view returns (address, uint32) {
return (hookTarget, hookedFns.toUint32());
return _getHooksConfig(hooksConfig);
}

/// @notice Set hooks contract and hooked functions.
/// @dev This funtion should be overriden to implement access control.
/// @param _hookTarget Hooks contract.
/// @dev This funtion should be overriden to implement access control and call _setHooksConfig().
/// @param _hooksTarget Hooks contract.
/// @param _hookedFns Hooked functions.
function setHooksConfig(address _hooksTarget, uint32 _hookedFns) public virtual;

/// @notice Set hooks contract and hooked functions.
/// @dev This funtion should be called when implementing setHooksConfig().
/// @param _hooksTarget Hooks contract.
/// @param _hookedFns Hooked functions.
function setHooksConfig(address _hookTarget, uint32 _hookedFns) public virtual {
if (_hookTarget != address(0) && IHookTarget(_hookTarget).isHookTarget() != IHookTarget.isHookTarget.selector) {
function _setHooksConfig(address _hooksTarget, uint32 _hookedFns) internal {
if (_hooksTarget != address(0) && IHookTarget(_hooksTarget).isHookTarget() != IHookTarget.isHookTarget.selector)
{
revert NotHooksContract();
}
if (_hookedFns != 0 && _hookTarget == address(0)) {
if (_hookedFns != 0 && _hooksTarget == address(0)) {
revert InvalidHooksTarget();
}
if (_hookedFns >= ACTIONS_COUNTER) revert InvalidHookedFns();

hookTarget = _hookTarget;
hookedFns = HooksType.wrap(_hookedFns);
hooksConfig = (uint256(uint160(_hooksTarget)) << 32) | uint256(_hookedFns);

emit SetHooksConfig(_hooksTarget, _hookedFns);
}

/// @notice Checks whether a hook has been installed for the function and if so, invokes the hook target.
/// @param _fn Function to check hook for.
/// @param _fn Function to call the hook for.
/// @param _caller Caller's address.
function _callHookTarget(uint32 _fn, address _caller) internal {
if (hookedFns.isNotSet(_fn)) return;
function _callHooksTarget(uint32 _fn, address _caller) internal {
(address target, uint32 hookedFns) = _getHooksConfig(hooksConfig);

address target = hookTarget;
if (hookedFns.isNotSet(_fn)) return;

(bool success, bytes memory data) = target.call(abi.encodePacked(msg.data, _caller));

if (!success) _revertBytes(data);
}

/// @notice Get the hooks contract and the hooked functions.
/// @return address Hooks contract.
/// @return uint32 Hooked functions.
function _getHooksConfig(uint256 _hooksConfig) internal pure returns (address, uint32) {
return (address(uint160(_hooksConfig >> 32)), uint32(_hooksConfig & HOOKS_MASK));
}

/// @dev Revert with call error or EmptyError
/// @param _errorMsg call revert message
function _revertBytes(bytes memory _errorMsg) private pure {
Expand Down
3 changes: 0 additions & 3 deletions src/Rebalancer.sol
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ contract Rebalancer {

IFourSixTwoSixAgg.Strategy memory strategyData = IFourSixTwoSixAgg(_curatedVault).getStrategy(_strategy);

// no rebalance if strategy have an allocated amount greater than cap
if ((strategyData.cap > 0) && (strategyData.allocated >= strategyData.cap)) return;

uint256 totalAllocationPointsCache = IFourSixTwoSixAgg(_curatedVault).totalAllocationPoints();
uint256 totalAssetsAllocatableCache = IFourSixTwoSixAgg(_curatedVault).totalAssetsAllocatable();
uint256 targetAllocation =
Expand Down
14 changes: 4 additions & 10 deletions src/lib/HooksLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,12 @@ pragma solidity ^0.8.0;
/// @dev This is copied from https://github.com/euler-xyz/euler-vault-kit/blob/30b0b9e36b0a912fe430c7482e9b3bb12d180a4e/src/EVault/shared/types/Flags.sol
library HooksLib {
/// @dev Are *all* of the Hooks in bitMask set?
function isSet(HooksType self, uint32 bitMask) internal pure returns (bool) {
return (HooksType.unwrap(self) & bitMask) == bitMask;
function isSet(uint32 _hookedFns, uint32 _fn) internal pure returns (bool) {
return (_hookedFns & _fn) == _fn;
}

/// @dev Are *none* of the Hooks in bitMask set?
function isNotSet(HooksType self, uint32 bitMask) internal pure returns (bool) {
return (HooksType.unwrap(self) & bitMask) == 0;
}

function toUint32(HooksType self) internal pure returns (uint32) {
return HooksType.unwrap(self);
function isNotSet(uint32 _hookedFns, uint32 _fn) internal pure returns (bool) {
return (_hookedFns & _fn) == 0;
}
}

type HooksType is uint32;
2 changes: 1 addition & 1 deletion test/common/FourSixTwoSixAggBase.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ contract FourSixTwoSixAggBase is EVaultTestBase {
vm.startPrank(deployer);
rebalancer = new Rebalancer();
fourSixTwoSixAgg = new FourSixTwoSixAgg(
evc,
address(evc),
address(0),
address(assetTST),
"assetTST_Agg",
Expand Down
2 changes: 1 addition & 1 deletion test/e2e/BalanceForwarderE2ETest.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ contract BalanceForwarderE2ETest is FourSixTwoSixAggBase {
trackingReward = address(new TrackingRewardStreams(address(evc), 2 weeks));

fourSixTwoSixAgg = new FourSixTwoSixAgg(
evc,
address(evc),
trackingReward,
address(assetTST),
"assetTST_Agg",
Expand Down
4 changes: 2 additions & 2 deletions test/unit/GulpTest.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ contract GulpTest is FourSixTwoSixAggBase {

function testGulpAfterNegativeYieldEqualToInterestLeft() public {
fourSixTwoSixAgg.gulp();
FourSixTwoSixAgg.ESRSlot memory ers = fourSixTwoSixAgg.getESRSlot();
FourSixTwoSixAgg.ESR memory ers = fourSixTwoSixAgg.getESRSlot();
assertEq(fourSixTwoSixAgg.interestAccrued(), 0);
assertEq(ers.interestLeft, 0);

Expand Down Expand Up @@ -106,7 +106,7 @@ contract GulpTest is FourSixTwoSixAggBase {

function testGulpAfterNegativeYieldBiggerThanInterestLeft() public {
fourSixTwoSixAgg.gulp();
FourSixTwoSixAgg.ESRSlot memory ers = fourSixTwoSixAgg.getESRSlot();
FourSixTwoSixAgg.ESR memory ers = fourSixTwoSixAgg.getESRSlot();
assertEq(fourSixTwoSixAgg.interestAccrued(), 0);
assertEq(ers.interestLeft, 0);

Expand Down
Loading