Skip to content

Commit

Permalink
refactor: 🎨 update model alias handling logic
Browse files Browse the repository at this point in the history
  • Loading branch information
pelikhan committed Dec 16, 2024
1 parent 8633332 commit 5057e49
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 79 deletions.
3 changes: 2 additions & 1 deletion docs/src/content/docs/reference/scripts/model-aliases.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,10 @@ By default, GenAIScript supports the following model aliases:
- `large`: a `gpt-4o`-like model
- `small`: `gpt-4o-mini` model or similar. A smaller, cheaper, faster model
- `vision`: `gpt-4o-mini`. A model that can analyze images
- `reasoning`: `o1` or `o1-preview`.
- `reasoning-small`: `o1-mini`.

The following aliases are also set so that you can override LLMs used by GenAIScript itself.

- `reasoning`: `large`. In the future, `o1` like models.
- `agent`: `large`. Model used by the Agent LLM.
- `memory`: `small`. Model used by the agent short-term memory.
49 changes: 40 additions & 9 deletions packages/cli/src/nodehost.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { lstat, readFile, unlink, writeFile } from "node:fs/promises"
import { ensureDir, exists, existsSync, remove } from "fs-extra"
import { resolve, dirname } from "node:path"
import { glob } from "glob"
import { debug, error, info, isQuiet, warn } from "./log"
import { debug, error, info, warn } from "./log"
import { execa } from "execa"
import { join } from "node:path"
import { createNodePath } from "./nodepath"
Expand All @@ -17,8 +17,7 @@ import {
parseTokenFromEnv,
} from "../../core/src/connection"
import {
DEFAULT_MODEL,
DEFAULT_TEMPERATURE,
DEFAULT_LARGE_MODEL,
MODEL_PROVIDER_AZURE_OPENAI,
SHELL_EXEC_TIMEOUT,
MODEL_PROVIDER_OLLAMA,
Expand All @@ -33,6 +32,14 @@ import {
DEFAULT_VISION_MODEL,
LARGE_MODEL_ID,
SMALL_MODEL_ID,
DEFAULT_SMALL_MODEL_CANDIDATES,
DEFAULT_LARGE_MODEL_CANDIDATES,
DEFAULT_EMBEDDINGS_MODEL_CANDIDATES,
DEFAULT_VISION_MODEL_CANDIDATES,
DEFAULT_REASONING_MODEL,
DEFAULT_REASONING_SMALL_MODEL,
DEFAULT_REASONING_SMALL_MODEL_CANDIDATES,
DEFAULT_REASONING_MODEL_CANDIDATES,
} from "../../core/src/constants"
import { tryReadText } from "../../core/src/fs"
import {
Expand Down Expand Up @@ -71,7 +78,6 @@ import {
} from "../../core/src/azurecontentsafety"
import { resolveGlobalConfiguration } from "../../core/src/config"
import { HostConfiguration } from "../../core/src/hostconfiguration"
import { YAMLStringify } from "../../core/src/yaml"

class NodeServerManager implements ServerManager {
async start(): Promise<void> {
Expand Down Expand Up @@ -171,11 +177,36 @@ export class NodeHost implements RuntimeHost {
Omit<ModelConfigurations, "large" | "small" | "vision" | "embeddings">
> = {
default: {
large: { model: DEFAULT_MODEL, source: "default" },
small: { model: DEFAULT_SMALL_MODEL, source: "default" },
vision: { model: DEFAULT_VISION_MODEL, source: "default" },
embeddings: { model: DEFAULT_EMBEDDINGS_MODEL, source: "default" },
reasoning: { model: LARGE_MODEL_ID, source: "default" },
large: {
model: DEFAULT_LARGE_MODEL,
source: "default",
candidates: DEFAULT_LARGE_MODEL_CANDIDATES,
},
small: {
model: DEFAULT_SMALL_MODEL,
source: "default",
candidates: DEFAULT_SMALL_MODEL_CANDIDATES,
},
vision: {
model: DEFAULT_VISION_MODEL,
source: "default",
candidates: DEFAULT_VISION_MODEL_CANDIDATES,
},
embeddings: {
model: DEFAULT_EMBEDDINGS_MODEL,
source: "default",
candidates: DEFAULT_EMBEDDINGS_MODEL_CANDIDATES,
},
reasoning: {
model: DEFAULT_REASONING_MODEL,
source: "default",
candidates: DEFAULT_REASONING_MODEL_CANDIDATES,
},
["reasoning-small"]: {
model: DEFAULT_REASONING_SMALL_MODEL,
source: "default",
candidates: DEFAULT_REASONING_SMALL_MODEL_CANDIDATES,
},
agent: { model: LARGE_MODEL_ID, source: "default" },
memory: { model: SMALL_MODEL_ID, source: "default" },
},
Expand Down
4 changes: 2 additions & 2 deletions packages/cli/src/parse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import { YAMLParse, YAMLStringify } from "../../core/src/yaml"
import { resolveTokenEncoder } from "../../core/src/encoders"
import {
CSV_REGEX,
DEFAULT_MODEL,
DEFAULT_LARGE_MODEL,
INI_REGEX,
JSON5_REGEX,
MD_REGEX,
Expand Down Expand Up @@ -204,7 +204,7 @@ export async function parseTokens(
filesGlobs: string[],
options: { excludedFiles: string[]; model: string }
) {
const { model = DEFAULT_MODEL } = options || {}
const { model = DEFAULT_LARGE_MODEL } = options || {}
const { encode: encoder } = await resolveTokenEncoder(model)

const files = await expandFiles(filesGlobs, options?.excludedFiles)
Expand Down
24 changes: 19 additions & 5 deletions packages/core/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ export const SMALL_MODEL_ID = "small"
export const LARGE_MODEL_ID = "large"
export const VISION_MODEL_ID = "vision"
export const DEFAULT_FENCE_FORMAT: FenceFormat = "xml"
export const DEFAULT_MODEL = "openai:gpt-4o"
export const DEFAULT_MODEL_CANDIDATES = [
export const DEFAULT_LARGE_MODEL = "openai:gpt-4o"
export const DEFAULT_LARGE_MODEL_CANDIDATES = [
"azure_serverless:gpt-4o",
DEFAULT_MODEL,
DEFAULT_LARGE_MODEL,
"google:gemini-1.5-pro-latest",
"anthropic:claude-2.1",
"mistral:mistral-large-latest",
Expand All @@ -69,7 +69,7 @@ export const DEFAULT_MODEL_CANDIDATES = [
export const DEFAULT_VISION_MODEL = "openai:gpt-4o"
export const DEFAULT_VISION_MODEL_CANDIDATES = [
"azure_serverless:gpt-4o",
DEFAULT_MODEL,
DEFAULT_VISION_MODEL,
"google:gemini-1.5-flash-latest",
"anthropic:claude-2.1",
"github:gpt-4o",
Expand All @@ -91,6 +91,20 @@ export const DEFAULT_EMBEDDINGS_MODEL_CANDIDATES = [
"github:text-embedding-3-small",
"client:text-embedding-3-small",
]
// Default model behind the "reasoning-small" alias (compact o1-class model).
export const DEFAULT_REASONING_SMALL_MODEL = "openai:o1-mini"
// Candidate model ids considered when resolving the "reasoning-small" alias
// (tried in order by the model-resolution loop in models.ts).
export const DEFAULT_REASONING_SMALL_MODEL_CANDIDATES = [
"azure_serverless:o1-mini",
DEFAULT_REASONING_SMALL_MODEL,
"github:o1-mini",
"client:o1-mini",
]
// Default model behind the "reasoning" alias (full-size o1-class model).
export const DEFAULT_REASONING_MODEL = "openai:o1"
// Candidate model ids considered when resolving the "reasoning" alias.
// NOTE(review): candidates use o1-preview while the default is o1 —
// presumably intentional for provider availability; confirm.
export const DEFAULT_REASONING_MODEL_CANDIDATES = [
"azure_serverless:o1-preview",
DEFAULT_REASONING_MODEL,
"github:o1-preview",
"client:o1-preview",
]
// Default model behind the "embeddings" alias.
export const DEFAULT_EMBEDDINGS_MODEL = "openai:text-embedding-ada-002"
// Default sampling temperature — where it is applied is not visible here.
export const DEFAULT_TEMPERATURE = 0.8
// Prefix for built-in identifiers — presumably bundled scripts; TODO confirm usage.
export const BUILTIN_PREFIX = "_builtin/"
Expand Down Expand Up @@ -329,4 +343,4 @@ export const IMAGE_DETAIL_LOW_HEIGHT = 512

export const MIN_LINE_NUMBER_LENGTH = 10

export const VSCODE_SERVER_MAX_RETRIES = 5
export const VSCODE_SERVER_MAX_RETRIES = 5
7 changes: 1 addition & 6 deletions packages/core/src/git.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@
// It includes functionality to find modified files, execute Git commands, and manage branches.

import { uniq } from "es-toolkit"
import {
DEFAULT_MODEL,
GIT_DIFF_MAX_TOKENS,
GIT_IGNORE_GENAI,
GIT_LOG_COUNT,
} from "./constants"
import { GIT_DIFF_MAX_TOKENS, GIT_IGNORE_GENAI } from "./constants"
import { llmifyDiff } from "./diff"
import { resolveFileContents } from "./file"
import { readText } from "./fs"
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/host.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ export interface AzureTokenResolver {
// A resolved model-alias entry: the concrete model id (plus optional
// temperature) together with provenance and optional fallbacks.
export type ModelConfiguration = Readonly<
Pick<ModelOptions, "model" | "temperature"> & {
// Which configuration layer supplied this value.
source: "cli" | "env" | "config" | "default"
// Optional fallback model ids — presumably tried in order when the
// primary model cannot be resolved (see the candidate loop in models.ts).
candidates?: string[]
}
>

Expand Down
71 changes: 23 additions & 48 deletions packages/core/src/models.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
import { uniq } from "es-toolkit"
import {
DEFAULT_EMBEDDINGS_MODEL_CANDIDATES,
DEFAULT_MODEL_CANDIDATES,
DEFAULT_SMALL_MODEL_CANDIDATES,
DEFAULT_VISION_MODEL_CANDIDATES,
LARGE_MODEL_ID,
MODEL_PROVIDER_LLAMAFILE,
MODEL_PROVIDER_OPENAI,
SMALL_MODEL_ID,
VISION_MODEL_ID,
} from "./constants"
import { errorMessage } from "./error"
import { LanguageModelConfiguration, host, runtimeHost } from "./host"
import { AbortSignalOptions, MarkdownTrace, TraceOptions } from "./trace"
import { arrayify, assert, logVerbose, toStringList } from "./util"
import { arrayify, assert, toStringList } from "./util"

/**
* model
Expand Down Expand Up @@ -117,55 +111,32 @@ export async function resolveModelConnectionInfo(
options?: {
model?: string
token?: boolean
candidates?: string[]
} & TraceOptions &
AbortSignalOptions
): Promise<{
info: ModelConnectionInfo
configuration?: LanguageModelConfiguration
}> {
const { trace, token: askToken, signal } = options || {}
const hint = options?.model || conn.model || ""
let candidates = options?.candidates
let m = hint
if (m === SMALL_MODEL_ID) {
m = undefined
candidates ??= [
runtimeHost.modelAliases.small.model,
...DEFAULT_SMALL_MODEL_CANDIDATES,
]
} else if (m === VISION_MODEL_ID) {
m = undefined
candidates ??= [
runtimeHost.modelAliases.vision.model,
...DEFAULT_VISION_MODEL_CANDIDATES,
]
} else if (m === LARGE_MODEL_ID) {
m = undefined
candidates ??= [
runtimeHost.modelAliases.large.model,
...DEFAULT_MODEL_CANDIDATES,
]
}
candidates ??= [
runtimeHost.modelAliases.large.model,
...DEFAULT_MODEL_CANDIDATES,
]

const { modelAliases } = runtimeHost
const hint = options?.model || conn.model
// supports candidate if no model hint or hint is a model alias
const supportsCandidates = !hint || !!modelAliases[hint]
let modelId = hint || LARGE_MODEL_ID
let candidates: string[]
// recursively resolve model aliases
if (m) {
const seen = [m]
const modelAliases = runtimeHost.modelAliases
while (modelAliases[m]) {
const alias = modelAliases[m].model
if (seen.includes(alias))
{
const seen: string[] = []
while (modelAliases[modelId]) {
const { model: id, candidates: c } = modelAliases[modelId]
if (seen.includes(id))
throw new Error(
`Circular model alias: ${alias}, seen ${[...seen].join(",")}`
`Circular model alias: ${id}, seen ${[...seen].join(",")}`
)
m = alias
seen.push(m)
seen.push(modelId)
modelId = id
if (supportsCandidates) candidates = c
}
if (seen.length > 1) logVerbose(`model_aliases: ${seen.join(" -> ")}`)
}

const resolveModel = async (
Expand Down Expand Up @@ -214,10 +185,14 @@ export async function resolveModelConnectionInfo(
}
}

if (m) {
return await resolveModel(m, { withToken: askToken, reportError: true })
if (!supportsCandidates) {
return await resolveModel(modelId, {
withToken: askToken,
reportError: true,
})
} else {
for (const candidate of uniq(candidates).filter((c) => !!c)) {
candidates = uniq([modelId, ...candidates].filter((c) => !!c))
for (const candidate of candidates) {
const res = await resolveModel(candidate, {
withToken: askToken,
reportError: false,
Expand Down
6 changes: 2 additions & 4 deletions packages/core/src/testhost.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ import {
import { TraceOptions } from "./trace"
import {
DEFAULT_EMBEDDINGS_MODEL,
DEFAULT_MODEL,
DEFAULT_LARGE_MODEL,
DEFAULT_SMALL_MODEL,
DEFAULT_TEMPERATURE,
DEFAULT_VISION_MODEL,
} from "./constants"
import {
Expand All @@ -38,7 +37,6 @@ import {
} from "node:path"
import { LanguageModel } from "./chat"
import { NotSupportedError } from "./error"
import { HostConfiguration } from "./hostconfiguration"
import { Project } from "./server/messages"

// Function to create a frozen object representing Node.js path methods
Expand Down Expand Up @@ -73,7 +71,7 @@ export class TestHost implements RuntimeHost {

// Default options for language models
readonly modelAliases: ModelConfigurations = {
large: { model: DEFAULT_MODEL, source: "default" },
large: { model: DEFAULT_LARGE_MODEL, source: "default" },
small: { model: DEFAULT_SMALL_MODEL, source: "default" },
vision: { model: DEFAULT_VISION_MODEL, source: "default" },
embeddings: { model: DEFAULT_EMBEDDINGS_MODEL, source: "default" },
Expand Down
4 changes: 0 additions & 4 deletions packages/core/src/vectorsearch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,6 @@ export async function vectorSearch(
},
{
token: true,
candidates: [
runtimeHost.modelAliases.embeddings.model,
...DEFAULT_EMBEDDINGS_MODEL_CANDIDATES,
],
}
)
if (info.error) throw new Error(info.error)
Expand Down

0 comments on commit 5057e49

Please sign in to comment.