Skip to content

Commit

Permalink
Update tiktoken-rs to 0.5.8, add Tiktoken.context_size_for_model/1
Browse files Browse the repository at this point in the history
The `Tiktoken.context_size_for_model/1` function delegates to
`tiktoken_rs::model::get_context_size(model)`, except for two specific
models which are not yet properly handled in the released version of
`tiktoken_rs`.
  • Loading branch information
JonathanTron committed Mar 9, 2024
1 parent 30c88c8 commit b28b408
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 58 deletions.
8 changes: 8 additions & 0 deletions lib/tiktoken.ex
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,12 @@ defmodule Tiktoken do
{:error, {:unsupported_model, model}}
end
end

@doc """
Returns the context window size (in tokens) for `model`.

Delegates to the native `tiktoken_rs::model::get_context_size`, except for
two models that the pinned tiktoken-rs release (0.5.8) does not yet know
about.
"""
@spec context_size_for_model(String.t()) :: pos_integer()
# Override clauses — remove once a tiktoken-rs release > 0.5.8 is adopted.
def context_size_for_model("gpt-3.5-turbo-1106"), do: 16_385
def context_size_for_model("gpt-4-0125-preview"), do: 128_000

# Everything else is answered by the native tiktoken-rs implementation.
def context_size_for_model(model), do: Tiktoken.Native.context_size_for_model(model)
end
2 changes: 2 additions & 0 deletions lib/tiktoken/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,7 @@ defmodule Tiktoken.Native do
# NIF stubs — each is replaced by the corresponding Rust implementation in
# native/tiktoken when the NIF library loads; if loading failed, calling
# any of them raises :nif_not_loaded via err/0.
def cl100k_encode_with_special_tokens(_input), do: err()
def cl100k_decode(_ids), do: err()

def context_size_for_model(_model), do: err()

# Raised only when the compiled NIF was not loaded (e.g. build failure).
defp err, do: :erlang.nif_error(:nif_not_loaded)
end
28 changes: 14 additions & 14 deletions native/tiktoken/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions native/tiktoken/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ path = "src/lib.rs"
crate-type = ["cdylib"]

[dependencies]
rustler = "0.30.0"
tiktoken-rs = "0.5.6"
rustler = "0.31.0"
tiktoken-rs = "0.5.8"
8 changes: 7 additions & 1 deletion native/tiktoken/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,11 @@ fn cl100k_decode(ids: Vec<usize>) -> Result<String, String> {
}
}

/// NIF: returns the context window size (in tokens) for `model`, as reported
/// by `tiktoken_rs::model::get_context_size`.
/// NOTE(review): the Elixir test suite expects 4096 for unknown models, so
/// get_context_size presumably falls back to a default — confirm against the
/// tiktoken-rs documentation.
#[rustler::nif]
fn context_size_for_model(model: &str) -> usize {
tiktoken_rs::model::get_context_size(model)
}

rustler::init!(
"Elixir.Tiktoken.Native",
[
Expand All @@ -199,6 +204,7 @@ rustler::init!(
cl100k_encode_ordinary,
cl100k_encode,
cl100k_encode_with_special_tokens,
cl100k_decode
cl100k_decode,
context_size_for_model
]
);
118 changes: 77 additions & 41 deletions test/tiktoken_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,49 +2,72 @@ defmodule TiktokenTest do
use ExUnit.Case
doctest Tiktoken

# Test fixtures: {model_name, expected_encoding_module, expected_context_size}.
# The third element is what Tiktoken.context_size_for_model/1 must return;
# commented-out entries are models without a tuple here — presumably not yet
# supported by the pinned tiktoken-rs release (verify before enabling).
@known_models [
# chat
{"gpt-3.5-turbo", Tiktoken.CL100K, 4_096},
{"gpt-3.5-turbo-0125", Tiktoken.CL100K, 4_096},
{"gpt-3.5-turbo-1106", Tiktoken.CL100K, 16_385},
{"gpt-3.5-turbo-instruct", Tiktoken.CL100K, 4_096},
{"gpt-3.5-turbo-16k", Tiktoken.CL100K, 16_384},
{"gpt-3.5-turbo-0613", Tiktoken.CL100K, 4_096},
{"gpt-3.5-turbo-16k-0613", Tiktoken.CL100K, 16_384},
{"gpt-4-0125-preview", Tiktoken.CL100K, 128_000},
{"gpt-4-turbo-preview", Tiktoken.CL100K, 8_192},
{"gpt-4-1106-preview", Tiktoken.CL100K, 128_000},
{"gpt-4-vision-preview", Tiktoken.CL100K, 8_192},
{"gpt-4-06-vision-preview", Tiktoken.CL100K, 8_192},
{"gpt-4", Tiktoken.CL100K, 8_192},
{"gpt-4-0613", Tiktoken.CL100K, 8_192},
{"gpt-4-32k", Tiktoken.CL100K, 32_768},
{"gpt-4-32k-0613", Tiktoken.CL100K, 32_768},
# text
{"text-davinci-003", Tiktoken.P50K, 4_097},
{"text-davinci-002", Tiktoken.P50K, 4_097},
{"text-davinci-001", Tiktoken.R50K, 4_096},
{"text-curie-001", Tiktoken.R50K, 2_049},
{"text-babbage-001", Tiktoken.R50K, 2_049},
{"text-ada-001", Tiktoken.R50K, 2_049},
{"davinci", Tiktoken.R50K, 2_049},
{"curie", Tiktoken.R50K, 2_049},
{"babbage", Tiktoken.R50K, 2_049},
{"ada", Tiktoken.R50K, 2_049},
# code
{"code-davinci-002", Tiktoken.P50K, 8_001},
{"code-davinci-001", Tiktoken.P50K, 4_096},
{"code-cushman-002", Tiktoken.P50K, 4_096},
{"code-cushman-001", Tiktoken.P50K, 2_048},
{"davinci-codex", Tiktoken.P50K, 2_049},
{"cushman-codex", Tiktoken.P50K, 4_096},
# edit
{"text-davinci-edit-001", Tiktoken.P50KEdit, 4_096},
{"code-davinci-edit-001", Tiktoken.P50KEdit, 4_096},
# embeddings
# {"text-embedding-3-large", Tiktoken.CL100K},
# {"text-embedding-3-small", Tiktoken.CL100K},
{"text-embedding-ada-002", Tiktoken.CL100K, 8_192},
# old embeddings
{"text-similarity-davinci-001", Tiktoken.R50K, 4_096},
{"text-similarity-curie-001", Tiktoken.R50K, 4_096},
{"text-similarity-babbage-001", Tiktoken.R50K, 4_096},
{"text-similarity-ada-001", Tiktoken.R50K, 4_096},
{"text-search-davinci-doc-001", Tiktoken.R50K, 4_096},
{"text-search-curie-doc-001", Tiktoken.R50K, 4_096},
{"text-search-babbage-doc-001", Tiktoken.R50K, 4_096},
{"text-search-ada-doc-001", Tiktoken.R50K, 4_096},
{"code-search-babbage-code-001", Tiktoken.R50K, 4_096},
{"code-search-ada-code-001", Tiktoken.R50K, 4_096}
# moderation
# {"text-moderation-latest", Tiktoken.CL100K},
# {"text-moderation-stable", Tiktoken.CL100K},
# {"text-moderation-007", Tiktoken.CL100K}
# open source
# {"gpt2", "gpt2"}
]

describe "encoding_for_model/1" do
test "get the proper module for supported model" do
[
# chat
{"gpt-3.5-turbo", Tiktoken.CL100K},
# text
{"text-davinci-003", Tiktoken.P50K},
{"text-davinci-002", Tiktoken.P50K},
{"text-davinci-001", Tiktoken.R50K},
{"text-curie-001", Tiktoken.R50K},
{"text-babbage-001", Tiktoken.R50K},
{"text-ada-001", Tiktoken.R50K},
{"davinci", Tiktoken.R50K},
{"curie", Tiktoken.R50K},
{"babbage", Tiktoken.R50K},
{"ada", Tiktoken.R50K},
# code
{"code-davinci-002", Tiktoken.P50K},
{"code-davinci-001", Tiktoken.P50K},
{"code-cushman-002", Tiktoken.P50K},
{"code-cushman-001", Tiktoken.P50K},
{"davinci-codex", Tiktoken.P50K},
{"cushman-codex", Tiktoken.P50K},
# edit
{"text-davinci-edit-001", Tiktoken.P50KEdit},
{"code-davinci-edit-001", Tiktoken.P50KEdit},
# embeddings
{"text-embedding-ada-002", Tiktoken.CL100K},
# old embeddings
{"text-similarity-davinci-001", Tiktoken.R50K},
{"text-similarity-curie-001", Tiktoken.R50K},
{"text-similarity-babbage-001", Tiktoken.R50K},
{"text-similarity-ada-001", Tiktoken.R50K},
{"text-search-davinci-doc-001", Tiktoken.R50K},
{"text-search-curie-doc-001", Tiktoken.R50K},
{"text-search-babbage-doc-001", Tiktoken.R50K},
{"text-search-ada-doc-001", Tiktoken.R50K},
{"code-search-babbage-code-001", Tiktoken.R50K},
{"code-search-ada-code-001", Tiktoken.R50K}
# open source
# {"gpt2", "gpt2"}
]
|> Enum.each(fn {model, mod} ->
@known_models
|> Enum.each(fn {model, mod, _context_size} ->
assert Tiktoken.encoding_for_model(model) == mod
end)
end
Expand Down Expand Up @@ -112,4 +135,17 @@ defmodule TiktokenTest do
Tiktoken.decode("gpt2", [1])
end
end

describe "context_size_for_model/1" do
  # Every fixture model must report the context size recorded in @known_models.
  test "get proper context size for model" do
    Enum.each(@known_models, fn {model, _encoding_mod, expected_size} ->
      assert Tiktoken.context_size_for_model(model) == expected_size
    end)
  end

  # Unrecognized model names fall back to the native default.
  test "get 4096 for unknown model" do
    assert Tiktoken.context_size_for_model("unknown") == 4_096
  end
end
end

0 comments on commit b28b408

Please sign in to comment.