Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gguf: better type usage #655

Merged
merged 11 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 46 additions & 24 deletions packages/gguf/scripts/generate-llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,53 @@ import { writeFileSync } from "node:fs";
const SOURCE_CPP_URL = "https://raw.githubusercontent.com/ggerganov/llama.cpp/master/llama.cpp";
const DEST_FILE_PATH = "./src/transformer-llm.ts";
const DEST_COMMON_SOURCE = `
type Attention<TArchitecture extends string> =
& { [K in \`\${TArchitecture}.attention.head_count\`]: number }
& { [K in \`\${TArchitecture}.attention.head_count_kv\`]: number }
& { [K in \`\${TArchitecture}.attention.layer_norm_epsilon\`]: number }
& { [K in \`\${TArchitecture}.attention.layer_norm_rms_epsilon\`]: number }
& { [K in \`\${TArchitecture}.attention.alibi_bias_max\`]: number }
& { [K in \`\${TArchitecture}.attention.clip_kqv\`]: number }
& { [K in \`\${TArchitecture}.attention.use_norm\`]: number };

type Rope<TArchitecture extends LLMArchitecture> =
& { [K in \`\${TArchitecture}.rope.dimension_count\`]: number }
& { [K in \`\${TArchitecture}.rope.freq_base\`]: number }
& { [K in \`\${TArchitecture}.rope.scale\`]: number }
& { [K in \`\${TArchitecture}.rope.scale_linear\`]: number };

type MOE<TArchitecture extends LLMArchitecture> =
& { [K in \`\${TArchitecture}.expert_count\`]: number }
& { [K in \`\${TArchitecture}.expert_used_count\`]: number };
/** This file is auto-generated by generate-llm.ts */

import type { ModelBase, GGUFGeneralInfo } from "./types";

type LLMBase<TArchitecture extends string> = Partial<Record<
\`\${TArchitecture}.vocab_size\`
| \`\${TArchitecture}.use_parallel_residual\`
| \`\${TArchitecture}.tensor_data_layout\`,
number
>>;

type Attention<TArchitecture extends string> = Record<
\`\${TArchitecture}.attention.head_count\`,
number
> & Partial<Record<
\`\${TArchitecture}.attention.head_count_kv\`
| \`\${TArchitecture}.attention.key_length\`
| \`\${TArchitecture}.attention.value_length\`,
number
>>;

export type TransformerLLMRopeScalingType = "none" | "linear" | "yarn";
type Rope<TArchitecture extends LLMArchitecture> = Partial<
Record<
\`\${TArchitecture}.rope.dimension_count\`
| \`\${TArchitecture}.rope.freq_base\`
| \`\${TArchitecture}.rope.scale_linear\`
| \`\${TArchitecture}.rope.scaling.factor\`
| \`\${TArchitecture}.rope.scaling.original_context_length\`,
number
>
& Record<\`\${TArchitecture}.rope.scaling.type\`, TransformerLLMRopeScalingType>
& Record<\`\${TArchitecture}.rope.finetuned\`, boolean>
>;

type MOE<TArchitecture extends LLMArchitecture> = Partial<
Record<
\`\${TArchitecture}.expert_count\`
| \`\${TArchitecture}.expert_used_count\`,
number
>
>;

export type TransformerLLMArchitecture = LLMArchitecture; // type alias
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = ModelBase<TArchitecture>
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = GGUFGeneralInfo<TArchitecture>
& LLMBase<TArchitecture>
& ModelBase<TArchitecture>
& MOE<TArchitecture>
& Attention<TArchitecture>
& Rope<TArchitecture>;
Expand Down Expand Up @@ -163,15 +189,11 @@ async function main() {
/////////////////////////////////////
// write result to file
const content = [
"/** This file is auto-generated by generate-llm.ts */",
"",
'import type { ModelBase } from "./types";',
"",
DEST_COMMON_SOURCE,
"export const LLM_ARCHITECTURES = [",
...archList.map((a) => `\t${JSON.stringify(a.name)},`),
"] as const;",
"type LLMArchitecture = (typeof LLM_ARCHITECTURES)[number];",
DEST_COMMON_SOURCE,
...archList.map((a) => {
let code = `export type ${a.tsName} = TransformerLLMBase<${JSON.stringify(a.name)}>`;
if (a.hparams.length) {
Expand Down
33 changes: 18 additions & 15 deletions packages/gguf/src/gguf.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,25 @@ describe("gguf", () => {
"llama.rope.dimension_count": 128,
});

const tokens = metadata["tokenizer.ggml.tokens"];
if (!Array.isArray(tokens)) {
throw new Error();
expect(metadata["tokenizer.ggml.model"]);
if (metadata["tokenizer.ggml.model"]) {
const tokens = metadata["tokenizer.ggml.tokens"];
if (!Array.isArray(tokens)) {
throw new Error();
}
expect(tokens.slice(0, 10)).toEqual([
"<unk>",
"<s>",
"</s>",
"<0x00>",
"<0x01>",
"<0x02>",
"<0x03>",
"<0x04>",
"<0x05>",
"<0x06>",
]);
}
expect(tokens.slice(0, 10)).toEqual([
"<unk>",
"<s>",
"</s>",
"<0x00>",
"<0x01>",
"<0x02>",
"<0x03>",
"<0x04>",
"<0x05>",
"<0x06>",
]);

/// Tensor infos
/// By convention we test the first and last tensor.
Expand Down
2 changes: 1 addition & 1 deletion packages/gguf/src/gguf.ts
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ export async function gguf(
offset += tensorCount.length;
const numKv = readVersionedSize(r.view, offset, version, littleEndian);
offset += numKv.length;
const metadata: GGUFMetadata = {
const metadata: GGUFMetadata<{ strict: false }> = {
version,
tensor_count: tensorCount.value,
kv_count: numKv.value,
Expand Down
82 changes: 51 additions & 31 deletions packages/gguf/src/transformer-llm.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,56 @@
/** This file is auto-generated by generate-llm.ts */

import type { ModelBase } from "./types";
import type { ModelBase, GGUFGeneralInfo } from "./types";

type LLMBase<TArchitecture extends string> = Partial<
Record<
`${TArchitecture}.vocab_size` | `${TArchitecture}.use_parallel_residual` | `${TArchitecture}.tensor_data_layout`,
number
>
>;

type Attention<TArchitecture extends string> = Record<`${TArchitecture}.attention.head_count`, number> &
Partial<
Record<
| `${TArchitecture}.attention.head_count_kv`
| `${TArchitecture}.attention.key_length`
| `${TArchitecture}.attention.value_length`,
number
>
>;

export type TransformerLLMRopeScalingType = "none" | "linear" | "yarn";
type Rope<TArchitecture extends LLMArchitecture> = Partial<
Record<
| `${TArchitecture}.rope.dimension_count`
| `${TArchitecture}.rope.freq_base`
| `${TArchitecture}.rope.scale_linear`
| `${TArchitecture}.rope.scaling.factor`
| `${TArchitecture}.rope.scaling.original_context_length`,
number
> &
Record<`${TArchitecture}.rope.scaling.type`, TransformerLLMRopeScalingType> &
Record<`${TArchitecture}.rope.finetuned`, boolean>
>;

type MOE<TArchitecture extends LLMArchitecture> = Partial<
Record<`${TArchitecture}.expert_count` | `${TArchitecture}.expert_used_count`, number>
>;

export type TransformerLLMArchitecture = LLMArchitecture; // type alias
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = GGUFGeneralInfo<TArchitecture> &
LLMBase<TArchitecture> &
ModelBase<TArchitecture> &
MOE<TArchitecture> &
Attention<TArchitecture> &
Rope<TArchitecture>;

export enum TransformerLLMPoolingType {
UNSPECIFIED = -1,
NONE = 0,
MEAN = 1,
CLS = 2,
}

export const LLM_ARCHITECTURES = [
"llama",
Expand Down Expand Up @@ -37,36 +87,6 @@ export const LLM_ARCHITECTURES = [
"olmo",
] as const;
type LLMArchitecture = (typeof LLM_ARCHITECTURES)[number];

type Attention<TArchitecture extends string> = { [K in `${TArchitecture}.attention.head_count`]: number } & {
[K in `${TArchitecture}.attention.head_count_kv`]: number;
} & { [K in `${TArchitecture}.attention.layer_norm_epsilon`]: number } & {
[K in `${TArchitecture}.attention.layer_norm_rms_epsilon`]: number;
} & { [K in `${TArchitecture}.attention.alibi_bias_max`]: number } & {
[K in `${TArchitecture}.attention.clip_kqv`]: number;
} & { [K in `${TArchitecture}.attention.use_norm`]: number };

type Rope<TArchitecture extends LLMArchitecture> = { [K in `${TArchitecture}.rope.dimension_count`]: number } & {
[K in `${TArchitecture}.rope.freq_base`]: number;
} & { [K in `${TArchitecture}.rope.scale`]: number } & { [K in `${TArchitecture}.rope.scale_linear`]: number };

type MOE<TArchitecture extends LLMArchitecture> = { [K in `${TArchitecture}.expert_count`]: number } & {
[K in `${TArchitecture}.expert_used_count`]: number;
};

export type TransformerLLMArchitecture = LLMArchitecture; // type alias
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = ModelBase<TArchitecture> &
MOE<TArchitecture> &
Attention<TArchitecture> &
Rope<TArchitecture>;

export enum TransformerLLMPoolingType {
UNSPECIFIED = -1,
NONE = 0,
MEAN = 1,
CLS = 2,
}

export type ArchLlama = TransformerLLMBase<"llama"> & {
"llama.attention.layer_norm_rms_epsilon": number;
};
Expand Down
55 changes: 55 additions & 0 deletions packages/gguf/src/types.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import { describe, it } from "vitest";
import type { gguf } from "./gguf";
import type { GGUFMetadata, GGUFParseOutput } from "./types";

describe("gguf-types", () => {
it("gguf() type can be casted between STRICT and NON_STRICT (at compile time)", async () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const result: Awaited<ReturnType<typeof gguf>> = { metadata: {} } as any;
const strictType = result as GGUFParseOutput<{ strict: true }>;
// @ts-expect-error because the key "abc" does not exist
strictType.metadata.abc = 123;
const nonStrictType = result as GGUFParseOutput<{ strict: false }>;
nonStrictType.metadata.abc = 123; // PASS, because it can be anything
// @ts-expect-error because ArrayBuffer is not a MetadataValue
nonStrictType.metadata.fff = ArrayBuffer;
});

it("GGUFType.NON_STRICT should be correct (at compile time)", async () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const model: GGUFMetadata<{ strict: false }> = {} as any;
model.kv_count = 123n;
model.abc = 456; // PASS, because it can be anything
});

it("GGUFType.STRICT should be correct (at compile time)", async () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const model: GGUFMetadata<{ strict: true }> = {} as any;

if (model["general.architecture"] === "whisper") {
model["encoder.whisper.block_count"] = 0;
// @ts-expect-error because it must be a number
model["encoder.whisper.block_count"] = "abc";
}

if (model["tokenizer.ggml.model"] === undefined) {
// @ts-expect-error because it's undefined
model["tokenizer.ggml.eos_token_id"] = 1;
}
if (model["tokenizer.ggml.model"] === "gpt2") {
// @ts-expect-error because it must be a number
model["tokenizer.ggml.eos_token_id"] = undefined;
model["tokenizer.ggml.eos_token_id"] = 1;
}

if (model["general.architecture"] === "mamba") {
model["mamba.ssm.conv_kernel"] = 0;
// @ts-expect-error because it must be a number
model["mamba.ssm.conv_kernel"] = "abc";
}
if (model["general.architecture"] === "llama") {
// @ts-expect-error llama does not have ssm.* keys
model["mamba.ssm.conv_kernel"] = 0;
}
});
});
75 changes: 56 additions & 19 deletions packages/gguf/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,32 @@ export enum GGUFValueType {
const ARCHITECTURES = [...LLM_ARCHITECTURES, "rwkv", "whisper"] as const;
export type Architecture = (typeof ARCHITECTURES)[number];

interface General {
"general.architecture": Architecture;
"general.name": string;
"general.file_type": number;
"general.quantization_version": number;
export interface GGUFGeneralInfo<TArchitecture extends Architecture> {
"general.architecture": TArchitecture;
"general.name"?: string;
"general.file_type"?: number;
"general.quantization_version"?: number;
}

type ModelMetadata = Whisper | RWKV | TransformerLLM;
interface NoModelMetadata {
"general.architecture"?: undefined;
}

export type ModelBase<
TArchitecture extends
| Architecture
| `encoder.${Extract<Architecture, "whisper">}`
| `decoder.${Extract<Architecture, "whisper">}`,
> = { [K in `${TArchitecture}.layer_count`]: number } & { [K in `${TArchitecture}.feed_forward_length`]: number } & {
[K in `${TArchitecture}.context_length`]: number;
} & { [K in `${TArchitecture}.embedding_length`]: number } & { [K in `${TArchitecture}.block_count`]: number };
> = Record<
| `${TArchitecture}.context_length`
| `${TArchitecture}.block_count`
| `${TArchitecture}.embedding_length`
| `${TArchitecture}.feed_forward_length`,
number
>;

/// Tokenizer

type TokenizerModel = "no_vocab" | "llama" | "gpt2" | "bert";
interface Tokenizer {
Expand All @@ -75,21 +86,47 @@ interface Tokenizer {
"tokenizer.ggml.bos_token_id": number;
"tokenizer.ggml.eos_token_id": number;
"tokenizer.ggml.add_bos_token": boolean;
"tokenizer.chat_template": string;
"tokenizer.chat_template"?: string;
}
interface NoTokenizer {
"tokenizer.ggml.model"?: undefined;
}

/// Models outside of llama.cpp: "rwkv" and "whisper"

export type RWKV = ModelBase<"rwkv"> & { "rwkv.architecture_version": number };
export type LLM = TransformerLLM | RWKV;
export type Whisper = ModelBase<"encoder.whisper"> & ModelBase<"decoder.whisper">;
export type Model = (LLM | Whisper) & Partial<Tokenizer>;
export type RWKV = GGUFGeneralInfo<"rwkv"> &
ModelBase<"rwkv"> & {
"rwkv.architecture_version": number;
};

export type GGUFMetadata = {
// TODO: whisper.cpp doesn't yet support gguf. This maybe changed in the future.
export type Whisper = GGUFGeneralInfo<"whisper"> &
ModelBase<"encoder.whisper"> &
ModelBase<"decoder.whisper"> & {
"whisper.encoder.mels_count": number;
"whisper.encoder.attention.head_count": number;
"whisper.decoder.attention.head_count": number;
};

/// Types for parse output

export interface GGUFMetadataOptions {
/**
* Enable strict type for known GGUF fields.
*
* @default true
*/
strict: boolean;
}

export type GGUFMetadata<Options extends GGUFMetadataOptions = { strict: true }> = {
version: Version;
tensor_count: bigint;
kv_count: bigint;
} & Partial<General> &
Partial<Model> &
Record<string, MetadataValue>;
} & GGUFModelKV &
(Options extends { strict: true } ? unknown : Record<string, MetadataValue>);

export type GGUFModelKV = (NoModelMetadata | ModelMetadata) & (NoTokenizer | Tokenizer);

export interface GGUFTensorInfo {
name: string;
Expand All @@ -99,7 +136,7 @@ export interface GGUFTensorInfo {
offset: bigint;
}

export interface GGUFParseOutput {
metadata: GGUFMetadata;
export interface GGUFParseOutput<Options extends GGUFMetadataOptions = { strict: true }> {
metadata: GGUFMetadata<Options>;
tensorInfos: GGUFTensorInfo[];
}
Loading