From ab96af1adbe0f987138da99d183af4240dabd586 Mon Sep 17 00:00:00 2001 From: Shigma Date: Tue, 7 Jan 2025 21:03:43 +0800 Subject: [PATCH] feat(satori): support async iterator bridge --- adapters/satori/src/bot.ts | 46 +++++++++++++++++++++++++++++----- packages/core/src/index.ts | 33 +++++++++--------------- packages/core/src/internal.ts | 39 ++++++++++++++++++++++++++++ packages/protocol/src/index.ts | 6 ++--- 4 files changed, 94 insertions(+), 30 deletions(-) diff --git a/adapters/satori/src/bot.ts b/adapters/satori/src/bot.ts index d74b06b9..c66ff567 100644 --- a/adapters/satori/src/bot.ts +++ b/adapters/satori/src/bot.ts @@ -5,13 +5,47 @@ function createInternal(bot: SatoriBot, prefix = '') { apply(target, thisArg, args) { const key = prefix.slice(1) bot.logger.debug('[request.internal]', key, args) - const form = new FormData() - args = JsonForm.dump(args, '$', form) - if (![...form.entries()].length) { - return bot.http.post('/v1/' + bot.getInternalUrl(`/_api/${key}`, {}, true), args) + + const impl = async (pagination = false) => { + const request = await JsonForm.encode(args) + if (pagination) { + request.headers.set('Satori-Pagination', 'true') + } + const response = await bot.http('/v1/' + bot.getInternalUrl(`/_api/${key}`, {}, true), { + method: 'POST', + headers: Object.fromEntries(request.headers.entries()), + data: request.body, + responseType: 'arraybuffer', + }) + return await JsonForm.decode({ body: response.data, headers: response.headers }) + } + + let promise: Promise | undefined + const result = {} as Promise & AsyncIterableIterator + for (const key of ['then', 'catch', 'finally']) { + result[key] = (...args: any[]) => { + return (promise ??= impl())[key](...args) + } } - form.append('$', JSON.stringify(args)) - return bot.http.post('/v1/' + bot.getInternalUrl(`/_api/${key}`, {}, true), form) + + let pagination: { data: any[]; next?: any } | undefined + result.next = async function () { + pagination ??= await impl(true) + if (!pagination.data) throw new Error('Invalid pagination response') + if (pagination.data.length) return { done: false, value: pagination.data.shift() } + if (!pagination.next) return { done: true, value: undefined } + args = pagination.next + pagination = await impl(true) + return this.next() + } + result[Symbol.asyncIterator] = function () { + return this + } + result[Symbol.for('satori.pagination')] = () => { + return impl(true) + } + + return result }, get(target, key, receiver) { if (typeof key === 'symbol' || key in target) { diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 6fb023b1..c6fadf94 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -202,31 +202,22 @@ export class Satori extends Service { this.defineInternalRoute('/_api/:name', async ({ bot, headers, params, method, body }) => { if (method !== 'POST') return { status: 405 } - const type = headers['content-type'] - let args: any - if (type?.startsWith('multipart/form-data')) { - const response = new globalThis.Response(body, { headers }) - const form = await response.formData() - const rawData = form.get('$') as string - try { - args = JSON.parse(rawData) - } catch { - return { status: 400 } - } - args = JsonForm.load(args, '$', form) - } else { - args = JSON.parse(new TextDecoder().decode(body)) - } + const args = await JsonForm.decode({ body, headers: new Headers(headers) }) + if (!args) return { status: 400 } try { - const result = await bot.internal[params.name](...args) - const body = new TextEncoder().encode(JSON.stringify(result)) - const headers = new Headers() - if (body.byteLength) { - headers.set('content-type', 'application/json') + let result = bot.internal[params.name](...args) + if (headers['satori-pagination']) { + if (!result?.[Symbol.for('satori.pagination')]) { + return { status: 400, statusText: 'This API does not support pagination' } + } + result = await result[Symbol.for('satori.pagination')]() + } else { + result = await result } - return { body, headers, status: 200 } + return { ...await JsonForm.encode(result), status: 200 } } catch (error) { if (!ctx.http.isError(error) || !error.response) throw error + // FIXME: missing response body return error.response } }) diff --git a/packages/core/src/internal.ts b/packages/core/src/internal.ts index 02f4ce04..b57174bf 100644 --- a/packages/core/src/internal.ts +++ b/packages/core/src/internal.ts @@ -132,4 +132,43 @@ export namespace JsonForm { return dump(value, `${path}.${key}`, form) }) } + + export interface Body { + body: ArrayBuffer + headers: Headers + } + + export async function decode(body: Body) { + const type = body.headers.get('content-type') + if (type.startsWith('multipart/form-data')) { + const response = new globalThis.Response(body.body, { headers: body.headers }) + const form = await response.formData() + const json = form.get('$') as string + return load(JSON.parse(json), '$', form) + } else if (type.startsWith('application/json')) { + return JSON.parse(new TextDecoder().decode(body.body)) + } + } + + export async function encode(data: any): Promise { + const form = new FormData() + const json = JSON.stringify(JsonForm.dump(data, '$', form)) + if ([...form.entries()].length) { + form.append('$', json) + const request = new Request('stub:', { + method: 'POST', + body: form, + }) + return { + body: await request.arrayBuffer(), + headers: request.headers, + } + } else { + const body = new TextEncoder().encode(json) + const headers = new Headers({ + 'content-type': 'application/json', + }) + return { body, headers } + } + } } diff --git a/packages/protocol/src/index.ts b/packages/protocol/src/index.ts index 59a831cb..db08928f 100644 --- a/packages/protocol/src/index.ts +++ b/packages/protocol/src/index.ts @@ -86,12 +86,12 @@ export const Methods: Dict = { 'guild.member.approve': Method('handleGuildMemberRequest', ['message_id', 'approve', 'comment']), } -export interface List { +export interface List { data: T[] next?: string } -export interface TwoWayList { +export interface BidiList { data: T[] prev?: string next?: string @@ -107,7 +107,7 @@ export interface Methods { sendMessage(channelId: string, content: Element.Fragment, referrer?: any, options?: SendOptions): Promise sendPrivateMessage(userId: string, content: Element.Fragment, guildId?: string, options?: SendOptions): Promise getMessage(channelId: string, messageId: string): Promise - getMessageList(channelId: string, next?: string, direction?: Direction, limit?: number, order?: Order): Promise> + getMessageList(channelId: string, next?: string, direction?: Direction, limit?: number, order?: Order): Promise> getMessageIter(channelId: string): AsyncIterable editMessage(channelId: string, messageId: string, content: Element.Fragment): Promise deleteMessage(channelId: string, messageId: string): Promise