Skip to content

Commit

Permalink
gguf: better type usage (#655)
Browse files Browse the repository at this point in the history
Follow up #640

Ref comments:
-
#640 (review)
by @julien-c suggests using a check `metadata["general.architecture"]
=== ...` to select the correct type
-
#640 (comment)
by @coyotte508 suggests using less generic but more verbose code

The type system introduced in this PR allows type-checking at both
compile time & runtime:

```ts
const model: GGUFMetadata<GGUFType.STRICT> = null as any;

if (model["general.architecture"] === "whisper") {
	model["encoder.whisper.block_count"] = 0;
	// @ts-expect-error because it must be a number
	model["encoder.whisper.block_count"] = "abc";
}

if (model["tokenizer.ggml.model"] === undefined) {
	// @ts-expect-error because it's undefined
	model["tokenizer.ggml.eos_token_id"] = 1;
}
if (model["tokenizer.ggml.model"] === "gpt2") {
	// @ts-expect-error because it must be a number
	model["tokenizer.ggml.eos_token_id"] = undefined;
	model["tokenizer.ggml.eos_token_id"] = 1;
}

if (model["general.architecture"] === "mamba") {
	model["mamba.ssm.conv_kernel"] = 0;
	// @ts-expect-error because it must be a number
	model["mamba.ssm.conv_kernel"] = "abc";
}
if (model["general.architecture"] === "llama") {
	// @ts-expect-error llama does not have ssm.* keys
	model["mamba.ssm.conv_kernel"] = 0;
}
```

Type checks can be disabled with `GGUFMetadata<GGUFType.NON_STRICT>`
  • Loading branch information
ngxson authored May 7, 2024
1 parent 99bbf1f commit 6a036d8
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 90 deletions.
70 changes: 46 additions & 24 deletions packages/gguf/scripts/generate-llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,53 @@ import { writeFileSync } from "node:fs";
const SOURCE_CPP_URL = "https://raw.githubusercontent.com/ggerganov/llama.cpp/master/llama.cpp";
const DEST_FILE_PATH = "./src/transformer-llm.ts";
const DEST_COMMON_SOURCE = `
type Attention<TArchitecture extends string> =
& { [K in \`\${TArchitecture}.attention.head_count\`]: number }
& { [K in \`\${TArchitecture}.attention.head_count_kv\`]: number }
& { [K in \`\${TArchitecture}.attention.layer_norm_epsilon\`]: number }
& { [K in \`\${TArchitecture}.attention.layer_norm_rms_epsilon\`]: number }
& { [K in \`\${TArchitecture}.attention.alibi_bias_max\`]: number }
& { [K in \`\${TArchitecture}.attention.clip_kqv\`]: number }
& { [K in \`\${TArchitecture}.attention.use_norm\`]: number };
type Rope<TArchitecture extends LLMArchitecture> =
& { [K in \`\${TArchitecture}.rope.dimension_count\`]: number }
& { [K in \`\${TArchitecture}.rope.freq_base\`]: number }
& { [K in \`\${TArchitecture}.rope.scale\`]: number }
& { [K in \`\${TArchitecture}.rope.scale_linear\`]: number };
type MOE<TArchitecture extends LLMArchitecture> =
& { [K in \`\${TArchitecture}.expert_count\`]: number }
& { [K in \`\${TArchitecture}.expert_used_count\`]: number };
/** This file is auto-generated by generate-llm.ts */
import type { ModelBase, GGUFGeneralInfo } from "./types";
type LLMBase<TArchitecture extends string> = Partial<Record<
\`\${TArchitecture}.vocab_size\`
| \`\${TArchitecture}.use_parallel_residual\`
| \`\${TArchitecture}.tensor_data_layout\`,
number
>>;
type Attention<TArchitecture extends string> = Record<
\`\${TArchitecture}.attention.head_count\`,
number
> & Partial<Record<
\`\${TArchitecture}.attention.head_count_kv\`
| \`\${TArchitecture}.attention.key_length\`
| \`\${TArchitecture}.attention.value_length\`,
number
>>;
export type TransformerLLMRopeScalingType = "none" | "linear" | "yarn";
type Rope<TArchitecture extends LLMArchitecture> = Partial<
Record<
\`\${TArchitecture}.rope.dimension_count\`
| \`\${TArchitecture}.rope.freq_base\`
| \`\${TArchitecture}.rope.scale_linear\`
| \`\${TArchitecture}.rope.scaling.factor\`
| \`\${TArchitecture}.rope.scaling.original_context_length\`,
number
>
& Record<\`\${TArchitecture}.rope.scaling.type\`, TransformerLLMRopeScalingType>
& Record<\`\${TArchitecture}.rope.finetuned\`, boolean>
>;
type MOE<TArchitecture extends LLMArchitecture> = Partial<
Record<
\`\${TArchitecture}.expert_count\`
| \`\${TArchitecture}.expert_used_count\`,
number
>
>;
export type TransformerLLMArchitecture = LLMArchitecture; // type alias
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = ModelBase<TArchitecture>
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = GGUFGeneralInfo<TArchitecture>
& LLMBase<TArchitecture>
& ModelBase<TArchitecture>
& MOE<TArchitecture>
& Attention<TArchitecture>
& Rope<TArchitecture>;
Expand Down Expand Up @@ -163,15 +189,11 @@ async function main() {
/////////////////////////////////////
// write result to file
const content = [
"/** This file is auto-generated by generate-llm.ts */",
"",
'import type { ModelBase } from "./types";',
"",
DEST_COMMON_SOURCE,
"export const LLM_ARCHITECTURES = [",
...archList.map((a) => `\t${JSON.stringify(a.name)},`),
"] as const;",
"type LLMArchitecture = (typeof LLM_ARCHITECTURES)[number];",
DEST_COMMON_SOURCE,
...archList.map((a) => {
let code = `export type ${a.tsName} = TransformerLLMBase<${JSON.stringify(a.name)}>`;
if (a.hparams.length) {
Expand Down
33 changes: 18 additions & 15 deletions packages/gguf/src/gguf.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,25 @@ describe("gguf", () => {
"llama.rope.dimension_count": 128,
});

const tokens = metadata["tokenizer.ggml.tokens"];
if (!Array.isArray(tokens)) {
throw new Error();
expect(metadata["tokenizer.ggml.model"]);
if (metadata["tokenizer.ggml.model"]) {
const tokens = metadata["tokenizer.ggml.tokens"];
if (!Array.isArray(tokens)) {
throw new Error();
}
expect(tokens.slice(0, 10)).toEqual([
"<unk>",
"<s>",
"</s>",
"<0x00>",
"<0x01>",
"<0x02>",
"<0x03>",
"<0x04>",
"<0x05>",
"<0x06>",
]);
}
expect(tokens.slice(0, 10)).toEqual([
"<unk>",
"<s>",
"</s>",
"<0x00>",
"<0x01>",
"<0x02>",
"<0x03>",
"<0x04>",
"<0x05>",
"<0x06>",
]);

/// Tensor infos
/// By convention we test the first and last tensor.
Expand Down
2 changes: 1 addition & 1 deletion packages/gguf/src/gguf.ts
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ export async function gguf(
offset += tensorCount.length;
const numKv = readVersionedSize(r.view, offset, version, littleEndian);
offset += numKv.length;
const metadata: GGUFMetadata = {
const metadata: GGUFMetadata<{ strict: false }> = {
version,
tensor_count: tensorCount.value,
kv_count: numKv.value,
Expand Down
82 changes: 51 additions & 31 deletions packages/gguf/src/transformer-llm.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,56 @@
/** This file is auto-generated by generate-llm.ts */

import type { ModelBase } from "./types";
import type { ModelBase, GGUFGeneralInfo } from "./types";

type LLMBase<TArchitecture extends string> = Partial<
Record<
`${TArchitecture}.vocab_size` | `${TArchitecture}.use_parallel_residual` | `${TArchitecture}.tensor_data_layout`,
number
>
>;

type Attention<TArchitecture extends string> = Record<`${TArchitecture}.attention.head_count`, number> &
Partial<
Record<
| `${TArchitecture}.attention.head_count_kv`
| `${TArchitecture}.attention.key_length`
| `${TArchitecture}.attention.value_length`,
number
>
>;

export type TransformerLLMRopeScalingType = "none" | "linear" | "yarn";
type Rope<TArchitecture extends LLMArchitecture> = Partial<
Record<
| `${TArchitecture}.rope.dimension_count`
| `${TArchitecture}.rope.freq_base`
| `${TArchitecture}.rope.scale_linear`
| `${TArchitecture}.rope.scaling.factor`
| `${TArchitecture}.rope.scaling.original_context_length`,
number
> &
Record<`${TArchitecture}.rope.scaling.type`, TransformerLLMRopeScalingType> &
Record<`${TArchitecture}.rope.finetuned`, boolean>
>;

type MOE<TArchitecture extends LLMArchitecture> = Partial<
Record<`${TArchitecture}.expert_count` | `${TArchitecture}.expert_used_count`, number>
>;

export type TransformerLLMArchitecture = LLMArchitecture; // type alias
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = GGUFGeneralInfo<TArchitecture> &
LLMBase<TArchitecture> &
ModelBase<TArchitecture> &
MOE<TArchitecture> &
Attention<TArchitecture> &
Rope<TArchitecture>;

export enum TransformerLLMPoolingType {
UNSPECIFIED = -1,
NONE = 0,
MEAN = 1,
CLS = 2,
}

export const LLM_ARCHITECTURES = [
"llama",
Expand Down Expand Up @@ -37,36 +87,6 @@ export const LLM_ARCHITECTURES = [
"olmo",
] as const;
type LLMArchitecture = (typeof LLM_ARCHITECTURES)[number];

type Attention<TArchitecture extends string> = { [K in `${TArchitecture}.attention.head_count`]: number } & {
[K in `${TArchitecture}.attention.head_count_kv`]: number;
} & { [K in `${TArchitecture}.attention.layer_norm_epsilon`]: number } & {
[K in `${TArchitecture}.attention.layer_norm_rms_epsilon`]: number;
} & { [K in `${TArchitecture}.attention.alibi_bias_max`]: number } & {
[K in `${TArchitecture}.attention.clip_kqv`]: number;
} & { [K in `${TArchitecture}.attention.use_norm`]: number };

type Rope<TArchitecture extends LLMArchitecture> = { [K in `${TArchitecture}.rope.dimension_count`]: number } & {
[K in `${TArchitecture}.rope.freq_base`]: number;
} & { [K in `${TArchitecture}.rope.scale`]: number } & { [K in `${TArchitecture}.rope.scale_linear`]: number };

type MOE<TArchitecture extends LLMArchitecture> = { [K in `${TArchitecture}.expert_count`]: number } & {
[K in `${TArchitecture}.expert_used_count`]: number;
};

export type TransformerLLMArchitecture = LLMArchitecture; // type alias
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = ModelBase<TArchitecture> &
MOE<TArchitecture> &
Attention<TArchitecture> &
Rope<TArchitecture>;

export enum TransformerLLMPoolingType {
UNSPECIFIED = -1,
NONE = 0,
MEAN = 1,
CLS = 2,
}

export type ArchLlama = TransformerLLMBase<"llama"> & {
"llama.attention.layer_norm_rms_epsilon": number;
};
Expand Down
55 changes: 55 additions & 0 deletions packages/gguf/src/types.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import { describe, it } from "vitest";
import type { gguf } from "./gguf";
import type { GGUFMetadata, GGUFParseOutput } from "./types";

describe("gguf-types", () => {
it("gguf() type can be casted between STRICT and NON_STRICT (at compile time)", async () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const result: Awaited<ReturnType<typeof gguf>> = { metadata: {} } as any;
const strictType = result as GGUFParseOutput<{ strict: true }>;
// @ts-expect-error because the key "abc" does not exist
strictType.metadata.abc = 123;
const nonStrictType = result as GGUFParseOutput<{ strict: false }>;
nonStrictType.metadata.abc = 123; // PASS, because it can be anything
// @ts-expect-error because ArrayBuffer is not a MetadataValue
nonStrictType.metadata.fff = ArrayBuffer;
});

it("GGUFType.NON_STRICT should be correct (at compile time)", async () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const model: GGUFMetadata<{ strict: false }> = {} as any;
model.kv_count = 123n;
model.abc = 456; // PASS, because it can be anything
});

it("GGUFType.STRICT should be correct (at compile time)", async () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const model: GGUFMetadata<{ strict: true }> = {} as any;

if (model["general.architecture"] === "whisper") {
model["encoder.whisper.block_count"] = 0;
// @ts-expect-error because it must be a number
model["encoder.whisper.block_count"] = "abc";
}

if (model["tokenizer.ggml.model"] === undefined) {
// @ts-expect-error because it's undefined
model["tokenizer.ggml.eos_token_id"] = 1;
}
if (model["tokenizer.ggml.model"] === "gpt2") {
// @ts-expect-error because it must be a number
model["tokenizer.ggml.eos_token_id"] = undefined;
model["tokenizer.ggml.eos_token_id"] = 1;
}

if (model["general.architecture"] === "mamba") {
model["mamba.ssm.conv_kernel"] = 0;
// @ts-expect-error because it must be a number
model["mamba.ssm.conv_kernel"] = "abc";
}
if (model["general.architecture"] === "llama") {
// @ts-expect-error llama does not have ssm.* keys
model["mamba.ssm.conv_kernel"] = 0;
}
});
});
75 changes: 56 additions & 19 deletions packages/gguf/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,32 @@ export enum GGUFValueType {
const ARCHITECTURES = [...LLM_ARCHITECTURES, "rwkv", "whisper"] as const;
export type Architecture = (typeof ARCHITECTURES)[number];

interface General {
"general.architecture": Architecture;
"general.name": string;
"general.file_type": number;
"general.quantization_version": number;
export interface GGUFGeneralInfo<TArchitecture extends Architecture> {
"general.architecture": TArchitecture;
"general.name"?: string;
"general.file_type"?: number;
"general.quantization_version"?: number;
}

type ModelMetadata = Whisper | RWKV | TransformerLLM;
interface NoModelMetadata {
"general.architecture"?: undefined;
}

export type ModelBase<
TArchitecture extends
| Architecture
| `encoder.${Extract<Architecture, "whisper">}`
| `decoder.${Extract<Architecture, "whisper">}`,
> = { [K in `${TArchitecture}.layer_count`]: number } & { [K in `${TArchitecture}.feed_forward_length`]: number } & {
[K in `${TArchitecture}.context_length`]: number;
} & { [K in `${TArchitecture}.embedding_length`]: number } & { [K in `${TArchitecture}.block_count`]: number };
> = Record<
| `${TArchitecture}.context_length`
| `${TArchitecture}.block_count`
| `${TArchitecture}.embedding_length`
| `${TArchitecture}.feed_forward_length`,
number
>;

/// Tokenizer

type TokenizerModel = "no_vocab" | "llama" | "gpt2" | "bert";
interface Tokenizer {
Expand All @@ -75,21 +86,47 @@ interface Tokenizer {
"tokenizer.ggml.bos_token_id": number;
"tokenizer.ggml.eos_token_id": number;
"tokenizer.ggml.add_bos_token": boolean;
"tokenizer.chat_template": string;
"tokenizer.chat_template"?: string;
}
interface NoTokenizer {
"tokenizer.ggml.model"?: undefined;
}

/// Models outside of llama.cpp: "rwkv" and "whisper"

export type RWKV = ModelBase<"rwkv"> & { "rwkv.architecture_version": number };
export type LLM = TransformerLLM | RWKV;
export type Whisper = ModelBase<"encoder.whisper"> & ModelBase<"decoder.whisper">;
export type Model = (LLM | Whisper) & Partial<Tokenizer>;
export type RWKV = GGUFGeneralInfo<"rwkv"> &
ModelBase<"rwkv"> & {
"rwkv.architecture_version": number;
};

export type GGUFMetadata = {
// TODO: whisper.cpp doesn't yet support gguf. This maybe changed in the future.
export type Whisper = GGUFGeneralInfo<"whisper"> &
ModelBase<"encoder.whisper"> &
ModelBase<"decoder.whisper"> & {
"whisper.encoder.mels_count": number;
"whisper.encoder.attention.head_count": number;
"whisper.decoder.attention.head_count": number;
};

/// Types for parse output

export interface GGUFMetadataOptions {
/**
* Enable strict type for known GGUF fields.
*
* @default true
*/
strict: boolean;
}

export type GGUFMetadata<Options extends GGUFMetadataOptions = { strict: true }> = {
version: Version;
tensor_count: bigint;
kv_count: bigint;
} & Partial<General> &
Partial<Model> &
Record<string, MetadataValue>;
} & GGUFModelKV &
(Options extends { strict: true } ? unknown : Record<string, MetadataValue>);

export type GGUFModelKV = (NoModelMetadata | ModelMetadata) & (NoTokenizer | Tokenizer);

export interface GGUFTensorInfo {
name: string;
Expand All @@ -99,7 +136,7 @@ export interface GGUFTensorInfo {
offset: bigint;
}

export interface GGUFParseOutput {
metadata: GGUFMetadata;
export interface GGUFParseOutput<Options extends GGUFMetadataOptions = { strict: true }> {
metadata: GGUFMetadata<Options>;
tensorInfos: GGUFTensorInfo[];
}

0 comments on commit 6a036d8

Please sign in to comment.