feat: New Setting UI.
Emt-lin committed Dec 20, 2024
1 parent ac0007b commit 167ceb4
Showing 36 changed files with 2,876 additions and 43 deletions.
35 changes: 21 additions & 14 deletions src/LLMProviders/chatModelManager.ts
@@ -1,8 +1,8 @@
import { CustomModel, getModelKey, ModelConfig, setModelKey } from "@/aiParams";
import { BUILTIN_CHAT_MODELS, ChatModelProviders } from "@/constants";
import { getDecryptedKey } from "@/encryptionService";
import { getSettings, subscribeToSettingsChange } from "@/settings/model";
import { safeFetch } from "@/utils";
import { getModelKeyFromModel, getSettings, subscribeToSettingsChange } from "@/settings/model";
import { err2String, safeFetch } from "@/utils";
import { HarmBlockThreshold, HarmCategory } from "@google/generative-ai";
import { ChatAnthropic } from "@langchain/anthropic";
import { ChatCohere } from "@langchain/cohere";
@@ -78,8 +78,8 @@ export default class ChatModelManager {
const isO1Model = modelName.startsWith("o1");
const baseConfig: ModelConfig = {
modelName: modelName,
temperature: settings.temperature,
streaming: true,
temperature: customModel.temperature ?? settings.temperature,
streaming: customModel.stream ?? true,
maxRetries: 3,
maxConcurrency: 3,
enableCors: customModel.enableCors,
@@ -96,7 +96,7 @@ export default class ChatModelManager {
fetch: customModel.enableCors ? safeFetch : undefined,
},
// @ts-ignore
openAIOrgId: getDecryptedKey(settings.openAIOrgId),
openAIOrgId: getDecryptedKey(customModel.openAIOrgId || settings.openAIOrgId),
...this.handleOpenAIExtraArgs(isO1Model, settings.maxTokens, settings.temperature),
},
[ChatModelProviders.ANTHROPIC]: {
@@ -111,9 +111,11 @@ },
},
[ChatModelProviders.AZURE_OPENAI]: {
azureOpenAIApiKey: getDecryptedKey(customModel.apiKey || settings.azureOpenAIApiKey),
azureOpenAIApiInstanceName: settings.azureOpenAIApiInstanceName,
azureOpenAIApiDeploymentName: settings.azureOpenAIApiDeploymentName,
azureOpenAIApiVersion: settings.azureOpenAIApiVersion,
azureOpenAIApiInstanceName:
customModel.azureOpenAIApiInstanceName || settings.azureOpenAIApiInstanceName,
azureOpenAIApiDeploymentName:
customModel.azureOpenAIApiDeploymentName || settings.azureOpenAIApiDeploymentName,
azureOpenAIApiVersion: customModel.azureOpenAIApiVersion || settings.azureOpenAIApiVersion,
configuration: {
baseURL: customModel.baseUrl,
fetch: customModel.enableCors ? safeFetch : undefined,
@@ -224,7 +226,7 @@ export default class ChatModelManager {
const getDefaultApiKey = this.providerApiKeyMap[model.provider as ChatModelProviders];

const apiKey = model.apiKey || getDefaultApiKey();
const modelKey = `${model.name}|${model.provider}`;
const modelKey = getModelKeyFromModel(model);
modelMap[modelKey] = {
hasApiKey: Boolean(model.apiKey || apiKey),
AIConstructor: constructor,
@@ -252,7 +254,7 @@ export default class ChatModelManager {
}

setChatModel(model: CustomModel): void {
const modelKey = `${model.name}|${model.provider}`;
const modelKey = getModelKeyFromModel(model);
if (!ChatModelManager.modelMap.hasOwnProperty(modelKey)) {
throw new Error(`No model found for: ${modelKey}`);
}
@@ -268,7 +270,7 @@ export default class ChatModelManager {

const modelConfig = this.getModelConfig(model);

setModelKey(`${model.name}|${model.provider}`);
setModelKey(modelKey);
try {
const newModelInstance = new selectedModel.AIConstructor({
...modelConfig,
@@ -327,7 +329,7 @@ export default class ChatModelManager {
// First try without CORS
await tryPing(false);
return true;
} catch (error) {
} catch (firstError) {
console.log("First ping attempt failed, trying with CORS...");
try {
// Second try with CORS
@@ -337,8 +339,13 @@
);
return true;
} catch (error) {
console.error("Chat model ping failed:", error);
throw error;
const msg =
"\nwithout CORS Error: " +
err2String(firstError) +
"\nwith CORS Error: " +
err2String(error);
// console.error("Chat model ping failed:", error);
throw new Error(msg);
}
}
}
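Note: the hunks above route each model's own overrides through to the provider config, falling back to the global settings when a field is unset. A minimal sketch of that fallback pattern follows; the helper name `resolveAzureChatConfig` and the `GlobalSettings` shape are illustrative, not part of this commit.

```ts
import { CustomModel } from "@/aiParams";

// Illustrative shape for the relevant global settings fields.
interface GlobalSettings {
  temperature: number;
  azureOpenAIApiInstanceName: string;
  azureOpenAIApiDeploymentName: string;
  azureOpenAIApiVersion: string;
}

// Hypothetical helper showing the per-model -> global fallback pattern used above.
function resolveAzureChatConfig(customModel: CustomModel, settings: GlobalSettings) {
  return {
    // ?? keeps an explicit 0 temperature or stream=false set on the model,
    // while || lets empty-string Azure fields fall through to the globals.
    temperature: customModel.temperature ?? settings.temperature,
    streaming: customModel.stream ?? true,
    azureOpenAIApiInstanceName:
      customModel.azureOpenAIApiInstanceName || settings.azureOpenAIApiInstanceName,
    azureOpenAIApiDeploymentName:
      customModel.azureOpenAIApiDeploymentName || settings.azureOpenAIApiDeploymentName,
    azureOpenAIApiVersion:
      customModel.azureOpenAIApiVersion || settings.azureOpenAIApiVersion,
  };
}
```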
28 changes: 18 additions & 10 deletions src/LLMProviders/embeddingManager.ts
@@ -3,8 +3,8 @@ import { CustomModel } from "@/aiParams";
import { EmbeddingModelProviders } from "@/constants";
import { getDecryptedKey } from "@/encryptionService";
import { CustomError } from "@/error";
import { getSettings, subscribeToSettingsChange } from "@/settings/model";
import { safeFetch } from "@/utils";
import { err2String, safeFetch } from "@/utils";
import { getSettings, subscribeToSettingsChange, getModelKeyFromModel } from "@/settings/model";
import { CohereEmbeddings } from "@langchain/cohere";
import { Embeddings } from "@langchain/core/embeddings";
import { GoogleGenerativeAIEmbeddings } from "@langchain/google-genai";
@@ -95,7 +95,7 @@ export default class EmbeddingManager {
const apiKey =
model.apiKey || this.providerApiKeyMap[model.provider as EmbeddingModelProviders]();

const modelKey = `${model.name}|${model.provider}`;
const modelKey = getModelKeyFromModel(model);
modelMap[modelKey] = {
hasApiKey: Boolean(apiKey),
EmbeddingConstructor: constructor,
@@ -121,7 +121,7 @@
// Get the custom model that matches the name and provider from the model key
private getCustomModel(modelKey: string): CustomModel {
return this.activeEmbeddingModels.filter((model) => {
const key = `${model.name}|${model.provider}`;
const key = getModelKeyFromModel(model);
return modelKey === key;
})[0];
}
@@ -186,9 +186,12 @@ export default class EmbeddingManager {
},
[EmbeddingModelProviders.AZURE_OPENAI]: {
azureOpenAIApiKey: getDecryptedKey(customModel.apiKey || settings.azureOpenAIApiKey),
azureOpenAIApiInstanceName: settings.azureOpenAIApiInstanceName,
azureOpenAIApiDeploymentName: settings.azureOpenAIApiEmbeddingDeploymentName,
azureOpenAIApiVersion: settings.azureOpenAIApiVersion,
azureOpenAIApiInstanceName:
customModel.azureOpenAIApiInstanceName || settings.azureOpenAIApiInstanceName,
azureOpenAIApiDeploymentName:
customModel.azureOpenAIApiEmbeddingDeploymentName ||
settings.azureOpenAIApiEmbeddingDeploymentName,
azureOpenAIApiVersion: customModel.azureOpenAIApiVersion || settings.azureOpenAIApiVersion,
configuration: {
baseURL: customModel.baseUrl,
fetch: customModel.enableCors ? safeFetch : undefined,
@@ -236,7 +239,7 @@ export default class EmbeddingManager {
// First try without CORS
await tryPing(false);
return true;
} catch (error) {
} catch (firstError) {
console.log("First ping attempt failed, trying with CORS...");
try {
// Second try with CORS
@@ -246,8 +249,13 @@
);
return true;
} catch (error) {
console.error("Embedding model ping failed:", error);
throw error;
const msg =
"\nwithout CORS Error: " +
err2String(firstError) +
"\nwith CORS Error: " +
err2String(error);
// console.error("Embedding model ping failed:", error);
throw new Error(msg);
}
}
}
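Both managers now report both failures when a connectivity check fails with and without CORS. A minimal sketch of that two-pass ping pattern, with `tryPing` standing in for the provider-specific request (its signature here is an assumption):

```ts
import { err2String } from "@/utils";

// Sketch of the retry-with-CORS flow and the aggregated error message shown above.
async function pingWithCorsFallback(
  tryPing: (enableCors: boolean) => Promise<void>
): Promise<boolean> {
  try {
    await tryPing(false); // first attempt without CORS proxying
    return true;
  } catch (firstError) {
    try {
      await tryPing(true); // retry through safeFetch with CORS enabled
      return true;
    } catch (error) {
      // Surface both failures so the settings UI can show the full picture.
      throw new Error(
        "\nwithout CORS Error: " + err2String(firstError) +
        "\nwith CORS Error: " + err2String(error)
      );
    }
  }
}
```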
11 changes: 11 additions & 0 deletions src/aiParams.ts
@@ -73,6 +73,17 @@ export interface CustomModel {
isBuiltIn?: boolean;
enableCors?: boolean;
core?: boolean;
stream?: boolean;
temperature?: number;
context?: number;
// OpenAI specific fields
openAIOrgId?: string;

// Azure OpenAI specific fields
azureOpenAIApiInstanceName?: string;
azureOpenAIApiDeploymentName?: string;
azureOpenAIApiVersion?: string;
azureOpenAIApiEmbeddingDeploymentName?: string;
}

export function setModelKey(modelKey: string) {
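With the new optional fields, a single model entry can carry its own sampling and provider overrides. An illustrative Azure OpenAI entry; every value below is a placeholder, and the required base fields (`name`, `provider`, `enabled`) are assumed from elsewhere in this diff:

```ts
import { CustomModel } from "@/aiParams";
import { ChatModelProviders } from "@/constants";

// Placeholder values only; not defaults shipped by the plugin.
const azureModel: CustomModel = {
  name: "gpt-4o",
  provider: ChatModelProviders.AZURE_OPENAI,
  enabled: true,
  stream: true,
  temperature: 0.1,
  azureOpenAIApiInstanceName: "my-instance",
  azureOpenAIApiDeploymentName: "my-gpt-4o-deployment",
  azureOpenAIApiVersion: "2024-06-01",
};
```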
15 changes: 7 additions & 8 deletions src/components/chat-components/ChatInput.tsx
@@ -1,4 +1,4 @@
import { CustomModel, useChainType, useModelKey } from "@/aiParams";
import { useChainType, useModelKey } from "@/aiParams";
import { ChainType } from "@/chainFactory";
import { AddImageModal } from "@/components/modals/AddImageModal";
import { ListPromptModal } from "@/components/modals/ListPromptModal";
@@ -7,7 +7,7 @@ import { ContextProcessor } from "@/contextProcessor";
import { CustomPromptProcessor } from "@/customPromptProcessor";
import { COPILOT_TOOL_NAMES } from "@/LLMProviders/intentAnalyzer";
import { Mention } from "@/mentions/Mention";
import { useSettingsValue } from "@/settings/model";
import { getModelKeyFromModel, useSettingsValue } from "@/settings/model";
import { ChatMessage } from "@/sharedState";
import { getToolDescription } from "@/tools/toolManager";
import { extractNoteTitles } from "@/utils";
@@ -40,8 +40,6 @@ interface ChatInputProps {
chatHistory: ChatMessage[];
}

const getModelKey = (model: CustomModel) => `${model.name}|${model.provider}`;

const ChatInput = forwardRef<{ focus: () => void }, ChatInputProps>(
(
{
@@ -447,8 +445,9 @@ const ChatInput = forwardRef<{ focus: () => void }, ChatInputProps>(
<div className="chat-input-left">
<DropdownMenu.Root open={isModelDropdownOpen} onOpenChange={setIsModelDropdownOpen}>
<DropdownMenu.Trigger className="model-select-button">
{settings.activeModels.find((model) => getModelKey(model) === currentModelKey)
?.name || "Select Model"}
{settings.activeModels.find(
(model) => getModelKeyFromModel(model) === currentModelKey
)?.name || "Select Model"}
<ChevronUp size={10} />
</DropdownMenu.Trigger>

@@ -458,8 +457,8 @@ const ChatInput = forwardRef<{ focus: () => void }, ChatInputProps>(
.filter((model) => model.enabled)
.map((model) => (
<DropdownMenu.Item
key={getModelKey(model)}
onSelect={() => setCurrentModelKey(getModelKey(model))}
key={getModelKeyFromModel(model)}
onSelect={() => setCurrentModelKey(getModelKeyFromModel(model))}
>
{model.name}
</DropdownMenu.Item>
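The local `getModelKey` helper in ChatInput is dropped in favor of the shared `getModelKeyFromModel` from `@/settings/model`. The commit does not show that helper's body; a presumed sketch, inferred from the `${model.name}|${model.provider}` strings it replaces throughout this diff:

```ts
import { CustomModel } from "@/aiParams";

// Presumed implementation of the shared key helper; not copied from
// src/settings/model.ts, only inferred from the call sites in this commit.
export function getModelKeyFromModel(model: CustomModel): string {
  return `${model.name}|${model.provider}`;
}
```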
50 changes: 50 additions & 0 deletions src/components/ui/button.tsx
@@ -0,0 +1,50 @@
import * as React from "react";
import { Slot } from "@radix-ui/react-slot";
import { cva, type VariantProps } from "class-variance-authority";

import { cn } from "@/lib/utils";

const buttonVariants = cva(
"inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium transition-colors focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg]:size-4 [&_svg]:shrink-0",
{
variants: {
variant: {
default: "bg-primary text-primary-foreground shadow hover:bg-primary/90",
destructive: "bg-destructive text-destructive-foreground shadow-sm hover:bg-destructive/90",
outline:
"border border-input bg-background shadow-sm hover:bg-accent hover:text-accent-foreground",
secondary: "bg-secondary text-secondary-foreground shadow-sm hover:bg-secondary/80",
ghost: "hover:bg-accent hover:text-accent-foreground",
link: "text-primary underline-offset-4 hover:underline",
},
size: {
default: "h-9 px-4 py-2",
sm: "h-8 rounded-md px-3 text-xs",
lg: "h-10 rounded-md px-8",
icon: "h-9 w-9",
},
},
defaultVariants: {
variant: "default",
size: "default",
},
}
);

export interface ButtonProps
extends React.ButtonHTMLAttributes<HTMLButtonElement>,
VariantProps<typeof buttonVariants> {
asChild?: boolean;
}

const Button = React.forwardRef<HTMLButtonElement, ButtonProps>(
({ className, variant, size, asChild = false, ...props }, ref) => {
const Comp = asChild ? Slot : "button";
return (
<Comp className={cn(buttonVariants({ variant, size, className }))} ref={ref} {...props} />
);
}
);
Button.displayName = "Button";

export { Button, buttonVariants };
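A possible usage of the new Button primitive in a settings pane; the component name, handler, copy, and link target below are illustrative, not part of this commit:

```tsx
import * as React from "react";
import { Button } from "@/components/ui/button";

// Hypothetical settings row demonstrating variants, sizes, and asChild.
export function ResetSettingsRow({ onReset }: { onReset: () => void }) {
  return (
    <div className="flex gap-2">
      <Button variant="secondary" size="sm" onClick={onReset}>
        Reset to defaults
      </Button>
      {/* asChild renders the button styles onto the anchor instead of nesting a <button> */}
      <Button asChild variant="link">
        <a href="https://example.com/docs">Documentation</a>
      </Button>
    </div>
  );
}
```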
27 changes: 27 additions & 0 deletions src/components/ui/checkbox.tsx
@@ -0,0 +1,27 @@
import * as React from "react";
import * as CheckboxPrimitive from "@radix-ui/react-checkbox";
import { Check } from "lucide-react";

import { cn } from "@/lib/utils";

const Checkbox = React.forwardRef<
React.ElementRef<typeof CheckboxPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof CheckboxPrimitive.Root>
>(({ className, ...props }, ref) => (
<CheckboxPrimitive.Root
ref={ref}
className={cn(
"checkbox-reset",
// "peer h-4 w-4 shrink-0 rounded-sm border border-primary shadow focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:cursor-not-allowed disabled:opacity-50 data-[state=checked]:bg-primary data-[state=checked]:text-primary-foreground",
className
)}
{...props}
>
<CheckboxPrimitive.Indicator className={cn("flex items-center justify-center text-current")}>
<Check className="h-4 w-4" />
</CheckboxPrimitive.Indicator>
</CheckboxPrimitive.Root>
));
Checkbox.displayName = CheckboxPrimitive.Root.displayName;

export { Checkbox };
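A possible controlled usage of the Checkbox primitive; the setting name and wrapper are illustrative. Radix can report an "indeterminate" state, so the handler narrows it to a boolean:

```tsx
import * as React from "react";
import { Checkbox } from "@/components/ui/checkbox";

// Hypothetical boolean setting toggle built on the new Checkbox.
export function StreamingToggle({
  value,
  onChange,
}: {
  value: boolean;
  onChange: (next: boolean) => void;
}) {
  return (
    <label className="flex items-center gap-2">
      <Checkbox checked={value} onCheckedChange={(state) => onChange(state === true)} />
      Stream responses
    </label>
  );
}
```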
9 changes: 9 additions & 0 deletions src/components/ui/collapsible.tsx
@@ -0,0 +1,9 @@
import * as CollapsiblePrimitive from "@radix-ui/react-collapsible";

const Collapsible = CollapsiblePrimitive.Root;

const CollapsibleTrigger = CollapsiblePrimitive.CollapsibleTrigger;

const CollapsibleContent = CollapsiblePrimitive.CollapsibleContent;

export { Collapsible, CollapsibleTrigger, CollapsibleContent };
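A possible usage of the re-exported Radix Collapsible for grouping settings; the section title and children are illustrative:

```tsx
import * as React from "react";
import {
  Collapsible,
  CollapsibleContent,
  CollapsibleTrigger,
} from "@/components/ui/collapsible";

// Hypothetical collapsible group for advanced model settings.
export function AdvancedModelSettings({ children }: { children: React.ReactNode }) {
  const [open, setOpen] = React.useState(false);
  return (
    <Collapsible open={open} onOpenChange={setOpen}>
      <CollapsibleTrigger>Advanced model settings</CollapsibleTrigger>
      <CollapsibleContent>{children}</CollapsibleContent>
    </Collapsible>
  );
}
```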