Skip to content

Commit

Permalink
ai: generate: remove endpoint from experiment & remove beta from path
Browse files Browse the repository at this point in the history
  • Loading branch information
gioelecerati committed Oct 1, 2024
1 parent 6715c2a commit 70441d7
Show file tree
Hide file tree
Showing 5 changed files with 291 additions and 61 deletions.
50 changes: 30 additions & 20 deletions packages/api/src/controllers/generate.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ afterEach(async () => {
await clearDatabase(server);
});

const testBothRoutes = (testFn) => {
describe("generate route", () => {
testFn("/generate");
});

describe("beta generate route", () => {
testFn("/beta/generate");
});
};

describe("controllers/generate", () => {
let client: TestClient;
let adminUser: User;
Expand Down Expand Up @@ -145,9 +155,9 @@ describe("controllers/generate", () => {
return form;
};

describe("API proxies", () => {
it("should call the AI Gateway for generate API /audio-to-text", async () => {
const res = await client.fetch("/beta/generate/audio-to-text", {
testBothRoutes((basePath) => {
it(`should call the AI Gateway for ${basePath}/audio-to-text`, async () => {
const res = await client.fetch(`${basePath}/audio-to-text`, {
method: "POST",
body: buildMultipartBody(
{},
Expand All @@ -162,8 +172,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ "audio-to-text": 1 });
});

it("should call the AI Gateway for generate API /text-to-image", async () => {
const res = await client.post("/beta/generate/text-to-image", {
it(`should call the AI Gateway for ${basePath}/text-to-image`, async () => {
const res = await client.post(`${basePath}/text-to-image`, {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(200);
Expand All @@ -174,8 +184,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ "text-to-image": 1 });
});

it("should call the AI Gateway for generate API /image-to-image", async () => {
const res = await client.fetch("/beta/generate/image-to-image", {
it(`should call the AI Gateway for ${basePath}/image-to-image`, async () => {
const res = await client.fetch(`${basePath}/image-to-image`, {
method: "POST",
body: buildMultipartBody({
prompt: "replace the suit with a bathing suit",
Expand All @@ -189,8 +199,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ "image-to-image": 1 });
});

it("should call the AI Gateway for generate API /image-to-video", async () => {
const res = await client.fetch("/beta/generate/image-to-video", {
it(`should call the AI Gateway for ${basePath}/image-to-video`, async () => {
const res = await client.fetch(`${basePath}/image-to-video`, {
method: "POST",
body: buildMultipartBody({}),
});
Expand All @@ -202,8 +212,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ "image-to-video": 1 });
});

it("should call the AI Gateway for generate API /upscale", async () => {
const res = await client.fetch("/beta/generate/upscale", {
it(`should call the AI Gateway for ${basePath}/upscale`, async () => {
const res = await client.fetch(`${basePath}/upscale`, {
method: "POST",
body: buildMultipartBody({ prompt: "enhance" }),
});
Expand All @@ -215,8 +225,8 @@ describe("controllers/generate", () => {
expect(aiGatewayCalls).toEqual({ upscale: 1 });
});

it("should call the AI Gateway for generate API /segment-anything-2", async () => {
const res = await client.fetch("/beta/generate/segment-anything-2", {
it(`should call the AI Gateway for ${basePath}/segment-anything-2`, async () => {
const res = await client.fetch(`${basePath}/segment-anything-2`, {
method: "POST",
body: buildMultipartBody({}),
});
Expand Down Expand Up @@ -260,7 +270,7 @@ describe("controllers/generate", () => {

for (const [title, input, error] of testCases) {
it(title, async () => {
const res = await client.fetch("/beta/generate/image-to-image", {
const res = await client.fetch("/generate/image-to-image", {
method: "POST",
body: input,
});
Expand All @@ -287,7 +297,7 @@ describe("controllers/generate", () => {
}

it("should log all requests to db", async () => {
const res = await client.post("/beta/generate/text-to-image", {
const res = await client.post("/generate/text-to-image", {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(200);
Expand Down Expand Up @@ -325,7 +335,7 @@ describe("controllers/generate", () => {
`{"details":{"msg":"sudden error"}}`,
);

const res = await client.post("/beta/generate/text-to-image", {
const res = await client.post("/generate/text-to-image", {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(500);
Expand All @@ -345,7 +355,7 @@ describe("controllers/generate", () => {
it("should log non JSON outputs as strings to db", async () => {
mockFetchHttpError(418, "text/plain", `I'm not Jason`);

const res = await client.post("/beta/generate/text-to-image", {
const res = await client.post("/generate/text-to-image", {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(418);
Expand All @@ -364,7 +374,7 @@ describe("controllers/generate", () => {
mockedFetchWithTimeout.mockImplementation(() => {
throw new Error("on your face");
});
const res = await client.post("/beta/generate/text-to-image", {
const res = await client.post("/generate/text-to-image", {
prompt: "a man in a suit and tie",
});
expect(res.status).toBe(500);
Expand Down Expand Up @@ -394,10 +404,10 @@ describe("controllers/generate", () => {

const makeAiGenReq = (pipeline: (typeof pipelines)[number]) =>
pipeline === "text-to-image"
? client.post(`/beta/generate/${pipeline}`, {
? client.post(`/generate/${pipeline}`, {
prompt: "whatever you feel like",
})
: client.fetch(`/beta/generate/${pipeline}`, {
: client.fetch(`/generate/${pipeline}`, {
method: "POST",
body: buildMultipartBody(
pipeline === "image-to-video" ? {} : { prompt: "make magic" },
Expand Down
8 changes: 7 additions & 1 deletion packages/api/src/controllers/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ const aiGenerateDurationMetric = new promclient.Histogram({

const app = Router();

app.use(experimentSubjectsOnly("ai-generate"));
// TODO: Remove beta paths middleware
app.use((req, res, next) => {
if (req.path.startsWith("/beta/generate")) {
req.url = req.url.replace("/beta/generate", "/generate");
}
next();
});

const rateLimiter: RequestHandler = async (req, res, next) => {
const now = Date.now();
Expand Down
2 changes: 2 additions & 0 deletions packages/api/src/controllers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ export default {
"api-token": apiToken,
asset,
auth,
generate,
// TODO: Remove beta paths
"beta/generate": generate,
broadcaster,
clip,
Expand Down
Loading

0 comments on commit 70441d7

Please sign in to comment.