Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add: optimized get token balances #104

Merged
merged 11 commits into from
Oct 18, 2023
11 changes: 9 additions & 2 deletions core/definitions/src/platform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import { NativeAddress } from "./address";
import { WormholeMessageId } from "./attestation";
import { ChainContext } from "./chain";
import { RpcConnection } from "./rpc";
import { ChainsConfig, SignedTx, TokenId, TxHash } from "./types";
import { AnyAddress, Balances, ChainsConfig, TokenId, TxHash } from "./types";
import { SignedTx } from "./types";
import { UniversalAddress } from "./universalAddress";

export interface PlatformUtils<P extends PlatformName> {
Expand All @@ -27,8 +28,14 @@ export interface PlatformUtils<P extends PlatformName> {
chain: ChainName,
rpc: RpcConnection<P>,
walletAddr: string,
token: NativeAddress<P> | UniversalAddress | "native",
token: AnyAddress,
): Promise<bigint | null>;
getBalances(
chain: ChainName,
rpc: RpcConnection<P>,
walletAddress: string,
tokens: AnyAddress[],
): Promise<Balances>;
getCurrentBlock(rpc: RpcConnection<P>): Promise<number>;

// Platform interaction utils
Expand Down
11 changes: 11 additions & 0 deletions core/definitions/src/testing/mocks/platform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import {
nativeIsRegistered,
NativeAddress,
UniversalAddress,
AnyAddress,
Balances,
} from "../..";
import { MockRpc } from "./rpc";
import { MockChain } from "./chain";
Expand Down Expand Up @@ -84,6 +86,15 @@ export class MockPlatform<P extends PlatformName> implements Platform<P> {
throw new Error("Method not implemented.");
}

getBalances(
chain: ChainName,
rpc: RpcConnection<PlatformName>,
walletAddress: string,
tokens: AnyAddress[],
): Promise<Balances> {
throw new Error("method not implemented");
}

getChain(chain: ChainName): ChainContext<P> {
if (chain in this.conf) return new MockChain<P>(this.conf[chain]!);
throw new Error("No configuration available for chain: " + chain);
Expand Down
14 changes: 14 additions & 0 deletions core/definitions/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ export type SequenceId = bigint;

export type SignedTx = any;

// TODO: Can we make this more dynamic?
barnjamin marked this conversation as resolved.
Show resolved Hide resolved
export type AnyAddress =
| NativeAddress<PlatformName>
| UniversalAddress
| string
| number
| Uint8Array
| number[]
| "native";

export type TokenId = ChainAddress;
export function isTokenId(thing: TokenId | any): thing is TokenId {
return (
Expand All @@ -23,6 +33,10 @@ export function isTokenId(thing: TokenId | any): thing is TokenId {
);
}

export type Balances = {
[key: string]: BigInt | null;
};

export interface Signer {
chain(): ChainName;
address(): string;
Expand Down
2 changes: 0 additions & 2 deletions platforms/cosmwasm/__tests__/unit/platform.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ import { expect, test } from "@jest/globals";
import {
chains,
chainConfigs,
DEFAULT_NETWORK,
chainToPlatform,
} from "@wormhole-foundation/connect-sdk";
import { CosmwasmPlatform } from "../../src/platform";

Expand Down
6 changes: 5 additions & 1 deletion platforms/cosmwasm/src/address.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
} from "@wormhole-foundation/connect-sdk";
import { CosmwasmPlatform } from "./platform";
import { nativeDenomToChain } from "./constants";
import { AnyCosmwasmAddress } from "./types";

declare global {
namespace Wormhole {
Expand Down Expand Up @@ -113,7 +114,10 @@ export class CosmwasmAddress implements Address {
// The denomType is "native", "ibc", or "factory"
private readonly denomType?: string;

constructor(address: string | Uint8Array | UniversalAddress) {
constructor(address: AnyCosmwasmAddress) {
if (address instanceof CosmwasmAddress) {
barnjamin marked this conversation as resolved.
Show resolved Hide resolved
Object.assign(this, address);
}
if (typeof address === "string") {
// A native denom like "uatom"
if (nativeDenomToChain.has(CosmwasmPlatform.network, address)) {
Expand Down
17 changes: 14 additions & 3 deletions platforms/cosmwasm/src/platform.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import { CosmWasmClient } from "@cosmjs/cosmwasm-stargate";
import { IbcExtension, QueryClient, setupIbcExtension } from "@cosmjs/stargate";
import {
BankExtension,
IbcExtension,
QueryClient,
setupBankExtension,
setupIbcExtension,
} from "@cosmjs/stargate";
import { TendermintClient } from "@cosmjs/tendermint-rpc";

import {
Expand Down Expand Up @@ -49,6 +55,7 @@ export module CosmwasmPlatform {
isSupportedChain,
getDecimals,
getBalance,
getBalances,
sendWait,
getCurrentBlock,
chainFromRpc,
Expand Down Expand Up @@ -111,10 +118,14 @@ export module CosmwasmPlatform {

export const getQueryClient = (
rpc: CosmWasmClient,
): QueryClient & IbcExtension => {
): QueryClient & BankExtension & IbcExtension => {
// @ts-ignore
const tmClient: TendermintClient = rpc.getTmClient()!;
return QueryClient.withExtensions(tmClient, setupIbcExtension);
return QueryClient.withExtensions(
tmClient,
setupBankExtension,
setupIbcExtension,
);
};

// cached channels from config if available
Expand Down
41 changes: 34 additions & 7 deletions platforms/cosmwasm/src/platformUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@ import {
SignedTx,
Network,
PlatformToChains,
WormholeMessageId,
nativeDecimals,
chainToPlatform,
PlatformUtils,
UniversalOrNative,
Balances,
} from "@wormhole-foundation/connect-sdk";
import {
IBC_TRANSFER_PORT,
Expand All @@ -19,7 +18,13 @@ import {
import { CosmWasmClient } from "@cosmjs/cosmwasm-stargate";
import { CosmwasmPlatform } from "./platform";
import { CosmwasmAddress } from "./address";
import { IbcExtension, QueryClient, setupIbcExtension } from "@cosmjs/stargate";
import {
BankExtension,
IbcExtension,
QueryClient,
setupIbcExtension,
} from "@cosmjs/stargate";
import { AnyCosmwasmAddress } from "./types";

// forces CosmwasmUtils to implement PlatformUtils
var _: PlatformUtils<"Cosmwasm"> = CosmwasmUtils;
Expand Down Expand Up @@ -55,11 +60,12 @@ export module CosmwasmUtils {
export async function getDecimals(
chain: ChainName,
rpc: CosmWasmClient,
token: UniversalOrNative<"Cosmwasm"> | "native",
token: AnyCosmwasmAddress | "native",
): Promise<bigint> {
if (token === "native") return nativeDecimals(CosmwasmPlatform.platform);

const { decimals } = await rpc.queryContractSmart(token.toString(), {
const addrStr = new CosmwasmAddress(token).toString();
const { decimals } = await rpc.queryContractSmart(addrStr, {
token_info: {},
});
return decimals;
Expand All @@ -69,7 +75,7 @@ export module CosmwasmUtils {
chain: ChainName,
rpc: CosmWasmClient,
walletAddress: string,
token: UniversalOrNative<"Cosmwasm"> | "native",
token: AnyCosmwasmAddress | "native",
): Promise<bigint | null> {
if (token === "native") {
const { amount } = await rpc.getBalance(
Expand All @@ -79,10 +85,31 @@ export module CosmwasmUtils {
return BigInt(amount);
}

const { amount } = await rpc.getBalance(walletAddress, token.toString());
const addrStr = new CosmwasmAddress(token).toString();
const { amount } = await rpc.getBalance(walletAddress, addrStr);
return BigInt(amount);
}

export async function getBalances(
chain: ChainName,
rpc: QueryClient & BankExtension & IbcExtension,
anondev2323 marked this conversation as resolved.
Show resolved Hide resolved
walletAddress: string,
tokens: (AnyCosmwasmAddress | "native")[],
): Promise<Balances> {
const allBalances = await rpc.bank.allBalances(walletAddress);
const balancesArr = tokens.map((token) => {
const address =
token === "native"
? getNativeDenom(chain)
: new CosmwasmAddress(token).toString();
const balance = allBalances.find((balance) => balance.denom === address);
const balanceBigInt = balance ? BigInt(balance.amount) : null;
return { [address]: balanceBigInt };
});

return balancesArr.reduce((obj, item) => Object.assign(obj, item), {});
}

function getNativeDenom(chain: ChainName): string {
// TODO: required because of const map
if (CosmwasmPlatform.network === "Devnet")
Expand Down
11 changes: 2 additions & 9 deletions platforms/cosmwasm/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,13 @@ import {
import { logs as cosmosLogs } from "@cosmjs/stargate";

export type CosmwasmChainName = PlatformToChains<"Cosmwasm">;
export type UniversalOrCosmwasm = UniversalOrNative<"Cosmwasm"> | string;
export type UniversalOrCosmwasm = UniversalOrNative<"Cosmwasm">;
export type AnyCosmwasmAddress = UniversalOrCosmwasm | string | Uint8Array;

export interface WrappedRegistryResponse {
address: string;
}

export const toCosmwasmAddrString = (addr: UniversalOrCosmwasm) =>
typeof addr === "string"
? addr
: (addr instanceof UniversalAddress
? addr.toNative("Cosmwasm")
: addr
).unwrap();

// TODO: do >1 key at a time
export const searchCosmosLogs = (
key: string,
Expand Down
8 changes: 6 additions & 2 deletions platforms/evm/src/address.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Address, UniversalAddress } from '@wormhole-foundation/connect-sdk';

import { ethers } from 'ethers';
import { AnyEvmAddress } from './types';

declare global {
namespace Wormhole {
Expand All @@ -16,10 +17,13 @@ export const EvmZeroAddress = ethers.ZeroAddress;
export class EvmAddress implements Address {
static readonly byteSize = 20;

//stored as checksum address
// stored as checksum address
private readonly address: string;

constructor(address: string | Uint8Array | UniversalAddress) {
constructor(address: AnyEvmAddress) {
if (address instanceof EvmAddress) {
Object.assign(this, address);
}
if (typeof address === 'string') {
if (!EvmAddress.isValidAddress(address))
throw new Error(
Expand Down
1 change: 1 addition & 0 deletions platforms/evm/src/platform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ export module EvmPlatform {
isSupportedChain,
getDecimals,
getBalance,
getBalances,
sendWait,
getCurrentBlock,
chainFromRpc,
Expand Down
36 changes: 26 additions & 10 deletions platforms/evm/src/platformUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ import {
nativeDecimals,
chainToPlatform,
PlatformUtils,
UniversalOrNative,
Balances,
} from '@wormhole-foundation/connect-sdk';

import { Provider } from 'ethers';
import { evmChainIdToNetworkChainPair } from './constants';
import { EvmAddress, EvmZeroAddress } from './address';
import { EvmContracts } from './contracts';
import { EvmPlatform } from './platform';
import { AnyEvmAddress } from './types';

// forces EvmUtils to implement PlatformUtils
var _: PlatformUtils<'Evm'> = EvmUtils;
Expand Down Expand Up @@ -48,32 +49,47 @@ export module EvmUtils {
export async function getDecimals(
chain: ChainName,
rpc: Provider,
token: UniversalOrNative<'Evm'> | 'native',
token: AnyEvmAddress | 'native',
): Promise<bigint> {
if (token === 'native') return nativeDecimals(EvmPlatform.platform);

const tokenContract = EvmContracts.getTokenImplementation(
rpc,
token.toString(),
new EvmAddress(token).toString(),
);
const decimals = await tokenContract.decimals();
return decimals;
return tokenContract.decimals();
}

export async function getBalance(
chain: ChainName,
rpc: Provider,
walletAddr: string,
token: UniversalOrNative<'Evm'> | 'native',
token: AnyEvmAddress | 'native',
): Promise<bigint | null> {
if (token === 'native') return await rpc.getBalance(walletAddr);
if (token === 'native') return rpc.getBalance(walletAddr);

const tokenImpl = EvmContracts.getTokenImplementation(
rpc,
token.toString(),
new EvmAddress(token).toString(),
);
const balance = await tokenImpl.balanceOf(walletAddr);
return balance;
return tokenImpl.balanceOf(walletAddr);
}

export async function getBalances(
chain: ChainName,
rpc: Provider,
walletAddr: string,
tokens: (AnyEvmAddress | 'native')[],
): Promise<Balances> {
const balancesArr = await Promise.all(
tokens.map(async (token) => {
const balance = await getBalance(chain, rpc, walletAddr, token);
const address =
token === 'native' ? 'native' : new EvmAddress(token).toString();
return { [address]: balance };
}),
);
return balancesArr.reduce((obj, item) => Object.assign(obj, item), {});
}

export async function sendWait(
Expand Down
8 changes: 2 additions & 6 deletions platforms/evm/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,8 @@ export const unusedNonce = 0;
export const unusedArbiterFee = 0n;

export type EvmChainName = PlatformToChains<'Evm'>;
export type UniversalOrEvm = UniversalOrNative<'Evm'> | string;

export const toEvmAddrString = (addr: UniversalOrEvm) =>
typeof addr === 'string'
? addr
: (addr instanceof UniversalAddress ? addr.toNative('Evm') : addr).unwrap();
export type UniversalOrEvm = UniversalOrNative<'Evm'>;
export type AnyEvmAddress = UniversalOrEvm | string | Uint8Array;

export const addFrom = (txReq: TransactionRequest, from: string) => ({
...txReq,
Expand Down
2 changes: 0 additions & 2 deletions platforms/solana/__tests__/unit/platform.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ import {

import { SolanaPlatform } from '../../src';

import { PublicKey } from '@solana/web3.js';

// @ts-ignore -- this is the mock we import above
import { getDefaultProvider } from '@solana/web3.js';

Expand Down
Loading
Loading