diff --git a/app/api/execute/route.ts b/app/api/execute/route.ts index a470216..096ba01 100644 --- a/app/api/execute/route.ts +++ b/app/api/execute/route.ts @@ -2,6 +2,7 @@ import { NextRequest, NextResponse } from "next/server"; import { getToolsFromContracts } from "@/utils/generateToolFromABI"; import { routeBodySchema } from "./schemas"; import { contractCollection } from "@/utils/collections"; +import { hashPin } from "@/utils/crypt"; export const runtime = "nodejs"; @@ -21,7 +22,9 @@ export async function POST(req: NextRequest) { return NextResponse.json({ error: errorMessages }, { status: 400 }); } - const { toolCall, didToken } = result.data; + const { toolCall, didToken, pin } = result.data; + + const encryptionContext = await hashPin(pin); // parse contractAddress from toolCall.name; Should be in format `${contractKey}_${functionName}_${overload function index}`` const contractKey = parseInt(toolCall.name.split("_").at(0) as string, 10); @@ -40,9 +43,11 @@ export async function POST(req: NextRequest) { } try { - const tool = getToolsFromContracts([contract], didToken).find( - (t) => t.name === toolCall.name, - ); + const tool = getToolsFromContracts( + [contract], + didToken, + encryptionContext, + ).find((t) => t.name === toolCall.name); if (!tool) { return NextResponse.json( diff --git a/app/api/execute/schemas.ts b/app/api/execute/schemas.ts index eb94c3c..da705f1 100644 --- a/app/api/execute/schemas.ts +++ b/app/api/execute/schemas.ts @@ -17,4 +17,10 @@ export const routeBodySchema = z.object({ invalid_type_error: "didToken must be a string", }) .min(1, "didToken cannot be an empty string"), + pin: z + .string({ + required_error: "pin is required", + invalid_type_error: "pin must be a string", + }) + .min(1, "pin cannot be an empty string"), }); diff --git a/app/api/wallet/route.ts b/app/api/wallet/route.ts index ad237b9..cd154c9 100644 --- a/app/api/wallet/route.ts +++ b/app/api/wallet/route.ts @@ -1,21 +1,34 @@ import { getWalletUUIDandAccessKey } from "@/utils/tee"; import { Magic } from "@magic-sdk/admin"; import { NextRequest, NextResponse } from "next/server"; +import { hashPin } from "@/utils/crypt"; export const runtime = "nodejs"; const magic = await Magic.init(process.env.MAGIC_SECRET_KEY); export async function GET(req: NextRequest) { try { - console.log(req.nextUrl.searchParams.get("didToken")); const didToken = req.nextUrl.searchParams.get("didToken") ?? ""; + const pin = req.nextUrl.searchParams.get("pin"); + + if (!didToken) throw new Error("TOKEN missing"); const userMetadata = await magic.users.getMetadataByToken(didToken); const publicAddress = userMetadata.publicAddress ?? ""; + const encryptionContext = pin ? await hashPin(pin) : undefined; - const result = await getWalletUUIDandAccessKey(publicAddress); - - return NextResponse.json({ wallet_address: result.wallet_address }); + try { + const result = await getWalletUUIDandAccessKey( + publicAddress, + encryptionContext, + ); + return NextResponse.json({ wallet_address: result.wallet_address }); + } catch (e) { + return NextResponse.json( + { error: "Wallet does not exist and no PIN was provided" }, + { status: 400 }, + ); + } } catch (e: any) { return NextResponse.json({ error: e.message }, { status: e.status ?? 500 }); } diff --git a/components/ChatMessageBubble.tsx b/components/ChatMessageBubble.tsx index c4a7fda..eb75173 100644 --- a/components/ChatMessageBubble.tsx +++ b/components/ChatMessageBubble.tsx @@ -12,6 +12,7 @@ import { ToolArgsTable } from "./ToolArgsTable"; import { useContracts } from "@/utils/useContracts"; import { CHAINS } from "@/constants"; import { IContract } from "@/types"; +import { usePinInput } from "./PinInput"; type IToolCall = { name: string; @@ -117,6 +118,7 @@ export function ToolCallMessageBubble(props: { message: Message }) { const { colorClassName, alignmentClassName, icon } = getStyleForRole( props.message.role, ); + const { getPin, pinInput } = usePinInput(); let content: { text: string; toolCall?: IToolCall } = { text: "", @@ -133,11 +135,14 @@ export function ToolCallMessageBubble(props: { message: Message }) { setLoading(true); try { + const pin = await getPin(); + if (!pin) throw new Error("Invalid PIN"); const resp = await fetch("/api/execute", { method: "POST", body: JSON.stringify({ toolCall, didToken, + pin, disabledContractKeys: disabledKeys, }), }); @@ -227,6 +232,8 @@ export function ToolCallMessageBubble(props: { message: Message }) { {renderContent} + + {pinInput} ); } diff --git a/components/MagicProvider.tsx b/components/MagicProvider.tsx index c4136c9..e8d0eee 100644 --- a/components/MagicProvider.tsx +++ b/components/MagicProvider.tsx @@ -3,6 +3,7 @@ import { Magic } from "magic-sdk"; import { Web3 } from "web3"; import { createContext, useContext, useEffect, useMemo, useState } from "react"; +import { usePinInput } from "./PinInput"; // Create and export the context export const MagicContext = createContext<{ @@ -39,6 +40,13 @@ const MagicProvider = ({ children }: any) => { const [address, setAddress] = useState(null); const [didToken, setDidToken] = useState(null); + const { getPin, pinInput } = usePinInput({ + title: "Enter your TEE Wallet PIN", + description: + "You will be asked to enter this value whenever you try to execute a transaction", + allowCancel: false, + }); + useEffect(() => { if (process.env.NEXT_PUBLIC_MAGIC_API_KEY) { const magic = new Magic(process.env.NEXT_PUBLIC_MAGIC_API_KEY || "", { @@ -69,14 +77,19 @@ const MagicProvider = ({ children }: any) => { let didToken = await magic.user.getIdToken(); setDidToken(didToken); - const response = await fetch(`/api/wallet?didToken=${didToken}`); + let response = await fetch(`/api/wallet?didToken=${didToken}`); + if (!response.ok) { + const pin = await getPin(); + response = await fetch(`/api/wallet?didToken=${didToken}&pin=${pin}`); + } + const json = await response.json(); setTEEWalletAddress(json.wallet_address); } setIsLoading(false); }; checkIfLoggedIn(); - }, [magic]); + }, [magic, getPin]); const handleLogin = async () => { if (!magic) return; @@ -101,6 +114,7 @@ const MagicProvider = ({ children }: any) => { setIsLoggedIn(false); setDidToken(null); setAddress(null); + setTEEWalletAddress(null); }; const value = useMemo(() => { @@ -124,6 +138,7 @@ const MagicProvider = ({ children }: any) => { }} > {children} + {pinInput} ); }; diff --git a/components/PinInput.tsx b/components/PinInput.tsx new file mode 100644 index 0000000..f570cd8 --- /dev/null +++ b/components/PinInput.tsx @@ -0,0 +1,109 @@ +import { + InputOTP, + InputOTPGroup, + InputOTPSlot, +} from "@/components/ui/input-otp"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogContent, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { useCallback, useEffect, useRef, useState } from "react"; +import { + AlertDialogCancel, + AlertDialogDescription, +} from "@radix-ui/react-alert-dialog"; + +export function PinInput(props: { + open: boolean; + title: string; + description: string; + onSubmit?: (s: string) => void; + onCancel?: () => void; + pinLength?: number; +}) { + const [value, setValue] = useState(""); + const length = props.pinLength ?? 4; + + useEffect(() => { + if (props.open) setValue(""); + }, [props.open]); + + return ( + + +
{ + e.preventDefault(); + if (value.length === length) props.onSubmit?.(value); + }} + > + + {props.title} + {props.description && ( + + {props.description} + + )} + + + {Array.from({ length }, (_, i) => ( + + ))} + + + + + + {props.onCancel && ( + props.onCancel?.()}> + Cancel + + )} + + Submit + + +
+
+
+ ); +} + +export const usePinInput = ({ + title = "Enter PIN", + description = "", + allowCancel = true, +} = {}) => { + const [isPinOpen, setIsPinOpen] = useState(false); + const pinPromiseRef = useRef<(s?: string) => void>(); + const pinInput = ( + + ); + + const getPin = useCallback(async () => { + setIsPinOpen(true); + const pin = await new Promise((resolve) => { + pinPromiseRef.current = resolve; + }); + setIsPinOpen(false); + return pin; + }, []); + + return { pinInput, getPin }; +}; diff --git a/components/ui/input-otp.tsx b/components/ui/input-otp.tsx new file mode 100644 index 0000000..f66fcfa --- /dev/null +++ b/components/ui/input-otp.tsx @@ -0,0 +1,71 @@ +"use client" + +import * as React from "react" +import { OTPInput, OTPInputContext } from "input-otp" +import { Dot } from "lucide-react" + +import { cn } from "@/lib/utils" + +const InputOTP = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, containerClassName, ...props }, ref) => ( + +)) +InputOTP.displayName = "InputOTP" + +const InputOTPGroup = React.forwardRef< + React.ElementRef<"div">, + React.ComponentPropsWithoutRef<"div"> +>(({ className, ...props }, ref) => ( +
+)) +InputOTPGroup.displayName = "InputOTPGroup" + +const InputOTPSlot = React.forwardRef< + React.ElementRef<"div">, + React.ComponentPropsWithoutRef<"div"> & { index: number } +>(({ index, className, ...props }, ref) => { + const inputOTPContext = React.useContext(OTPInputContext) + const { char, hasFakeCaret, isActive } = inputOTPContext.slots[index] + + return ( +
+ {char} + {hasFakeCaret && ( +
+
+
+ )} +
+ ) +}) +InputOTPSlot.displayName = "InputOTPSlot" + +const InputOTPSeparator = React.forwardRef< + React.ElementRef<"div">, + React.ComponentPropsWithoutRef<"div"> +>(({ ...props }, ref) => ( +
+ +
+)) +InputOTPSeparator.displayName = "InputOTPSeparator" + +export { InputOTP, InputOTPGroup, InputOTPSlot, InputOTPSeparator } diff --git a/package-lock.json b/package-lock.json index 845e0cf..f3029f5 100644 --- a/package-lock.json +++ b/package-lock.json @@ -36,6 +36,7 @@ "eslint-config-next": "13.4.12", "ethers": "^6.13.2", "etherscan-api": "^10.3.0", + "input-otp": "^1.2.4", "langchain": "^0.2.12", "lucide-react": "^0.428.0", "magic-sdk": "^28.5.0", @@ -5997,6 +5998,15 @@ "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", "license": "ISC" }, + "node_modules/input-otp": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/input-otp/-/input-otp-1.2.4.tgz", + "integrity": "sha512-md6rhmD+zmMnUh5crQNSQxq3keBRYvE3odbr4Qb9g2NWzQv9azi+t1a3X4TBTbh98fsGHgEEJlzbe1q860uGCA==", + "peerDependencies": { + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + } + }, "node_modules/internal-slot": { "version": "1.0.7", "resolved": "https://registry.npmjs.org/internal-slot/-/internal-slot-1.0.7.tgz", diff --git a/package.json b/package.json index 7dce402..91416cc 100644 --- a/package.json +++ b/package.json @@ -41,6 +41,7 @@ "eslint-config-next": "13.4.12", "ethers": "^6.13.2", "etherscan-api": "^10.3.0", + "input-otp": "^1.2.4", "langchain": "^0.2.12", "lucide-react": "^0.428.0", "magic-sdk": "^28.5.0", diff --git a/utils/crypt.ts b/utils/crypt.ts new file mode 100644 index 0000000..b0a0991 --- /dev/null +++ b/utils/crypt.ts @@ -0,0 +1,11 @@ +import crypto from "crypto"; + +export async function hashPin(pin: string) { + try { + const hash = crypto.createHash("sha512"); + hash.update(pin); + return hash.digest("hex"); + } catch (error) { + console.error("Error hashing password:", error); + } +} diff --git a/utils/generateToolFromABI.ts b/utils/generateToolFromABI.ts index bc1d470..42bf77e 100644 --- a/utils/generateToolFromABI.ts +++ b/utils/generateToolFromABI.ts @@ -12,11 +12,12 @@ type IZodGeneric = ZodBoolean | ZodNumber | ZodString; export const getToolsFromContracts = ( contracts: IContract[], didToken?: string, + encryptionContext?: string, ) => contracts.flatMap((contract) => (contract.abi ?? []) .filter((f: any) => f.name && f.type === "function") - .map(generateToolFromABI(contract, didToken)), + .map(generateToolFromABI(contract, didToken, encryptionContext)), ); export const getContractABIDescriptions = ( @@ -39,7 +40,7 @@ export const getContractABIDescriptions = ( ]); const generateToolFromABI = - (contract: IContract, didToken?: string) => + (contract: IContract, didToken?: string, encryptionContext?: string) => (func: AbiFunction, _: number, abi: AbiFunction[]): any => { const name = getToolName(contract, func, abi); const abiDescription = contract.abiDescriptions?.find((d) => @@ -66,7 +67,7 @@ const generateToolFromABI = contract.context ?? "" }`, schema: z.object(schema), - func: getToolFunction(didToken, contract, func), + func: getToolFunction(didToken, encryptionContext, contract, func), }); }; @@ -121,10 +122,23 @@ const getToolDescription = (contract: IContract, func: AbiFunction) => { }; const getToolFunction = - (didToken: string | undefined, contract: IContract, func: AbiFunction) => + ( + didToken: string | undefined, + encryptionContext: string | undefined, + contract: IContract, + func: AbiFunction, + ) => async (args: Record): Promise => { // This function should return a string according to the link hence the stringifed JSON // https://js.langchain.com/v0.2/docs/how_to/custom_tools/#dynamicstructuredtool + if (!encryptionContext) { + return JSON.stringify({ + message: "No encryptionContext", + status: "failure", + payload: {}, + }); + } + if (!didToken) { return JSON.stringify({ message: "No didToken", @@ -146,6 +160,7 @@ const getToolFunction = value: args.transactionValue ?? 0, args: ensuredArgOrder, publicAddress, + encryptionContext, }); const { transactionHash, message, status } = txReceipt; diff --git a/utils/kvCache.ts b/utils/kvCache.ts index 3cfdf47..ab983e1 100644 --- a/utils/kvCache.ts +++ b/utils/kvCache.ts @@ -16,7 +16,7 @@ export class KVCache { } public async delete(key: string): Promise { - await kv.set(this.getStorageKey(key), JSON.stringify({})); + await kv.del(this.getStorageKey(key)); } public async get(key: string): Promise { diff --git a/utils/tee.ts b/utils/tee.ts index 1d961f4..ab127df 100644 --- a/utils/tee.ts +++ b/utils/tee.ts @@ -75,15 +75,17 @@ async function signTransaction({ payload, access_key, wallet_id, + encryption_context, }: { payload: IWalletTxPayload; access_key: string; wallet_id: string; + encryption_context: string; }) { try { const response = await axiosInstance.post("/wallet/sign_transaction", { payload: payload, - encryption_context: "0000", + encryption_context: encryption_context, access_key: access_key, wallet_id: wallet_id, }); @@ -99,6 +101,7 @@ async function signTransaction({ export async function getWalletUUIDandAccessKey( publicAddress: string, + encryptionContext?: string, ): Promise { try { // pa = public address @@ -115,6 +118,10 @@ export async function getWalletUUIDandAccessKey( }; } console.log(`pa:${publicAddress} NOT in cache`); + + if (!encryptionContext) + throw new Error("Wallet not found and missing encryption context"); + const walletGroups = await getWalletGroups(); // For now assume the first wallet group in case the magic tenant has more than one @@ -123,7 +130,7 @@ export async function getWalletUUIDandAccessKey( const walletResponse = await createWallet({ wallet_group_id: walletGroup.uuid, network: "mainnet", - encryption_context: "0000", + encryption_context: encryptionContext, }); const wallet = walletResponse.data; @@ -152,17 +159,19 @@ export async function getTransactionReceipt({ value: rawValue, args, publicAddress, + encryptionContext, }: { contract: IContract; functionName: string; value: number; args: any[]; publicAddress: string; + encryptionContext: string; }): Promise { try { // TODO: wrap in Error class to denote ABI error const { wallet_id, wallet_address, access_key } = - await getWalletUUIDandAccessKey(publicAddress); + await getWalletUUIDandAccessKey(publicAddress, encryptionContext); const RPC_URL = `${CHAINS[contract.chainId].rpcURI}${ALCHEMY_KEY}`; const provider = new ethers.JsonRpcProvider(RPC_URL); @@ -202,8 +211,9 @@ export async function getTransactionReceipt({ const signedTx = await signTransaction({ payload, - access_key: access_key, - wallet_id: wallet_id, + access_key, + wallet_id, + encryption_context: encryptionContext, }); console.log({ signedTx });