Skip to content

Commit

Permalink
Merge branch 'main' into vertex_ai_support_model_version
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarrazin authored Sep 27, 2024
2 parents e4f60b5 + 88e062f commit a050c81
Show file tree
Hide file tree
Showing 44 changed files with 696 additions and 462 deletions.
3 changes: 2 additions & 1 deletion .env
Original file line number Diff line number Diff line change
Expand Up @@ -175,4 +175,5 @@ BODY_SIZE_LIMIT=15728640
HF_ORG_ADMIN=
HF_ORG_EARLY_ACCESS=

PUBLIC_SMOOTH_UPDATES=false
PUBLIC_SMOOTH_UPDATES=false
COMMUNITY_TOOLS=false
322 changes: 165 additions & 157 deletions chart/env/prod.yaml

Large diffs are not rendered by default.

251 changes: 139 additions & 112 deletions package-lock.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
"browser-image-resizer": "^2.4.1",
"date-fns": "^2.29.3",
"dotenv": "^16.0.3",
"express": "^4.19.2",
"express": "^4.21.0",
"file-type": "^19.4.1",
"google-auth-library": "^9.13.0",
"handlebars": "^4.7.8",
Expand Down
3 changes: 3 additions & 0 deletions src/lib/components/AssistantToolPicker.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@
class="w-full cursor-pointer px-3 py-2 text-left hover:bg-blue-500 hover:text-white"
>
{suggestion.displayName}
{#if suggestion.createdByName}
<span class="text-xs text-gray-500"> by {suggestion.createdByName}</span>
{/if}
</button>
{/each}
</div>
Expand Down
16 changes: 15 additions & 1 deletion src/lib/components/ModelCardMetadata.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
import CarbonEarth from "~icons/carbon/earth";
import CarbonArrowUpRight from "~icons/carbon/arrow-up-right";
import BIMeta from "~icons/bi/meta";
import CarbonCode from "~icons/carbon/code";
import type { Model } from "$lib/types/Model";
export let model: Pick<Model, "name" | "datasetName" | "websiteUrl" | "modelUrl" | "datasetUrl">;
export let model: Pick<
Model,
"name" | "datasetName" | "websiteUrl" | "modelUrl" | "datasetUrl" | "hasInferenceAPI"
>;
export let variant: "light" | "dark" = "light";
</script>
Expand Down Expand Up @@ -35,6 +39,16 @@
<div class="max-sm:hidden">&nbsp;page</div></a
>
{/if}
{#if model.hasInferenceAPI}
<a
href={"https://huggingface.co/playground?modelId=" + model.name}
target="_blank"
rel="noreferrer"
class="flex items-center hover:underline"
><CarbonCode class="mr-1.5 shrink-0 text-xs text-gray-400" />
API
</a>
{/if}
{#if model.websiteUrl}
<a
href={model.websiteUrl}
Expand Down
3 changes: 1 addition & 2 deletions src/lib/components/NavMenu.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@
Assistants
</a>
{/if}
<!-- XXX: feature_flag_tools -->
{#if $page.data.user?.isEarlyAccess}
{#if $page.data.enableCommunityTools}
<a
href="{base}/tools"
class="flex h-9 flex-none items-center gap-1.5 rounded-lg pl-2.5 pr-2 text-gray-500 hover:bg-gray-100 dark:text-gray-400 dark:hover:bg-gray-700"
Expand Down
2 changes: 1 addition & 1 deletion src/lib/components/ToolLogo.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
})();
</script>

<div class="flex {sizeClass} items-center justify-center">
<div class="flex {sizeClass} relative items-center justify-center">
<svg xmlns="http://www.w3.org/2000/svg" class="absolute {sizeClass} h-full" viewBox="0 0 52 58">
<defs>
<linearGradient id="gradient-{gradientColor}" gradientTransform="rotate(90)">
Expand Down
3 changes: 1 addition & 2 deletions src/lib/components/ToolsMenu.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@
{/if}
</button>
</div>
<!-- XXX: feature_flag_tools -->
{#if $page.data.user?.isEarlyAccess}
{#if $page.data.enableCommunityTools}
<a
href="{base}/tools"
class="col-span-2 my-1 h-fit w-fit items-center justify-center rounded-full bg-purple-500/20 px-2.5 py-1.5 text-sm hover:bg-purple-500/30"
Expand Down
8 changes: 5 additions & 3 deletions src/lib/components/chat/ChatWindow.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import UploadedFile from "./UploadedFile.svelte";
import { useSettingsStore } from "$lib/stores/settings";
import type { ToolFront } from "$lib/types/Tool";
import ModelSwitch from "./ModelSwitch.svelte";
export let messages: Message[] = [];
export let loading = false;
Expand Down Expand Up @@ -279,6 +280,9 @@
on:vote
on:continue
/>
{#if isReadOnly}
<ModelSwitch {models} {currentModel} />
{/if}
</div>
{:else if pending}
<ChatMessage
Expand Down Expand Up @@ -403,9 +407,7 @@
<ChatInput value="Sorry, something went wrong. Please try again." disabled={true} />
{:else}
<ChatInput
placeholder={isReadOnly
? "This conversation is read-only. Start a new one to continue!"
: "Ask anything"}
placeholder={isReadOnly ? "This conversation is read-only." : "Ask anything"}
bind:value={message}
on:submit={handleSubmit}
on:beforeinput={(ev) => {
Expand Down
5 changes: 4 additions & 1 deletion src/lib/components/chat/FileDropzone.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
!mimeTypes.some((mimeType: string) => {
const [type, subtype] = mimeType.split("/");
const [fileType, fileSubtype] = file.type.split("/");
return type === fileType && (subtype === "*" || fileSubtype === subtype);
return (
(type === "*" || type === fileType) &&
(subtype === "*" || subtype === fileSubtype)
);
})
) {
setErrorMsg(`Some file type not supported. Only allowed: ${mimeTypes.join(", ")}`);
Expand Down
60 changes: 60 additions & 0 deletions src/lib/components/chat/ModelSwitch.svelte
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
<script lang="ts">
import { invalidateAll } from "$app/navigation";
import { page } from "$app/stores";
import { base } from "$app/paths";
import type { Model } from "$lib/types/Model";
export let models: Model[];
export let currentModel: Model;
let selectedModelId = models.map((m) => m.id).includes(currentModel.id)
? currentModel.id
: models[0].id;
async function handleModelChange() {
if (!$page.params.id) return;
try {
const response = await fetch(`${base}/conversation/${$page.params.id}`, {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ model: selectedModelId }),
});
if (!response.ok) {
throw new Error("Failed to update model");
}
await invalidateAll();
} catch (error) {
console.error(error);
}
}
</script>

<div
class="mx-auto mt-0 flex w-fit flex-col items-center justify-center gap-2 rounded-lg border border-gray-200 bg-gray-500/20 p-4 dark:border-gray-800"
>
<span>
This model is no longer available. Switch to a new one to continue this conversation:
</span>
<div class="flex items-center space-x-2">
<select
bind:value={selectedModelId}
class="rounded-md bg-gray-100 px-2 py-1 dark:bg-gray-900 max-sm:max-w-32"
>
{#each models as model}
<option value={model.id}>{model.name}</option>
{/each}
</select>
<button
on:click={handleModelChange}
disabled={selectedModelId === currentModel.id}
class="rounded-md bg-gray-100 px-2 py-1 dark:bg-gray-900"
>
Accept
</button>
</div>
</div>
33 changes: 26 additions & 7 deletions src/lib/migrations/routines/02-update-assistants-models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,36 @@ const updateAssistantsModels: Migration = {
name: "Update deprecated models in assistants with the default model",
up: async () => {
const models = (await import("$lib/server/models")).models;

const oldModels = (await import("$lib/server/models")).oldModels;
const { assistants } = collections;

const modelIds = models.map((el) => el.id); // string[]
const modelIds = models.map((el) => el.id);
const defaultModelId = models[0].id;

// Find all assistants whose modelId is not in modelIds, and update it to use defaultModelId
await assistants.updateMany(
{ modelId: { $nin: modelIds } },
{ $set: { modelId: defaultModelId } }
);
// Find all assistants whose modelId is not in modelIds, and update it
const bulkOps = await assistants
.find({ modelId: { $nin: modelIds } })
.map((assistant) => {
// has an old model
let newModelId = defaultModelId;

const oldModel = oldModels.find((m) => m.id === assistant.modelId);
if (oldModel && oldModel.transferTo && !!models.find((m) => m.id === oldModel.transferTo)) {
newModelId = oldModel.transferTo;
}

return {
updateOne: {
filter: { _id: assistant._id },
update: { $set: { modelId: newModelId } },
},
};
})
.toArray();

if (bulkOps.length > 0) {
await assistants.bulkWrite(bulkOps);
}

return true;
},
Expand Down
14 changes: 10 additions & 4 deletions src/lib/server/endpoints/aws/endpointBedrock.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import { z } from "zod";
import type { Endpoint } from "../endpoints";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import {
BedrockRuntimeClient,
InvokeModelWithResponseStreamCommand,
} from "@aws-sdk/client-bedrock-runtime";
import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images";
import type { EndpointMessage } from "../endpoints";
import type { MessageFile } from "$lib/types/Message";
Expand Down Expand Up @@ -40,6 +36,16 @@ export async function endpointBedrock(
): Promise<Endpoint> {
const { region, model, anthropicVersion, multimodal } =
endpointBedrockParametersSchema.parse(input);

let BedrockRuntimeClient, InvokeModelWithResponseStreamCommand;
try {
({ BedrockRuntimeClient, InvokeModelWithResponseStreamCommand } = await import(
"@aws-sdk/client-bedrock-runtime"
));
} catch (error) {
throw new Error("Failed to import @aws-sdk/client-bedrock-runtime. Make sure it's installed.");
}

const client = new BedrockRuntimeClient({
region,
});
Expand Down
2 changes: 2 additions & 0 deletions src/lib/server/endpoints/cohere/endpointCohere.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ export async function endpointCohere(
});

stream = await cohere.chatStream({
forceSingleStep: true,
message: prompt,
rawPrompting: true,
model: model.id ?? model.name,
Expand All @@ -82,6 +83,7 @@ export async function endpointCohere(

stream = await cohere
.chatStream({
forceSingleStep: true,
model: model.id ?? model.name,
chatHistory: formattedMessages.slice(0, -1),
message: formattedMessages[formattedMessages.length - 1].message,
Expand Down
15 changes: 3 additions & 12 deletions src/lib/server/endpoints/tgi/endpointTgi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ export const endpointTgiParametersSchema = z.object({
supportedMimeTypes: ["image/jpeg", "image/webp"],
preferredMimeType: "image/webp",
maxSizeInMB: 5,
maxWidth: 224,
maxHeight: 224,
maxWidth: 378,
maxHeight: 980,
}),
})
.default({}),
Expand Down Expand Up @@ -81,22 +81,13 @@ export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>):
};
}

const whiteImage = {
mime: "image/png",
image: Buffer.from(
"/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQH/2wBDAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEBAQH/wAARCAAQABADAREAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD+/igAoAKACgD/2Q==",
"base64"
),
};

async function prepareMessage(
isMultimodal: boolean,
message: EndpointMessage,
imageProcessor: ImageProcessor
): Promise<EndpointMessage> {
if (!isMultimodal) return message;

const files = await Promise.all(message.files?.map(imageProcessor) ?? [whiteImage]);
const files = await Promise.all(message.files?.map(imageProcessor) ?? []);
const markdowns = files.map(
(file) => `![](data:${file.mime};base64,${file.image.toString("base64")})`
);
Expand Down
45 changes: 38 additions & 7 deletions src/lib/server/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import JSON5 from "json5";
import { getTokenizer } from "$lib/utils/getTokenizer";
import { logger } from "$lib/server/logger";
import { ToolResultStatus, type ToolInput } from "$lib/types/Tool";
import { isHuggingChat } from "$lib/utils/isHuggingChat";

type Optional<T, K extends keyof T> = Pick<Partial<T>, K> & Omit<T, K>;

Expand Down Expand Up @@ -253,10 +254,6 @@ const processModel = async (m: z.infer<typeof modelConfig>) => ({
parameters: { ...m.parameters, stop_sequences: m.parameters?.stop },
});

export type ProcessedModel = Awaited<ReturnType<typeof processModel>> & {
getEndpoint: () => Promise<Endpoint>;
};

const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
...m,
getEndpoint: async (): Promise<Endpoint> => {
Expand Down Expand Up @@ -316,10 +313,43 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
},
});

export const models: ProcessedModel[] = await Promise.all(
modelsRaw.map((e) => processModel(e).then(addEndpoint))
const hasInferenceAPI = async (m: Awaited<ReturnType<typeof processModel>>) => {
if (!isHuggingChat) {
return false;
}

const r = await fetch(`https://huggingface.co/api/models/${m.id}`);

if (!r.ok) {
logger.warn(`Failed to check if ${m.id} has inference API: ${r.statusText}`);
return false;
}

const json = await r.json();

if (json.cardData.inference === false) {
return false;
}

return true;
};

export const models = await Promise.all(
modelsRaw.map((e) =>
processModel(e)
.then(addEndpoint)
.then(async (m) => ({
...m,
hasInferenceAPI: await hasInferenceAPI(m),
}))
)
);

export type ProcessedModel = (typeof models)[number];

// super ugly but not sure how to make typescript happier
export const validModelIdSchema = z.enum(models.map((m) => m.id) as [string, ...string[]]);

export const defaultModel = models[0];

// Models that have been deprecated
Expand All @@ -330,6 +360,7 @@ export const oldModels = env.OLD_MODELS
id: z.string().optional(),
name: z.string().min(1),
displayName: z.string().min(1).optional(),
transferTo: validModelIdSchema.optional(),
})
)
.parse(JSON5.parse(env.OLD_MODELS))
Expand All @@ -353,5 +384,5 @@ export const smallModel = env.TASK_MODEL

export type BackendModel = Optional<
typeof defaultModel,
"preprompt" | "parameters" | "multimodal" | "unlisted" | "tools"
"preprompt" | "parameters" | "multimodal" | "unlisted" | "tools" | "hasInferenceAPI"
>;
Loading

0 comments on commit a050c81

Please sign in to comment.