diff --git a/.forge-snapshots/DynamicFees.swap_updateDynamicFeeInBeforeSwap.snap b/.forge-snapshots/DynamicFees.swap_updateDynamicFeeInBeforeSwap.snap index 98aa37ccb..6a2708470 100644 --- a/.forge-snapshots/DynamicFees.swap_updateDynamicFeeInBeforeSwap.snap +++ b/.forge-snapshots/DynamicFees.swap_updateDynamicFeeInBeforeSwap.snap @@ -1 +1 @@ -141181 \ No newline at end of file +139992 \ No newline at end of file diff --git a/.forge-snapshots/DynamicFees.swap_withDynamicFee.snap b/.forge-snapshots/DynamicFees.swap_withDynamicFee.snap index b978e8a3e..5d173b590 100644 --- a/.forge-snapshots/DynamicFees.swap_withDynamicFee.snap +++ b/.forge-snapshots/DynamicFees.swap_withDynamicFee.snap @@ -1 +1 @@ -90176 \ No newline at end of file +89366 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.addLiquidity.snap b/.forge-snapshots/PoolManager.addLiquidity.snap index b36da47c1..d098236d6 100644 --- a/.forge-snapshots/PoolManager.addLiquidity.snap +++ b/.forge-snapshots/PoolManager.addLiquidity.snap @@ -1 +1 @@ -145959 \ No newline at end of file +144743 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.addLiquidity_withEmptyHook.snap b/.forge-snapshots/PoolManager.addLiquidity_withEmptyHook.snap index e2bd4e7ac..9d842f32a 100644 --- a/.forge-snapshots/PoolManager.addLiquidity_withEmptyHook.snap +++ b/.forge-snapshots/PoolManager.addLiquidity_withEmptyHook.snap @@ -1 +1 @@ -265799 \ No newline at end of file +264566 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.addLiquidity_withNative.snap b/.forge-snapshots/PoolManager.addLiquidity_withNative.snap index 2e5687233..f004ea00c 100644 --- a/.forge-snapshots/PoolManager.addLiquidity_withNative.snap +++ b/.forge-snapshots/PoolManager.addLiquidity_withNative.snap @@ -1 +1 @@ -140642 \ No newline at end of file +139426 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.bytecodeSize.snap b/.forge-snapshots/PoolManager.bytecodeSize.snap index 3cf278ca1..838b1c452 100644 --- a/.forge-snapshots/PoolManager.bytecodeSize.snap +++ b/.forge-snapshots/PoolManager.bytecodeSize.snap @@ -1 +1 @@ -23026 \ No newline at end of file +22598 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.collectProtocolFees_erc20.snap b/.forge-snapshots/PoolManager.collectProtocolFees_erc20.snap index 36cf22341..b9fe27a3f 100644 --- a/.forge-snapshots/PoolManager.collectProtocolFees_erc20.snap +++ b/.forge-snapshots/PoolManager.collectProtocolFees_erc20.snap @@ -1 +1 @@ -24938 \ No newline at end of file +24961 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.collectProtocolFees_native.snap b/.forge-snapshots/PoolManager.collectProtocolFees_native.snap index 5d4c1b5c0..984502b79 100644 --- a/.forge-snapshots/PoolManager.collectProtocolFees_native.snap +++ b/.forge-snapshots/PoolManager.collectProtocolFees_native.snap @@ -1 +1 @@ -36611 \ No newline at end of file +36634 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.donate_gasWith1Token.snap b/.forge-snapshots/PoolManager.donate_gasWith1Token.snap index 3829a9cc6..0715154f4 100644 --- a/.forge-snapshots/PoolManager.donate_gasWith1Token.snap +++ b/.forge-snapshots/PoolManager.donate_gasWith1Token.snap @@ -1 +1 @@ -101654 \ No newline at end of file +101137 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.donate_gasWith2Tokens.snap b/.forge-snapshots/PoolManager.donate_gasWith2Tokens.snap index f59cb30b4..0bc63eb2a 100644 --- a/.forge-snapshots/PoolManager.donate_gasWith2Tokens.snap +++ b/.forge-snapshots/PoolManager.donate_gasWith2Tokens.snap @@ -1 +1 @@ -128667 \ No newline at end of file +132127 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.initialize.snap b/.forge-snapshots/PoolManager.initialize.snap index df1ae0e82..33f4191f2 100644 --- a/.forge-snapshots/PoolManager.initialize.snap +++ b/.forge-snapshots/PoolManager.initialize.snap @@ -1 +1 @@ -51819 \ No newline at end of file +51266 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.removeLiquidity.snap b/.forge-snapshots/PoolManager.removeLiquidity.snap index 24c0b40b7..1c6861c96 100644 --- a/.forge-snapshots/PoolManager.removeLiquidity.snap +++ b/.forge-snapshots/PoolManager.removeLiquidity.snap @@ -1 +1 @@ -150224 \ No newline at end of file +148876 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.removeLiquidity_withEmptyHook.snap b/.forge-snapshots/PoolManager.removeLiquidity_withEmptyHook.snap index 8c729b8ed..cc078112e 100644 --- a/.forge-snapshots/PoolManager.removeLiquidity_withEmptyHook.snap +++ b/.forge-snapshots/PoolManager.removeLiquidity_withEmptyHook.snap @@ -1 +1 @@ -56574 \ No newline at end of file +55226 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.removeLiquidity_withNative.snap b/.forge-snapshots/PoolManager.removeLiquidity_withNative.snap index ab62f35c6..52870c04c 100644 --- a/.forge-snapshots/PoolManager.removeLiquidity_withNative.snap +++ b/.forge-snapshots/PoolManager.removeLiquidity_withNative.snap @@ -1 +1 @@ -148760 \ No newline at end of file +147412 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.swap_againstLiquidity.snap b/.forge-snapshots/PoolManager.swap_againstLiquidity.snap index 0c800252f..8209798a6 100644 --- a/.forge-snapshots/PoolManager.swap_againstLiquidity.snap +++ b/.forge-snapshots/PoolManager.swap_againstLiquidity.snap @@ -1 +1 @@ -60874 \ No newline at end of file +60138 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.swap_againstLiquidityWithNativeToken.snap b/.forge-snapshots/PoolManager.swap_againstLiquidityWithNativeToken.snap index addfbca9d..82b18dab6 100644 --- a/.forge-snapshots/PoolManager.swap_againstLiquidityWithNativeToken.snap +++ b/.forge-snapshots/PoolManager.swap_againstLiquidityWithNativeToken.snap @@ -1 +1 @@ -72868 \ No newline at end of file +72132 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.swap_burn6909ForInput.snap b/.forge-snapshots/PoolManager.swap_burn6909ForInput.snap index b01d299d4..957d1faa2 100644 --- a/.forge-snapshots/PoolManager.swap_burn6909ForInput.snap +++ b/.forge-snapshots/PoolManager.swap_burn6909ForInput.snap @@ -1 +1 @@ -80988 \ No newline at end of file +80315 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.swap_burnNative6909ForInput.snap b/.forge-snapshots/PoolManager.swap_burnNative6909ForInput.snap index 3630e5b1a..3a61378cc 100644 --- a/.forge-snapshots/PoolManager.swap_burnNative6909ForInput.snap +++ b/.forge-snapshots/PoolManager.swap_burnNative6909ForInput.snap @@ -1 +1 @@ -76943 \ No newline at end of file +76165 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.swap_mintNativeOutputAs6909.snap b/.forge-snapshots/PoolManager.swap_mintNativeOutputAs6909.snap index 566df1219..9fbea3b86 100644 --- a/.forge-snapshots/PoolManager.swap_mintNativeOutputAs6909.snap +++ b/.forge-snapshots/PoolManager.swap_mintNativeOutputAs6909.snap @@ -1 +1 @@ -139268 \ No newline at end of file +138660 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.swap_mintOutputAs6909.snap b/.forge-snapshots/PoolManager.swap_mintOutputAs6909.snap index 982481321..d8ad01be3 100644 --- a/.forge-snapshots/PoolManager.swap_mintOutputAs6909.snap +++ b/.forge-snapshots/PoolManager.swap_mintOutputAs6909.snap @@ -1 +1 @@ -156077 \ No newline at end of file +155291 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.swap_simple.snap b/.forge-snapshots/PoolManager.swap_simple.snap index ba806de9b..a150e2e48 100644 --- a/.forge-snapshots/PoolManager.swap_simple.snap +++ b/.forge-snapshots/PoolManager.swap_simple.snap @@ -1 +1 @@ -147285 \ No newline at end of file +146476 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.swap_simpleWithNative.snap b/.forge-snapshots/PoolManager.swap_simpleWithNative.snap index df22d79b2..bb2d92b47 100644 --- a/.forge-snapshots/PoolManager.swap_simpleWithNative.snap +++ b/.forge-snapshots/PoolManager.swap_simpleWithNative.snap @@ -1 +1 @@ -133409 \ No newline at end of file +132600 \ No newline at end of file diff --git a/.forge-snapshots/PoolManager.swap_withHooks.snap b/.forge-snapshots/PoolManager.swap_withHooks.snap index 160b64063..e182446d6 100644 --- a/.forge-snapshots/PoolManager.swap_withHooks.snap +++ b/.forge-snapshots/PoolManager.swap_withHooks.snap @@ -1 +1 @@ -60849 \ No newline at end of file +60113 \ No newline at end of file diff --git a/.forge-snapshots/SkipCallsTestsHook.swap_skipsHookCallifHookIsCaller.snap b/.forge-snapshots/SkipCallsTestsHook.swap_skipsHookCallifHookIsCaller.snap new file mode 100644 index 000000000..828d0d44f --- /dev/null +++ b/.forge-snapshots/SkipCallsTestsHook.swap_skipsHookCallifHookIsCaller.snap @@ -0,0 +1 @@ +155723 \ No newline at end of file diff --git a/.forge-snapshots/flipTick_gasCostOfFlippingATickThatResultsInDeletingAWord.snap b/.forge-snapshots/flipTick_gasCostOfFlippingATickThatResultsInDeletingAWord.snap index 3819e9389..e3b929b4e 100644 --- a/.forge-snapshots/flipTick_gasCostOfFlippingATickThatResultsInDeletingAWord.snap +++ b/.forge-snapshots/flipTick_gasCostOfFlippingATickThatResultsInDeletingAWord.snap @@ -1 +1 @@ -5409 \ No newline at end of file +5407 \ No newline at end of file diff --git a/.forge-snapshots/flipTick_gasCostOfFlippingFirstTickInWordToInitialized.snap b/.forge-snapshots/flipTick_gasCostOfFlippingFirstTickInWordToInitialized.snap index 0cd31fe26..e602141e2 100644 --- a/.forge-snapshots/flipTick_gasCostOfFlippingFirstTickInWordToInitialized.snap +++ b/.forge-snapshots/flipTick_gasCostOfFlippingFirstTickInWordToInitialized.snap @@ -1 +1 @@ -22506 \ No newline at end of file +22504 \ No newline at end of file diff --git a/.forge-snapshots/flipTick_gasCostOfFlippingSecondTickInWordToInitialized.snap b/.forge-snapshots/flipTick_gasCostOfFlippingSecondTickInWordToInitialized.snap index 38966a2c0..aa0e4d419 100644 --- a/.forge-snapshots/flipTick_gasCostOfFlippingSecondTickInWordToInitialized.snap +++ b/.forge-snapshots/flipTick_gasCostOfFlippingSecondTickInWordToInitialized.snap @@ -1 +1 @@ -5515 \ No newline at end of file +5513 \ No newline at end of file diff --git a/.forge-snapshots/nextInitializedTickWithinOneWord_lteFalse_gasCostForEntireWord.snap b/.forge-snapshots/nextInitializedTickWithinOneWord_lteFalse_gasCostForEntireWord.snap index 13f668d19..0f3866130 100644 --- a/.forge-snapshots/nextInitializedTickWithinOneWord_lteFalse_gasCostForEntireWord.snap +++ b/.forge-snapshots/nextInitializedTickWithinOneWord_lteFalse_gasCostForEntireWord.snap @@ -1 +1 @@ -2592 \ No newline at end of file +2578 \ No newline at end of file diff --git a/.forge-snapshots/nextInitializedTickWithinOneWord_lteFalse_gasCostJustBelowBoundary.snap b/.forge-snapshots/nextInitializedTickWithinOneWord_lteFalse_gasCostJustBelowBoundary.snap index 13f668d19..0f3866130 100644 --- a/.forge-snapshots/nextInitializedTickWithinOneWord_lteFalse_gasCostJustBelowBoundary.snap +++ b/.forge-snapshots/nextInitializedTickWithinOneWord_lteFalse_gasCostJustBelowBoundary.snap @@ -1 +1 @@ -2592 \ No newline at end of file +2578 \ No newline at end of file diff --git a/.forge-snapshots/nextInitializedTickWithinOneWord_lteFalse_gasCostOnBoundary.snap b/.forge-snapshots/nextInitializedTickWithinOneWord_lteFalse_gasCostOnBoundary.snap index 13f668d19..0f3866130 100644 --- a/.forge-snapshots/nextInitializedTickWithinOneWord_lteFalse_gasCostOnBoundary.snap +++ b/.forge-snapshots/nextInitializedTickWithinOneWord_lteFalse_gasCostOnBoundary.snap @@ -1 +1 @@ -2592 \ No newline at end of file +2578 \ No newline at end of file diff --git a/.forge-snapshots/nextInitializedTickWithinOneWord_lteTrue_gasCostForEntireWord.snap b/.forge-snapshots/nextInitializedTickWithinOneWord_lteTrue_gasCostForEntireWord.snap index 0c52fc6da..a3412f309 100644 --- a/.forge-snapshots/nextInitializedTickWithinOneWord_lteTrue_gasCostForEntireWord.snap +++ b/.forge-snapshots/nextInitializedTickWithinOneWord_lteTrue_gasCostForEntireWord.snap @@ -1 +1 @@ -2591 \ No newline at end of file +2556 \ No newline at end of file diff --git a/.forge-snapshots/nextInitializedTickWithinOneWord_lteTrue_gasCostJustBelowBoundary.snap b/.forge-snapshots/nextInitializedTickWithinOneWord_lteTrue_gasCostJustBelowBoundary.snap index 7b34ebbf2..47fd2ab34 100644 --- a/.forge-snapshots/nextInitializedTickWithinOneWord_lteTrue_gasCostJustBelowBoundary.snap +++ b/.forge-snapshots/nextInitializedTickWithinOneWord_lteTrue_gasCostJustBelowBoundary.snap @@ -1 +1 @@ -2900 \ No newline at end of file +2865 \ No newline at end of file diff --git a/.forge-snapshots/nextInitializedTickWithinOneWord_lteTrue_gasCostOnBoundary.snap b/.forge-snapshots/nextInitializedTickWithinOneWord_lteTrue_gasCostOnBoundary.snap index 0c52fc6da..a3412f309 100644 --- a/.forge-snapshots/nextInitializedTickWithinOneWord_lteTrue_gasCostOnBoundary.snap +++ b/.forge-snapshots/nextInitializedTickWithinOneWord_lteTrue_gasCostOnBoundary.snap @@ -1 +1 @@ -2591 \ No newline at end of file +2556 \ No newline at end of file diff --git a/.forge-snapshots/swap against liquidity.snap b/.forge-snapshots/swap against liquidity.snap new file mode 100644 index 000000000..2f58c626a --- /dev/null +++ b/.forge-snapshots/swap against liquidity.snap @@ -0,0 +1 @@ +60135 \ No newline at end of file diff --git a/.forge-snapshots/swap burn 6909 for input.snap b/.forge-snapshots/swap burn 6909 for input.snap new file mode 100644 index 000000000..edbef057b --- /dev/null +++ b/.forge-snapshots/swap burn 6909 for input.snap @@ -0,0 +1 @@ +80312 \ No newline at end of file diff --git a/.forge-snapshots/swap burn native 6909 for input.snap b/.forge-snapshots/swap burn native 6909 for input.snap new file mode 100644 index 000000000..795e97fab --- /dev/null +++ b/.forge-snapshots/swap burn native 6909 for input.snap @@ -0,0 +1 @@ +76162 \ No newline at end of file diff --git a/.forge-snapshots/swap skips hook call if hook is caller.snap b/.forge-snapshots/swap skips hook call if hook is caller.snap new file mode 100644 index 000000000..828d0d44f --- /dev/null +++ b/.forge-snapshots/swap skips hook call if hook is caller.snap @@ -0,0 +1 @@ +155723 \ No newline at end of file diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml new file mode 100644 index 000000000..4de459a7d --- /dev/null +++ b/.github/workflows/coverage.yml @@ -0,0 +1,89 @@ +name: code coverage + +on: + pull_request: + branches: + - main + +jobs: + comment-forge-coverage: + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: write + + steps: + - name: Checkout code + uses: actions/checkout@v3 + with: + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Install foundry + uses: foundry-rs/foundry-toolchain@v1 + with: + version: nightly + - name: Run Forge build + run: | + forge --version + forge build --sizes + id: build + - name: Run forge coverage + id: coverage + run: | + { + echo 'COVERAGE<> "$GITHUB_OUTPUT" + env: + FOUNDRY_RPC_URL: '${{ secrets.RPC_URL }}' + + - name: Check coverage is updated + uses: actions/github-script@v5 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const file = "coverage.txt" + if(!fs.existsSync(file)) { + console.log("Nothing to check"); + return + } + const currentCoverage = fs.readFileSync(file, "utf8").trim(); + const newCoverage = (`${{ steps.coverage.outputs.COVERAGE }}`).trim(); + if (newCoverage != currentCoverage) { + core.setFailed(`Code coverage not updated. Run : forge coverage | grep -v 'test/' | tail -n +6 > coverage.txt`); + } + + - name: Comment on PR + id: comment + uses: actions/github-script@v5 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const {data: comments} = await github.rest.issues.listComments({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + }) + + const botComment = comments.find(comment => comment.user.id === 41898282) + + const output = `${{ steps.coverage.outputs.COVERAGE }}`; + const commentBody = `Forge code coverage:\n${output}\n`; + + if (botComment) { + github.rest.issues.updateComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: botComment.id, + body: commentBody + }) + } else { + github.rest.issues.createComment({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + body: commentBody + }); + } \ No newline at end of file diff --git a/src/Owned.sol b/src/Owned.sol deleted file mode 100644 index 7f31fbab8..000000000 --- a/src/Owned.sol +++ /dev/null @@ -1,29 +0,0 @@ -// SPDX-License-Identifier: BUSL-1.1 -pragma solidity ^0.8.20; - -contract Owned { - address public owner; - bytes12 private STORAGE_PLACEHOLDER; - - error InvalidCaller(); - - /// @notice Emitted when the owner of the factory is changed - /// @param oldOwner The owner before the owner was changed - /// @param newOwner The owner after the owner was changed - event OwnerChanged(address indexed oldOwner, address indexed newOwner); - - modifier onlyOwner() { - if (msg.sender != owner) revert InvalidCaller(); - _; - } - - constructor() { - owner = msg.sender; - emit OwnerChanged(address(0), msg.sender); - } - - function setOwner(address _owner) external onlyOwner { - emit OwnerChanged(owner, _owner); - owner = _owner; - } -} diff --git a/src/PoolManager.sol b/src/PoolManager.sol index cfea4768c..a77dc91bc 100644 --- a/src/PoolManager.sol +++ b/src/PoolManager.sol @@ -10,7 +10,6 @@ import {Currency, CurrencyLibrary} from "./types/Currency.sol"; import {PoolKey} from "./types/PoolKey.sol"; import {TickMath} from "./libraries/TickMath.sol"; import {NoDelegateCall} from "./NoDelegateCall.sol"; -import {Owned} from "./Owned.sol"; import {IHooks} from "./interfaces/IHooks.sol"; import {IPoolManager} from "./interfaces/IPoolManager.sol"; import {IUnlockCallback} from "./interfaces/callback/IUnlockCallback.sol"; @@ -48,12 +47,16 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim constructor(uint256 controllerGasLimit) ProtocolFees(controllerGasLimit) {} + function _getPool(PoolId id) internal view override returns (Pool.State storage) { + return pools[id]; + } + /// @inheritdoc IPoolManager function getSlot0(PoolId id) external view override - returns (uint160 sqrtPriceX96, int24 tick, uint16 protocolFee, uint24 swapFee) + returns (uint160 sqrtPriceX96, int24 tick, uint24 protocolFee, uint24 swapFee) { Pool.Slot0 memory slot0 = pools[id].slot0; @@ -104,6 +107,7 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim function initialize(PoolKey memory key, uint160 sqrtPriceX96, bytes calldata hookData) external override + noDelegateCall returns (int24 tick) { // see TickBitmap.sol for overflow conditions that can arise from tick spacing being too large @@ -112,12 +116,12 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim if (key.currency0 >= key.currency1) revert CurrenciesOutOfOrderOrEqual(); if (!key.hooks.isValidHookAddress(key.fee)) revert Hooks.HookAddressNotValid(address(key.hooks)); - uint24 swapFee = key.fee.getSwapFee(); + uint24 swapFee = key.fee.getInitialSwapFee(); key.hooks.beforeInitialize(key, sqrtPriceX96, hookData); PoolId id = key.toId(); - (, uint16 protocolFee) = _fetchProtocolFee(key); + (, uint24 protocolFee) = _fetchProtocolFee(key); tick = pools[id].initialize(sqrtPriceX96, protocolFee, swapFee); @@ -128,7 +132,7 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim } /// @inheritdoc IPoolManager - function unlock(bytes calldata data) external override returns (bytes memory result) { + function unlock(bytes calldata data) external override noDelegateCall returns (bytes memory result) { if (Lock.isUnlocked()) revert AlreadyUnlocked(); Lock.unlock(); @@ -172,7 +176,7 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim PoolKey memory key, IPoolManager.ModifyLiquidityParams memory params, bytes calldata hookData - ) external override noDelegateCall onlyWhenUnlocked returns (BalanceDelta delta) { + ) external override onlyWhenUnlocked returns (BalanceDelta delta) { PoolId id = key.toId(); _checkPoolInitialized(id); @@ -199,7 +203,6 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim function swap(PoolKey memory key, IPoolManager.SwapParams memory params, bytes calldata hookData) external override - noDelegateCall onlyWhenUnlocked returns (BalanceDelta delta) { @@ -222,11 +225,9 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim _accountPoolBalanceDelta(key, delta); - // the fee is on the input currency - unchecked { - if (feeForProtocol > 0) { - protocolFeesAccrued[params.zeroForOne ? key.currency0 : key.currency1] += feeForProtocol; - } + // The fee is on the input currency. + if (feeForProtocol > 0) { + _updateProtocolFees(params.zeroForOne ? key.currency0 : key.currency1, feeForProtocol); } emit Swap( @@ -240,7 +241,6 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim function donate(PoolKey memory key, uint256 amount0, uint256 amount1, bytes calldata hookData) external override - noDelegateCall onlyWhenUnlocked returns (BalanceDelta delta) { @@ -257,7 +257,7 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim } /// @inheritdoc IPoolManager - function take(Currency currency, address to, uint256 amount) external override noDelegateCall onlyWhenUnlocked { + function take(Currency currency, address to, uint256 amount) external override onlyWhenUnlocked { // subtraction must be safe _accountDelta(currency, -(amount.toInt128())); if (!currency.isNative()) reservesOf[currency] -= amount; @@ -265,14 +265,7 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim } /// @inheritdoc IPoolManager - function settle(Currency currency) - external - payable - override - noDelegateCall - onlyWhenUnlocked - returns (uint256 paid) - { + function settle(Currency currency) external payable override onlyWhenUnlocked returns (uint256 paid) { if (currency.isNative()) { paid = msg.value; } else { @@ -285,26 +278,18 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim } /// @inheritdoc IPoolManager - function mint(address to, uint256 id, uint256 amount) external override noDelegateCall onlyWhenUnlocked { + function mint(address to, uint256 id, uint256 amount) external override onlyWhenUnlocked { // subtraction must be safe _accountDelta(CurrencyLibrary.fromId(id), -(amount.toInt128())); _mint(to, id, amount); } /// @inheritdoc IPoolManager - function burn(address from, uint256 id, uint256 amount) external override noDelegateCall onlyWhenUnlocked { + function burn(address from, uint256 id, uint256 amount) external override onlyWhenUnlocked { _accountDelta(CurrencyLibrary.fromId(id), amount.toInt128()); _burnFrom(from, id, amount); } - function setProtocolFee(PoolKey memory key) external { - (bool success, uint16 newProtocolFee) = _fetchProtocolFee(key); - if (!success) revert ProtocolFeeControllerCallFailedOrInvalidResult(); - PoolId id = key.toId(); - pools[id].setProtocolFee(newProtocolFee); - emit ProtocolFeeUpdated(id, newProtocolFee); - } - function updateDynamicSwapFee(PoolKey memory key, uint24 newDynamicSwapFee) external { if (!key.fee.isDynamicFee() || msg.sender != address(key.hooks)) revert UnauthorizedDynamicSwapFeeUpdate(); newDynamicSwapFee.validate(); @@ -343,4 +328,12 @@ contract PoolManager is IPoolManager, ProtocolFees, NoDelegateCall, ERC6909Claim function getPoolBitmapInfo(PoolId id, int16 word) external view returns (uint256 tickBitmap) { return pools[id].getPoolBitmapInfo(word); } + + function getFeeGrowthGlobals(PoolId id) + external + view + returns (uint256 feeGrowthGlobal0x128, uint256 feeGrowthGlobal1x128) + { + return pools[id].getFeeGrowthGlobals(); + } } diff --git a/src/ProtocolFees.sol b/src/ProtocolFees.sol index b70ad12b8..b7d2b062b 100644 --- a/src/ProtocolFees.sol +++ b/src/ProtocolFees.sol @@ -4,14 +4,17 @@ pragma solidity ^0.8.19; import {Currency, CurrencyLibrary} from "./types/Currency.sol"; import {IProtocolFeeController} from "./interfaces/IProtocolFeeController.sol"; import {IProtocolFees} from "./interfaces/IProtocolFees.sol"; -import {Pool} from "./libraries/Pool.sol"; import {PoolKey} from "./types/PoolKey.sol"; -import {Owned} from "./Owned.sol"; +import {ProtocolFeeLibrary} from "./libraries/ProtocolFeeLibrary.sol"; +import {Owned} from "solmate/auth/Owned.sol"; +import {PoolId, PoolIdLibrary} from "./types/PoolId.sol"; +import {Pool} from "./libraries/Pool.sol"; abstract contract ProtocolFees is IProtocolFees, Owned { using CurrencyLibrary for Currency; - - uint8 public constant MIN_PROTOCOL_FEE_DENOMINATOR = 4; + using ProtocolFeeLibrary for uint24; + using PoolIdLibrary for PoolKey; + using Pool for Pool.State; mapping(Currency currency => uint256) public protocolFeesAccrued; @@ -19,15 +22,44 @@ abstract contract ProtocolFees is IProtocolFees, Owned { uint256 private immutable controllerGasLimit; - constructor(uint256 _controllerGasLimit) { + constructor(uint256 _controllerGasLimit) Owned(msg.sender) { controllerGasLimit = _controllerGasLimit; } + /// @inheritdoc IProtocolFees + function setProtocolFeeController(IProtocolFeeController controller) external onlyOwner { + protocolFeeController = controller; + emit ProtocolFeeControllerUpdated(address(controller)); + } + + /// @inheritdoc IProtocolFees + function setProtocolFee(PoolKey memory key) external { + (bool success, uint24 newProtocolFee) = _fetchProtocolFee(key); + if (!success) revert ProtocolFeeControllerCallFailedOrInvalidResult(); + PoolId id = key.toId(); + _getPool(id).setProtocolFee(newProtocolFee); + emit ProtocolFeeUpdated(id, newProtocolFee); + } + + /// @inheritdoc IProtocolFees + function collectProtocolFees(address recipient, Currency currency, uint256 amount) + external + returns (uint256 amountCollected) + { + if (msg.sender != address(protocolFeeController)) revert InvalidCaller(); + + amountCollected = (amount == 0) ? protocolFeesAccrued[currency] : amount; + protocolFeesAccrued[currency] -= amountCollected; + currency.transfer(recipient, amountCollected); + } + + function _getPool(PoolId id) internal virtual returns (Pool.State storage); + /// @notice Fetch the protocol fees for a given pool, returning false if the call fails or the returned fees are invalid. /// @dev to prevent an invalid protocol fee controller from blocking pools from being initialized /// the success of this function is NOT checked on initialize and if the call fails, the protocol fees are set to 0. /// @dev the success of this function must be checked when called in setProtocolFee - function _fetchProtocolFee(PoolKey memory key) internal returns (bool success, uint16 protocolFees) { + function _fetchProtocolFee(PoolKey memory key) internal returns (bool success, uint24 protocolFees) { if (address(protocolFeeController) != address(0)) { // note that EIP-150 mandates that calls requesting more than 63/64ths of remaining gas // will be allotted no more than this amount, so controllerGasLimit must be set with this @@ -44,40 +76,16 @@ abstract contract ProtocolFees is IProtocolFees, Owned { assembly { returnData := mload(add(_data, 0x20)) } - // Ensure return data does not overflow a uint16 and that the underlying fees are within bounds. - (success, protocolFees) = returnData == uint16(returnData) && _isValidProtocolFee(uint16(returnData)) - ? (true, uint16(returnData)) + // Ensure return data does not overflow a uint24 and that the underlying fees are within bounds. + (success, protocolFees) = (returnData == uint24(returnData)) && uint24(returnData).validate() + ? (true, uint24(returnData)) : (false, 0); } } - function _isValidProtocolFee(uint16 fee) internal pure returns (bool) { - if (fee != 0) { - uint16 fee0 = fee % 256; - uint16 fee1 = fee >> 8; - // The fee is specified as a denominator so it cannot be LESS than the MIN_PROTOCOL_FEE_DENOMINATOR (unless it is 0). - if ( - (fee0 != 0 && fee0 < MIN_PROTOCOL_FEE_DENOMINATOR) || (fee1 != 0 && fee1 < MIN_PROTOCOL_FEE_DENOMINATOR) - ) { - return false; - } + function _updateProtocolFees(Currency currency, uint256 amount) internal { + unchecked { + protocolFeesAccrued[currency] += amount; } - return true; - } - - function setProtocolFeeController(IProtocolFeeController controller) external onlyOwner { - protocolFeeController = controller; - emit ProtocolFeeControllerUpdated(address(controller)); - } - - function collectProtocolFees(address recipient, Currency currency, uint256 amount) - external - returns (uint256 amountCollected) - { - if (msg.sender != address(protocolFeeController)) revert InvalidCaller(); - - amountCollected = (amount == 0) ? protocolFeesAccrued[currency] : amount; - protocolFeesAccrued[currency] -= amountCollected; - currency.transfer(recipient, amountCollected); } } diff --git a/src/interfaces/IPoolManager.sol b/src/interfaces/IPoolManager.sol index 752afdfec..b256d8a91 100644 --- a/src/interfaces/IPoolManager.sol +++ b/src/interfaces/IPoolManager.sol @@ -81,8 +81,6 @@ interface IPoolManager is IProtocolFees, IERC6909Claims { uint24 fee ); - event ProtocolFeeUpdated(PoolId indexed id, uint16 protocolFee); - /// @notice Returns the constant representing the maximum tickSpacing for an initialized pool key function MAX_TICK_SPACING() external view returns (int24); @@ -93,7 +91,7 @@ interface IPoolManager is IProtocolFees, IERC6909Claims { function getSlot0(PoolId id) external view - returns (uint160 sqrtPriceX96, int24 tick, uint16 protocolFee, uint24 swapFee); + returns (uint160 sqrtPriceX96, int24 tick, uint24 protocolFee, uint24 swapFee); /// @notice Get the current value of liquidity of the given pool function getLiquidity(PoolId id) external view returns (uint128 liquidity); @@ -110,6 +108,12 @@ interface IPoolManager is IProtocolFees, IERC6909Claims { /// @notice Getter for the bitmap given the poolId and word position function getPoolBitmapInfo(PoolId id, int16 word) external view returns (uint256 tickBitmap); + /// @notice Getter for the fee growth globals for the given poolId + function getFeeGrowthGlobals(PoolId id) + external + view + returns (uint256 feeGrowthGlobal0, uint256 feeGrowthGlobal1); + /// @notice Get the position struct for a specified pool and position function getPosition(PoolId id, address owner, int24 tickLower, int24 tickUpper) external @@ -187,10 +191,6 @@ interface IPoolManager is IProtocolFees, IERC6909Claims { /// @notice Called by the user to pay what is owed function settle(Currency token) external payable returns (uint256 paid); - /// @notice Sets the protocol's swap fee for the given pool - /// Protocol fees are always a portion of the LP swap fee that is owed. If that fee is 0, no protocol fees will accrue even if it is set to > 0. - function setProtocolFee(PoolKey memory key) external; - /// @notice Updates the pools swap fees for the a pool that has enabled dynamic swap fees. function updateDynamicSwapFee(PoolKey memory key, uint24 newDynamicSwapFee) external; diff --git a/src/interfaces/IProtocolFeeController.sol b/src/interfaces/IProtocolFeeController.sol index fbe8799c8..7f2d72ae5 100644 --- a/src/interfaces/IProtocolFeeController.sol +++ b/src/interfaces/IProtocolFeeController.sol @@ -7,5 +7,5 @@ interface IProtocolFeeController { /// @notice Returns the protocol fees for a pool given the conditions of this contract /// @param poolKey The pool key to identify the pool. The controller may want to use attributes on the pool /// to determine the protocol fee, hence the entire key is needed. - function protocolFeeForPool(PoolKey memory poolKey) external view returns (uint16); + function protocolFeeForPool(PoolKey memory poolKey) external view returns (uint24); } diff --git a/src/interfaces/IProtocolFees.sol b/src/interfaces/IProtocolFees.sol index 2113f17b5..aeed1a0e4 100644 --- a/src/interfaces/IProtocolFees.sol +++ b/src/interfaces/IProtocolFees.sol @@ -2,20 +2,35 @@ pragma solidity ^0.8.19; import {Currency} from "../types/Currency.sol"; +import {IProtocolFeeController} from "../interfaces/IProtocolFeeController.sol"; +import {PoolId} from "../types/PoolId.sol"; +import {PoolKey} from "../types/PoolKey.sol"; interface IProtocolFees { /// @notice Thrown when not enough gas is provided to look up the protocol fee error ProtocolFeeCannotBeFetched(); /// @notice Thrown when the call to fetch the protocol fee reverts or returns invalid data. error ProtocolFeeControllerCallFailedOrInvalidResult(); - /// @notice Thrown when a pool does not have a dynamic fee. - error FeeNotDynamic(); + + /// @notice Thrown when collectProtocolFees is not called by the controller. + error InvalidCaller(); event ProtocolFeeControllerUpdated(address protocolFeeController); - /// @notice Returns the minimum denominator for the protocol fee, which restricts it to a maximum of 25% - function MIN_PROTOCOL_FEE_DENOMINATOR() external view returns (uint8); + event ProtocolFeeUpdated(PoolId indexed id, uint24 protocolFee); /// @notice Given a currency address, returns the protocol fees accrued in that currency function protocolFeesAccrued(Currency) external view returns (uint256); + + /// @notice Sets the protocol's swap fee for the given pool + /// Protocol fees are always a portion of the LP swap fee that is owed. If that fee is 0, no protocol fees will accrue even if it is set to > 0. + function setProtocolFee(PoolKey memory key) external; + + /// @notice Sets the protocol fee controller + function setProtocolFeeController(IProtocolFeeController) external; + + /// @notice Collects the protocol fees for a given recipient and currency, returning the amount collected + function collectProtocolFees(address, Currency, uint256) external returns (uint256); + + function protocolFeeController() external view returns (IProtocolFeeController); } diff --git a/src/libraries/Hooks.sol b/src/libraries/Hooks.sol index 77873154e..895876a2a 100644 --- a/src/libraries/Hooks.sol +++ b/src/libraries/Hooks.sol @@ -71,13 +71,13 @@ library Hooks { } /// @notice Ensures that the hook address includes at least one hook flag or dynamic fees, or is the 0 address - /// @param hook The hook to verify - function isValidHookAddress(IHooks hook, uint24 fee) internal pure returns (bool) { + /// @param self The hook to verify + function isValidHookAddress(IHooks self, uint24 fee) internal pure returns (bool) { // If there is no hook contract set, then fee cannot be dynamic // If a hook contract is set, it must have at least 1 flag set, or have a dynamic fee - return address(hook) == address(0) + return address(self) == address(0) ? !fee.isDynamicFee() - : (uint160(address(hook)) >= AFTER_DONATE_FLAG || fee.isDynamicFee()); + : (uint160(address(self)) >= AFTER_DONATE_FLAG || fee.isDynamicFee()); } /// @notice performs a hook call using the given calldata on the given hook @@ -103,9 +103,17 @@ library Hooks { } } + /// @notice modifier to prevent calling a hook if they initiated the action + modifier noSelfCall(IHooks self) { + if (msg.sender != address(self)) { + _; + } + } + /// @notice calls beforeInitialize hook if permissioned and validates return value function beforeInitialize(IHooks self, PoolKey memory key, uint160 sqrtPriceX96, bytes calldata hookData) internal + noSelfCall(self) { if (self.hasPermission(BEFORE_INITIALIZE_FLAG)) { self.callHook( @@ -117,6 +125,7 @@ library Hooks { /// @notice calls afterInitialize hook if permissioned and validates return value function afterInitialize(IHooks self, PoolKey memory key, uint160 sqrtPriceX96, int24 tick, bytes calldata hookData) internal + noSelfCall(self) { if (self.hasPermission(AFTER_INITIALIZE_FLAG)) { self.callHook( @@ -131,10 +140,10 @@ library Hooks { PoolKey memory key, IPoolManager.ModifyLiquidityParams memory params, bytes calldata hookData - ) internal { - if (params.liquidityDelta > 0 && key.hooks.hasPermission(BEFORE_ADD_LIQUIDITY_FLAG)) { + ) internal noSelfCall(self) { + if (params.liquidityDelta > 0 && self.hasPermission(BEFORE_ADD_LIQUIDITY_FLAG)) { self.callHook(abi.encodeWithSelector(IHooks.beforeAddLiquidity.selector, msg.sender, key, params, hookData)); - } else if (params.liquidityDelta <= 0 && key.hooks.hasPermission(BEFORE_REMOVE_LIQUIDITY_FLAG)) { + } else if (params.liquidityDelta <= 0 && self.hasPermission(BEFORE_REMOVE_LIQUIDITY_FLAG)) { self.callHook( abi.encodeWithSelector(IHooks.beforeRemoveLiquidity.selector, msg.sender, key, params, hookData) ); @@ -148,12 +157,12 @@ library Hooks { IPoolManager.ModifyLiquidityParams memory params, BalanceDelta delta, bytes calldata hookData - ) internal { - if (params.liquidityDelta > 0 && key.hooks.hasPermission(AFTER_ADD_LIQUIDITY_FLAG)) { + ) internal noSelfCall(self) { + if (params.liquidityDelta > 0 && self.hasPermission(AFTER_ADD_LIQUIDITY_FLAG)) { self.callHook( abi.encodeWithSelector(IHooks.afterAddLiquidity.selector, msg.sender, key, params, delta, hookData) ); - } else if (params.liquidityDelta <= 0 && key.hooks.hasPermission(AFTER_REMOVE_LIQUIDITY_FLAG)) { + } else if (params.liquidityDelta <= 0 && self.hasPermission(AFTER_REMOVE_LIQUIDITY_FLAG)) { self.callHook( abi.encodeWithSelector(IHooks.afterRemoveLiquidity.selector, msg.sender, key, params, delta, hookData) ); @@ -163,8 +172,9 @@ library Hooks { /// @notice calls beforeSwap hook if permissioned and validates return value function beforeSwap(IHooks self, PoolKey memory key, IPoolManager.SwapParams memory params, bytes calldata hookData) internal + noSelfCall(self) { - if (key.hooks.hasPermission(BEFORE_SWAP_FLAG)) { + if (self.hasPermission(BEFORE_SWAP_FLAG)) { self.callHook(abi.encodeWithSelector(IHooks.beforeSwap.selector, msg.sender, key, params, hookData)); } } @@ -176,8 +186,8 @@ library Hooks { IPoolManager.SwapParams memory params, BalanceDelta delta, bytes calldata hookData - ) internal { - if (key.hooks.hasPermission(AFTER_SWAP_FLAG)) { + ) internal noSelfCall(self) { + if (self.hasPermission(AFTER_SWAP_FLAG)) { self.callHook(abi.encodeWithSelector(IHooks.afterSwap.selector, msg.sender, key, params, delta, hookData)); } } @@ -185,8 +195,9 @@ library Hooks { /// @notice calls beforeDonate hook if permissioned and validates return value function beforeDonate(IHooks self, PoolKey memory key, uint256 amount0, uint256 amount1, bytes calldata hookData) internal + noSelfCall(self) { - if (key.hooks.hasPermission(BEFORE_DONATE_FLAG)) { + if (self.hasPermission(BEFORE_DONATE_FLAG)) { self.callHook( abi.encodeWithSelector(IHooks.beforeDonate.selector, msg.sender, key, amount0, amount1, hookData) ); @@ -196,8 +207,9 @@ library Hooks { /// @notice calls afterDonate hook if permissioned and validates return value function afterDonate(IHooks self, PoolKey memory key, uint256 amount0, uint256 amount1, bytes calldata hookData) internal + noSelfCall(self) { - if (key.hooks.hasPermission(AFTER_DONATE_FLAG)) { + if (self.hasPermission(AFTER_DONATE_FLAG)) { self.callHook( abi.encodeWithSelector(IHooks.afterDonate.selector, msg.sender, key, amount0, amount1, hookData) ); diff --git a/src/libraries/Pool.sol b/src/libraries/Pool.sol index 70652d799..242a74197 100644 --- a/src/libraries/Pool.sol +++ b/src/libraries/Pool.sol @@ -10,6 +10,7 @@ import {TickMath} from "./TickMath.sol"; import {SqrtPriceMath} from "./SqrtPriceMath.sol"; import {SwapMath} from "./SwapMath.sol"; import {BalanceDelta, toBalanceDelta} from "../types/BalanceDelta.sol"; +import {ProtocolFeeLibrary} from "./ProtocolFeeLibrary.sol"; library Pool { using SafeCast for *; @@ -17,6 +18,7 @@ library Pool { using Position for mapping(bytes32 => Position.Info); using Position for Position.Info; using Pool for State; + using ProtocolFeeLibrary for uint24; /// @notice Thrown when tickLower is not below tickUpper /// @param tickLower The invalid tickLower @@ -64,11 +66,10 @@ library Pool { uint160 sqrtPriceX96; // the current tick int24 tick; - // protocol swap fee represented as integer denominator (1/x), taken as a % of the LP swap fee - // upper 8 bits are for 1->0, and the lower 8 are for 0->1 - // the minimum permitted denominator is 4 - meaning the maximum protocol fee is 25% - // granularity is increments of 0.38% (100/type(uint8).max) - uint16 protocolFee; + // protocol swap fee, taken as a % of the LP swap fee + // upper 12 bits are for 1->0, and the lower 12 are for 0->1 + // the maximum is 2500 - meaning the maximum protocol fee is 25% + uint24 protocolFee; // used for the swap fee, either static at initialize or dynamic via hook uint24 swapFee; } @@ -103,7 +104,7 @@ library Pool { if (tickUpper > TickMath.MAX_TICK) revert TickUpperOutOfBounds(tickUpper); } - function initialize(State storage self, uint160 sqrtPriceX96, uint16 protocolFee, uint24 swapFee) + function initialize(State storage self, uint160 sqrtPriceX96, uint24 protocolFee, uint24 swapFee) internal returns (int24 tick) { @@ -114,7 +115,7 @@ library Pool { self.slot0 = Slot0({sqrtPriceX96: sqrtPriceX96, tick: tick, protocolFee: protocolFee, swapFee: swapFee}); } - function setProtocolFee(State storage self, uint16 protocolFee) internal { + function setProtocolFee(State storage self, uint24 protocolFee) internal { if (self.isNotInitialized()) revert PoolNotInitialized(); self.slot0.protocolFee = protocolFee; @@ -155,7 +156,10 @@ library Pool { internal returns (BalanceDelta result) { - checkTicks(params.tickLower, params.tickUpper); + int128 liquidityDelta = params.liquidityDelta; + int24 tickLower = params.tickLower; + int24 tickUpper = params.tickUpper; + checkTicks(tickLower, tickUpper); uint256 feesOwed0; uint256 feesOwed1; @@ -163,82 +167,77 @@ library Pool { ModifyLiquidityState memory state; // if we need to update the ticks, do it - if (params.liquidityDelta != 0) { + if (liquidityDelta != 0) { (state.flippedLower, state.liquidityGrossAfterLower) = - updateTick(self, params.tickLower, params.liquidityDelta, false); - (state.flippedUpper, state.liquidityGrossAfterUpper) = - updateTick(self, params.tickUpper, params.liquidityDelta, true); + updateTick(self, tickLower, liquidityDelta, false); + (state.flippedUpper, state.liquidityGrossAfterUpper) = updateTick(self, tickUpper, liquidityDelta, true); - if (params.liquidityDelta > 0) { + // `>` and `>=` are logically equivalent here but `>=` is cheaper + if (liquidityDelta >= 0) { uint128 maxLiquidityPerTick = tickSpacingToMaxLiquidityPerTick(params.tickSpacing); if (state.liquidityGrossAfterLower > maxLiquidityPerTick) { - revert TickLiquidityOverflow(params.tickLower); + revert TickLiquidityOverflow(tickLower); } if (state.liquidityGrossAfterUpper > maxLiquidityPerTick) { - revert TickLiquidityOverflow(params.tickUpper); + revert TickLiquidityOverflow(tickUpper); } } if (state.flippedLower) { - self.tickBitmap.flipTick(params.tickLower, params.tickSpacing); + self.tickBitmap.flipTick(tickLower, params.tickSpacing); } if (state.flippedUpper) { - self.tickBitmap.flipTick(params.tickUpper, params.tickSpacing); + self.tickBitmap.flipTick(tickUpper, params.tickSpacing); } } - (state.feeGrowthInside0X128, state.feeGrowthInside1X128) = - getFeeGrowthInside(self, params.tickLower, params.tickUpper); + (state.feeGrowthInside0X128, state.feeGrowthInside1X128) = getFeeGrowthInside(self, tickLower, tickUpper); - (feesOwed0, feesOwed1) = self.positions.get(params.owner, params.tickLower, params.tickUpper).update( - params.liquidityDelta, state.feeGrowthInside0X128, state.feeGrowthInside1X128 - ); + Position.Info storage position = self.positions.get(params.owner, tickLower, tickUpper); + (feesOwed0, feesOwed1) = + position.update(liquidityDelta, state.feeGrowthInside0X128, state.feeGrowthInside1X128); // clear any tick data that is no longer needed - if (params.liquidityDelta < 0) { + if (liquidityDelta < 0) { if (state.flippedLower) { - clearTick(self, params.tickLower); + clearTick(self, tickLower); } if (state.flippedUpper) { - clearTick(self, params.tickUpper); + clearTick(self, tickUpper); } } } - if (params.liquidityDelta != 0) { - if (self.slot0.tick < params.tickLower) { + if (liquidityDelta != 0) { + int24 tick = self.slot0.tick; + uint160 sqrtPriceX96 = self.slot0.sqrtPriceX96; + if (tick < tickLower) { // current tick is below the passed range; liquidity can only become in range by crossing from left to // right, when we'll need _more_ currency0 (it's becoming more valuable) so user must provide it result = toBalanceDelta( SqrtPriceMath.getAmount0Delta( - TickMath.getSqrtRatioAtTick(params.tickLower), - TickMath.getSqrtRatioAtTick(params.tickUpper), - params.liquidityDelta + TickMath.getSqrtRatioAtTick(tickLower), TickMath.getSqrtRatioAtTick(tickUpper), liquidityDelta ).toInt128(), 0 ); - } else if (self.slot0.tick < params.tickUpper) { + } else if (tick < tickUpper) { result = toBalanceDelta( - SqrtPriceMath.getAmount0Delta( - self.slot0.sqrtPriceX96, TickMath.getSqrtRatioAtTick(params.tickUpper), params.liquidityDelta - ).toInt128(), - SqrtPriceMath.getAmount1Delta( - TickMath.getSqrtRatioAtTick(params.tickLower), self.slot0.sqrtPriceX96, params.liquidityDelta - ).toInt128() + SqrtPriceMath.getAmount0Delta(sqrtPriceX96, TickMath.getSqrtRatioAtTick(tickUpper), liquidityDelta) + .toInt128(), + SqrtPriceMath.getAmount1Delta(TickMath.getSqrtRatioAtTick(tickLower), sqrtPriceX96, liquidityDelta) + .toInt128() ); - self.liquidity = params.liquidityDelta < 0 - ? self.liquidity - uint128(-params.liquidityDelta) - : self.liquidity + uint128(params.liquidityDelta); + self.liquidity = liquidityDelta < 0 + ? self.liquidity - uint128(-liquidityDelta) + : self.liquidity + uint128(liquidityDelta); } else { // current tick is above the passed range; liquidity can only become in range by crossing from right to // left, when we'll need _more_ currency1 (it's becoming more valuable) so user must provide it result = toBalanceDelta( 0, SqrtPriceMath.getAmount1Delta( - TickMath.getSqrtRatioAtTick(params.tickLower), - TickMath.getSqrtRatioAtTick(params.tickUpper), - params.liquidityDelta + TickMath.getSqrtRatioAtTick(tickLower), TickMath.getSqrtRatioAtTick(tickUpper), liquidityDelta ).toInt128() ); } @@ -252,7 +251,7 @@ library Pool { // liquidity at the beginning of the swap uint128 liquidityStart; // the protocol fee for the input token - uint8 protocolFee; + uint16 protocolFee; } // the top level state of the swap, the results of which are recorded in storage at the end @@ -305,7 +304,8 @@ library Pool { Slot0 memory slot0Start = self.slot0; swapFee = slot0Start.swapFee; - if (params.zeroForOne) { + bool zeroForOne = params.zeroForOne; + if (zeroForOne) { if (params.sqrtPriceLimitX96 >= slot0Start.sqrtPriceX96) { revert PriceLimitAlreadyExceeded(slot0Start.sqrtPriceX96, params.sqrtPriceLimitX96); } @@ -323,7 +323,7 @@ library Pool { SwapCache memory cache = SwapCache({ liquidityStart: self.liquidity, - protocolFee: params.zeroForOne ? uint8(slot0Start.protocolFee % 256) : uint8(slot0Start.protocolFee >> 8) + protocolFee: zeroForOne ? slot0Start.protocolFee.getZeroForOneFee() : slot0Start.protocolFee.getOneForZeroFee() }); bool exactInput = params.amountSpecified < 0; @@ -333,7 +333,7 @@ library Pool { amountCalculated: 0, sqrtPriceX96: slot0Start.sqrtPriceX96, tick: slot0Start.tick, - feeGrowthGlobalX128: params.zeroForOne ? self.feeGrowthGlobal0X128 : self.feeGrowthGlobal1X128, + feeGrowthGlobalX128: zeroForOne ? self.feeGrowthGlobal0X128 : self.feeGrowthGlobal1X128, liquidity: cache.liquidityStart }); @@ -343,7 +343,7 @@ library Pool { step.sqrtPriceStartX96 = state.sqrtPriceX96; (step.tickNext, step.initialized) = - self.tickBitmap.nextInitializedTickWithinOneWord(state.tick, params.tickSpacing, params.zeroForOne); + self.tickBitmap.nextInitializedTickWithinOneWord(state.tick, params.tickSpacing, zeroForOne); // ensure that we do not overshoot the min/max tick, as the tick bitmap is not aware of these bounds if (step.tickNext < TickMath.MIN_TICK) { @@ -359,7 +359,7 @@ library Pool { (state.sqrtPriceX96, step.amountIn, step.amountOut, step.feeAmount) = SwapMath.computeSwapStep( state.sqrtPriceX96, ( - params.zeroForOne + zeroForOne ? step.sqrtPriceNextX96 < params.sqrtPriceLimitX96 : step.sqrtPriceNextX96 > params.sqrtPriceLimitX96 ) ? params.sqrtPriceLimitX96 : step.sqrtPriceNextX96, @@ -383,9 +383,9 @@ library Pool { // if the protocol fee is on, calculate how much is owed, decrement feeAmount, and increment protocolFee if (cache.protocolFee > 0) { - // A: calculate the amount of the fee that should go to the protocol - uint256 delta = step.feeAmount / cache.protocolFee; - // A: subtract it from the regular fee and add it to the protocol fee + // calculate the amount of the fee that should go to the protocol + uint256 delta = step.feeAmount * cache.protocolFee / ProtocolFeeLibrary.BIPS_DENOMINATOR; + // subtract it from the regular fee and add it to the protocol fee unchecked { step.feeAmount -= delta; feeForProtocol += delta; @@ -406,13 +406,13 @@ library Pool { int128 liquidityNet = Pool.crossTick( self, step.tickNext, - (params.zeroForOne ? state.feeGrowthGlobalX128 : self.feeGrowthGlobal0X128), - (params.zeroForOne ? self.feeGrowthGlobal1X128 : state.feeGrowthGlobalX128) + (zeroForOne ? state.feeGrowthGlobalX128 : self.feeGrowthGlobal0X128), + (zeroForOne ? self.feeGrowthGlobal1X128 : state.feeGrowthGlobalX128) ); // if we're moving leftward, we interpret liquidityNet as the opposite sign // safe because liquidityNet cannot be type(int128).min unchecked { - if (params.zeroForOne) liquidityNet = -liquidityNet; + if (zeroForOne) liquidityNet = -liquidityNet; } state.liquidity = liquidityNet < 0 @@ -421,7 +421,7 @@ library Pool { } unchecked { - state.tick = params.zeroForOne ? step.tickNext - 1 : step.tickNext; + state.tick = zeroForOne ? step.tickNext - 1 : step.tickNext; } } else if (state.sqrtPriceX96 != step.sqrtPriceStartX96) { // recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved @@ -435,14 +435,14 @@ library Pool { if (cache.liquidityStart != state.liquidity) self.liquidity = state.liquidity; // update fee growth global - if (params.zeroForOne) { + if (zeroForOne) { self.feeGrowthGlobal0X128 = state.feeGrowthGlobalX128; } else { self.feeGrowthGlobal1X128 = state.feeGrowthGlobalX128; } unchecked { - if (params.zeroForOne == exactInput) { + if (zeroForOne == exactInput) { result = toBalanceDelta( (params.amountSpecified - state.amountSpecifiedRemaining).toInt128(), state.amountCalculated.toInt128() diff --git a/src/libraries/PoolGetters.sol b/src/libraries/PoolGetters.sol index 6a8eac5c3..a05761b34 100644 --- a/src/libraries/PoolGetters.sol +++ b/src/libraries/PoolGetters.sol @@ -11,4 +11,12 @@ library PoolGetters { function getPoolBitmapInfo(Pool.State storage pool, int16 word) internal view returns (uint256 tickBitmap) { return pool.tickBitmap[word]; } + + function getFeeGrowthGlobals(Pool.State storage pool) + internal + view + returns (uint256 feeGrowthGlobal0x128, uint256 feeGrowthGlobal1x128) + { + return (pool.feeGrowthGlobal0X128, pool.feeGrowthGlobal1X128); + } } diff --git a/src/libraries/Position.sol b/src/libraries/Position.sol index cb0149d6a..ec0b5f3e4 100644 --- a/src/libraries/Position.sol +++ b/src/libraries/Position.sol @@ -29,9 +29,18 @@ library Position { function get(mapping(bytes32 => Info) storage self, address owner, int24 tickLower, int24 tickUpper) internal view - returns (Position.Info storage position) + returns (Info storage position) { - position = self[keccak256(abi.encodePacked(owner, tickLower, tickUpper))]; + // positionKey = keccak256(abi.encodePacked(owner, tickLower, tickUpper)) + bytes32 positionKey; + /// @solidity memory-safe-assembly + assembly { + mstore(0x06, tickUpper) // [0x23, 0x26) + mstore(0x03, tickLower) // [0x20, 0x23) + mstore(0, owner) // [0x0c, 0x20) + positionKey := keccak256(0x0c, 0x1a) + } + position = self[positionKey]; } /// @notice Credits accumulated fees to a user's position @@ -47,26 +56,23 @@ library Position { uint256 feeGrowthInside0X128, uint256 feeGrowthInside1X128 ) internal returns (uint256 feesOwed0, uint256 feesOwed1) { - Info memory _self = self; + uint128 liquidity = self.liquidity; uint128 liquidityNext; if (liquidityDelta == 0) { - if (_self.liquidity == 0) revert CannotUpdateEmptyPosition(); // disallow pokes for 0 liquidity positions - liquidityNext = _self.liquidity; + if (liquidity == 0) revert CannotUpdateEmptyPosition(); // disallow pokes for 0 liquidity positions + liquidityNext = liquidity; } else { - liquidityNext = liquidityDelta < 0 - ? _self.liquidity - uint128(-liquidityDelta) - : _self.liquidity + uint128(liquidityDelta); + liquidityNext = + liquidityDelta < 0 ? liquidity - uint128(-liquidityDelta) : liquidity + uint128(liquidityDelta); } // calculate accumulated fees. overflow in the subtraction of fee growth is expected unchecked { - feesOwed0 = FullMath.mulDiv( - feeGrowthInside0X128 - _self.feeGrowthInside0LastX128, _self.liquidity, FixedPoint128.Q128 - ); - feesOwed1 = FullMath.mulDiv( - feeGrowthInside1X128 - _self.feeGrowthInside1LastX128, _self.liquidity, FixedPoint128.Q128 - ); + feesOwed0 = + FullMath.mulDiv(feeGrowthInside0X128 - self.feeGrowthInside0LastX128, liquidity, FixedPoint128.Q128); + feesOwed1 = + FullMath.mulDiv(feeGrowthInside1X128 - self.feeGrowthInside1LastX128, liquidity, FixedPoint128.Q128); } // update the position diff --git a/src/libraries/ProtocolFeeLibrary.sol b/src/libraries/ProtocolFeeLibrary.sol new file mode 100644 index 000000000..28d75dc3d --- /dev/null +++ b/src/libraries/ProtocolFeeLibrary.sol @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +pragma solidity ^0.8.20; + +library ProtocolFeeLibrary { + // Max protocol fee is 25% (2500 bips) + uint16 public constant MAX_PROTOCOL_FEE = 2500; + + // Total bips + uint16 internal constant BIPS_DENOMINATOR = 10_000; + + function getZeroForOneFee(uint24 self) internal pure returns (uint16) { + return uint16(self & (4096 - 1)); + } + + function getOneForZeroFee(uint24 self) internal pure returns (uint16) { + return uint16(self >> 12); + } + + function validate(uint24 self) internal pure returns (bool) { + if (self != 0) { + uint16 fee0 = getZeroForOneFee(self); + uint16 fee1 = getOneForZeroFee(self); + // The fee is represented in bips so it cannot be GREATER than the MAX_PROTOCOL_FEE. + if ((fee0 > MAX_PROTOCOL_FEE) || (fee1 > MAX_PROTOCOL_FEE)) { + return false; + } + } + return true; + } +} diff --git a/src/libraries/SwapFeeLibrary.sol b/src/libraries/SwapFeeLibrary.sol index d90b8c654..7fec6095b 100644 --- a/src/libraries/SwapFeeLibrary.sol +++ b/src/libraries/SwapFeeLibrary.sol @@ -23,7 +23,7 @@ library SwapFeeLibrary { if (self > MAX_SWAP_FEE) revert FeeTooLarge(); } - function getSwapFee(uint24 self) internal pure returns (uint24 swapFee) { + function getInitialSwapFee(uint24 self) internal pure returns (uint24 swapFee) { // the initial fee for a dynamic fee pool is 0 if (self.isDynamicFee()) return 0; swapFee = self & STATIC_FEE_MASK; diff --git a/src/libraries/TickBitmap.sol b/src/libraries/TickBitmap.sol index 0bdefcaa1..b04d5d19e 100644 --- a/src/libraries/TickBitmap.sol +++ b/src/libraries/TickBitmap.sol @@ -19,7 +19,7 @@ library TickBitmap { function position(int24 tick) internal pure returns (int16 wordPos, uint8 bitPos) { unchecked { wordPos = int16(tick >> 8); - bitPos = uint8(int8(tick % 256)); + bitPos = uint8(int8(tick & (256 - 1))); } } diff --git a/src/libraries/TickMath.sol b/src/libraries/TickMath.sol index cd8e5cf01..4c9b95b94 100644 --- a/src/libraries/TickMath.sol +++ b/src/libraries/TickMath.sol @@ -76,7 +76,7 @@ library TickMath { // this divides by 1<<32 rounding up to go from a Q128.128 to a Q128.96. // we then downcast because we know the result always fits within 160 bits due to our tick input constraint // we round up in the division so getTickAtSqrtRatio of the output price is always consistent - sqrtPriceX96 = uint160((ratio >> 32) + (ratio % (1 << 32) == 0 ? 0 : 1)); + sqrtPriceX96 = uint160((ratio >> 32) + (ratio & ((1 << 32) - 1) == 0 ? 0 : 1)); } } diff --git a/src/test/PoolDonateTest.sol b/src/test/PoolDonateTest.sol index d568c32e0..55d95c392 100644 --- a/src/test/PoolDonateTest.sol +++ b/src/test/PoolDonateTest.sol @@ -43,23 +43,17 @@ contract PoolDonateTest is PoolTestBase { CallbackData memory data = abi.decode(rawData, (CallbackData)); - (, uint256 poolBalanceBefore0, int256 deltaBefore0) = - _fetchBalances(data.key.currency0, data.sender, address(this)); - (, uint256 poolBalanceBefore1, int256 deltaBefore1) = - _fetchBalances(data.key.currency1, data.sender, address(this)); + (,, int256 deltaBefore0) = _fetchBalances(data.key.currency0, data.sender, address(this)); + (,, int256 deltaBefore1) = _fetchBalances(data.key.currency1, data.sender, address(this)); require(deltaBefore0 == 0, "deltaBefore0 is not 0"); require(deltaBefore1 == 0, "deltaBefore1 is not 0"); BalanceDelta delta = manager.donate(data.key, data.amount0, data.amount1, data.hookData); - (, uint256 poolBalanceAfter0, int256 deltaAfter0) = - _fetchBalances(data.key.currency0, data.sender, address(this)); - (, uint256 poolBalanceAfter1, int256 deltaAfter1) = - _fetchBalances(data.key.currency1, data.sender, address(this)); + (,, int256 deltaAfter0) = _fetchBalances(data.key.currency0, data.sender, address(this)); + (,, int256 deltaAfter1) = _fetchBalances(data.key.currency1, data.sender, address(this)); - require(poolBalanceBefore0 == poolBalanceAfter0, "poolBalanceBefore0 is not equal to poolBalanceAfter0"); - require(poolBalanceBefore1 == poolBalanceAfter1, "poolBalanceBefore1 is not equal to poolBalanceAfter1"); require(deltaAfter0 == -int256(data.amount0), "deltaAfter0 is not equal to -int256(data.amount0)"); require(deltaAfter1 == -int256(data.amount1), "deltaAfter1 is not equal to -int256(data.amount1)"); diff --git a/src/test/PoolSwapTest.sol b/src/test/PoolSwapTest.sol index 0817017b3..255fb678a 100644 --- a/src/test/PoolSwapTest.sol +++ b/src/test/PoolSwapTest.sol @@ -52,23 +52,16 @@ contract PoolSwapTest is PoolTestBase { CallbackData memory data = abi.decode(rawData, (CallbackData)); - (, uint256 poolBalanceBefore0, int256 deltaBefore0) = - _fetchBalances(data.key.currency0, data.sender, address(this)); - (, uint256 poolBalanceBefore1, int256 deltaBefore1) = - _fetchBalances(data.key.currency1, data.sender, address(this)); + (,, int256 deltaBefore0) = _fetchBalances(data.key.currency0, data.sender, address(this)); + (,, int256 deltaBefore1) = _fetchBalances(data.key.currency1, data.sender, address(this)); require(deltaBefore0 == 0, "deltaBefore0 is not equal to 0"); require(deltaBefore1 == 0, "deltaBefore1 is not equal to 0"); BalanceDelta delta = manager.swap(data.key, data.params, data.hookData); - (, uint256 poolBalanceAfter0, int256 deltaAfter0) = - _fetchBalances(data.key.currency0, data.sender, address(this)); - (, uint256 poolBalanceAfter1, int256 deltaAfter1) = - _fetchBalances(data.key.currency1, data.sender, address(this)); - - require(poolBalanceBefore0 == poolBalanceAfter0, "poolBalanceBefore0 is not equal to poolBalanceAfter0"); - require(poolBalanceBefore1 == poolBalanceAfter1, "poolBalanceBefore1 is not equal to poolBalanceAfter1"); + (,, int256 deltaAfter0) = _fetchBalances(data.key.currency0, data.sender, address(this)); + (,, int256 deltaAfter1) = _fetchBalances(data.key.currency1, data.sender, address(this)); if (data.params.zeroForOne) { if (data.params.amountSpecified < 0) { diff --git a/src/test/ProtocolFeeControllerTest.sol b/src/test/ProtocolFeeControllerTest.sol index 1b46080d8..30aec61a7 100644 --- a/src/test/ProtocolFeeControllerTest.sol +++ b/src/test/ProtocolFeeControllerTest.sol @@ -9,36 +9,36 @@ import {PoolKey} from "../types/PoolKey.sol"; contract ProtocolFeeControllerTest is IProtocolFeeController { using PoolIdLibrary for PoolKey; - mapping(PoolId => uint16) public swapFeeForPool; + mapping(PoolId => uint24) public protocolFee; - function protocolFeeForPool(PoolKey memory key) external view returns (uint16) { - return swapFeeForPool[key.toId()]; + function protocolFeeForPool(PoolKey memory key) external view returns (uint24) { + return protocolFee[key.toId()]; } // for tests to set pool protocol fees - function setSwapFeeForPool(PoolId id, uint16 fee) external { - swapFeeForPool[id] = fee; + function setProtocolFeeForPool(PoolId id, uint24 fee) external { + protocolFee[id] = fee; } } /// @notice Reverts on call contract RevertingProtocolFeeControllerTest is IProtocolFeeController { - function protocolFeeForPool(PoolKey memory /* key */ ) external pure returns (uint16) { + function protocolFeeForPool(PoolKey memory /* key */ ) external pure returns (uint24) { revert(); } } /// @notice Returns an out of bounds protocol fee contract OutOfBoundsProtocolFeeControllerTest is IProtocolFeeController { - function protocolFeeForPool(PoolKey memory /* key */ ) external pure returns (uint16) { - // set both swap and withdraw fees to 1, which is less than MIN_PROTOCOL_FEE_DENOMINATOR - return 0x001001; + function protocolFeeForPool(PoolKey memory /* key */ ) external pure returns (uint24) { + // set both swap fees to 2501, which is greater than MAX_PROTOCOL_FEE + return 0x9C59C5; } } /// @notice Return a value that overflows a uint24 contract OverflowProtocolFeeControllerTest is IProtocolFeeController { - function protocolFeeForPool(PoolKey memory /* key */ ) external pure returns (uint16) { + function protocolFeeForPool(PoolKey memory /* key */ ) external pure returns (uint24) { assembly { let ptr := mload(0x40) mstore(ptr, 0xFFFFAAA001) @@ -49,7 +49,7 @@ contract OverflowProtocolFeeControllerTest is IProtocolFeeController { /// @notice Returns data that is larger than a word contract InvalidReturnSizeProtocolFeeControllerTest is IProtocolFeeController { - function protocolFeeForPool(PoolKey memory /* key */ ) external view returns (uint16) { + function protocolFeeForPool(PoolKey memory /* key */ ) external view returns (uint24) { address a = address(this); assembly { let ptr := mload(0x40) diff --git a/src/test/SkipCallsTestHook.sol b/src/test/SkipCallsTestHook.sol new file mode 100644 index 000000000..0ebc72185 --- /dev/null +++ b/src/test/SkipCallsTestHook.sol @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.24; + +import {Hooks} from "../libraries/Hooks.sol"; +import {BaseTestHooks} from "./BaseTestHooks.sol"; +import {IHooks} from "../interfaces/IHooks.sol"; +import {IPoolManager} from "../interfaces/IPoolManager.sol"; +import {PoolKey} from "../types/PoolKey.sol"; +import {BalanceDelta} from "../types/BalanceDelta.sol"; +import {PoolId, PoolIdLibrary} from "../types/PoolId.sol"; +import {IERC20Minimal} from "../interfaces/external/IERC20Minimal.sol"; +import {CurrencyLibrary, Currency} from "../types/Currency.sol"; +import {PoolTestBase} from "./PoolTestBase.sol"; +import {Constants} from "../../test/utils/Constants.sol"; +import {Test} from "forge-std/Test.sol"; + +contract SkipCallsTestHook is BaseTestHooks, Test { + using PoolIdLibrary for PoolKey; + using Hooks for IHooks; + + uint256 public counter; + IPoolManager manager; + + function setManager(IPoolManager _manager) external { + manager = _manager; + } + + function beforeInitialize(address, PoolKey calldata key, uint160 sqrtPriceX96, bytes calldata hookData) + external + override + returns (bytes4) + { + counter++; + _initialize(key, sqrtPriceX96, hookData); + return IHooks.beforeInitialize.selector; + } + + function afterInitialize(address, PoolKey calldata key, uint160 sqrtPriceX96, int24, bytes calldata hookData) + external + override + returns (bytes4) + { + counter++; + _initialize(key, sqrtPriceX96, hookData); + return IHooks.afterInitialize.selector; + } + + function beforeAddLiquidity( + address, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata params, + bytes calldata hookData + ) external override returns (bytes4) { + counter++; + _addLiquidity(key, params, hookData); + return IHooks.beforeAddLiquidity.selector; + } + + function afterAddLiquidity( + address, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata params, + BalanceDelta, + bytes calldata hookData + ) external override returns (bytes4) { + counter++; + _addLiquidity(key, params, hookData); + return IHooks.afterAddLiquidity.selector; + } + + function beforeRemoveLiquidity( + address, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata params, + bytes calldata hookData + ) external override returns (bytes4) { + counter++; + _removeLiquidity(key, params, hookData); + return IHooks.beforeRemoveLiquidity.selector; + } + + function afterRemoveLiquidity( + address, + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams calldata params, + BalanceDelta, + bytes calldata hookData + ) external override returns (bytes4) { + counter++; + _removeLiquidity(key, params, hookData); + return IHooks.afterRemoveLiquidity.selector; + } + + function beforeSwap(address, PoolKey calldata key, IPoolManager.SwapParams calldata params, bytes calldata hookData) + external + override + returns (bytes4) + { + counter++; + _swap(key, params, hookData); + return IHooks.beforeSwap.selector; + } + + function afterSwap( + address, + PoolKey calldata key, + IPoolManager.SwapParams calldata params, + BalanceDelta, + bytes calldata hookData + ) external override returns (bytes4) { + counter++; + _swap(key, params, hookData); + return IHooks.afterSwap.selector; + } + + function beforeDonate(address, PoolKey calldata key, uint256 amt0, uint256 amt1, bytes calldata hookData) + external + override + returns (bytes4) + { + counter++; + _donate(key, amt0, amt1, hookData); + return IHooks.beforeDonate.selector; + } + + function afterDonate(address, PoolKey calldata key, uint256 amt0, uint256 amt1, bytes calldata hookData) + external + override + returns (bytes4) + { + counter++; + _donate(key, amt0, amt1, hookData); + return IHooks.afterDonate.selector; + } + + function _initialize(PoolKey memory key, uint160 sqrtPriceX96, bytes calldata hookData) public { + // initialize a new pool with different fee + key.fee = 2000; + IPoolManager(manager).initialize(key, sqrtPriceX96, hookData); + } + + function _swap(PoolKey calldata key, IPoolManager.SwapParams memory params, bytes calldata hookData) public { + IPoolManager(manager).swap(key, params, hookData); + address payer = abi.decode(hookData, (address)); + int256 delta0 = IPoolManager(manager).currencyDelta(address(this), key.currency0); + assertEq(delta0, params.amountSpecified); + int256 delta1 = IPoolManager(manager).currencyDelta(address(this), key.currency1); + assert(delta1 > 0); + IERC20Minimal(Currency.unwrap(key.currency0)).transferFrom(payer, address(manager), uint256(-delta0)); + manager.settle(key.currency0); + manager.take(key.currency1, payer, uint256(delta1)); + } + + function _addLiquidity( + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams memory params, + bytes calldata hookData + ) public { + IPoolManager(manager).modifyLiquidity(key, params, hookData); + address payer = abi.decode(hookData, (address)); + int256 delta0 = IPoolManager(manager).currencyDelta(address(this), key.currency0); + int256 delta1 = IPoolManager(manager).currencyDelta(address(this), key.currency1); + + assert(delta0 < 0 || delta1 < 0); + assert(!(delta0 > 0 || delta1 > 0)); + + IERC20Minimal(Currency.unwrap(key.currency0)).transferFrom(payer, address(manager), uint256(-delta0)); + manager.settle(key.currency0); + IERC20Minimal(Currency.unwrap(key.currency1)).transferFrom(payer, address(manager), uint256(-delta1)); + manager.settle(key.currency1); + } + + function _removeLiquidity( + PoolKey calldata key, + IPoolManager.ModifyLiquidityParams memory params, + bytes calldata hookData + ) public { + // first hook needs to add liquidity for itself + IPoolManager.ModifyLiquidityParams memory newParams = + IPoolManager.ModifyLiquidityParams({tickLower: -120, tickUpper: 120, liquidityDelta: 1e18}); + IPoolManager(manager).modifyLiquidity(key, newParams, hookData); + // hook removes liquidity + IPoolManager(manager).modifyLiquidity(key, params, hookData); + address payer = abi.decode(hookData, (address)); + int256 delta0 = IPoolManager(manager).currencyDelta(address(this), key.currency0); + int256 delta1 = IPoolManager(manager).currencyDelta(address(this), key.currency1); + + assert(delta0 < 0 || delta1 < 0); + assert(!(delta0 > 0 || delta1 > 0)); + + IERC20Minimal(Currency.unwrap(key.currency0)).transferFrom(payer, address(manager), uint256(-delta0)); + manager.settle(key.currency0); + IERC20Minimal(Currency.unwrap(key.currency1)).transferFrom(payer, address(manager), uint256(-delta1)); + manager.settle(key.currency1); + } + + function _donate(PoolKey calldata key, uint256 amt0, uint256 amt1, bytes calldata hookData) public { + IPoolManager(manager).donate(key, amt0, amt1, hookData); + address payer = abi.decode(hookData, (address)); + int256 delta0 = IPoolManager(manager).currencyDelta(address(this), key.currency0); + int256 delta1 = IPoolManager(manager).currencyDelta(address(this), key.currency1); + IERC20Minimal(Currency.unwrap(key.currency0)).transferFrom(payer, address(manager), uint256(-delta0)); + IERC20Minimal(Currency.unwrap(key.currency1)).transferFrom(payer, address(manager), uint256(-delta1)); + manager.settle(key.currency0); + manager.settle(key.currency1); + } +} diff --git a/src/types/BalanceDelta.sol b/src/types/BalanceDelta.sol index 7a6e53406..49ae295e0 100644 --- a/src/types/BalanceDelta.sol +++ b/src/types/BalanceDelta.sol @@ -9,8 +9,7 @@ using BalanceDeltaLibrary for BalanceDelta global; function toBalanceDelta(int128 _amount0, int128 _amount1) pure returns (BalanceDelta balanceDelta) { /// @solidity memory-safe-assembly assembly { - balanceDelta := - or(shl(128, _amount0), and(0x00000000000000000000000000000000ffffffffffffffffffffffffffffffff, _amount1)) + balanceDelta := or(shl(128, _amount0), and(0xffffffffffffffffffffffffffffffff, _amount1)) } } diff --git a/src/types/PoolId.sol b/src/types/PoolId.sol index 0f3c438be..88387a8d9 100644 --- a/src/types/PoolId.sol +++ b/src/types/PoolId.sol @@ -7,7 +7,11 @@ type PoolId is bytes32; /// @notice Library for computing the ID of a pool library PoolIdLibrary { - function toId(PoolKey memory poolKey) internal pure returns (PoolId) { - return PoolId.wrap(keccak256(abi.encode(poolKey))); + /// @notice Returns value equal to keccak256(abi.encode(poolKey)) + function toId(PoolKey memory poolKey) internal pure returns (PoolId poolId) { + /// @solidity memory-safe-assembly + assembly { + poolId := keccak256(poolKey, mul(32, 5)) + } } } diff --git a/test/ERC6909Claims.t.sol b/test/ERC6909Claims.t.sol index fc01cde5a..dbf3dc478 100644 --- a/test/ERC6909Claims.t.sol +++ b/test/ERC6909Claims.t.sol @@ -31,7 +31,11 @@ contract ERC6909ClaimsTest is Test { if (mintAmount == type(uint256).max) { assertEq(token.allowance(sender, address(this), id), type(uint256).max); } else { - assertEq(token.allowance(sender, address(this), id), mintAmount - transferAmount); + if (sender != address(this)) { + assertEq(token.allowance(sender, address(this), id), mintAmount - transferAmount); + } else { + assertEq(token.allowance(sender, address(this), id), mintAmount); + } } assertEq(token.balanceOf(sender, id), mintAmount - transferAmount); } @@ -255,7 +259,11 @@ contract ERC6909ClaimsTest is Test { if (mintAmount == type(uint256).max) { assertEq(token.allowance(sender, address(this), id), type(uint256).max); } else { - assertEq(token.allowance(sender, address(this), id), mintAmount - transferAmount); + if (sender != address(this)) { + assertEq(token.allowance(sender, address(this), id), mintAmount - transferAmount); + } else { + assertEq(token.allowance(sender, address(this), id), mintAmount); + } } if (sender == receiver) { @@ -367,6 +375,7 @@ contract ERC6909ClaimsTest is Test { token.mint(sender, id, amount); + vm.assume(sender != address(this)); token.transferFrom(sender, receiver, id, amount); } } diff --git a/test/Owned.t.sol b/test/Owned.t.sol deleted file mode 100644 index 9bbeda0bf..000000000 --- a/test/Owned.t.sol +++ /dev/null @@ -1,43 +0,0 @@ -// SPDX-License-Identifier: UNLICENSED -pragma solidity ^0.8.20; - -import {Test} from "forge-std/Test.sol"; -import {Vm} from "forge-std/Vm.sol"; -import {Owned} from "../src/Owned.sol"; - -contract OwnedTest is Test { - Owned owned; - - function testConstructor(address owner) public { - deployOwnedWithOwner(owner); - - assertEq(owner, owned.owner()); - } - - function testSetOwnerFromOwner(address oldOwner, address nextOwner) public { - // set the old owner as the owner - deployOwnedWithOwner(oldOwner); - - // old owner passes over ownership - vm.prank(oldOwner); - owned.setOwner(nextOwner); - assertEq(nextOwner, owned.owner()); - } - - function testSetOwnerFromNonOwner(address oldOwner, address nextOwner) public { - // set the old owner as the owner - deployOwnedWithOwner(oldOwner); - - if (oldOwner != nextOwner) { - vm.startPrank(nextOwner); - vm.expectRevert(Owned.InvalidCaller.selector); - owned.setOwner(nextOwner); - vm.stopPrank(); - } - } - - function deployOwnedWithOwner(address owner) internal { - vm.prank(owner); - owned = new Owned(); - } -} diff --git a/test/PoolManager.t.sol b/test/PoolManager.t.sol index 65bdf04c0..00f6452ba 100644 --- a/test/PoolManager.t.sol +++ b/test/PoolManager.t.sol @@ -8,7 +8,6 @@ import {IPoolManager} from "../src/interfaces/IPoolManager.sol"; import {IProtocolFees} from "../src/interfaces/IProtocolFees.sol"; import {IProtocolFeeController} from "../src/interfaces/IProtocolFeeController.sol"; import {PoolManager} from "../src/PoolManager.sol"; -import {Owned} from "../src/Owned.sol"; import {TickMath} from "../src/libraries/TickMath.sol"; import {Pool} from "../src/libraries/Pool.sol"; import {Deployers} from "./utils/Deployers.sol"; @@ -30,12 +29,15 @@ import {Position} from "../src/libraries/Position.sol"; import {Constants} from "./utils/Constants.sol"; import {SafeCast} from "../src/libraries/SafeCast.sol"; import {AmountHelpers} from "./utils/AmountHelpers.sol"; +import {ProtocolFeeLibrary} from "../src/libraries/ProtocolFeeLibrary.sol"; +import {IProtocolFees} from "../src/interfaces/IProtocolFees.sol"; contract PoolManagerTest is Test, Deployers, GasSnapshot { using Hooks for IHooks; using PoolIdLibrary for PoolKey; using SwapFeeLibrary for uint24; using CurrencyLibrary for Currency; + using ProtocolFeeLibrary for uint24; event UnlockCallback(); event ProtocolFeeControllerUpdated(address feeController); @@ -52,13 +54,15 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { int24 tick, uint24 fee ); - event ProtocolFeeUpdated(PoolId indexed id, uint16 protocolFee); + event Transfer( address caller, address indexed sender, address indexed receiver, uint256 indexed id, uint256 amount ); PoolEmptyUnlockTest emptyUnlockRouter; + uint24 constant MAX_FEE_BOTH_TOKENS = (2500 << 12) | 2500; // 2500 2500 + function setUp() public { initializeManagerRoutersAndPoolsWithLiq(IHooks(address(0))); @@ -69,7 +73,7 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { snapSize("PoolManager.bytecodeSize", address(manager)); } - function test_feeControllerSet() public { + function test_setProtocolFeeController_succeeds() public { deployFreshManager(); assertEq(address(manager.protocolFeeController()), address(0)); vm.expectEmit(false, false, false, true, address(manager)); @@ -78,6 +82,16 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { assertEq(address(manager.protocolFeeController()), address(feeController)); } + function test_setProtocolFeeController_failsIfNotOwner() public { + deployFreshManager(); + assertEq(address(manager.protocolFeeController()), address(0)); + + vm.prank(address(1)); // not the owner address + vm.expectRevert("UNAUTHORIZED"); + manager.setProtocolFeeController(feeController); + assertEq(address(manager.protocolFeeController()), address(0)); + } + function test_addLiquidity_failsIfNotInitialized() public { vm.expectRevert(Pool.PoolNotInitialized.selector); modifyLiquidityRouter.modifyLiquidity(uninitializedKey, LIQ_PARAMS, ZERO_BYTES); @@ -741,17 +755,17 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { snapEnd(); } - function test_swap_accruesProtocolFees(uint8 protocolFee1, uint8 protocolFee0) public { - protocolFee0 = uint8(bound(protocolFee0, 4, type(uint8).max)); - protocolFee1 = uint8(bound(protocolFee1, 4, type(uint8).max)); + function test_swap_accruesProtocolFees(uint16 protocolFee0, uint16 protocolFee1) public { + protocolFee0 = uint16(bound(protocolFee0, 1, 2500)); + protocolFee1 = uint16(bound(protocolFee1, 1, 2500)); - uint16 protocolFee = (uint16(protocolFee1) << 8) | (uint16(protocolFee0) & uint16(0xFF)); + uint24 protocolFee = (uint24(protocolFee1) << 12) | uint24(protocolFee0); - feeController.setSwapFeeForPool(key.toId(), protocolFee); + feeController.setProtocolFeeForPool(key.toId(), protocolFee); manager.setProtocolFee(key); - (Pool.Slot0 memory slot0,,,) = manager.pools(key.toId()); - assertEq(slot0.protocolFee, protocolFee); + (,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId()); + assertEq(slot0ProtocolFee, protocolFee); // Add liquidity - Fees dont accrue for positive liquidity delta. IPoolManager.ModifyLiquidityParams memory params = LIQ_PARAMS; @@ -775,7 +789,7 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { swapRouter.swap(key, swapParams, PoolSwapTest.TestSettings(true, true, false), ZERO_BYTES); uint256 expectedTotalSwapFee = uint256(-swapParams.amountSpecified) * key.fee / 1e6; - uint256 expectedProtocolFee = expectedTotalSwapFee / protocolFee1; + uint256 expectedProtocolFee = expectedTotalSwapFee * protocolFee1 / 1e4; assertEq(manager.protocolFeesAccrued(currency0), 0); assertEq(manager.protocolFeesAccrued(currency1), expectedProtocolFee); } @@ -801,7 +815,7 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { // test successful donation if pool has liquidity function test_donate_succeedsWhenPoolHasLiquidity() public { - (, uint256 feeGrowthGlobal0X128, uint256 feeGrowthGlobal1X128,) = manager.pools(key.toId()); + (uint256 feeGrowthGlobal0X128, uint256 feeGrowthGlobal1X128) = manager.getFeeGrowthGlobals(key.toId()); assertEq(feeGrowthGlobal0X128, 0); assertEq(feeGrowthGlobal1X128, 0); @@ -809,19 +823,19 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { donateRouter.donate(key, 100, 200, ZERO_BYTES); snapEnd(); - (, feeGrowthGlobal0X128, feeGrowthGlobal1X128,) = manager.pools(key.toId()); + (feeGrowthGlobal0X128, feeGrowthGlobal1X128) = manager.getFeeGrowthGlobals(key.toId()); assertEq(feeGrowthGlobal0X128, 34028236692093846346337); assertEq(feeGrowthGlobal1X128, 68056473384187692692674); } function test_donate_succeedsForNativeTokensWhenPoolHasLiquidity() public { - (, uint256 feeGrowthGlobal0X128, uint256 feeGrowthGlobal1X128,) = manager.pools(nativeKey.toId()); + (uint256 feeGrowthGlobal0X128, uint256 feeGrowthGlobal1X128) = manager.getFeeGrowthGlobals(nativeKey.toId()); assertEq(feeGrowthGlobal0X128, 0); assertEq(feeGrowthGlobal1X128, 0); donateRouter.donate{value: 100}(nativeKey, 100, 200, ZERO_BYTES); - (, feeGrowthGlobal0X128, feeGrowthGlobal1X128,) = manager.pools(nativeKey.toId()); + (feeGrowthGlobal0X128, feeGrowthGlobal1X128) = manager.getFeeGrowthGlobals(nativeKey.toId()); assertEq(feeGrowthGlobal0X128, 34028236692093846346337); assertEq(feeGrowthGlobal1X128, 68056473384187692692674); } @@ -931,29 +945,29 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { manager.burn(address(this), key.currency0.toId(), 1); } - function test_setProtocolFee_updatesProtocolFeeForInitializedPool(uint16 protocolFee) public { - (Pool.Slot0 memory slot0,,,) = manager.pools(key.toId()); - assertEq(slot0.protocolFee, 0); - feeController.setSwapFeeForPool(key.toId(), protocolFee); + function test_setProtocolFee_updatesProtocolFeeForInitializedPool(uint24 protocolFee) public { + (,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId()); + assertEq(slot0ProtocolFee, 0); + feeController.setProtocolFeeForPool(key.toId(), protocolFee); - uint8 fee0 = uint8(protocolFee >> 8); - uint8 fee1 = uint8(protocolFee % 256); - if ((0 < fee0 && fee0 < 4) || (0 < fee1 && fee1 < 4)) { + uint16 fee0 = protocolFee.getZeroForOneFee(); + uint16 fee1 = protocolFee.getOneForZeroFee(); + if ((fee0 > 2500) || (fee1 > 2500)) { vm.expectRevert(IProtocolFees.ProtocolFeeControllerCallFailedOrInvalidResult.selector); manager.setProtocolFee(key); } else { vm.expectEmit(false, false, false, true); - emit ProtocolFeeUpdated(key.toId(), protocolFee); + emit IProtocolFees.ProtocolFeeUpdated(key.toId(), protocolFee); manager.setProtocolFee(key); - (slot0,,,) = manager.pools(key.toId()); - assertEq(slot0.protocolFee, protocolFee); + (,, slot0ProtocolFee,) = manager.getSlot0(key.toId()); + assertEq(slot0ProtocolFee, protocolFee); } } function test_setProtocolFee_failsWithInvalidProtocolFeeControllers() public { - (Pool.Slot0 memory slot0,,,) = manager.pools(key.toId()); - assertEq(slot0.protocolFee, 0); + (,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId()); + assertEq(slot0ProtocolFee, 0); manager.setProtocolFeeController(revertingFeeController); vm.expectRevert(IProtocolFees.ProtocolFeeControllerCallFailedOrInvalidResult.selector); @@ -973,34 +987,30 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { } function test_collectProtocolFees_initializesWithProtocolFeeIfCalled() public { - uint16 protocolFee = 1028; // 00000100 00000100 - - // sets the upper 12 bits - feeController.setSwapFeeForPool(uninitializedKey.toId(), uint16(protocolFee)); + feeController.setProtocolFeeForPool(uninitializedKey.toId(), MAX_FEE_BOTH_TOKENS); manager.initialize(uninitializedKey, SQRT_RATIO_1_1, ZERO_BYTES); - (Pool.Slot0 memory slot0,,,) = manager.pools(uninitializedKey.toId()); - assertEq(slot0.protocolFee, protocolFee); + (,, uint24 slot0ProtocolFee,) = manager.getSlot0(uninitializedKey.toId()); + assertEq(slot0ProtocolFee, MAX_FEE_BOTH_TOKENS); } function test_collectProtocolFees_revertsIfCallerIsNotController() public { - vm.expectRevert(Owned.InvalidCaller.selector); + vm.expectRevert(IProtocolFees.InvalidCaller.selector); manager.collectProtocolFees(address(1), currency0, 0); } function test_collectProtocolFees_ERC20_accumulateFees_gas() public { - uint16 protocolFee = 1028; // 00000100 00000100 uint256 expectedFees = 7; - feeController.setSwapFeeForPool(key.toId(), uint16(protocolFee)); + feeController.setProtocolFeeForPool(key.toId(), MAX_FEE_BOTH_TOKENS); manager.setProtocolFee(key); - (Pool.Slot0 memory slot0,,,) = manager.pools(key.toId()); - assertEq(slot0.protocolFee, protocolFee); + (,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId()); + assertEq(slot0ProtocolFee, MAX_FEE_BOTH_TOKENS); swapRouter.swap( key, - IPoolManager.SwapParams(true, 10000, SQRT_RATIO_1_2), + IPoolManager.SwapParams(true, -10000, SQRT_RATIO_1_2), PoolSwapTest.TestSettings(true, true, false), ZERO_BYTES ); @@ -1017,42 +1027,40 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { } function test_collectProtocolFees_ERC20_returnsAllFeesIf0IsProvidedAsParameter() public { - uint16 protocolFee = 1028; // 00000100 00000100 uint256 expectedFees = 7; - feeController.setSwapFeeForPool(key.toId(), uint16(protocolFee)); + feeController.setProtocolFeeForPool(key.toId(), MAX_FEE_BOTH_TOKENS); manager.setProtocolFee(key); - (Pool.Slot0 memory slot0,,,) = manager.pools(key.toId()); - assertEq(slot0.protocolFee, protocolFee); + (,, uint24 slot0ProtocolFee,) = manager.getSlot0(key.toId()); + assertEq(slot0ProtocolFee, MAX_FEE_BOTH_TOKENS); swapRouter.swap( key, - IPoolManager.SwapParams(true, 10000, SQRT_RATIO_1_2), + IPoolManager.SwapParams(false, -10000, TickMath.MAX_SQRT_RATIO - 1), PoolSwapTest.TestSettings(true, true, false), ZERO_BYTES ); - assertEq(manager.protocolFeesAccrued(currency0), expectedFees); - assertEq(manager.protocolFeesAccrued(currency1), 0); - assertEq(currency0.balanceOf(address(1)), 0); - vm.prank(address(feeController)); - manager.collectProtocolFees(address(1), currency0, 0); - assertEq(currency0.balanceOf(address(1)), expectedFees); assertEq(manager.protocolFeesAccrued(currency0), 0); + assertEq(manager.protocolFeesAccrued(currency1), expectedFees); + assertEq(currency1.balanceOf(address(1)), 0); + vm.prank(address(feeController)); + manager.collectProtocolFees(address(1), currency1, 0); + assertEq(currency1.balanceOf(address(1)), expectedFees); + assertEq(manager.protocolFeesAccrued(currency1), 0); } function test_collectProtocolFees_nativeToken_accumulateFees_gas() public { - uint16 protocolFee = 1028; // 00000100 00000100 uint256 expectedFees = 7; Currency nativeCurrency = CurrencyLibrary.NATIVE; // set protocol fee before initializing the pool as it is fetched on initialization - feeController.setSwapFeeForPool(nativeKey.toId(), uint16(protocolFee)); + feeController.setProtocolFeeForPool(nativeKey.toId(), MAX_FEE_BOTH_TOKENS); manager.setProtocolFee(nativeKey); - (Pool.Slot0 memory slot0,,,) = manager.pools(nativeKey.toId()); - assertEq(slot0.protocolFee, protocolFee); + (,, uint24 slot0ProtocolFee,) = manager.getSlot0(nativeKey.toId()); + assertEq(slot0ProtocolFee, MAX_FEE_BOTH_TOKENS); swapRouter.swap{value: 10000}( nativeKey, @@ -1073,15 +1081,14 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { } function test_collectProtocolFees_nativeToken_returnsAllFeesIf0IsProvidedAsParameter() public { - uint16 protocolFee = 1028; // 00000100 00000100 uint256 expectedFees = 7; Currency nativeCurrency = CurrencyLibrary.NATIVE; - feeController.setSwapFeeForPool(nativeKey.toId(), uint16(protocolFee)); + feeController.setProtocolFeeForPool(nativeKey.toId(), MAX_FEE_BOTH_TOKENS); manager.setProtocolFee(nativeKey); - (Pool.Slot0 memory slot0,,,) = manager.pools(nativeKey.toId()); - assertEq(slot0.protocolFee, protocolFee); + (,, uint24 slot0ProtocolFee,) = manager.getSlot0(nativeKey.toId()); + assertEq(slot0ProtocolFee, MAX_FEE_BOTH_TOKENS); swapRouter.swap{value: 10000}( nativeKey, diff --git a/test/PoolManagerInitialize.t.sol b/test/PoolManagerInitialize.t.sol index c7bc91716..c414579f5 100644 --- a/test/PoolManagerInitialize.t.sol +++ b/test/PoolManagerInitialize.t.sol @@ -21,11 +21,13 @@ import {PoolId, PoolIdLibrary} from "../src/types/PoolId.sol"; import {SwapFeeLibrary} from "../src/libraries/SwapFeeLibrary.sol"; import {ProtocolFeeControllerTest} from "../src/test/ProtocolFeeControllerTest.sol"; import {IProtocolFeeController} from "../src/interfaces/IProtocolFeeController.sol"; +import {ProtocolFeeLibrary} from "../src/libraries/ProtocolFeeLibrary.sol"; contract PoolManagerInitializeTest is Test, Deployers, GasSnapshot { using Hooks for IHooks; using PoolIdLibrary for PoolKey; using SwapFeeLibrary for uint24; + using ProtocolFeeLibrary for uint24; event Initialize( PoolId indexed poolId, @@ -78,9 +80,9 @@ contract PoolManagerInitializeTest is Test, Deployers, GasSnapshot { emit Initialize(key0.toId(), key0.currency0, key0.currency1, key0.fee, key0.tickSpacing, key0.hooks); manager.initialize(key0, sqrtPriceX96, ZERO_BYTES); - (Pool.Slot0 memory slot0,,,) = manager.pools(key0.toId()); - assertEq(slot0.sqrtPriceX96, sqrtPriceX96); - assertEq(slot0.protocolFee, 0); + (uint160 slot0SqrtPriceX96,, uint24 slot0ProtocolFee,) = manager.getSlot0(key0.toId()); + assertEq(slot0SqrtPriceX96, sqrtPriceX96); + assertEq(slot0ProtocolFee, 0); } } @@ -100,10 +102,11 @@ contract PoolManagerInitializeTest is Test, Deployers, GasSnapshot { ); manager.initialize(uninitializedKey, sqrtPriceX96, ZERO_BYTES); - (Pool.Slot0 memory slot0,,,) = manager.pools(uninitializedKey.toId()); - assertEq(slot0.sqrtPriceX96, sqrtPriceX96); - assertEq(slot0.protocolFee, 0); - assertEq(slot0.tick, TickMath.getTickAtSqrtRatio(sqrtPriceX96)); + (uint160 slot0SqrtPriceX96, int24 slot0Tick, uint24 slot0ProtocolFee,) = + manager.getSlot0(uninitializedKey.toId()); + assertEq(slot0SqrtPriceX96, sqrtPriceX96); + assertEq(slot0ProtocolFee, 0); + assertEq(slot0Tick, TickMath.getTickAtSqrtRatio(sqrtPriceX96)); } function test_initialize_succeedsWithHooks(uint160 sqrtPriceX96) public { @@ -122,8 +125,8 @@ contract PoolManagerInitializeTest is Test, Deployers, GasSnapshot { uninitializedKey.hooks = IHooks(mockAddr); int24 tick = manager.initialize(uninitializedKey, sqrtPriceX96, ZERO_BYTES); - (Pool.Slot0 memory slot0,,,) = manager.pools(uninitializedKey.toId()); - assertEq(slot0.sqrtPriceX96, sqrtPriceX96, "sqrtPrice"); + (uint160 slot0SqrtPriceX96,,,) = manager.getSlot0(uninitializedKey.toId()); + assertEq(slot0SqrtPriceX96, sqrtPriceX96, "sqrtPrice"); bytes32 beforeSelector = MockHooks.beforeInitialize.selector; bytes memory beforeParams = abi.encode(address(this), uninitializedKey, sqrtPriceX96, ZERO_BYTES); @@ -169,8 +172,8 @@ contract PoolManagerInitializeTest is Test, Deployers, GasSnapshot { uninitializedKey.hooks = mockHooks; manager.initialize(uninitializedKey, sqrtPriceX96, ZERO_BYTES); - (Pool.Slot0 memory slot0,,,) = manager.pools(uninitializedKey.toId()); - assertEq(slot0.sqrtPriceX96, sqrtPriceX96); + (uint160 slot0SqrtPriceX96,,,) = manager.getSlot0(uninitializedKey.toId()); + assertEq(slot0SqrtPriceX96, sqrtPriceX96); } function test_initialize_revertsWithIdenticalTokens(uint160 sqrtPriceX96) public { @@ -195,21 +198,21 @@ contract PoolManagerInitializeTest is Test, Deployers, GasSnapshot { manager.initialize(uninitializedKey, sqrtPriceX96, ZERO_BYTES); } - function test_initialize_fetchFeeWhenController(uint16 protocolFee) public { + function test_initialize_fetchFeeWhenController(uint24 protocolFee) public { manager.setProtocolFeeController(feeController); - feeController.setSwapFeeForPool(uninitializedKey.toId(), protocolFee); + feeController.setProtocolFeeForPool(uninitializedKey.toId(), protocolFee); - uint8 fee0 = uint8(protocolFee >> 8); - uint8 fee1 = uint8(protocolFee % 256); + uint16 fee0 = protocolFee.getZeroForOneFee(); + uint16 fee1 = protocolFee.getOneForZeroFee(); manager.initialize(uninitializedKey, SQRT_RATIO_1_1, ZERO_BYTES); - (Pool.Slot0 memory slot0,,,) = manager.pools(uninitializedKey.toId()); - assertEq(slot0.sqrtPriceX96, SQRT_RATIO_1_1); - if ((0 < fee0 && fee0 < 4) || (0 < fee1 && fee1 < 4)) { - assertEq(slot0.protocolFee, 0); + (uint160 slot0SqrtPriceX96,, uint24 slot0ProtocolFee,) = manager.getSlot0(uninitializedKey.toId()); + assertEq(slot0SqrtPriceX96, SQRT_RATIO_1_1); + if ((fee0 > 2500) || (fee1 > 2500)) { + assertEq(slot0ProtocolFee, 0); } else { - assertEq(slot0.protocolFee, protocolFee); + assertEq(slot0ProtocolFee, protocolFee); } } @@ -316,8 +319,8 @@ contract PoolManagerInitializeTest is Test, Deployers, GasSnapshot { ); manager.initialize(uninitializedKey, sqrtPriceX96, ZERO_BYTES); // protocol fees should default to 0 - (Pool.Slot0 memory slot0,,,) = manager.pools(uninitializedKey.toId()); - assertEq(slot0.protocolFee, 0); + (,, uint24 slot0ProtocolFee,) = manager.getSlot0(uninitializedKey.toId()); + assertEq(slot0ProtocolFee, 0); // call to setProtocolFee should also revert vm.expectRevert(IProtocolFees.ProtocolFeeControllerCallFailedOrInvalidResult.selector); manager.setProtocolFee(uninitializedKey); @@ -340,8 +343,8 @@ contract PoolManagerInitializeTest is Test, Deployers, GasSnapshot { ); manager.initialize(uninitializedKey, sqrtPriceX96, ZERO_BYTES); // protocol fees should default to 0 - (Pool.Slot0 memory slot0,,,) = manager.pools(uninitializedKey.toId()); - assertEq(slot0.protocolFee, 0); + (,, uint24 slot0ProtocolFee,) = manager.getSlot0(uninitializedKey.toId()); + assertEq(slot0ProtocolFee, 0); } function test_initialize_succeedsWithOverflowFeeController(uint160 sqrtPriceX96) public { @@ -361,8 +364,8 @@ contract PoolManagerInitializeTest is Test, Deployers, GasSnapshot { ); manager.initialize(uninitializedKey, sqrtPriceX96, ZERO_BYTES); // protocol fees should default to 0 - (Pool.Slot0 memory slot0,,,) = manager.pools(uninitializedKey.toId()); - assertEq(slot0.protocolFee, 0); + (,, uint24 slot0ProtocolFee,) = manager.getSlot0(uninitializedKey.toId()); + assertEq(slot0ProtocolFee, 0); } function test_initialize_succeedsWithWrongReturnSizeFeeController(uint160 sqrtPriceX96) public { @@ -382,8 +385,8 @@ contract PoolManagerInitializeTest is Test, Deployers, GasSnapshot { ); manager.initialize(uninitializedKey, sqrtPriceX96, ZERO_BYTES); // protocol fees should default to 0 - (Pool.Slot0 memory slot0,,,) = manager.pools(uninitializedKey.toId()); - assertEq(slot0.protocolFee, 0); + (,, uint24 slot0ProtocolFee,) = manager.getSlot0(uninitializedKey.toId()); + assertEq(slot0ProtocolFee, 0); } function test_initialize_gas() public { diff --git a/test/SkipCallsTestHook.t.sol b/test/SkipCallsTestHook.t.sol new file mode 100644 index 000000000..1fc8ba984 --- /dev/null +++ b/test/SkipCallsTestHook.t.sol @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.20; + +import {Test} from "forge-std/Test.sol"; +import {Vm} from "forge-std/Vm.sol"; +import {PoolId, PoolIdLibrary} from "../src/types/PoolId.sol"; +import {Hooks} from "../src/libraries/Hooks.sol"; +import {SwapFeeLibrary} from "../src/libraries/SwapFeeLibrary.sol"; +import {IPoolManager} from "../src/interfaces/IPoolManager.sol"; +import {IProtocolFees} from "../src/interfaces/IProtocolFees.sol"; +import {IHooks} from "../src/interfaces/IHooks.sol"; +import {PoolKey} from "../src/types/PoolKey.sol"; +import {PoolManager} from "../src/PoolManager.sol"; +import {PoolSwapTest} from "../src/test/PoolSwapTest.sol"; +import {Deployers} from "./utils/Deployers.sol"; +import {GasSnapshot} from "forge-gas-snapshot/GasSnapshot.sol"; +import {Currency, CurrencyLibrary} from "../src/types/Currency.sol"; +import {MockERC20} from "solmate/test/utils/mocks/MockERC20.sol"; +import {Constants} from "../test/utils/Constants.sol"; +import {SkipCallsTestHook} from "../src/test/SkipCallsTestHook.sol"; + +contract SkipCallsTest is Test, Deployers, GasSnapshot { + using PoolIdLibrary for PoolKey; + + IPoolManager.SwapParams swapParams = + IPoolManager.SwapParams({zeroForOne: true, amountSpecified: -100, sqrtPriceLimitX96: SQRT_RATIO_1_2}); + + PoolSwapTest.TestSettings testSettings = + PoolSwapTest.TestSettings({withdrawTokens: true, settleUsingTransfer: true, currencyAlreadySent: false}); + + uint160 clearAllHookPermisssionsMask; + uint256 hookPermissionCount = 10; + + function setUp() public { + clearAllHookPermisssionsMask = ~uint160(0) >> hookPermissionCount; + } + + function deploy(SkipCallsTestHook skipCallsTestHook) private { + SkipCallsTestHook impl = new SkipCallsTestHook(); + vm.etch(address(skipCallsTestHook), address(impl).code); + deployFreshManagerAndRouters(); + skipCallsTestHook.setManager(IPoolManager(manager)); + (currency0, currency1) = deployMintAndApprove2Currencies(); + + assertEq(skipCallsTestHook.counter(), 0); + + (key,) = initPool(currency0, currency1, IHooks(address(skipCallsTestHook)), 3000, SQRT_RATIO_1_1, ZERO_BYTES); + } + + function approveAndAddLiquidity(SkipCallsTestHook skipCallsTestHook) private { + MockERC20(Currency.unwrap(key.currency0)).approve(address(skipCallsTestHook), Constants.MAX_UINT256); + MockERC20(Currency.unwrap(key.currency1)).approve(address(skipCallsTestHook), Constants.MAX_UINT256); + modifyLiquidityRouter.modifyLiquidity(key, LIQ_PARAMS, abi.encode(address(this))); + } + + function test_beforeInitialize_skipIfCalledByHook() public { + SkipCallsTestHook skipCallsTestHook = SkipCallsTestHook( + address(uint160(type(uint160).max & clearAllHookPermisssionsMask | Hooks.BEFORE_INITIALIZE_FLAG)) + ); + + // initializes pool and increments counter + deploy(skipCallsTestHook); + assertEq(skipCallsTestHook.counter(), 1); + } + + function test_afterInitialize_skipIfCalledByHook() public { + SkipCallsTestHook skipCallsTestHook = SkipCallsTestHook( + address(uint160(type(uint160).max & clearAllHookPermisssionsMask | Hooks.AFTER_INITIALIZE_FLAG)) + ); + + // initializes pool and increments counter + deploy(skipCallsTestHook); + assertEq(skipCallsTestHook.counter(), 1); + } + + function test_beforeAddLiquidity_skipIfCalledByHook() public { + SkipCallsTestHook skipCallsTestHook = SkipCallsTestHook( + address(uint160(type(uint160).max & clearAllHookPermisssionsMask | Hooks.BEFORE_ADD_LIQUIDITY_FLAG)) + ); + + deploy(skipCallsTestHook); + assertEq(skipCallsTestHook.counter(), 0); + + // adds liquidity and increments counter + approveAndAddLiquidity(skipCallsTestHook); + assertEq(skipCallsTestHook.counter(), 1); + // adds liquidity again and increments counter + modifyLiquidityRouter.modifyLiquidity(key, LIQ_PARAMS, abi.encode(address(this))); + assertEq(skipCallsTestHook.counter(), 2); + } + + function test_afterAddLiquidity_skipIfCalledByHook() public { + SkipCallsTestHook skipCallsTestHook = SkipCallsTestHook( + address(uint160(type(uint160).max & clearAllHookPermisssionsMask | Hooks.AFTER_ADD_LIQUIDITY_FLAG)) + ); + + deploy(skipCallsTestHook); + assertEq(skipCallsTestHook.counter(), 0); + + // adds liquidity and increments counter + approveAndAddLiquidity(skipCallsTestHook); + assertEq(skipCallsTestHook.counter(), 1); + // adds liquidity and increments counter again + modifyLiquidityRouter.modifyLiquidity(key, LIQ_PARAMS, abi.encode(address(this))); + assertEq(skipCallsTestHook.counter(), 2); + } + + function test_beforeRemoveLiquidity_skipIfCalledByHook() public { + SkipCallsTestHook skipCallsTestHook = SkipCallsTestHook( + address(uint160(type(uint160).max & clearAllHookPermisssionsMask | Hooks.BEFORE_REMOVE_LIQUIDITY_FLAG)) + ); + + deploy(skipCallsTestHook); + approveAndAddLiquidity(skipCallsTestHook); + assertEq(skipCallsTestHook.counter(), 0); + + // removes liquidity and increments counter + modifyLiquidityRouter.modifyLiquidity(key, REMOVE_LIQ_PARAMS, abi.encode(address(this))); + assertEq(skipCallsTestHook.counter(), 1); + // adds liquidity again + modifyLiquidityRouter.modifyLiquidity(key, LIQ_PARAMS, abi.encode(address(this))); + // removes liquidity again and increments counter + modifyLiquidityRouter.modifyLiquidity(key, REMOVE_LIQ_PARAMS, abi.encode(address(this))); + assertEq(skipCallsTestHook.counter(), 2); + } + + function test_afterRemoveLiquidity_skipIfCalledByHook() public { + SkipCallsTestHook skipCallsTestHook = SkipCallsTestHook( + address(uint160(type(uint160).max & clearAllHookPermisssionsMask | Hooks.AFTER_REMOVE_LIQUIDITY_FLAG)) + ); + + deploy(skipCallsTestHook); + approveAndAddLiquidity(skipCallsTestHook); + assertEq(skipCallsTestHook.counter(), 0); + + // removes liquidity and increments counter + modifyLiquidityRouter.modifyLiquidity(key, REMOVE_LIQ_PARAMS, abi.encode(address(this))); + assertEq(skipCallsTestHook.counter(), 1); + // adds liquidity again + modifyLiquidityRouter.modifyLiquidity(key, LIQ_PARAMS, abi.encode(address(this))); + // removes liquidity again and increments counter + modifyLiquidityRouter.modifyLiquidity(key, REMOVE_LIQ_PARAMS, abi.encode(address(this))); + assertEq(skipCallsTestHook.counter(), 2); + } + + function test_beforeSwap_skipIfCalledByHook() public { + SkipCallsTestHook skipCallsTestHook = SkipCallsTestHook( + address(uint160(type(uint160).max & clearAllHookPermisssionsMask | Hooks.BEFORE_SWAP_FLAG)) + ); + + deploy(skipCallsTestHook); + approveAndAddLiquidity(skipCallsTestHook); + assertEq(skipCallsTestHook.counter(), 0); + + // swaps and increments counter + swapRouter.swap(key, swapParams, testSettings, abi.encode(address(this))); + assertEq(skipCallsTestHook.counter(), 1); + // swaps again and increments counter + swapRouter.swap(key, swapParams, testSettings, abi.encode(address(this))); + assertEq(skipCallsTestHook.counter(), 2); + } + + function test_gas_beforeSwap_skipIfCalledByHook() public { + SkipCallsTestHook skipCallsTestHook = SkipCallsTestHook( + address(uint160(type(uint160).max & clearAllHookPermisssionsMask | Hooks.BEFORE_SWAP_FLAG)) + ); + + deploy(skipCallsTestHook); + approveAndAddLiquidity(skipCallsTestHook); + assertEq(skipCallsTestHook.counter(), 0); + + // swaps and increments counter + snapStart("SkipCallsTestsHook.swap_skipsHookCallifHookIsCaller"); + swapRouter.swap(key, swapParams, testSettings, abi.encode(address(this))); + snapEnd(); + assertEq(skipCallsTestHook.counter(), 1); + } + + function test_afterSwap_skipIfCalledByHook() public { + SkipCallsTestHook skipCallsTestHook = SkipCallsTestHook( + address(uint160(type(uint160).max & clearAllHookPermisssionsMask | Hooks.AFTER_SWAP_FLAG)) + ); + + deploy(skipCallsTestHook); + approveAndAddLiquidity(skipCallsTestHook); + assertEq(skipCallsTestHook.counter(), 0); + + // swaps and increments counter + swapRouter.swap(key, swapParams, testSettings, abi.encode(address(this))); + assertEq(skipCallsTestHook.counter(), 1); + // swaps again and increments counter + swapRouter.swap(key, swapParams, testSettings, abi.encode(address(this))); + assertEq(skipCallsTestHook.counter(), 2); + } + + function test_beforeDonate_skipIfCalledByHook() public { + SkipCallsTestHook skipCallsTestHook = SkipCallsTestHook( + address(uint160(type(uint160).max & clearAllHookPermisssionsMask | Hooks.BEFORE_DONATE_FLAG)) + ); + + deploy(skipCallsTestHook); + approveAndAddLiquidity(skipCallsTestHook); + assertEq(skipCallsTestHook.counter(), 0); + + // donates and increments counter + donateRouter.donate(key, 100, 200, abi.encode(address(this))); + assertEq(skipCallsTestHook.counter(), 1); + // donates again and increments counter + donateRouter.donate(key, 100, 200, abi.encode(address(this))); + assertEq(skipCallsTestHook.counter(), 2); + } + + function test_afterDonate_skipIfCalledByHook() public { + SkipCallsTestHook skipCallsTestHook = SkipCallsTestHook( + address(uint160(type(uint160).max & clearAllHookPermisssionsMask | Hooks.AFTER_DONATE_FLAG)) + ); + + deploy(skipCallsTestHook); + approveAndAddLiquidity(skipCallsTestHook); + assertEq(skipCallsTestHook.counter(), 0); + + // donates and increments counter + donateRouter.donate(key, 100, 200, abi.encode(address(this))); + assertEq(skipCallsTestHook.counter(), 1); + // donates again and increments counter + donateRouter.donate(key, 100, 200, abi.encode(address(this))); + assertEq(skipCallsTestHook.counter(), 2); + } +} diff --git a/test/Tick.t.sol b/test/Tick.t.sol index 565572ddb..b99d3a128 100644 --- a/test/Tick.t.sol +++ b/test/Tick.t.sol @@ -495,7 +495,8 @@ contract TickTest is Test, GasSnapshot { assertEq((maxTick - minTick) % tickSpacing, 0); uint256 numTicks = uint256(int256((maxTick - minTick) / tickSpacing)) + 1; - // max liquidity at every tick is less than the cap - assertGt(type(uint128).max, uint256(maxLiquidityPerTick) * numTicks); + + // sum of max liquidity on each tick is at most the cap + assertGe(type(uint128).max, uint256(maxLiquidityPerTick) * numTicks); } } diff --git a/test/libraries/NonZeroDeltaCount.t.sol b/test/libraries/NonZeroDeltaCount.t.sol index 5dbcf34b0..523324f01 100644 --- a/test/libraries/NonZeroDeltaCount.t.sol +++ b/test/libraries/NonZeroDeltaCount.t.sol @@ -5,16 +5,16 @@ import {Test} from "forge-std/Test.sol"; import {NonZeroDeltaCount} from "src/libraries/NonZeroDeltaCount.sol"; contract NonZeroDeltaCountTest is Test { - address constant ADDRESS_AS = 0xaAaAaAaaAaAaAaaAaAAAAAAAAaaaAaAaAaaAaaAa; - address constant ADDRESS_BS = 0xbBbBBBBbbBBBbbbBbbBbbbbBBbBbbbbBbBbbBBbB; - function test_incrementNonzeroDeltaCount() public { + assertEq(NonZeroDeltaCount.read(), 0); NonZeroDeltaCount.increment(); assertEq(NonZeroDeltaCount.read(), 1); } function test_decrementNonzeroDeltaCount() public { + assertEq(NonZeroDeltaCount.read(), 0); NonZeroDeltaCount.increment(); + assertEq(NonZeroDeltaCount.read(), 1); NonZeroDeltaCount.decrement(); assertEq(NonZeroDeltaCount.read(), 0); } @@ -22,6 +22,7 @@ contract NonZeroDeltaCountTest is Test { // Reading from right to left. Bit of 0: call increase. Bit of 1: call decrease. // The library allows over over/underflow so we dont check for that here function test_fuzz_nonZeroDeltaCount(uint256 instructions) public { + assertEq(NonZeroDeltaCount.read(), 0); uint256 expectedCount; for (uint256 i = 0; i < 256; i++) { if ((instructions & (1 << i)) == 0) { diff --git a/test/libraries/Pool.t.sol b/test/libraries/Pool.t.sol index 6874038a1..2af3df6a7 100644 --- a/test/libraries/Pool.t.sol +++ b/test/libraries/Pool.t.sol @@ -121,9 +121,9 @@ contract PoolTest is Test { state.swap(params); if (params.zeroForOne) { - assertLe(state.slot0.sqrtPriceX96, params.sqrtPriceLimitX96); - } else { assertGe(state.slot0.sqrtPriceX96, params.sqrtPriceLimitX96); + } else { + assertLe(state.slot0.sqrtPriceX96, params.sqrtPriceLimitX96); } } } diff --git a/test/libraries/PoolId.t.sol b/test/libraries/PoolId.t.sol new file mode 100644 index 000000000..fc8fc48d2 --- /dev/null +++ b/test/libraries/PoolId.t.sol @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.20; + +import {Test} from "forge-std/Test.sol"; +import {PoolId, PoolIdLibrary} from "src/types/PoolId.sol"; +import {PoolKey} from "src/types/PoolKey.sol"; + +contract PoolIdTest is Test { + using PoolIdLibrary for PoolKey; + + function test_fuzz_toId(PoolKey memory poolKey) public { + bytes memory encodedKey = abi.encode(poolKey); + bytes32 expectedHash = keccak256(encodedKey); + assertEq(PoolId.unwrap(poolKey.toId()), expectedHash, "hashes not equal"); + } +} diff --git a/test/libraries/Position.t.sol b/test/libraries/Position.t.sol new file mode 100644 index 000000000..cda1da7ca --- /dev/null +++ b/test/libraries/Position.t.sol @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.20; + +import {Test} from "forge-std/Test.sol"; +import {Position} from "src/libraries/Position.sol"; + +contract PositionTest is Test { + using Position for mapping(bytes32 => Position.Info); + + mapping(bytes32 => Position.Info) internal positions; + + function test_get_fuzz(address owner, int24 tickLower, int24 tickUpper) public { + bytes32 positionKey = keccak256(abi.encodePacked(owner, tickLower, tickUpper)); + Position.Info storage expectedPosition = positions[positionKey]; + Position.Info storage position = positions.get(owner, tickLower, tickUpper); + bytes32 expectedPositionSlot; + bytes32 positionSlot; + assembly ("memory-safe") { + expectedPositionSlot := expectedPosition.slot + positionSlot := position.slot + } + assertEq(positionSlot, expectedPositionSlot, "slots not equal"); + } +} diff --git a/test/libraries/SwapFeeLibrary.t.sol b/test/libraries/SwapFeeLibrary.t.sol new file mode 100644 index 000000000..a1976e1dd --- /dev/null +++ b/test/libraries/SwapFeeLibrary.t.sol @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.20; + +import "src/libraries/SwapFeeLibrary.sol"; +import "forge-std/Test.sol"; + +contract SwapFeeLibraryTest is Test { + function test_isDynamicFee_returnsTrue() public { + uint24 dynamicFee = 0x800000; + assertTrue(SwapFeeLibrary.isDynamicFee(dynamicFee)); + } + + function test_isDynamicFee_returnsTrue_forMaxValue() public { + uint24 dynamicFee = 0xFFFFFF; + assertTrue(SwapFeeLibrary.isDynamicFee(dynamicFee)); + } + + function test_isDynamicFee_returnsFalse() public { + uint24 dynamicFee = 0x7FFFFF; + assertFalse(SwapFeeLibrary.isDynamicFee(dynamicFee)); + } + + function test_fuzz_isDynamicFee(uint24 fee) public { + assertEq((fee >> 23 == 1), SwapFeeLibrary.isDynamicFee(fee)); + } + + function test_validate_doesNotRevertWithNoFee() public pure { + uint24 fee = 0; + SwapFeeLibrary.validate(fee); + } + + function test_validate_doesNotRevert() public pure { + uint24 fee = 500000; // 50% + SwapFeeLibrary.validate(fee); + } + + function test_validate_doesNotRevertWithMaxFee() public pure { + uint24 maxFee = 1000000; // 100% + SwapFeeLibrary.validate(maxFee); + } + + function test_validate_revertsWithFeeTooLarge() public { + uint24 fee = 1000001; + vm.expectRevert(SwapFeeLibrary.FeeTooLarge.selector); + SwapFeeLibrary.validate(fee); + } + + function test_fuzz_validate(uint24 fee) public { + if (fee > 1000000) { + vm.expectRevert(SwapFeeLibrary.FeeTooLarge.selector); + } + SwapFeeLibrary.validate(fee); + } + + function test_getInitialSwapFee_forStaticFeeIsCorrect() public { + uint24 staticFee = 3000; // 30 bps + assertEq(SwapFeeLibrary.getInitialSwapFee(staticFee), staticFee); + } + + function test_getInitialSwapFee_revertsWithFeeTooLarge_forStaticFee() public { + uint24 staticFee = 1000001; + vm.expectRevert(SwapFeeLibrary.FeeTooLarge.selector); + SwapFeeLibrary.getInitialSwapFee(staticFee); + } + + function test_getInitialSwapFee_forDynamicFeeIsZero() public { + uint24 dynamicFee = 0x800BB8; + assertEq(SwapFeeLibrary.getInitialSwapFee(dynamicFee), 0); + } + + function test_fuzz_getInitialSwapFee(uint24 fee) public { + if (fee >> 23 == 1) { + assertEq(SwapFeeLibrary.getInitialSwapFee(fee), 0); + } else if (fee > 1000000) { + vm.expectRevert(SwapFeeLibrary.FeeTooLarge.selector); + SwapFeeLibrary.getInitialSwapFee(fee); + } else { + assertEq(SwapFeeLibrary.getInitialSwapFee(fee), fee); + } + } +} diff --git a/test/types/BalanceDelta.t.sol b/test/types/BalanceDelta.t.sol index aafe83f2e..adfa9bec3 100644 --- a/test/types/BalanceDelta.t.sol +++ b/test/types/BalanceDelta.t.sol @@ -2,10 +2,10 @@ pragma solidity ^0.8.20; import {Test} from "forge-std/Test.sol"; -import {BalanceDelta, toBalanceDelta} from "../../src/types/BalanceDelta.sol"; +import {BalanceDelta, toBalanceDelta} from "src/types/BalanceDelta.sol"; contract TestBalanceDelta is Test { - function testToBalanceDelta() public { + function test_toBalanceDelta() public { BalanceDelta balanceDelta = toBalanceDelta(0, 0); assertEq(balanceDelta.amount0(), 0); assertEq(balanceDelta.amount1(), 0); @@ -27,17 +27,52 @@ contract TestBalanceDelta is Test { assertEq(balanceDelta.amount1(), type(int128).min); } - function testToBalanceDelta(int128 x, int128 y) public { + function test_fuzz_toBalanceDelta(int128 x, int128 y) public { + BalanceDelta balanceDelta = toBalanceDelta(x, y); + int256 expectedBD = int256(uint256(bytes32(abi.encodePacked(x, y)))); + assertEq(BalanceDelta.unwrap(balanceDelta), expectedBD); + } + + function test_fuzz_amount0_amount1(int128 x, int128 y) public { BalanceDelta balanceDelta = toBalanceDelta(x, y); assertEq(balanceDelta.amount0(), x); assertEq(balanceDelta.amount1(), y); } - function testAdd(int128 a, int128 b, int128 c, int128 d) public { + function test_add() public { + BalanceDelta balanceDelta = toBalanceDelta(0, 0) + toBalanceDelta(0, 0); + assertEq(balanceDelta.amount0(), 0); + assertEq(balanceDelta.amount1(), 0); + + balanceDelta = toBalanceDelta(-1000, 1000) + toBalanceDelta(1000, -1000); + assertEq(balanceDelta.amount0(), 0); + assertEq(balanceDelta.amount1(), 0); + + balanceDelta = + toBalanceDelta(type(int128).min, type(int128).max) + toBalanceDelta(type(int128).max, type(int128).min); + assertEq(balanceDelta.amount0(), -1); + assertEq(balanceDelta.amount1(), -1); + + balanceDelta = toBalanceDelta(type(int128).max / 2 + 1, type(int128).max / 2 + 1) + + toBalanceDelta(type(int128).max / 2, type(int128).max / 2); + assertEq(balanceDelta.amount0(), type(int128).max); + assertEq(balanceDelta.amount1(), type(int128).max); + } + + function test_add_revertsOnOverflow() public { + // should revert because type(int128).max + 1 is not possible + vm.expectRevert(); + toBalanceDelta(type(int128).max, 0) + toBalanceDelta(1, 0); + + vm.expectRevert(); + toBalanceDelta(0, type(int128).max) + toBalanceDelta(0, 1); + } + + function test_fuzz_add(int128 a, int128 b, int128 c, int128 d) public { int256 ac = int256(a) + c; int256 bd = int256(b) + d; - // make sure the addition doesn't overflow + // if the addition overflows it should revert if (ac != int128(ac) || bd != int128(bd)) { vm.expectRevert(); } @@ -47,16 +82,52 @@ contract TestBalanceDelta is Test { assertEq(balanceDelta.amount1(), bd); } - function testSub(int128 a, int128 b, int128 c, int128 d) public { + function test_sub() public { + BalanceDelta balanceDelta = toBalanceDelta(0, 0) - toBalanceDelta(0, 0); + assertEq(balanceDelta.amount0(), 0); + assertEq(balanceDelta.amount1(), 0); + + balanceDelta = toBalanceDelta(-1000, 1000) - toBalanceDelta(1000, -1000); + assertEq(balanceDelta.amount0(), -2000); + assertEq(balanceDelta.amount1(), 2000); + + balanceDelta = + toBalanceDelta(-1000, -1000) - toBalanceDelta(-(type(int128).min + 1000), -(type(int128).min + 1000)); + assertEq(balanceDelta.amount0(), type(int128).min); + assertEq(balanceDelta.amount1(), type(int128).min); + + balanceDelta = toBalanceDelta(type(int128).min / 2, type(int128).min / 2) + - toBalanceDelta(-(type(int128).min / 2), -(type(int128).min / 2)); + assertEq(balanceDelta.amount0(), type(int128).min); + assertEq(balanceDelta.amount1(), type(int128).min); + } + + function test_sub_revertsOnUnderflow() public { + // should revert because type(int128).min - 1 is not possible + vm.expectRevert(); + toBalanceDelta(type(int128).min, 0) - toBalanceDelta(1, 0); + + vm.expectRevert(); + toBalanceDelta(0, type(int128).min) - toBalanceDelta(0, 1); + } + + function test_fuzz_sub(int128 a, int128 b, int128 c, int128 d) public { int256 ac = int256(a) - c; int256 bd = int256(b) - d; - // make sure the subtraction doesn't underflow - vm.assume(ac == int128(ac)); - vm.assume(bd == int128(bd)); + // if the subtraction underflows it should revert + if (ac != int128(ac) || bd != int128(bd)) { + vm.expectRevert(); + } BalanceDelta balanceDelta = toBalanceDelta(a, b) - toBalanceDelta(c, d); assertEq(balanceDelta.amount0(), ac); assertEq(balanceDelta.amount1(), bd); } + + function test_fuzz_eq(int128 a, int128 b, int128 c, int128 d) public { + bool isEqual = (toBalanceDelta(a, b) == toBalanceDelta(c, d)); + if (a == c && b == d) assertTrue(isEqual); + else assertFalse(isEqual); + } } diff --git a/test/utils/AmountHelpers.sol b/test/utils/AmountHelpers.sol index a956816ce..298fd3f77 100644 --- a/test/utils/AmountHelpers.sol +++ b/test/utils/AmountHelpers.sol @@ -3,7 +3,6 @@ pragma solidity ^0.8.20; import {LiquidityAmounts} from "./LiquidityAmounts.sol"; import {IPoolManager} from "../../src/interfaces/IPoolManager.sol"; -import {PoolManager} from "../../src/PoolManager.sol"; import {PoolId, PoolIdLibrary} from "../../src/types/PoolId.sol"; import {TickMath} from "../../src/libraries/TickMath.sol"; import {PoolKey} from "../../src/types/PoolKey.sol"; @@ -12,7 +11,7 @@ import {PoolKey} from "../../src/types/PoolKey.sol"; /// @notice Helps calculate amounts for bounding fuzz tests library AmountHelpers { function getMaxAmountInForPool( - PoolManager manager, + IPoolManager manager, IPoolManager.ModifyLiquidityParams memory params, PoolKey memory key ) public view returns (uint256 amount0, uint256 amount1) { @@ -24,6 +23,6 @@ library AmountHelpers { uint160 sqrtPriceX96Upper = TickMath.getSqrtRatioAtTick(params.tickUpper); amount0 = LiquidityAmounts.getAmount0ForLiquidity(sqrtPriceX96Lower, sqrtPriceX96, liquidity); - amount1 = LiquidityAmounts.getAmount0ForLiquidity(sqrtPriceX96Upper, sqrtPriceX96, liquidity); + amount1 = LiquidityAmounts.getAmount1ForLiquidity(sqrtPriceX96Upper, sqrtPriceX96, liquidity); } } diff --git a/test/utils/Deployers.sol b/test/utils/Deployers.sol index e87866101..13af1028c 100644 --- a/test/utils/Deployers.sol +++ b/test/utils/Deployers.sol @@ -52,7 +52,7 @@ contract Deployers { // Global variables Currency internal currency0; Currency internal currency1; - PoolManager manager; + IPoolManager manager; PoolModifyLiquidityTest modifyLiquidityRouter; PoolSwapTest swapRouter; PoolDonateTest donateRouter;