Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ai: generate: remove endpoint from experiment & remove beta from path #2318

Merged
merged 4 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 30 additions & 20 deletions packages/api/src/controllers/generate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ afterEach(async () => {
await clearDatabase(server);
});

const testBothRoutes = (testFn) => {
describe("generate route", () => {
testFn("/generate");
});

describe("beta generate route", () => {
testFn("/beta/generate");
});
};
Comment on lines +74 to +82
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep in mind we need to update CloudFlare routes to also skip the regular worker on the new route.


describe("controllers/generate", () => {
let client: TestClient;
let adminUser: User;
Expand Down Expand Up @@ -145,9 +155,9 @@ describe("controllers/generate", () => {
return form;
};

describe("API proxies", () => {
it("should call the AI Gateway for generate API /audio-to-text", async () => {
const res = await client.fetch("/beta/generate/audio-to-text", {
testBothRoutes((basePath) => {
it(`should call the AI Gateway for ${basePath}/audio-to-text`, async () => {
const res = await client.fetch(`${basePath}/audio-to-text`, {
method: "POST",
body: buildMultipartBody(
{},
Expand All @@ -162,8 +172,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ "audio-to-text": 1 });
});

it("should call the AI Gateway for generate API /text-to-image", async () => {
const res = await client.post("/beta/generate/text-to-image", {
it(`should call the AI Gateway for ${basePath}/text-to-image`, async () => {
const res = await client.post(`${basePath}/text-to-image`, {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(200);
Expand All @@ -174,8 +184,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ "text-to-image": 1 });
});

it("should call the AI Gateway for generate API /image-to-image", async () => {
const res = await client.fetch("/beta/generate/image-to-image", {
it(`should call the AI Gateway for ${basePath}/image-to-image`, async () => {
const res = await client.fetch(`${basePath}/image-to-image`, {
method: "POST",
body: buildMultipartBody({
prompt: "replace the suit with a bathing suit",
Expand All @@ -189,8 +199,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ "image-to-image": 1 });
});

it("should call the AI Gateway for generate API /image-to-video", async () => {
const res = await client.fetch("/beta/generate/image-to-video", {
it(`should call the AI Gateway for ${basePath}/image-to-video`, async () => {
const res = await client.fetch(`${basePath}/image-to-video`, {
method: "POST",
body: buildMultipartBody({}),
});
Expand All @@ -202,8 +212,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ "image-to-video": 1 });
});

it("should call the AI Gateway for generate API /upscale", async () => {
const res = await client.fetch("/beta/generate/upscale", {
it(`should call the AI Gateway for ${basePath}/upscale`, async () => {
const res = await client.fetch(`${basePath}/upscale`, {
method: "POST",
body: buildMultipartBody({ prompt: "enhance" }),
});
Expand All @@ -215,8 +225,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ upscale: 1 });
});

it("should call the AI Gateway for generate API /segment-anything-2", async () => {
const res = await client.fetch("/beta/generate/segment-anything-2", {
it(`should call the AI Gateway for ${basePath}/segment-anything-2`, async () => {
const res = await client.fetch(`${basePath}/segment-anything-2`, {
method: "POST",
body: buildMultipartBody({}),
});
Expand Down Expand Up @@ -260,7 +270,7 @@ describe("controllers/generate", () => {

for (const [title, input, error] of testCases) {
it(title, async () => {
const res = await client.fetch("/beta/generate/image-to-image", {
const res = await client.fetch("/generate/image-to-image", {
gioelecerati marked this conversation as resolved.
Show resolved Hide resolved
method: "POST",
body: input,
});
Expand All @@ -287,7 +297,7 @@ describe("controllers/generate", () => {
}

it("should log all requests to db", async () => {
const res = await client.post("/beta/generate/text-to-image", {
const res = await client.post("/generate/text-to-image", {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(200);
Expand Down Expand Up @@ -325,7 +335,7 @@ describe("controllers/generate", () => {
`{"details":{"msg":"sudden error"}}`,
);

const res = await client.post("/beta/generate/text-to-image", {
const res = await client.post("/generate/text-to-image", {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(500);
Expand All @@ -345,7 +355,7 @@ describe("controllers/generate", () => {
it("should log non JSON outputs as strings to db", async () => {
mockFetchHttpError(418, "text/plain", `I'm not Jason`);

const res = await client.post("/beta/generate/text-to-image", {
const res = await client.post("/generate/text-to-image", {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(418);
Expand All @@ -364,7 +374,7 @@ describe("controllers/generate", () => {
mockedFetchWithTimeout.mockImplementation(() => {
throw new Error("on your face");
});
const res = await client.post("/beta/generate/text-to-image", {
const res = await client.post("/generate/text-to-image", {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(500);
Expand Down Expand Up @@ -394,10 +404,10 @@ describe("controllers/generate", () => {

const makeAiGenReq = (pipeline: (typeof pipelines)[number]) =>
pipeline === "text-to-image"
? client.post(`/beta/generate/${pipeline}`, {
? client.post(`/generate/${pipeline}`, {
prompt: "whatever you feel like",
})
: client.fetch(`/beta/generate/${pipeline}`, {
: client.fetch(`/generate/${pipeline}`, {
method: "POST",
body: buildMultipartBody(
pipeline === "image-to-video" ? {} : { prompt: "make magic" },
Expand Down
2 changes: 0 additions & 2 deletions packages/api/src/controllers/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ const aiGenerateDurationMetric = new promclient.Histogram({

const app = Router();

app.use(experimentSubjectsOnly("ai-generate"));
gioelecerati marked this conversation as resolved.
Show resolved Hide resolved

const rateLimiter: RequestHandler = async (req, res, next) => {
const now = Date.now();
const [[{ count, min }]] = await db.aiGenerateLog.find(
Expand Down
2 changes: 2 additions & 0 deletions packages/api/src/controllers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ export default {
"api-token": apiToken,
asset,
auth,
generate,
// TODO: Remove beta paths
"beta/generate": generate,
broadcaster,
clip,
Expand Down
137 changes: 130 additions & 7 deletions packages/api/src/schema/ai-api-schema.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
openapi: 3.1.0
paths:
/api/beta/generate/text-to-image:
/api/generate/text-to-image:
post:
tags:
- generate
Expand Down Expand Up @@ -60,7 +60,7 @@ paths:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: textToImage
/api/beta/generate/image-to-image:
/api/generate/image-to-image:
post:
tags:
- generate
Expand Down Expand Up @@ -120,7 +120,7 @@ paths:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: imageToImage
/api/beta/generate/image-to-video:
/api/generate/image-to-video:
post:
tags:
- generate
Expand Down Expand Up @@ -180,7 +180,7 @@ paths:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: imageToVideo
/api/beta/generate/upscale:
/api/generate/upscale:
post:
tags:
- generate
Expand Down Expand Up @@ -240,7 +240,7 @@ paths:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: upscale
/api/beta/generate/audio-to-text:
/api/generate/audio-to-text:
post:
tags:
- generate
Expand Down Expand Up @@ -308,7 +308,7 @@ paths:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: audioToText
/api/beta/generate/segment-anything-2:
/api/generate/segment-anything-2:
post:
tags:
- generate
Expand Down Expand Up @@ -368,6 +368,65 @@ paths:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: segmentAnything2
/api/generate/llm:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep in mind you might have conflict with @mjh1's #2319, but it should be solved by simply running the pull-ai-schema script again

post:
tags:
- generate
summary: LLM
description: Generate text using a language model.
operationId: genLLM
requestBody:
content:
application/x-www-form-urlencoded:
schema:
$ref: '#/components/schemas/Body_genLLM'
required: true
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/LLMResponse'
'400':
description: Bad Request
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/HTTPError'
- $ref: '#/components/schemas/studio-api-error'
'401':
description: Unauthorized
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/HTTPError'
- $ref: '#/components/schemas/studio-api-error'
'422':
description: Validation Error
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/HTTPValidationError'
- $ref: '#/components/schemas/studio-api-error'
'500':
description: Internal Server Error
content:
application/json:
schema:
oneOf:
- $ref: '#/components/schemas/HTTPError'
- $ref: '#/components/schemas/studio-api-error'
default:
description: Error
content:
application/json:
schema:
$ref: '#/components/schemas/studio-api-error'
x-speakeasy-name-override: llm
components:
schemas:
APIError:
Expand Down Expand Up @@ -414,6 +473,14 @@ components:
title: Model Id
description: Hugging Face model ID used for image generation.
default: timbrooks/instruct-pix2pix
loras:
type: string
title: Loras
description: >-
A LoRA (Low-Rank Adaptation) model and its corresponding weight for
image generation. Example: { "latent-consistency/lcm-lora-sdxl":
1.0, "nerijs/pixel-art-xl": 1.2}.
default: ''
strength:
type: number
title: Strength
Expand Down Expand Up @@ -533,6 +600,41 @@ components:
- image
title: Body_genImageToVideo
additionalProperties: false
Body_genLLM:
properties:
prompt:
type: string
title: Prompt
model_id:
type: string
title: Model Id
default: ''
system_msg:
type: string
title: System Msg
default: ''
temperature:
type: number
title: Temperature
default: 0.7
max_tokens:
type: integer
title: Max Tokens
default: 256
history:
type: string
title: History
default: '[]'
stream:
type: boolean
title: Stream
default: false
type: object
required:
- prompt
- model_id
title: Body_genLLM
additionalProperties: false
Body_genSegmentAnything2:
properties:
image:
Expand All @@ -544,7 +646,7 @@ components:
type: string
title: Model Id
description: Hugging Face model ID used for image generation.
default: 'facebook/sam2-hiera-large'
default: facebook/sam2-hiera-large
point_coords:
type: string
title: Point Coords
Expand Down Expand Up @@ -667,6 +769,19 @@ components:
- images
title: ImageResponse
description: Response model for image generation.
LLMResponse:
properties:
response:
type: string
title: Response
tokens_used:
type: integer
title: Tokens Used
type: object
required:
- response
- tokens_used
title: LLMResponse
MasksResponse:
properties:
masks:
Expand Down Expand Up @@ -734,6 +849,14 @@ components:
title: Model Id
description: Hugging Face model ID used for image generation.
default: SG161222/RealVisXL_V4.0_Lightning
loras:
type: string
title: Loras
description: >-
A LoRA (Low-Rank Adaptation) model and its corresponding weight for
image generation. Example: { "latent-consistency/lcm-lora-sdxl":
1.0, "nerijs/pixel-art-xl": 1.2}.
default: ''
prompt:
type: string
title: Prompt
Expand Down
4 changes: 2 additions & 2 deletions packages/api/src/schema/pull-ai-schema.js
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ const downloadAiSchema = async () => {

// patches to the paths section
schema.paths = mapObject(schema.paths, (path, value) => {
// prefix paths with /api/beta/generate
path = `/api/beta/generate${path}`;
// prefix paths with /api/generate
path = `/api/generate${path}`;
// remove security field
delete value.post.security;
// add Studio API error as oneOf to all of the error responses
Expand Down
Loading