From b28b4080a55451b248582a2998971f80dbda71a5 Mon Sep 17 00:00:00 2001 From: Jonathan Tron Date: Sat, 9 Mar 2024 11:07:26 +0100 Subject: [PATCH] Update tiktoken-rs to 0.5.8, add `Tiktoken.context_size_for_model/1` The `Tiktoken.context_size_for_model/1` delegates to `tiktoken_rs::model::get_context_size(model)`, except for two specific models which are not yet properly handled in released version of `tiktoken_rs`. --- lib/tiktoken.ex | 8 +++ lib/tiktoken/native.ex | 2 + native/tiktoken/Cargo.lock | 28 ++++----- native/tiktoken/Cargo.toml | 4 +- native/tiktoken/src/lib.rs | 8 ++- test/tiktoken_test.exs | 118 ++++++++++++++++++++++++------------- 6 files changed, 110 insertions(+), 58 deletions(-) diff --git a/lib/tiktoken.ex b/lib/tiktoken.ex index ee27616..14a0fe7 100644 --- a/lib/tiktoken.ex +++ b/lib/tiktoken.ex @@ -38,4 +38,12 @@ defmodule Tiktoken do {:error, {:unsupported_model, model}} end end + + # Those two can be removed when a release of tiktoken-rs > 0.5.8 is released + def context_size_for_model("gpt-3.5-turbo-1106"), do: 16_385 + def context_size_for_model("gpt-4-0125-preview"), do: 128_000 + + def context_size_for_model(model) do + Tiktoken.Native.context_size_for_model(model) + end end diff --git a/lib/tiktoken/native.ex b/lib/tiktoken/native.ex index 3e5b151..a41de1d 100644 --- a/lib/tiktoken/native.ex +++ b/lib/tiktoken/native.ex @@ -32,5 +32,7 @@ defmodule Tiktoken.Native do def cl100k_encode_with_special_tokens(_input), do: err() def cl100k_decode(_ids), do: err() + def context_size_for_model(_model), do: err() + defp err, do: :erlang.nif_error(:nif_not_loaded) end diff --git a/native/tiktoken/Cargo.lock b/native/tiktoken/Cargo.lock index 9e935ce..6e888fd 100644 --- a/native/tiktoken/Cargo.lock +++ b/native/tiktoken/Cargo.lock @@ -13,9 +13,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.75" +version = "1.0.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" +checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1" [[package]] name = "autocfg" @@ -25,9 +25,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "base64" -version = "0.21.4" +version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" [[package]] name = "bit-set" @@ -69,9 +69,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "fancy-regex" -version = "0.11.0" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2" +checksum = "7493d4c459da9f84325ad297371a6b2b8a162800873a22e3b6b6512e61d18c05" dependencies = [ "bit-set", "regex", @@ -204,9 +204,9 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustler" -version = "0.30.0" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4b4fea69e23de68c42c06769d6624d2d018da550c17244dd4b691f90ced4a7e" +checksum = "a75d458f38f550976d0e4b347ca57241c192019777e46af7af73b27783287088" dependencies = [ "lazy_static", "rustler_codegen", @@ -215,9 +215,9 @@ dependencies = [ [[package]] name = "rustler_codegen" -version = "0.30.0" +version = "0.31.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "406061bd07aaf052c344257afed4988c5ec8efe4d2352b4c2cf27ea7c8575b12" +checksum = "dbd46408f51c0ca6a68dc36aa4f90e3554960bd1b7cc513e6ff2ccad7dd92aff" dependencies = [ "heck", "proc-macro2", @@ -227,9 +227,9 @@ dependencies = [ [[package]] name = "rustler_sys" -version = "2.3.1" +version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a7c0740e5322b64e2b952d8f0edce5f90fcf6f6fe74cca3f6e78eb3de5ea858" +checksum = "ff76ba8524729d7c9db2b3e80f2269d1fdef39b5a60624c33fd794797e69b558" dependencies = [ "regex", "unreachable", @@ -274,9 +274,9 @@ dependencies = [ [[package]] name = "tiktoken-rs" -version = "0.5.6" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1e520ded49607c6b80a4ab517f564c05e2b34f2c549dbd7b6a528caa2009dda" +checksum = "40894b788eb28bbb7e36bdc8b7b1b1488b9c93fa3730f315ab965330c94c0842" dependencies = [ "anyhow", "base64", diff --git a/native/tiktoken/Cargo.toml b/native/tiktoken/Cargo.toml index 5154920..b630fa3 100644 --- a/native/tiktoken/Cargo.toml +++ b/native/tiktoken/Cargo.toml @@ -10,5 +10,5 @@ path = "src/lib.rs" crate-type = ["cdylib"] [dependencies] -rustler = "0.30.0" -tiktoken-rs = "0.5.6" +rustler = "0.31.0" +tiktoken-rs = "0.5.8" diff --git a/native/tiktoken/src/lib.rs b/native/tiktoken/src/lib.rs index 103eb63..3a3defd 100644 --- a/native/tiktoken/src/lib.rs +++ b/native/tiktoken/src/lib.rs @@ -180,6 +180,11 @@ fn cl100k_decode(ids: Vec) -> Result { } } +#[rustler::nif] +fn context_size_for_model(model: &str) -> usize { + tiktoken_rs::model::get_context_size(model) +} + rustler::init!( "Elixir.Tiktoken.Native", [ @@ -199,6 +204,7 @@ rustler::init!( cl100k_encode_ordinary, cl100k_encode, cl100k_encode_with_special_tokens, - cl100k_decode + cl100k_decode, + context_size_for_model ] ); diff --git a/test/tiktoken_test.exs b/test/tiktoken_test.exs index b9694cb..bae2fc6 100644 --- a/test/tiktoken_test.exs +++ b/test/tiktoken_test.exs @@ -2,49 +2,72 @@ defmodule TiktokenTest do use ExUnit.Case doctest Tiktoken + @known_models [ + # chat + {"gpt-3.5-turbo", Tiktoken.CL100K, 4_096}, + {"gpt-3.5-turbo-0125", Tiktoken.CL100K, 4_096}, + {"gpt-3.5-turbo-1106", Tiktoken.CL100K, 16_385}, + {"gpt-3.5-turbo-instruct", Tiktoken.CL100K, 4_096}, + {"gpt-3.5-turbo-16k", Tiktoken.CL100K, 16_384}, + {"gpt-3.5-turbo-0613", Tiktoken.CL100K, 4_096}, + {"gpt-3.5-turbo-16k-0613", Tiktoken.CL100K, 16_384}, + {"gpt-4-0125-preview", Tiktoken.CL100K, 128_000}, + {"gpt-4-turbo-preview", Tiktoken.CL100K, 8_192}, + {"gpt-4-1106-preview", Tiktoken.CL100K, 128_000}, + {"gpt-4-vision-preview", Tiktoken.CL100K, 8_192}, + {"gpt-4-06-vision-preview", Tiktoken.CL100K, 8_192}, + {"gpt-4", Tiktoken.CL100K, 8_192}, + {"gpt-4-0613", Tiktoken.CL100K, 8_192}, + {"gpt-4-32k", Tiktoken.CL100K, 32_768}, + {"gpt-4-32k-0613", Tiktoken.CL100K, 32_768}, + # text + {"text-davinci-003", Tiktoken.P50K, 4_097}, + {"text-davinci-002", Tiktoken.P50K, 4_097}, + {"text-davinci-001", Tiktoken.R50K, 4_096}, + {"text-curie-001", Tiktoken.R50K, 2_049}, + {"text-babbage-001", Tiktoken.R50K, 2_049}, + {"text-ada-001", Tiktoken.R50K, 2_049}, + {"davinci", Tiktoken.R50K, 2_049}, + {"curie", Tiktoken.R50K, 2_049}, + {"babbage", Tiktoken.R50K, 2_049}, + {"ada", Tiktoken.R50K, 2_049}, + # code + {"code-davinci-002", Tiktoken.P50K, 8_001}, + {"code-davinci-001", Tiktoken.P50K, 4_096}, + {"code-cushman-002", Tiktoken.P50K, 4_096}, + {"code-cushman-001", Tiktoken.P50K, 2_048}, + 
{"davinci-codex", Tiktoken.P50K, 2_049}, + {"cushman-codex", Tiktoken.P50K, 4_096}, + # edit + {"text-davinci-edit-001", Tiktoken.P50KEdit, 4_096}, + {"code-davinci-edit-001", Tiktoken.P50KEdit, 4_096}, + # embeddings + # {"text-embedding-3-large", Tiktoken.CL100K}, + # {"text-embedding-3-small", Tiktoken.CL100K}, + {"text-embedding-ada-002", Tiktoken.CL100K, 8_192}, + # old embeddings + {"text-similarity-davinci-001", Tiktoken.R50K, 4_096}, + {"text-similarity-curie-001", Tiktoken.R50K, 4_096}, + {"text-similarity-babbage-001", Tiktoken.R50K, 4_096}, + {"text-similarity-ada-001", Tiktoken.R50K, 4_096}, + {"text-search-davinci-doc-001", Tiktoken.R50K, 4_096}, + {"text-search-curie-doc-001", Tiktoken.R50K, 4_096}, + {"text-search-babbage-doc-001", Tiktoken.R50K, 4_096}, + {"text-search-ada-doc-001", Tiktoken.R50K, 4_096}, + {"code-search-babbage-code-001", Tiktoken.R50K, 4_096}, + {"code-search-ada-code-001", Tiktoken.R50K, 4_096} + # moderation + # {"text-moderation-latest", Tiktoken.CL100K}, + # {"text-moderation-stable", Tiktoken.CL100K}, + # {"text-moderation-007", Tiktoken.CL100K} + # open source + # {"gpt2", "gpt2"} + ] + describe "encoding_for_model/1" do test "get the proper module for supported model" do - [ - # chat - {"gpt-3.5-turbo", Tiktoken.CL100K}, - # text - {"text-davinci-003", Tiktoken.P50K}, - {"text-davinci-002", Tiktoken.P50K}, - {"text-davinci-001", Tiktoken.R50K}, - {"text-curie-001", Tiktoken.R50K}, - {"text-babbage-001", Tiktoken.R50K}, - {"text-ada-001", Tiktoken.R50K}, - {"davinci", Tiktoken.R50K}, - {"curie", Tiktoken.R50K}, - {"babbage", Tiktoken.R50K}, - {"ada", Tiktoken.R50K}, - # code - {"code-davinci-002", Tiktoken.P50K}, - {"code-davinci-001", Tiktoken.P50K}, - {"code-cushman-002", Tiktoken.P50K}, - {"code-cushman-001", Tiktoken.P50K}, - {"davinci-codex", Tiktoken.P50K}, - {"cushman-codex", Tiktoken.P50K}, - # edit - {"text-davinci-edit-001", Tiktoken.P50KEdit}, - {"code-davinci-edit-001", Tiktoken.P50KEdit}, - # embeddings - {"text-embedding-ada-002", Tiktoken.CL100K}, - # old embeddings - {"text-similarity-davinci-001", Tiktoken.R50K}, - {"text-similarity-curie-001", Tiktoken.R50K}, - {"text-similarity-babbage-001", Tiktoken.R50K}, - {"text-similarity-ada-001", Tiktoken.R50K}, - {"text-search-davinci-doc-001", Tiktoken.R50K}, - {"text-search-curie-doc-001", Tiktoken.R50K}, - {"text-search-babbage-doc-001", Tiktoken.R50K}, - {"text-search-ada-doc-001", Tiktoken.R50K}, - {"code-search-babbage-code-001", Tiktoken.R50K}, - {"code-search-ada-code-001", Tiktoken.R50K} - # open source - # {"gpt2", "gpt2"} - ] - |> Enum.each(fn {model, mod} -> + @known_models + |> Enum.each(fn {model, mod, _context_size} -> assert Tiktoken.encoding_for_model(model) == mod end) end @@ -112,4 +135,17 @@ defmodule TiktokenTest do Tiktoken.decode("gpt2", [1]) end end + + describe "context_size_for_model/1" do + test "get proper context size for model" do + @known_models + |> Enum.each(fn {model, _mod, context_size} -> + assert Tiktoken.context_size_for_model(model) == context_size + end) + end + + test "get 4096 for unknown model" do + assert Tiktoken.context_size_for_model("unknown") == 4_096 + end + end end