Skip to content

Commit

Permalink
api: Adjust API handlers to new ai schema
Browse files Browse the repository at this point in the history
  • Loading branch information
victorges committed Sep 13, 2024
1 parent 574c82d commit 632c328
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 18 deletions.
40 changes: 23 additions & 17 deletions packages/api/src/controllers/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ import sql from "sql-template-strings";
import { v4 as uuid } from "uuid";
import logger from "../logger";
import { authorizer, validateFormData, validatePost } from "../middleware";
import { defaultModels } from "../schema/pull-ai-schema";
import { AiGenerateLog } from "../schema/types";
import { db } from "../store";
import { BadRequestError } from "../store/errors";
import { fetchWithTimeout } from "../util";
import { fetchWithTimeout, kebabToCamel } from "../util";
import { experimentSubjectsOnly } from "./experiment";
import { pathJoin2 } from "./helpers";

Expand Down Expand Up @@ -170,13 +171,24 @@ function logAiGenerateRequest(

function registerGenerateHandler(
type: AiGenerateType,
defaultModel: string,
isJSONReq = false, // multipart by default
): RequestHandler {
const path = `/${type}`;
const payloadParsers = isJSONReq
? [validatePost(`${type}-payload`)]
: [multipart.any(), validateFormData(`${type}-payload`)];

let payloadParsers: RequestHandler[];
let camelType = kebabToCamel(type);
camelType = camelType[0].toUpperCase() + camelType.slice(1);
if (isJSONReq) {
payloadParsers = [validatePost(`${camelType}Params`)];
} else {
payloadParsers = [
multipart.any(),
validateFormData(`Body_gen${camelType}`),
];
}

const defaultModel = defaultModels[type];

return app.post(
path,
authorizer({}),
Expand Down Expand Up @@ -236,17 +248,11 @@ function registerGenerateHandler(
);
}

registerGenerateHandler(
"text-to-image",
"SG161222/RealVisXL_V4.0_Lightning",
true,
);
registerGenerateHandler("image-to-image", "timbrooks/instruct-pix2pix");
registerGenerateHandler(
"image-to-video",
"stabilityai/stable-video-diffusion-img2vid-xt-1-1",
);
registerGenerateHandler("upscale", "stabilityai/stable-diffusion-x4-upscaler");
registerGenerateHandler("audio-to-text", "openai/whisper-large-v3");
registerGenerateHandler("text-to-image", true);
registerGenerateHandler("image-to-image");
registerGenerateHandler("image-to-video");
registerGenerateHandler("upscale");
registerGenerateHandler("audio-to-text");
registerGenerateHandler("segment-anything-2");

export default app;
2 changes: 2 additions & 0 deletions packages/api/src/schema/db-schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1453,12 +1453,14 @@ components:
- image-to-image
- image-to-video
- upscale
- segment-anything-2
request:
oneOf:
- $ref: "./ai-api-schema.yaml#/components/schemas/TextToImageParams"
- $ref: "./ai-api-schema.yaml#/components/schemas/Body_genImageToImage"
- $ref: "./ai-api-schema.yaml#/components/schemas/Body_genImageToVideo"
- $ref: "./ai-api-schema.yaml#/components/schemas/Body_genUpscale"
- $ref: "./ai-api-schema.yaml#/components/schemas/Body_genSegmentAnything2"
statusCode:
type: integer
description: HTTP status code received from the AI gateway
Expand Down
3 changes: 2 additions & 1 deletion packages/api/src/schema/pull-ai-schema.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ import path from "path";
// This downloads the AI schema from the AI worker repo and saves in the local
// ai-api-schema.yaml file, referenced by our main api-schema.yaml file.

const defaultModels = {
export const defaultModels = {
"text-to-image": "SG161222/RealVisXL_V4.0_Lightning",
"image-to-image": "timbrooks/instruct-pix2pix",
"image-to-video": "stabilityai/stable-video-diffusion-img2vid-xt-1-1",
upscale: "stabilityai/stable-diffusion-x4-upscaler",
"audio-to-text": "openai/whisper-large-v3",
"segment-anything-2": "facebook/sam2-hiera-large:",
};
const schemaDir = path.resolve(__dirname, ".");
const aiSchemaUrl =
Expand Down

0 comments on commit 632c328

Please sign in to comment.