diff --git a/src/FoldCaptiveStaking.sol b/src/FoldCaptiveStaking.sol index 6af2cfb..ba95082 100644 --- a/src/FoldCaptiveStaking.sol +++ b/src/FoldCaptiveStaking.sol @@ -214,10 +214,7 @@ contract FoldCaptiveStaking is Owned(msg.sender) { function compound() public isInitialized { collectPositionFees(); - uint256 fee0Owed = (token0FeesPerLiquidity - balances[msg.sender].token0FeeDebt) * balances[msg.sender].amount - / liquidityUnderManagement; - uint256 fee1Owed = (token1FeesPerLiquidity - balances[msg.sender].token1FeeDebt) * balances[msg.sender].amount - / liquidityUnderManagement; + (uint256 fee0Owed, uint256 fee1Owed) = owedFees(); INonfungiblePositionManager.IncreaseLiquidityParams memory params = INonfungiblePositionManager .IncreaseLiquidityParams({ @@ -243,14 +240,23 @@ contract FoldCaptiveStaking is Owned(msg.sender) { emit Compounded(msg.sender, liquidity, fee0Owed, fee1Owed); } + /// @notice User-specific function to view fees owed on the singular position + function owedFees() public view returns (uint256, uint256) { + uint256 fee0Owed = ((token0FeesPerLiquidity - + balances[msg.sender].token0FeeDebt) * balances[msg.sender].amount) / + liquidityUnderManagement; + uint256 fee1Owed = ((token1FeesPerLiquidity - + balances[msg.sender].token1FeeDebt) * balances[msg.sender].amount) / + liquidityUnderManagement; + + return (fee0Owed, fee1Owed); + } + /// @notice User-specific function to collect fees on the singular position function collectFees() public isInitialized { collectPositionFees(); - uint256 fee0Owed = (token0FeesPerLiquidity - balances[msg.sender].token0FeeDebt) * balances[msg.sender].amount - / liquidityUnderManagement; - uint256 fee1Owed = (token1FeesPerLiquidity - balances[msg.sender].token1FeeDebt) * balances[msg.sender].amount - / liquidityUnderManagement; + (uint256 fee0Owed, uint256 fee1Owed) = owedFees(); token0.transfer(msg.sender, fee0Owed); token1.transfer(msg.sender, fee1Owed); @@ -261,10 +267,16 @@ contract FoldCaptiveStaking is Owned(msg.sender) { emit FeesCollected(msg.sender, fee0Owed, fee1Owed); } + /// @notice User-specific function to view rewards owed on the singular position + function owedRewards() public view returns (uint256) { + return + ((rewardsPerLiquidity - balances[msg.sender].rewardDebt) * + balances[msg.sender].amount) / liquidityUnderManagement; + } + /// @notice User-specific Rewards for Protocol Rewards function collectRewards() public isInitialized { - uint256 rewardsOwed = (rewardsPerLiquidity - balances[msg.sender].rewardDebt) * balances[msg.sender].amount - / liquidityUnderManagement; + uint256 rewardsOwed = owedRewards(); WETH9.transfer(msg.sender, rewardsOwed); @@ -373,7 +385,8 @@ contract FoldCaptiveStaking is Owned(msg.sender) { amount1Max: uint128(amount1) }); - (uint256 amount0Collected, uint256 amount1Collected) = positionManager.collect(collectParams); + (uint256 amount0Collected, uint256 amount1Collected) = positionManager + .collect(collectParams); if (amount0Collected != amount0 || amount1Collected != amount1) { revert WithdrawFailed(); diff --git a/test/UnitTests.t.sol b/test/UnitTests.t.sol index 3d8c969..49efd2b 100644 --- a/test/UnitTests.t.sol +++ b/test/UnitTests.t.sol @@ -51,6 +51,42 @@ contract UnitTests is BaseCaptiveTest { assertEq(amount, liq / 4); } + /// @dev Ensure that owed fees are returned correctly. + function testOwedFees() public { + testAddLiquidity(); + uint256 owedReards = foldCaptiveStaking.owedRewards(); + (uint256 amount, uint256 rewardDebt, , ) = foldCaptiveStaking.balances( + User01 + ); + uint256 rewardOwedCheck = ((foldCaptiveStaking.rewardsPerLiquidity() - + rewardDebt) * amount) / + foldCaptiveStaking.liquidityUnderManagement(); + + assertEq(rewardOwedCheck, owedReards); + } + + /// @dev Ensure that owed fees are returned correctly. + function testOwedRewards() public { + testAddLiquidity(); + (uint256 fee0Owed, uint256 fee1Owed) = foldCaptiveStaking.owedFees(); + ( + uint256 amount, + , + uint256 token0FeeDebt, + uint256 token1FeeDebt + ) = foldCaptiveStaking.balances(User01); + uint256 fee0OwedCheck = ((foldCaptiveStaking.token0FeesPerLiquidity() - + token0FeeDebt) * amount) / + foldCaptiveStaking.liquidityUnderManagement(); + + uint256 fee1OwedCheck = ((foldCaptiveStaking.token1FeesPerLiquidity() - + token1FeeDebt) * amount) / + foldCaptiveStaking.liquidityUnderManagement(); + + assertEq(fee0OwedCheck, fee0Owed); + assertEq(fee1OwedCheck, fee1Owed); + } + /// @dev Ensure fees are accrued correctly and distributed proportionately. function testFeesAccrue() public { testAddLiquidity();