diff --git a/src/hooks/wallets/mpc/__tests__/useMPC.test.ts b/src/hooks/wallets/mpc/__tests__/useMPC.test.ts new file mode 100644 index 0000000000..8665b5ec9d --- /dev/null +++ b/src/hooks/wallets/mpc/__tests__/useMPC.test.ts @@ -0,0 +1,225 @@ +import * as useOnboard from '@/hooks/wallets/useOnboard' +import { renderHook, waitFor } from '@/tests/test-utils' +import { getMPCCoreKitInstance, setMPCCoreKitInstance, useInitMPC } from '../useMPC' +import * as useChains from '@/hooks/useChains' +import { type ChainInfo, RPC_AUTHENTICATION } from '@safe-global/safe-gateway-typescript-sdk' +import { hexZeroPad } from 'ethers/lib/utils' +import { ONBOARD_MPC_MODULE_LABEL } from '@/services/mpc/module' +import { type Web3AuthMPCCoreKit, COREKIT_STATUS } from '@web3auth/mpc-core-kit' +import { type EIP1193Provider, type OnboardAPI } from '@web3-onboard/core' + +jest.mock('@web3auth/mpc-core-kit', () => ({ + ...jest.requireActual('@web3auth/mpc-core-kit'), + Web3AuthMPCCoreKit: jest.fn(), +})) + +type MPCProvider = Web3AuthMPCCoreKit['provider'] + +/** + * Mock for creating and initializing the MPC Core Kit + */ +class MockMPCCoreKit { + provider: MPCProvider | null = null + status = COREKIT_STATUS.NOT_INITIALIZED + private mockState + private mockProvider + + /** + * The parameters are set in the mock MPC Core Kit after init() get's called + * + * @param mockState + * @param mockProvider + */ + constructor(mockState: COREKIT_STATUS, mockProvider: MPCProvider) { + this.mockState = mockState + this.mockProvider = mockProvider + } + + init() { + this.status = this.mockState + this.provider = this.mockProvider + return Promise.resolve() + } +} + +/** + * Small helper class that implements registering RPC event listeners and event emiting. + * Used to test that events onboard relies on are getting called correctly + */ +class EventEmittingMockProvider { + private chainChangedListeners: Function[] = [] + + addListener(event: string, listener: Function) { + if (event === 'chainChanged') { + this.chainChangedListeners.push(listener) + } + } + + emit(event: string, ...args: any[]) { + this.chainChangedListeners.forEach((listener) => listener(...args)) + } +} + +describe('useInitMPC', () => { + beforeEach(() => { + jest.resetAllMocks() + }) + it('should set the coreKit if user is not logged in yet', async () => { + const connectWalletSpy = jest.fn().mockImplementation(() => Promise.resolve()) + jest.spyOn(useOnboard, 'connectWallet').mockImplementation(connectWalletSpy) + jest.spyOn(useOnboard, 'getConnectedWallet').mockReturnValue(null) + jest.spyOn(useOnboard, 'default').mockReturnValue({ + state: { + get: () => ({ + wallets: [], + walletModules: [], + }), + }, + } as unknown as OnboardAPI) + jest.spyOn(useChains, 'useCurrentChain').mockReturnValue({ + chainId: '5', + chainName: 'Goerli', + blockExplorerUriTemplate: { + address: 'https://goerli.someprovider.io/{address}', + txHash: 'https://goerli.someprovider.io/{txHash}', + api: 'https://goerli.someprovider.io/', + }, + nativeCurrency: { + decimals: 18, + logoUri: 'https://logo.goerli.com', + name: 'Goerli ETH', + symbol: 'ETH', + }, + rpcUri: { + authentication: RPC_AUTHENTICATION.NO_AUTHENTICATION, + value: 'https://goerli.somerpc.io', + }, + } as unknown as ChainInfo) + + const mockWeb3AuthMpcCoreKit = jest.spyOn(require('@web3auth/mpc-core-kit'), 'Web3AuthMPCCoreKit') + mockWeb3AuthMpcCoreKit.mockImplementation(() => { + return new MockMPCCoreKit(COREKIT_STATUS.INITIALIZED, null) + }) + + renderHook(() => useInitMPC()) + + await waitFor(() => { + expect(getMPCCoreKitInstance()).toBeDefined() + expect(connectWalletSpy).not.toBeCalled() + }) + }) + + it('should call connectWallet after rehydrating a web3auth session', async () => { + const connectWalletSpy = jest.fn().mockImplementation(() => Promise.resolve()) + jest.spyOn(useOnboard, 'connectWallet').mockImplementation(connectWalletSpy) + jest.spyOn(useOnboard, 'getConnectedWallet').mockReturnValue(null) + jest.spyOn(useOnboard, 'default').mockReturnValue({ + state: { + get: () => ({ + wallets: [], + walletModules: [], + }), + }, + } as unknown as OnboardAPI) + jest.spyOn(useChains, 'useCurrentChain').mockReturnValue({ + chainId: '5', + chainName: 'Goerli', + blockExplorerUriTemplate: { + address: 'https://goerli.someprovider.io/{address}', + txHash: 'https://goerli.someprovider.io/{txHash}', + api: 'https://goerli.someprovider.io/', + }, + nativeCurrency: { + decimals: 18, + logoUri: 'https://logo.goerli.com', + name: 'Goerli ETH', + symbol: 'ETH', + }, + rpcUri: { + authentication: RPC_AUTHENTICATION.NO_AUTHENTICATION, + value: 'https://goerli.somerpc.io', + }, + } as unknown as ChainInfo) + + const mockWeb3AuthMpcCoreKit = jest.spyOn(require('@web3auth/mpc-core-kit'), 'Web3AuthMPCCoreKit') + const mockProvider = jest.fn() + mockWeb3AuthMpcCoreKit.mockImplementation(() => { + return new MockMPCCoreKit(COREKIT_STATUS.INITIALIZED, mockProvider as unknown as MPCProvider) + }) + + renderHook(() => useInitMPC()) + + await waitFor(() => { + expect(connectWalletSpy).toBeCalled() + expect(getMPCCoreKitInstance()).toBeDefined() + }) + }) + + it('should copy event handlers and emit chainChanged if the current chain is updated', async () => { + const connectWalletSpy = jest.fn().mockImplementation(() => Promise.resolve()) + jest.spyOn(useOnboard, 'connectWallet').mockImplementation(connectWalletSpy) + jest.spyOn(useOnboard, 'getConnectedWallet').mockReturnValue({ + address: hexZeroPad('0x1', 20), + label: ONBOARD_MPC_MODULE_LABEL, + chainId: '1', + provider: {} as unknown as EIP1193Provider, + }) + jest.spyOn(useOnboard, 'default').mockReturnValue({ + state: { + get: () => ({ + wallets: [], + walletModules: [], + }), + }, + } as unknown as OnboardAPI) + jest.spyOn(useChains, 'useCurrentChain').mockReturnValue({ + chainId: '5', + chainName: 'Goerli', + blockExplorerUriTemplate: { + address: 'https://goerli.someprovider.io/{address}', + txHash: 'https://goerli.someprovider.io/{txHash}', + api: 'https://goerli.someprovider.io/', + }, + nativeCurrency: { + decimals: 18, + logoUri: 'https://logo.goerli.com', + name: 'Goerli ETH', + symbol: 'ETH', + }, + rpcUri: { + authentication: RPC_AUTHENTICATION.NO_AUTHENTICATION, + value: 'https://goerli.somerpc.io', + }, + } as unknown as ChainInfo) + + const mockWeb3AuthMpcCoreKit = jest.spyOn(require('@web3auth/mpc-core-kit'), 'Web3AuthMPCCoreKit') + const mockChainChangedListener = jest.fn() + const mockProviderBefore = { + listeners: (eventName: string) => { + if (eventName === 'chainChanged') { + return [mockChainChangedListener] + } + }, + } + + setMPCCoreKitInstance({ + provider: mockProviderBefore, + } as unknown as Web3AuthMPCCoreKit) + + const mockProvider = new EventEmittingMockProvider() + mockWeb3AuthMpcCoreKit.mockImplementation(() => { + return new MockMPCCoreKit( + require('@web3auth/mpc-core-kit').COREKIT_STATUS.INITIALIZED, + mockProvider as unknown as MPCProvider, + ) + }) + + renderHook(() => useInitMPC()) + + await waitFor(() => { + expect(mockChainChangedListener).toHaveBeenCalledWith('0x5') + expect(getMPCCoreKitInstance()).toBeDefined() + expect(connectWalletSpy).not.toBeCalled() + }) + }) +}) diff --git a/src/hooks/wallets/mpc/useMPC.ts b/src/hooks/wallets/mpc/useMPC.ts index cf1472ab13..b7d66c444f 100644 --- a/src/hooks/wallets/mpc/useMPC.ts +++ b/src/hooks/wallets/mpc/useMPC.ts @@ -1,6 +1,6 @@ import { useEffect } from 'react' import ExternalStore from '@/services/ExternalStore' -import { Web3AuthMPCCoreKit, WEB3AUTH_NETWORK, COREKIT_STATUS } from '@web3auth/mpc-core-kit' +import { Web3AuthMPCCoreKit, WEB3AUTH_NETWORK } from '@web3auth/mpc-core-kit' import { CHAIN_NAMESPACES } from '@web3auth/base' import { WEB3_AUTH_CLIENT_ID } from '@/config/constants' @@ -29,6 +29,14 @@ export const useInitMPC = () => { tickerName: chain.nativeCurrency.name, } + const currentInstance = getStore() + let previousChainChangedListeners: Function[] = [] + if (currentInstance?.provider) { + // We are already connected. We copy onboards event listener for the chainChanged event to propagate a potentially new chainId + const oldProvider = currentInstance.provider + previousChainChangedListeners = oldProvider.listeners('chainChanged') + } + const web3AuthCoreKit = new Web3AuthMPCCoreKit({ web3AuthClientId: WEB3_AUTH_CLIENT_ID, // Available networks are "sapphire_devnet", "sapphire_mainnet" @@ -43,33 +51,27 @@ export const useInitMPC = () => { .init() .then(() => { setStore(web3AuthCoreKit) - // If rehydration was successful, connect to onboard - if (web3AuthCoreKit.status === COREKIT_STATUS.INITIALIZED) { - console.log('Logged in', web3AuthCoreKit) - // await mpcCoreKit.enableMFA({}) - const connectedWallet = getConnectedWallet(onboard.state.get().wallets) - if (!connectedWallet) { - connectWallet(onboard, { - autoSelect: { - label: ONBOARD_MPC_MODULE_LABEL, - disableModals: true, - }, - }).catch((reason) => console.error('Error connecting to MPC module:', reason)) - } else { - // To propagate the changedChain we disconnect and connect - onboard - .disconnectWallet({ - label: ONBOARD_MPC_MODULE_LABEL, - }) - .then(() => - connectWallet(onboard, { - autoSelect: { - label: ONBOARD_MPC_MODULE_LABEL, - disableModals: true, - }, - }), - ) + if (!web3AuthCoreKit.provider) { + return + } + const connectedWallet = getConnectedWallet(onboard.state.get().wallets) + if (!connectedWallet) { + connectWallet(onboard, { + autoSelect: { + label: ONBOARD_MPC_MODULE_LABEL, + disableModals: true, + }, + }).catch((reason) => console.error('Error connecting to MPC module:', reason)) + } else { + const newProvider = web3AuthCoreKit.provider + + // To propagate the changedChain we disconnect and connect + if (previousChainChangedListeners.length > 0 && newProvider) { + previousChainChangedListeners.forEach((previousListener) => + newProvider.addListener('chainChanged', (...args: []) => previousListener(...args)), + ) + newProvider.emit('chainChanged', `0x${Number(chainConfig.chainId).toString(16)}`) } } }) @@ -79,4 +81,6 @@ export const useInitMPC = () => { export const getMPCCoreKitInstance = getStore +export const setMPCCoreKitInstance = setStore + export default useStore