From 8f4be930e79c7ff1c9466152e879816bbc662eda Mon Sep 17 00:00:00 2001 From: Emily Williams Date: Thu, 3 Oct 2024 12:53:25 -0400 Subject: [PATCH] fix(router-sdk): Prepare router-sdk to fully support mixed routes (#135) --- .../src/entities/mixedRoute/route.test.ts | 36 +++++++++++++++++++ .../src/entities/mixedRoute/route.ts | 19 +++++----- .../src/entities/mixedRoute/trade.ts | 16 +++++---- .../src/utils/encodeMixedRouteToPath.ts | 4 +-- sdks/router-sdk/src/utils/getOutputAmount.ts | 22 ------------ sdks/router-sdk/src/utils/index.ts | 6 ++-- sdks/router-sdk/src/utils/pathCurrency.ts | 36 +++++++++++++++++++ 7 files changed, 97 insertions(+), 42 deletions(-) delete mode 100644 sdks/router-sdk/src/utils/getOutputAmount.ts create mode 100644 sdks/router-sdk/src/utils/pathCurrency.ts diff --git a/sdks/router-sdk/src/entities/mixedRoute/route.test.ts b/sdks/router-sdk/src/entities/mixedRoute/route.test.ts index d650eee1d..0887147fe 100644 --- a/sdks/router-sdk/src/entities/mixedRoute/route.test.ts +++ b/sdks/router-sdk/src/entities/mixedRoute/route.test.ts @@ -78,6 +78,8 @@ describe('MixedRoute', () => { expect(route.path).toEqual([token1, token0, weth]) expect(route.input).toEqual(token1) expect(route.output).toEqual(weth) + expect(route.pathInput).toEqual(token1) + expect(route.pathOutput).toEqual(weth) expect(route.chainId).toEqual(1) }) @@ -87,6 +89,8 @@ describe('MixedRoute', () => { expect(route.path).toEqual([token0, weth, token1]) expect(route.input).toEqual(token0) expect(route.output).toEqual(token1) + expect(route.pathInput).toEqual(token0) + expect(route.pathOutput).toEqual(token1) expect(route.chainId).toEqual(1) }) @@ -125,6 +129,17 @@ describe('MixedRoute', () => { expect(route.chainId).toEqual(1) }) + it('wraps complex mixed route object that unwraps WETH to ETH at the end', () => { + const route = new MixedRouteSDK([pool_v3_0_1, pool_v3_1_weth], token0, ETHER) + expect(route.pools).toEqual([pool_v3_0_1, pool_v3_1_weth]) + expect(route.path).toEqual([token0, token1, weth]) + expect(route.input).toEqual(token0) + expect(route.output).toEqual(ETHER) + expect(route.pathInput).toEqual(token0) + expect(route.pathOutput).toEqual(weth) + expect(route.chainId).toEqual(1) + }) + it('wraps complex mixed route object with multihop V2 in the beginning and constructs a path', () => { const route = new MixedRouteSDK([pair_0_1, pair_1_weth, pool_v3_2_weth], token0, token2) expect(route.pools).toEqual([pair_0_1, pair_1_weth, pool_v3_2_weth]) @@ -463,6 +478,7 @@ describe('MixedRoute', () => { const route = new MixedRouteSDK([pool_v3_0_1], token0, token1) expect(partitionMixedRouteByProtocol(route)).toStrictEqual([[pool_v3_0_1]]) }) + it('returns correct for single pool', () => { const route = new MixedRouteSDK([pair_0_1], token0, token1) expect(partitionMixedRouteByProtocol(route)).toStrictEqual([[pair_0_1]]) @@ -492,6 +508,7 @@ describe('MixedRoute', () => { }) expect(result[2][0]).toStrictEqual(pool_v3_2_3) }) + it('consecutive pair at the end', () => { const route: MixedRouteSDK = new MixedRouteSDK( [pool_v3_0_1, pair_1_weth, pair_weth_2, pair_2_3], @@ -506,6 +523,7 @@ describe('MixedRoute', () => { expect(pair).toStrictEqual(referenceSecondPart[i]) }) }) + it('consecutive pair at the beginning', () => { const route: MixedRouteSDK = new MixedRouteSDK( [pair_0_1, pair_1_weth, pair_weth_2, pool_v3_2_3], @@ -520,5 +538,23 @@ describe('MixedRoute', () => { }) expect(result[1][0]).toStrictEqual(pool_v3_2_3) }) + + it('returns correct for route with V4Pool', () => { + const route: MixedRouteSDK = new MixedRouteSDK( + [pool_v4_0_1, pool_v4_1_eth, pair_weth_2, pool_v3_2_3], + token0, + token3 + ) + + const result = partitionMixedRouteByProtocol(route) + expect(result.length).toEqual(3) + expect(result[0].length).toEqual(2) + expect(result[1].length).toEqual(1) + expect(result[2].length).toEqual(1) + expect(result[0][0]).toStrictEqual(pool_v4_0_1) + expect(result[0][1]).toStrictEqual(pool_v4_1_eth) + expect(result[1][0]).toStrictEqual(pair_weth_2) + expect(result[2][0]).toStrictEqual(pool_v3_2_3) + }) }) }) diff --git a/sdks/router-sdk/src/entities/mixedRoute/route.ts b/sdks/router-sdk/src/entities/mixedRoute/route.ts index b61096b0a..f308d38fa 100644 --- a/sdks/router-sdk/src/entities/mixedRoute/route.ts +++ b/sdks/router-sdk/src/entities/mixedRoute/route.ts @@ -2,6 +2,7 @@ import invariant from 'tiny-invariant' import { Currency, Price, Token } from '@uniswap/sdk-core' import { Pool as V4Pool } from '@uniswap/v4-sdk' import { isValidTokenPath } from '../../utils/isValidTokenPath' +import { getPathCurrency } from '../../utils/pathCurrency' import { TPool } from '../../utils/TPool' /** @@ -14,7 +15,8 @@ export class MixedRouteSDK { public readonly path: Currency[] public readonly input: TInput public readonly output: TOutput - public readonly adjustedInput: Currency // routes with v2/v3 initial pool must wrap native input currency before trading + public readonly pathInput: Currency // routes may need to wrap/unwrap a currency to begin trading path + public readonly pathOutput: Currency // routes may need to wrap/unwrap a currency at the end of trading path private _midPrice: Price | null = null @@ -31,13 +33,10 @@ export class MixedRouteSDK { const allOnSameChain = pools.every((pool) => pool.chainId === chainId) invariant(allOnSameChain, 'CHAIN_IDS') - if (pools[0] instanceof V4Pool) { - this.adjustedInput = pools[0].involvesToken(input) ? input : input.wrapped - } else { - this.adjustedInput = input.wrapped // no native currencies in v2/v3 - } + this.pathInput = getPathCurrency(input, pools[0]) + this.pathOutput = getPathCurrency(output, pools[pools.length - 1]) - invariant(pools[0].involvesToken(this.adjustedInput as Token), 'INPUT') + invariant(pools[0].involvesToken(this.pathInput as Token), 'INPUT') const lastPool = pools[pools.length - 1] if (lastPool instanceof V4Pool) { invariant(lastPool.involvesToken(output) || lastPool.involvesToken(output.wrapped), 'OUTPUT') @@ -48,8 +47,8 @@ export class MixedRouteSDK { /** * Normalizes token0-token1 order and selects the next token/fee step to add to the path * */ - const tokenPath: Currency[] = [this.adjustedInput] - pools[0].token0.equals(this.adjustedInput) ? tokenPath.push(pools[0].token1) : tokenPath.push(pools[0].token0) + const tokenPath: Currency[] = [this.pathInput] + pools[0].token0.equals(this.pathInput) ? tokenPath.push(pools[0].token1) : tokenPath.push(pools[0].token0) for (let i = 1; i < pools.length; i++) { const prevPool = pools[i - 1] @@ -90,7 +89,7 @@ export class MixedRouteSDK { } }, - this.pools[0].token0.equals(this.adjustedInput) + this.pools[0].token0.equals(this.pathInput) ? { nextInput: this.pools[0].token1, price: this.pools[0].token0Price.asFraction, diff --git a/sdks/router-sdk/src/entities/mixedRoute/trade.ts b/sdks/router-sdk/src/entities/mixedRoute/trade.ts index 5405805ad..634437a0f 100644 --- a/sdks/router-sdk/src/entities/mixedRoute/trade.ts +++ b/sdks/router-sdk/src/entities/mixedRoute/trade.ts @@ -1,11 +1,11 @@ -import { Currency, Fraction, Percent, Price, sortedInsert, CurrencyAmount, TradeType } from '@uniswap/sdk-core' +import { Currency, Fraction, Percent, Price, sortedInsert, CurrencyAmount, TradeType, Token } from '@uniswap/sdk-core' import { Pair } from '@uniswap/v2-sdk' import { BestTradeOptions, Pool as V3Pool } from '@uniswap/v3-sdk' import { Pool as V4Pool } from '@uniswap/v4-sdk' import invariant from 'tiny-invariant' import { ONE, ZERO } from '../../constants' import { MixedRouteSDK } from './route' -import { getOutputAmount } from '../../utils/getOutputAmount' +import { amountWithPathCurrency } from '../../utils/pathCurrency' import { TPool } from '../../utils/TPool' /** @@ -203,10 +203,12 @@ export class MixedRouteTrade + ) amounts[i + 1] = outputAmount } @@ -255,11 +257,13 @@ export class MixedRouteTrade + ) amounts[i + 1] = outputAmount } diff --git a/sdks/router-sdk/src/utils/encodeMixedRouteToPath.ts b/sdks/router-sdk/src/utils/encodeMixedRouteToPath.ts index dabcf12b1..eca13741e 100644 --- a/sdks/router-sdk/src/utils/encodeMixedRouteToPath.ts +++ b/sdks/router-sdk/src/utils/encodeMixedRouteToPath.ts @@ -26,9 +26,9 @@ export function encodeMixedRouteToPath(route: MixedRouteSDK) let types: string[] if (containsV4Pool) { - path = [route.adjustedInput.isNative ? ADDRESS_ZERO : route.adjustedInput.address] + path = [route.pathInput.isNative ? ADDRESS_ZERO : route.pathInput.address] types = ['address'] - let currencyIn = route.adjustedInput + let currencyIn = route.pathInput for (const pool of route.pools) { const currencyOut = currencyIn.equals(pool.token0) ? pool.token1 : pool.token0 diff --git a/sdks/router-sdk/src/utils/getOutputAmount.ts b/sdks/router-sdk/src/utils/getOutputAmount.ts deleted file mode 100644 index ed67e57dd..000000000 --- a/sdks/router-sdk/src/utils/getOutputAmount.ts +++ /dev/null @@ -1,22 +0,0 @@ -import { Currency, CurrencyAmount } from '@uniswap/sdk-core' -import { Pool as V4Pool } from '@uniswap/v4-sdk' -import { TPool } from './TPool' - -export async function getOutputAmount( - pool: TPool, - amountIn: CurrencyAmount -): Promise<[CurrencyAmount, TPool]> { - if (pool instanceof V4Pool) { - if (pool.involvesCurrency(amountIn.currency)) { - return await pool.getOutputAmount(amountIn) - } - if (pool.token0.wrapped.equals(amountIn.currency)) { - return await pool.getOutputAmount(CurrencyAmount.fromRawAmount(pool.token0, amountIn.quotient)) - } - if (pool.token1.wrapped.equals(amountIn.currency)) { - return await pool.getOutputAmount(CurrencyAmount.fromRawAmount(pool.token1, amountIn.quotient)) - } - } - - return await pool.getOutputAmount(amountIn.wrapped) -} diff --git a/sdks/router-sdk/src/utils/index.ts b/sdks/router-sdk/src/utils/index.ts index 75b6a7c0b..b41e75c14 100644 --- a/sdks/router-sdk/src/utils/index.ts +++ b/sdks/router-sdk/src/utils/index.ts @@ -1,6 +1,7 @@ import { Currency, Token } from '@uniswap/sdk-core' import { Pair } from '@uniswap/v2-sdk' import { Pool as V3Pool } from '@uniswap/v3-sdk' +import { Pool as V4Pool } from '@uniswap/v4-sdk' import { MixedRouteSDK } from '../entities/mixedRoute/route' import { TPool } from './TPool' @@ -16,8 +17,9 @@ export const partitionMixedRouteByProtocol = (route: MixedRouteSDK, pool: TPool): CurrencyAmount { + return CurrencyAmount.fromFractionalAmount( + getPathCurrency(amount.currency, pool), + amount.numerator, + amount.denominator + ) +} + +export function getPathCurrency(currency: Currency, pool: TPool): Currency { + // return currency if the currency matches a currency of the pool + if (pool.involvesToken(currency as Token)) { + return currency + + // return if currency.wrapped if pool involves wrapped currency + } else if (pool.involvesToken(currency.wrapped as Token)) { + return currency.wrapped + + // return native currency if pool involves native version of wrapped currency (only applies to V4) + } else if (pool instanceof V4Pool) { + if (pool.currency0.wrapped.equals(currency)) { + return pool.token0 + } else if (pool.token1.wrapped.equals(currency)) { + return pool.token1 + } + + // otherwise the token is invalid + } else { + throw new Error(`Expected currency ${currency.symbol} to be either ${pool.token0.symbol} or ${pool.token1.symbol}`) + } + + return currency // this line needed for typescript to compile +}