Skip to content

Commit

Permalink
Merge pull request elizaOS#703 from antpb/switch-to-js-tiktoken
Browse files Browse the repository at this point in the history
fix: Switch from tiktoken to js-tiktoken for worker compatibility
  • Loading branch information
shakkernerd authored Nov 30, 2024
2 parents 6ae925e + a952fa7 commit 8c35b9e
Show file tree
Hide file tree
Showing 4 changed files with 401 additions and 63 deletions.
2 changes: 1 addition & 1 deletion packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@
"gaxios": "6.7.1",
"glob": "11.0.0",
"js-sha1": "0.7.0",
"js-tiktoken": "^1.0.15",
"langchain": "0.3.6",
"ollama-ai-provider": "0.16.1",
"openai": "4.73.0",
"tiktoken": "1.0.17",
"tinyld": "1.3.4",
"together-ai": "0.7.0",
"unique-names-generator": "4.7.1",
Expand Down
12 changes: 4 additions & 8 deletions packages/core/src/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import {
import { Buffer } from "buffer";
import { createOllama } from "ollama-ai-provider";
import OpenAI from "openai";
import { encoding_for_model, TiktokenModel } from "tiktoken";
import { encodingForModel, TiktokenModel } from "js-tiktoken";
import Together from "together-ai";
import { ZodSchema } from "zod";
import { elizaLogger } from "./index.ts";
Expand Down Expand Up @@ -429,7 +429,7 @@ export function trimTokens(
if (maxTokens <= 0) throw new Error("maxTokens must be positive");

// Get the tokenizer for the model
const encoding = encoding_for_model(model);
const encoding = encodingForModel(model);

try {
// Encode the text into tokens
Expand All @@ -443,16 +443,12 @@ export function trimTokens(
// Keep the most recent tokens by slicing from the end
const truncatedTokens = tokens.slice(-maxTokens);

// Decode back to text and convert to string
const decodedText = encoding.decode(truncatedTokens);
return new TextDecoder().decode(decodedText);
// Decode back to text - js-tiktoken decode() returns a string directly
return encoding.decode(truncatedTokens);
} catch (error) {
console.error("Error in trimTokens:", error);
// Return truncated string if tokenization fails
return context.slice(-maxTokens * 4); // Rough estimate of 4 chars per token
} finally {
// Clean up tokenizer resources
encoding.free();
}
}

Expand Down
53 changes: 52 additions & 1 deletion packages/core/src/tests/generation.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { describe, expect, it, vi, beforeEach } from "vitest";
import { ModelProviderName, IAgentRuntime } from "../types";
import { models } from "../models";
import { generateText, generateTrueOrFalse, splitChunks } from "../generation";
import { generateText, generateTrueOrFalse, splitChunks, trimTokens } from "../generation";
import type { TiktokenModel } from "js-tiktoken";

// Mock the elizaLogger
vi.mock("../index.ts", () => ({
Expand Down Expand Up @@ -120,4 +121,54 @@ describe("Generation", () => {
expect(chunks).toEqual([content]);
});
});

describe("trimTokens", () => {
    // js-tiktoken model id used for every case in this suite.
    const model = "gpt-4" as TiktokenModel;

    it("should return empty string for empty input", () => {
        expect(trimTokens("", 100, model)).toBe("");
    });

    it("should throw error for negative maxTokens", () => {
        expect(() => trimTokens("test", -1, model)).toThrow("maxTokens must be positive");
    });

    it("should return unchanged text if within token limit", () => {
        const input = "This is a short text";
        expect(trimTokens(input, 10, model)).toBe(input);
    });

    it("should truncate text to specified token limit", () => {
        // A sentence long enough that a 5-token budget must cut it down.
        const input =
            "This is a much longer text that will definitely exceed our very small token limit and need to be truncated to fit within the specified constraints.";
        const trimmed = trimTokens(input, 5, model);

        // Exact output depends on the tokenizer, so assert structural
        // properties instead: shorter than the input, non-empty, and a
        // contiguous piece of the original text.
        expect(trimmed.length).toBeLessThan(input.length);
        expect(trimmed.length).toBeGreaterThan(0);
        expect(input).toContain(trimmed);
    });

    it("should handle non-ASCII characters", () => {
        const trimmed = trimTokens("Hello 👋 World 🌍", 5, model);
        expect(trimmed.length).toBeGreaterThan(0);
    });

    it("should handle multiline text", () => {
        const input = `Line 1
Line 2
Line 3
Line 4
Line 5`;
        const trimmed = trimTokens(input, 5, model);
        expect(trimmed.length).toBeGreaterThan(0);
        expect(trimmed.length).toBeLessThan(input.length);
    });
});
});
Loading

0 comments on commit 8c35b9e

Please sign in to comment.