From d96f4e3dd5bd03eaa9dacfbf2570b9b163045c93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 29 Nov 2024 11:47:27 +0100 Subject: [PATCH 01/10] Refactor the llama.cpp interface --- outlines/models/__init__.py | 2 +- outlines/models/llamacpp.py | 311 ++++------------ tests/generate/test_integration_llamacpp.py | 374 -------------------- tests/models/test_llamacpp.py | 118 ++++++ 4 files changed, 179 insertions(+), 626 deletions(-) delete mode 100644 tests/generate/test_integration_llamacpp.py create mode 100644 tests/models/test_llamacpp.py diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py index d51e5e483..7bfb09fb0 100644 --- a/outlines/models/__init__.py +++ b/outlines/models/__init__.py @@ -11,7 +11,7 @@ from .anthropic import Anthropic from .exllamav2 import ExLlamaV2Model, exl2 from .gemini import Gemini -from .llamacpp import LlamaCpp, llamacpp +from .llamacpp import LlamaCpp from .mlxlm import MLXLM, mlxlm from .openai import AzureOpenAI, OpenAI from .transformers import Transformers, TransformerTokenizer, mamba, transformers diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py index 904b193c4..0e597c5a7 100644 --- a/outlines/models/llamacpp.py +++ b/outlines/models/llamacpp.py @@ -1,25 +1,11 @@ -import dataclasses import pickle import warnings -from typing import ( - TYPE_CHECKING, - Dict, - Iterator, - List, - Optional, - Set, - Tuple, - TypedDict, - Union, -) - -from typing_extensions import Unpack - -from outlines.generate.api import GenerationParameters, SamplingParameters +from typing import TYPE_CHECKING, Dict, Iterator, List, Set, Tuple, Union + from outlines.models.tokenizer import Tokenizer if TYPE_CHECKING: - from llama_cpp import Llama, LogitsProcessorList + from llama_cpp import Llama class LlamaCppTokenizer(Tokenizer): @@ -107,185 +93,77 @@ def __setstate__(self, state): raise NotImplementedError("Cannot load a pickled llamacpp tokenizer") -class LlamaCppParams(TypedDict, total=False): - suffix: Optional[str] - temperature: float - top_p: float - min_p: float - typical_p: float - seed: int - max_tokens: int - logits_processor: "LogitsProcessorList" - stop: Optional[Union[str, List[str]]] - frequence_penalty: float - presence_penalty: float - repeat_penalty: float - top_k: int - tfs_z: float - mirostat_mode: int - mirostat_tau: float - mirostat_eta: float - stream: bool - - class LlamaCpp: - """Represents a model provided by the `llama-cpp-python` library. - - We wrap models from model providing libraries in order to give all of - them the same interface in Outlines and allow users to easily switch - between providers. This class wraps the `llama_cpp.Llama` class from the - `llama-cpp-python` library. - - """ - - def __init__(self, model: "Llama"): - self.model = model + """Wraps a model provided by the `llama-cpp-python` library.""" - @property - def tokenizer(self): - return LlamaCppTokenizer(self.model) + def __init__(self, model_path: Union[str, "Llama"], **kwargs): + from llama_cpp import Llama - def prepare_generation_parameters( - self, - generation_parameters: GenerationParameters, - sampling_parameters: SamplingParameters, - structure_logits_processor, - **llama_cpp_params: Unpack[LlamaCppParams], - ): - """Prepare the generation parameters. + if isinstance(model_path, Llama): + self.model = model_path + else: + # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved + if "tokenizer" not in kwargs: + warnings.warn( + "The pre-tokenizer in `llama.cpp` handles unicode improperly " + + "(https://github.com/ggerganov/llama.cpp/pull/5613)\n" + + "Outlines may raise a `RuntimeError` when building the regex index.\n" + + "To circumvent this error when using `models.llamacpp()` you may pass the argument" + + "`tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained()`\n" + ) - `llama-cpp-python` uses different default values + self.model = Llama(model_path, **kwargs) - """ - from llama_cpp import LogitsProcessorList + self.tokenizer = LlamaCppTokenizer(self.model) - max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters) - - # We update `llama_cpp_params` with the values the user passed to the - # generator. - if "stop" not in llama_cpp_params: - llama_cpp_params["stop"] = stop_at - if "seed" not in llama_cpp_params: - llama_cpp_params["seed"] = seed - - # Somehow `llama-cpp-python` generates `max_tokens + 1` tokens - if "max_tokens" not in llama_cpp_params: - if max_tokens is None: - llama_cpp_params["max_tokens"] = -1 # indicates unlimited tokens - else: - llama_cpp_params["max_tokens"] = max_tokens - 1 - else: - llama_cpp_params["max_tokens"] = llama_cpp_params["max_tokens"] - 1 - - sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple( - sampling_parameters - ) + @classmethod + def from_pretrained(cls, repo_id, filename, **kwargs): + """Download the model weights from Hugging Face and create a `Llama` instance""" + from llama_cpp import Llama - # We update the `llama_cpp_params` with the sampling values that - # were specified by the user via the `Sampler` class, unless they - # are also specified in `llama_cpp_params`. We also disable other - # sampling methods that are enabled by default and reset the temperature - # value. - # - # See https://github.com/ggerganov/llama.cpp/blob/e11a8999b5690f810c2c99c14347f0834e68c524/common/sampling.h#L22 - # for the default values in `llama.cpp` and indications to disable the sampling modes. - # Mirostat sampling, tail-free sampling and all penalties are disabled by default. - # - # See https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__ - # for default values in `llama-cpp-python` - if sampler == "beam_search": - raise NotImplementedError( - "The `llama_cpp_python` library does not support Beam Search." - ) - if num_samples != 1: - raise NotImplementedError( - "The `llama_cpp_python` library does not allow to take several samples." + # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved + if "tokenizer" not in kwargs: + warnings.warn( + "The pre-tokenizer in `llama.cpp` handles unicode improperly " + + "(https://github.com/ggerganov/llama.cpp/pull/5613)\n" + + "Outlines may raise a `RuntimeError` when building the regex index.\n" + + "To circumvent this error when using `models.llamacpp()` you may pass the argument" + + "`tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained()`\n" ) - if "top_p" not in llama_cpp_params: - if top_p is not None: - llama_cpp_params["top_p"] = top_p - else: - llama_cpp_params["top_p"] = 1.0 - - if "min_p" not in llama_cpp_params: - llama_cpp_params["min_p"] = 0.0 - - if "top_k" not in llama_cpp_params: - if top_k is not None: - llama_cpp_params["top_k"] = top_k - else: - llama_cpp_params["top_k"] = -1 - - if "temperature" not in llama_cpp_params: - if temperature is not None: - llama_cpp_params["temperature"] = temperature - else: - llama_cpp_params["temperature"] = 1.0 - - if "repeat_penalty" not in llama_cpp_params: - llama_cpp_params["repeat_penalty"] = 1.0 - - # The choice to stream or not should happen via the high-level API - llama_cpp_params["stream"] = False - - if structure_logits_processor is not None: - if "logits_processor" in llama_cpp_params: - llama_cpp_params["logits_processor"].append(structure_logits_processor) - else: - llama_cpp_params["logits_processor"] = LogitsProcessorList( - [structure_logits_processor] - ) - return llama_cpp_params + model = Llama.from_pretrained(repo_id, filename, **kwargs) + return cls(model) - def generate( - self, - prompts: Union[str, List[str]], - generation_parameters: GenerationParameters, - structure_logits_processor, - sampling_parameters: SamplingParameters, - **llama_cpp_params: Unpack[LlamaCppParams], - ) -> str: + def generate(self, prompt: str, logits_processor, **inference_kwargs) -> str: """Generate text using `llama-cpp-python`. Arguments --------- - prompts - A prompt or list of prompts. - generation_parameters - An instance of `GenerationParameters` that contains the prompt, - the maximum number of tokens, stop sequences and seed. All the - arguments to `SequenceGeneratorAdapter`'s `__cal__` method. + prompt + A prompt. logits_processor The logits processor to use when generating text. - sampling_parameters - An instance of `SamplingParameters`, a dataclass that contains - the name of the sampler to use and related parameters as available - in Outlines. - llama_cpp_params - Keyword arguments that can be passed to - `llama_cpp_python.Llama.__call__`. The values in `llama_cpp_params` - supersede the values of the parameters in `generation_parameters` and - `sampling_parameters`. See the `llama_cpp_python` documentation for - a list of possible values: https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__ + inference_kwargs + The inference kwargs that can be passed to the `Llama.__call__` method + in the `llama-cpp-python` library. Returns ------- The generated text. """ - if not isinstance(prompts, str): + from llama_cpp import LogitsProcessorList + + if not isinstance(prompt, str): raise NotImplementedError( "The `llama-cpp-python` library does not support batch inference." ) - llama_cpp_params = self.prepare_generation_parameters( - generation_parameters, - sampling_parameters, - structure_logits_processor, - **llama_cpp_params, + completion = self.model( + prompt, + logits_processor=LogitsProcessorList([logits_processor]), + **inference_kwargs, ) - completion = self.model(prompts, **llama_cpp_params) result = completion["choices"][0]["text"] self.model.reset() @@ -293,55 +171,38 @@ def generate( return result def stream( - self, - prompts: Union[str, List[str]], - generation_parameters: GenerationParameters, - structure_logits_processor, - sampling_parameters: SamplingParameters, - **llama_cpp_params: Unpack[LlamaCppParams], + self, prompt: str, logits_processor, **inference_kwargs ) -> Iterator[str]: """Stream text using `llama-cpp-python`. Arguments --------- - prompts - A prompt or list of prompts. - generation_parameters - An instance of `GenerationParameters` that contains the prompt, - the maximum number of tokens, stop sequences and seed. All the - arguments to `SequenceGeneratorAdapter`'s `__cal__` method. + prompt + A prompt. logits_processor The logits processor to use when generating text. - sampling_parameters - An instance of `SamplingParameters`, a dataclass that contains - the name of the sampler to use and related parameters as available - in Outlines. - llama_cpp_params - Keyword arguments that can be passed to - `llama_cpp_python.Llama.__call__`. The values in `llama_cpp_params` - supersede the values of the parameters in `generation_parameters` and - `sampling_parameters`. See the `llama_cpp_python` documentation for - a list of possible values: https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__ + inference_kwargs + The inference kwargs that can be passed to the `Llama.__call__` method + in the `llama-cpp-python` library. Returns ------- A generator that return strings. """ + from llama_cpp import LogitsProcessorList - if not isinstance(prompts, str): + if not isinstance(prompt, str): raise NotImplementedError( "The `llama-cpp-python` library does not support batch inference." ) - llama_cpp_params = self.prepare_generation_parameters( - generation_parameters, - sampling_parameters, - structure_logits_processor, - **llama_cpp_params, + generator = self.model( + prompt, + logits_processor=LogitsProcessorList([logits_processor]), + stream=True, + **inference_kwargs, ) - llama_cpp_params["stream"] = True - generator = self.model(prompts, **llama_cpp_params) def token_generator() -> Iterator[str]: while True: @@ -353,55 +214,3 @@ def token_generator() -> Iterator[str]: return return token_generator() - - def load_lora(self, adapter_path: str): - if self.model._model.apply_lora_from_file( - adapter_path, - 1.0, - ): - raise RuntimeError(f"Failed to apply LoRA from lora path: {adapter_path}") - - -def llamacpp( - repo_id: str, filename: Optional[str] = None, **llamacpp_model_params -) -> LlamaCpp: - """Load a model from the `llama-cpp-python` library. - - We use the `Llama.from_pretrained` classmethod that downloads models - directly from the HuggingFace hub, instead of asking users to specify - a path to the downloaded model. One can still load a local model - by initializing `llama_cpp.Llama` directly. - - Arguments - --------- - repo_id - The name of the model repository. - filename: - A filename of glob pattern to match the model file in the repo. - llama_cpp_model_params - Llama-specific model parameters. See the `llama-cpp-python` documentation - for the full list: https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__ - - """ - from llama_cpp import Llama - - # Default to using the model's full context length - if "n_ctx" not in llamacpp_model_params: - llamacpp_model_params["n_ctx"] = 0 - - if "verbose" not in llamacpp_model_params: - llamacpp_model_params["verbose"] = False - - # TODO: Remove when https://github.com/ggerganov/llama.cpp/pull/5613 is resolved - if "tokenizer" not in llamacpp_model_params: - warnings.warn( - "The pre-tokenizer in `llama.cpp` handles unicode improperly " - + "(https://github.com/ggerganov/llama.cpp/pull/5613)\n" - + "Outlines may raise a `RuntimeError` when building the regex index.\n" - + "To circumvent this error when using `models.llamacpp()` you may pass the argument" - + "`tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained()`\n" - ) - - model = Llama.from_pretrained(repo_id, filename, **llamacpp_model_params) - - return LlamaCpp(model) diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py deleted file mode 100644 index fd5be2171..000000000 --- a/tests/generate/test_integration_llamacpp.py +++ /dev/null @@ -1,374 +0,0 @@ -import datetime -import re - -import pytest -from pydantic import BaseModel, constr - -import outlines.generate as generate -import outlines.models as models -import outlines.samplers as samplers - -TEST_MODEL = "./llama-test-model/TinyMistral-248M-v2-Instruct.Q4_K_M.gguf" - - -@pytest.fixture(scope="session") -def model(tmp_path_factory): - return models.llamacpp( - repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF", - filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf", - ) - - -@pytest.mark.parametrize( - "generator_type,params", - ( - (generate.text, []), - (generate.regex, ("[0-9]",)), - # (generate.cfg, (grammars.arithmetic,)), # Awaiting CFG fix - ), -) -def test_llamacpp_generation_api(model, generator_type, params): - generator = generator_type(model, *params) - - res = generator("test", max_tokens=10) - assert isinstance(res, str) - - res = generator("test", max_tokens=10) - assert isinstance(res, str) - - res = generator("test", stop_at=".") - assert isinstance(res, str) - - res = generator("test", stop_at=[".", "ab"]) - assert isinstance(res, str) - - res = generator("test", stop_at=[".", "ab"]) - assert isinstance(res, str) - - res1 = generator("test", seed=1, max_tokens=10) - res2 = generator("test", seed=1, max_tokens=10) - assert isinstance(res1, str) - assert isinstance(res2, str) - assert res1 == res2 - - -def test_llama_cpp_streaming_api(model): - generator = generate.text(model) - token_generator = generator.stream("test", max_tokens=10) - tokens = [token for token in token_generator] - assert len(tokens) <= 10 - assert isinstance(tokens[0], str) - - -@pytest.mark.xfail(reason="Batch inference is not available in `llama-cpp-python`.") -def test_llamacpp_batch_inference(model): - generator = generate.text(model) - res = generator(["test", "test1"]) - assert len(res) == 2 - - -def test_llamacpp_sampling_params(model): - generator = generate.text(model) - - params = { - "frequency_penalty": 1.0, - "presence_penalty": 1.0, - } - res = generator("test", seed=1, max_tokens=10, **params) - assert isinstance(res, str) - - -def test_llamacpp_greedy_sampling(model): - sampler = samplers.greedy() - generator = generate.text(model, sampler) - res = generator("test", max_tokens=20) - assert isinstance(res, str) - - -def test_llamacpp_multinomial_sampling(model): - sampler = samplers.multinomial() - generator = generate.text(model, sampler) - res = generator("test", max_tokens=10) - assert isinstance(res, str) - - sampler = samplers.multinomial(1, temperature=1.0) - generator = generate.text(model, sampler) - res = generator("test", max_tokens=10) - assert isinstance(res, str) - - sampler = samplers.multinomial(1, top_k=1) - generator = generate.text(model, sampler) - res = generator("test", max_tokens=10) - assert isinstance(res, str) - - sampler = samplers.multinomial(1, top_p=0.5) - generator = generate.text(model, sampler) - res = generator("test", max_tokens=10) - assert isinstance(res, str) - - -def test_llamacpp_several_samples(model): - sampler = samplers.multinomial(3) - generator = generate.text(model, sampler) - with pytest.raises(NotImplementedError, match="allow to take several samples"): - generator("test") - - -def test_llamacpp_beam_search(model): - sampler = samplers.beam_search(1) - generator = generate.text(model, sampler) - - with pytest.raises(NotImplementedError, match="does not support Beam Search"): - generator("test") - - -def test_llamacpp_text_stop(model): - prompt = ( - "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n" - ) - sequence = generate.text(model)(prompt, stop_at="a", max_tokens=100) - assert isinstance(sequence, str) - assert sequence.find("a") == -1 - - -def test_llamacpp_regex(model): - prompt = ( - "<|im_start|>user\nWrite an email address<|im_end|>\n<|im_start|>assistant\n" - ) - regex_str = r"([a-z]{10})@([a-z]{5})\.([a-z]{3})" - generator = generate.regex(model, regex_str) - - # One prompt - sequence = generator(prompts=prompt) - assert isinstance(sequence, str) - assert re.fullmatch(pattern=regex_str, string=sequence) is not None - - -def test_llamacpp_integer(model): - prompt = ( - "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n" - ) - sequence = generate.format(model, int)(prompt, max_tokens=10) - assert isinstance(sequence, int) - assert sequence != "" - int(sequence) - - -def test_llamacpp_float(model): - prompt = ( - "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n" - ) - sequence = generate.format(model, float)(prompt, max_tokens=10) - assert isinstance(sequence, float) - - assert sequence != "" - float(sequence) - - -def test_llamacpp_bool(model): - prompt = ( - "<|im_start|>user\nIs this True or False?<|im_end|>\n<|im_start|>assistant\n" - ) - sequence = generate.format(model, bool)(prompt, max_tokens=10) - assert isinstance(sequence, bool) - - assert sequence != "" - bool(sequence) - - -def test_llamacpp_date(model): - prompt = ( - "<|im_start|>user\nWhat day is it today?<|im_end|>\n<|im_start|>assistant\n" - ) - sequence = generate.format(model, datetime.date)(prompt, max_tokens=20, seed=10) - assert isinstance(sequence, datetime.date) - - -def test_llamacpp_time(model): - prompt = "<|im_start|>user\nWhat time is it?<|im_end|>\n<|im_start|>assistant\n" - sequence = generate.format(model, datetime.time)(prompt, max_tokens=10) - assert isinstance(sequence, datetime.time) - - -def test_llamacpp_datetime(model): - prompt = "<|im_start|>user\nWhat time is it?<|im_end|>\n<|im_start|>assistant\n" - sequence = generate.format(model, datetime.datetime)(prompt, max_tokens=20) - assert isinstance(sequence, datetime.datetime) - - -def test_llamacpp_choice(model): - prompt = ( - "<|im_start|>user\nWrite a short sentence<|im_end|>\n<|im_start|>assistant\n" - ) - sequence = generate.choice(model, ["test", "choice"])(prompt) - assert sequence == "test" or sequence == "choice" - - -def test_llamacpp_json_basic(model): - prompt = "<|im_start|>user\nOutput some JSON<|im_end|>\n<|im_start|>assistant\n" - - class Spam(BaseModel): - spam: constr(max_length=10) - fuzz: bool - - result = generate.json(model, Spam, whitespace_pattern="")( - prompt, max_tokens=100, temperature=0.0, seed=1 - ) - assert isinstance(result, BaseModel) - assert isinstance(result.spam, str) - assert isinstance(result.fuzz, bool) - assert len(result.spam) <= 10 - - -def test_llamacpp_json_schema(model): - prompt = "<|im_start|>user\nOutput some JSON<|im_end|>\n<|im_start|>assistant\n" - - schema = """{ - "title": "spam", - "type": "object", - "properties": { - "foo" : {"type": "boolean"}, - "bar": {"type": "string", "maxLength": 4} - }, - "required": ["foo", "bar"] - } - """ - - result = generate.json(model, schema, whitespace_pattern="")( - prompt, max_tokens=100, temperature=0, seed=10 - ) - assert isinstance(result, dict) - assert isinstance(result["foo"], bool) - assert isinstance(result["bar"], str) - - -@pytest.mark.parametrize( - "repo,model_path,hf_tokenizer_uri", - [ - ("Qwen/Qwen1.5-0.5B-Chat-GGUF", "*q2*.gguf", "Qwen/Qwen1.5-0.5B-Chat"), - ("TheBloke/phi-2-GGUF", "*Q2*.gguf", "microsoft/phi-2"), - ], -) -def test_byte_tokenizer_regression(repo, model_path, hf_tokenizer_uri): - """Reproduce https://github.com/dottxt-ai/outlines/issues/820""" - import llama_cpp - - model = models.llamacpp( - repo, - model_path, - tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained( - hf_tokenizer_uri - ), - ) - generator = generate.choice(model, ["skirt", "dress", "pen", "jacket"]) - generator("Pick the odd word out: skirt, dress, pen, jacket") - - -def test_llama_cpp_pre_tokenizer_remains_broken(): - """If fails, llama.cpp pre-tokenizer is fixed -> revert #892, remove `with pytest.raises`""" - repo = "Qwen/Qwen1.5-0.5B-Chat-GGUF" - model_path = "*q2*.gguf" - - model = models.llamacpp(repo, model_path) - with pytest.raises(RuntimeError): - generate.choice(model, ["skirt", "dress", "pen", "jacket"]) - - -@pytest.mark.skip("Caching for guide was temporarily turned off") -def test_RegexGuide_caching(model, temp_cache_dir): - import llama_cpp - - import outlines.caching - from outlines.fsm.guide import cached_create_states_mapping - - assert outlines.caching._caching_enabled - - regex = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" - prompt = "What is the IP address of the Google DNS servers? " - - cache = outlines.caching.get_cache() - - # Returns (hits, misses) - _ = cache.stats(enable=True) - assert cache.statistics - - assert cached_create_states_mapping.__memory__ is cache - - generator = generate.regex(model, regex, sampler=samplers.greedy()) - assert cache.stats() == (0, 1) - - model_2 = models.llamacpp( - "Qwen/Qwen1.5-0.5B-Chat-GGUF", - "*q2*.gguf", - tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained( - "Qwen/Qwen1.5-0.5B-Chat" - ), - ) - generator_2 = generate.regex(model_2, regex, sampler=samplers.greedy()) - assert cache.stats() == (0, 2) - - # These two different models and tokenizers should not have the same state - # mapping results - assert ( - generator.logits_processor.guide.states_to_token_maps - != generator_2.logits_processor.guide.states_to_token_maps - ) - - generator_3 = generate.regex(model_2, regex, sampler=samplers.greedy()) - assert cache.stats() == (1, 2) - assert ( - generator_2.logits_processor.guide.states_to_token_maps - == generator_3.logits_processor.guide.states_to_token_maps - ) - - # Just for fun... - structured = generator(prompt, max_tokens=30) - structured_2 = generator_2(prompt, max_tokens=30) - - assert re.fullmatch(regex, structured) - assert re.fullmatch(regex, structured_2) - assert structured != structured_2 - - -@pytest.mark.xfail( - reason="Some versions of the Hermes-2-Pro-Llama-3 model have a broken config" -) -def test_tokenizer_vocabulary_decode_sanity(): - """Assert the decoded newline token (198) is the same as the normalized vocab token""" - import llama_cpp - - model = models.llamacpp( - "bartowski/Meta-Llama-3-8B-Instruct-GGUF", - "Meta-Llama-3-8B-Instruct-IQ1_M.gguf", - tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained( - "NousResearch/Hermes-2-Pro-Llama-3-8B", - ), - ) - tokenizer = generate.regex(model, "a").logits_processor.tokenizer - - decoded_nl_token = tokenizer.decode([198])[0] - vocab_nl_token = tokenizer.convert_token_to_string( - [token for token, token_id in tokenizer.vocabulary.items() if token_id == 198][ - 0 - ] - ) - assert decoded_nl_token == vocab_nl_token - - -def test_no_length_constraint_when_unset(): - """Assert that models.llamacpp doesn't have an implicit max_tokens preventing full sequence generation""" - import llama_cpp - - model = models.llamacpp( - repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF", - filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf", - tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained( - "Locutusque/TinyMistral-248M-Instruct" - ), - ) - - long_pattern = "abcdefg" * 10 - generator = generate.regex(model, long_pattern) - - output = generator("a") - assert re.match(long_pattern, output) diff --git a/tests/models/test_llamacpp.py b/tests/models/test_llamacpp.py new file mode 100644 index 000000000..ffb0110cd --- /dev/null +++ b/tests/models/test_llamacpp.py @@ -0,0 +1,118 @@ +import json +from enum import Enum + +import pytest +from pydantic import BaseModel + +from outlines.models import LlamaCpp +from outlines.processors import RegexLogitsProcessor +from outlines.types import Choice, Json, Regex + + +def test_load_model(): + model = LlamaCpp.from_pretrained( + repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF", + filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf", + ) + + assert isinstance(model, LlamaCpp) + + +@pytest.fixture(scope="session") +def model(tmp_path_factory): + return LlamaCpp.from_pretrained( + repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF", + filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf", + ) + + +def test_llamacpp_simple(model): + result = model.generate("Respond with one word. Not more.", None) + assert isinstance(result, str) + + +def test_llamacpp_regex(model): + regex_str = Regex(r"[0-9]").to_regex() + logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer) + result = model.generate("Respond with one word. Not more.", logits_processor) + assert isinstance(result, str) + + +def test_llamacpp_json(model): + class Foo(BaseModel): + bar: str + + regex_str = Json(Foo).to_regex() + logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer) + result = model.generate( + "foo? Respond with one word.", logits_processor, max_tokens=1000 + ) + + assert isinstance(result, str) + assert "bar" in json.loads(result) + + +def test_llamacpp_choice(model): + class Foo(Enum): + bar = "Bar" + foor = "Foo" + + regex_str = Choice(Foo).to_regex() + logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer) + result = model.generate("foo?", logits_processor) + + assert result == "Foo" or result == "Bar" + + +def test_llamacpp_text_stop(model): + result = model.generate("Write the letter a.", None, stop="a", max_tokens=100) + assert "a" not in result + + +def test_llamacpp_stream_text_stop(model): + generator = model.stream("Write the letter a.", None, stop="a", max_tokens=100) + + result = next(generator) + assert isinstance(result, str) + assert result != "a" + + +def test_llamacpp_stream_simple(model): + generator = model.stream("Respond with one word. Not more.", None) + + x = next(generator) + assert isinstance(x, str) + + +def test_llamacpp_stream_regex(model): + regex_str = Regex(r"[0-9]").to_regex() + logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer) + generator = model.stream("Respond with one word. Not more.", logits_processor) + + x = next(generator) + assert isinstance(x, str) + + +def test_llamacpp_stream_json(model): + class Foo(BaseModel): + bar: int + + regex_str = Json(Foo).to_regex() + logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer) + generator = model.stream("foo?", logits_processor) + + x = next(generator) + assert x == "{" + + +def test_llamacpp_stream_choice(model): + class Foo(Enum): + bar = "Bar" + foor = "Foo" + + regex_str = Choice(Foo).to_regex() + logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer) + generator = model.stream("foo?", logits_processor) + + x = next(generator) + assert isinstance(x, str) From 103b6ae75f980729575c00e6d42a741d4e29080d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 29 Nov 2024 18:46:37 +0100 Subject: [PATCH 02/10] Add `LocalModel` and `APIModel` types --- outlines/models/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py index 7bfb09fb0..620108bb9 100644 --- a/outlines/models/__init__.py +++ b/outlines/models/__init__.py @@ -19,3 +19,6 @@ from .vllm import VLLM, vllm LogitsGenerator = Union[Transformers, LlamaCpp, OpenAI, ExLlamaV2Model, MLXLM, VLLM] + +LocalModel = LlamaCpp +APIModel = Union[AzureOpenAI, OpenAI, Anthropic, Gemini] From a97db8b9597eb42832cff98b858d216d1010f5d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 29 Nov 2024 20:46:08 +0100 Subject: [PATCH 03/10] Add `Generator` function and builder --- outlines/generate/__init__.py | 71 +++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/outlines/generate/__init__.py b/outlines/generate/__init__.py index f28cbd80d..12041af49 100644 --- a/outlines/generate/__init__.py +++ b/outlines/generate/__init__.py @@ -1,3 +1,9 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Union, cast, get_args + +from outlines.models import APIModel, LocalModel +from outlines.types import Choice, Json, List, Regex + from .api import SequenceGenerator from .cfg import cfg from .choice import choice @@ -6,3 +12,68 @@ from .json import json from .regex import regex from .text import text + +if TYPE_CHECKING: + from outlines.processors import RegexLogitsProcessor + + +@dataclass +class APIGenerator: + """Represents an API-based generator. + + Attributes + ---------- + model + An instance of a model wrapper. + output_type + The output type. + + """ + + model: APIModel + output_type: Optional[Union[Json, List, Choice, Regex]] = None + + def __call__(self, prompt, **inference_kwargs): + return self.model.generate(prompt, self.output_type, **inference_kwargs) + + +@dataclass +class LocalGenerator: + """Represents a local model-based generator. + + We use this class to keep track of the logits processor which can be quite + expensive to build. + + Attributes + ---------- + model + An instance of a model wrapper. + output_type + The output type. + + """ + + model: LocalModel + output_type: Optional[Union[Json, List, Choice, Regex]] + + def __post_init__(self): + if self.output_type is None: + self.logits_processor = None + else: + regex_string = self.output_type.to_regex() + self.logits_processor = RegexLogitsProcessor( + regex_string, self.model.tokenizer + ) + + def __call__(self, prompt, **inference_kwargs): + return self.model.generate(prompt, self.logits_processor, **inference_kwargs) + + +def Generator( + model: Union[LocalModel, APIModel], + output_type: Optional[Union[Json, List, Choice, Regex]] = None, +): + if isinstance(model, APIModel): # type: ignore + return APIGenerator(model, output_type) # type: ignore + else: + return LocalGenerator(model, output_type) # type: ignore From af030c12e39641eab097929a8defdc978acc5e75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 29 Nov 2024 18:46:02 +0100 Subject: [PATCH 04/10] Add `to_regex` method to the different types --- outlines/types/__init__.py | 34 ++++++++++++++++++++++++++++++++++ tests/test_types.py | 12 ++++++++++++ 2 files changed, 46 insertions(+) diff --git a/outlines/types/__init__.py b/outlines/types/__init__.py index 46f49f36c..8a79b8162 100644 --- a/outlines/types/__init__.py +++ b/outlines/types/__init__.py @@ -8,6 +8,8 @@ from pydantic import BaseModel, TypeAdapter from typing_extensions import _TypedDictMeta # type: ignore +from outlines.fsm.json_schema import build_regex_from_schema + from . import airports, countries from .email import Email from .isbn import ISBN @@ -30,6 +32,7 @@ class Json: """ definition: Union[str, dict] + whitespace_pattern: str = " " def to_json_schema(self): if isinstance(self.definition, str): @@ -52,11 +55,21 @@ def to_json_schema(self): return schema + def to_regex(self): + schema = self.to_json_schema() + schema_str = json.dumps(schema) + return build_regex_from_schema(schema_str, self.whitespace_pattern) + @dataclass class List: definition: list + def to_regex(self): + raise NotImplementedError( + "Structured generation for lists of objects are not implemented yet." + ) + @dataclass class Choice: @@ -67,3 +80,24 @@ class Choice: def __post_init__(self): if isinstance(self.definition, list): self.definition = Enum("Definition", [(x, x) for x in self.definition]) + + def to_list(self): + if isinstance(self.definition, list): + return self.definition + else: + return [x.value for x in self.definition] + + def to_regex(self): + choices = self.to_list() + regex_str = r"(" + r"|".join(choices) + r")" + return regex_str + + +@dataclass +class Regex: + """Represents a string defined by a regular expression.""" + + definition: str + + def to_regex(self): + return self.definition diff --git a/tests/test_types.py b/tests/test_types.py index d54ac1479..0fcaec140 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -52,6 +52,18 @@ def test_type_choice(): choice_type = types.Choice(choices) assert choice_type.definition.a.value == "a" + regex_str = choice_type.to_regex() + assert regex_str == "(a|b)" + + +def test_type_list(): + class Foo(BaseModel): + bar: int + + list_type = types.List(Foo) + with pytest.raises(NotImplementedError, match="Structured"): + list_type.to_regex() + @pytest.mark.parametrize( "custom_type,test_string,should_match", From aa7e3100b391d7572b4c76d48284667d2c9f6a00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 2 Dec 2024 09:03:46 +0100 Subject: [PATCH 05/10] Support dataclasses to define Json Schema --- outlines/types/__init__.py | 4 +++- tests/test_types.py | 8 ++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/outlines/types/__init__.py b/outlines/types/__init__.py index 8a79b8162..7d0780208 100644 --- a/outlines/types/__init__.py +++ b/outlines/types/__init__.py @@ -1,5 +1,5 @@ import json -from dataclasses import dataclass +from dataclasses import dataclass, is_dataclass from enum import Enum, EnumMeta from typing import Union @@ -43,6 +43,8 @@ def to_json_schema(self): schema = self.definition.model_json_schema() elif isinstance(self.definition, _TypedDictMeta): schema = TypeAdapter(self.definition).json_schema() + elif is_dataclass(self.definition): + schema = TypeAdapter(self.definition).json_schema() else: raise TypeError( "The Json definition must be a JSON Schema string, dictionary or Pydantic model." diff --git a/tests/test_types.py b/tests/test_types.py index 0fcaec140..4c4627bea 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1,5 +1,6 @@ import json import re +from dataclasses import dataclass import pytest from jsonschema.exceptions import SchemaError @@ -46,6 +47,13 @@ class Foo(TypedDict): json_type = types.Json(Foo) assert json_type.to_json_schema() == json_schema_dict + @dataclass + class Foo: + bar: int + + json_type = types.Json(Foo) + assert json_type.to_json_schema() == json_schema_dict + def test_type_choice(): choices = ["a", "b"] From b5a02a1a89f2ac22efddcf88637bc6328169cd06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 2 Dec 2024 11:23:58 +0100 Subject: [PATCH 06/10] Update the documentation of the llama.cpp integration --- docs/reference/models/llamacpp.md | 100 ++++++------------------------ 1 file changed, 19 insertions(+), 81 deletions(-) diff --git a/docs/reference/models/llamacpp.md b/docs/reference/models/llamacpp.md index 51b62eca8..52e892b76 100644 --- a/docs/reference/models/llamacpp.md +++ b/docs/reference/models/llamacpp.md @@ -12,50 +12,38 @@ Outlines provides an integration with [Llama.cpp](https://github.com/ggerganov/l ## Load the model -You can initialize the model by passing the name of the repository on the HuggingFace Hub, and the filenames (or glob pattern): +To load a model you can use the same interface as you would using `llamap-cpp-python` directly. The default method is to initialize the model by passing the path to the weights on your machine. Assuming [Phi2's weights](https://huggingface.co/TheBloke/phi-2-GGUF) are in the current directory: ```python from outlines import models -model = models.llamacpp("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf") +llm = models.LlamaCpp("./phi-2.Q4_K_M.gguf") ``` -This will download the model files to the hub cache folder and load the weights in memory. +You can initialize the model by passing the name of the repository on the HuggingFace Hub, and the filenames (or glob pattern): -You can also initialize the model by passing the path to the weights on your machine. Assuming [Phi2's weights](https://huggingface.co/TheBloke/phi-2-GGUF) are in the current directory: ```python from outlines import models -from llama_cpp import Llama -llm = Llama("./phi-2.Q4_K_M.gguf") -model = models.LlamaCpp(llm) +model = models.LlamaCpp.from_pretrained("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf") ``` -If you need more control, you can pass the same keyword arguments to the model as you would pass in the [llama-ccp-library][llamacpp]: +This will download the model files to the hub cache folder and load the weights in memory. + + +You can pass the same keyword arguments to the model as you would pass in the [llama-ccp-library][llamacpp]: ```python from outlines import models -model = models.llamacpp( +model = models.LlamaCpp( "TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf" n_ctx=512, # to set the context length value ) ``` -**Main parameters:** - -| Parameters | Type | Description | Default | -|------------|------|-------------|---------| -| `n_gpu_layers`| `int` | Number of layers to offload to GPU. If -1, all layers are offloaded | `0` | -| `split_mode` | `int` | How to split the model across GPUs. `1` for layer-wise split, `2` for row-wise split | `1` | -| `main_gpu` | `int` | Main GPU | `0` | -| `tensor_split` | `Optional[List[float]]` | How split tensors should be distributed across GPUs. If `None` the model is not split. | `None` | -| `n_ctx` | `int` | Text context. Inference from the model if set to `0` | `0` | -| `n_threads` | `Optional[int]` | Number of threads to use for generation. All available threads if set to `None`.| `None` | -| `verbose` | `bool` | Print verbose outputs to `stderr` | `False` | - See the [llama-cpp-python documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__) for the full list of parameters. ### Load the model on GPU @@ -69,87 +57,39 @@ See the [llama-cpp-python documentation](https://llama-cpp-python.readthedocs.io ```python from outlines import models -model = models.llamacpp( +model = models.LlamaCpp( "TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf", n_gpu_layers=-1, # to use GPU acceleration ) ``` -This also works with generators built with `generate.regex`, `generate.json`, `generate.cfg`, `generate.format` and `generate.choice`. -### Load LoRA adapters +## Generate text + -You can load LoRA adapters dynamically: +To generate text you must first create a `Generator` object by passing the model instance and, possibley, the expected output type: ```python from outlines import models, generate -model = models.llamacpp("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf") -generator = generate.text(model) -answer_1 = generator("prompt") -model.load_lora("./path/to/adapter.gguf") -answer_2 = generator("prompt") +model = models.LlamaCpp("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf") +generator = Generator(model) ``` -To load another adapter you need to re-initialize the model. Otherwise the adapter will be added on top of the previous one: +You can pass to the generator the same keyword arguments you would pass in `llama-cpp-python`: ```python -from outlines import models - -model = models.llamacpp("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf") -model.load_lora("./path/to/adapter1.gguf") # Load first adapter - -model = models.llamacpp("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf") -model.load_lora("./path/to/adapter2.gguf") # Load second adapter +answer = generator("A prompt", presence_penalty=0.8) ``` -## Generate text - -In addition to the parameters described in the [text generation section](../text.md) you can pass extra keyword arguments, for instance to set sampling parameters not exposed in Outlines' public API: +You can also stream the tokens: ```python -from outlines import models, generate - - -model = models.llamacpp("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf") -generator = generate.text(model) - -answer = generator("A prompt", presence_penalty=0.8) +tokens = generator.stream("A prompt") ``` -**Extra keyword arguments:** - -The value of the keyword arguments you pass to the generator suspersede the values set when initializing the sampler or generator. All extra sampling methods and repetition penalties are disabled by default. - -| Parameters | Type | Description | Default | -|------------|------|-------------|---------| -| `suffix` | `Optional[str]` | A suffix to append to the generated text. If `None` no suffix is added. | `None` | -| `echo` | `bool` | Whether to preprend the prompt to the completion. | `False` | -| `seed` | `int` | The random seed to use for sampling. | `None` | -| `max_tokens` | `Optional[int]` | The maximum number of tokens to generate. If `None` the maximum number of tokens depends on `n_ctx`. | `16` | -| `frequence_penalty` | `float` | The penalty to apply to tokens based on their frequency in the past 64 tokens. | `0.0` | -| `presence_penalty` | `float` | The penalty to apply to tokens based on their presence in the past 64 tokens. | `0.0` | -| `repeat_penalty` | `float` | The penalty to apply to repeated tokens in the past 64 tokens. | `1.` | -| `stopping_criteria` | `Optional[StoppingCriteriaList]` | A list of stopping criteria to use. | `None` -| `logits_processor` | `Optional[LogitsProcessorList]` | A list of logits processors to use. The logits processor used for structured generation will be added to this list. | `None` -| `temperature` | `float` | The temperature to use for sampling | `1.0` | -| `top_p` | `float` | The top-p value to use for [nucleus sampling][degeneration]. | `1.` | -| `min_p` | `float` | The min-p value to use for [minimum-p sampling][minimum-p]. | `0.` | -| `typical_p` | `float` | The p value to use for [locally typical sampling][locally-typical]. | `1.0` | -| `stop` | `Optional[Union[str, List[str]]]` | A list of strings that stop generation when encountered. | `[]` | -| `top_k` | `int` | The top-k value used for [top-k sampling][top-k]. Negative value to consider all logit values. | `-1.` | -| `tfs_z` | `float` | The [tail-free sampling][tail-free] parameter. | `1.0` | -| `mirostat_mode` | `int` | The [mirostat sampling][mirostat] mode. | `0` | -| `mirostat_tau` | `float` | The target cross-entropy for [mirostat sampling][mirostat].| `5.0` | -| `mirostat_eta` | `float` | The learning rate used to update `mu` in [mirostat sampling][mirostat]. | `0.1` | - -See the [llama-cpp-python documentation][llama-cpp-python-call] for the full and up-to-date list of parameters and the [llama.cpp code][llama-cpp-sampling-params] for the default values of other -sampling parameters. - -### Streaming - ## Installation @@ -216,8 +156,6 @@ CMAKE_ARGS="-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS" pip install llama-cpp- - SYCL - - [llamacpp]: https://github.com/abetlen/llama-cpp-python [llama-cpp-python-call]: https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__call__ [llama-cpp-python-install]: https://github.com/abetlen/llama-cpp-python/tree/08b16afe11e7b42adec2fed0a781123383476045?tab=readme-ov-file#supported-backends From 2436f83ccd43ff77f6d9c0f63dbb7f3a6314bd79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 2 Dec 2024 11:24:26 +0100 Subject: [PATCH 07/10] Update the documentation of the OpenAI integration --- docs/reference/models/openai.md | 172 ++++++++++---------------------- 1 file changed, 51 insertions(+), 121 deletions(-) diff --git a/docs/reference/models/openai.md b/docs/reference/models/openai.md index 638107568..3297f32de 100644 --- a/docs/reference/models/openai.md +++ b/docs/reference/models/openai.md @@ -2,7 +2,7 @@ !!! Installation - You need to install the `openai` library to be able to use the OpenAI API in Outlines. Or alternatively: + You need to install the `openai` library to be able to use the OpenAI API in Outlines. Or alternatively you can run: ```bash pip install "outlines[openai]" @@ -10,44 +10,30 @@ ## OpenAI models -Outlines supports models available via the OpenAI Chat API, e.g. GPT-4o, ChatGPT and GPT-4. You can initialize the model by passing the model name to `outlines.models.openai`: +Outlines supports models available via the OpenAI Chat API, e.g. GPT-4o, ChatGPT and GPT-4. You can initialize the model by passing the model name to `outlines.models.OpenAI`: ```python from outlines import models -model = models.openai("gpt-4o-mini") -model = models.openai("gpt-4o") +model = models.OpenAI("gpt-4o-mini") +model = models.OpenAI("gpt-4o") ``` -Check the [OpenAI documentation](https://platform.openai.com/docs/models/gpt-4o) for an up-to-date list of available models. You can pass any parameter you would pass to `openai.AsyncOpenAI` as keyword arguments: +Check the [OpenAI documentation](https://platform.openai.com/docs/models/gpt-4o) for an up-to-date list of available models. You can pass any parameter you would pass to `openai.OpenAI` as keyword arguments: ```python import os from outlines import models -model = models.openai( +model = models.OpenAI( "gpt-4o-mini", api_key=os.environ["OPENAI_API_KEY"] ) ``` -The following table enumerates the possible parameters. Refer to the [OpenAI SDK's code](https://github.com/openai/openai-python/blob/54a5911f5215148a0bdeb10e2bcfb84f635a75b9/src/openai/_client.py) for an up-to-date list. - -**Parameters:** - -| **Parameters** | **Type** | **Description** | **Default** | -|----------------|:---------|:----------------|:------------| -| `api_key` | `str` | OpenAI API key. Infered from `OPENAI_API_KEY` if not specified | `None` | -| `organization` | `str` | OpenAI organization id. Infered from `OPENAI_ORG_ID` if not specified | `None` | -| `project` | `str` | OpenAI project id. Infered from `OPENAI_PROJECT_ID` if not specified.| `None` | -| `base_url` | `str | https.URL` | Base URL for the endpoint. Infered from `OPENAI_BASE_URL` if no specified. | `None` | -| `timeout` | `float` | Request timeout.| `NOT_GIVEN` | -| `max_retries` | `int` | Maximum number of retries for failing requests | `2` | -| `default_headers` | `Mapping[str, str]` | Default HTTP headers | `None` | -| `default_query` | `Mapping[str, str]` | Custom parameters added to the HTTP queries | `None` | -| `http_client` | `https.AsyncClient` | User-specified `httpx` client | `None` | +Refer to the [OpenAI SDK's code](https://github.com/openai/openai-python/blob/54a5911f5215148a0bdeb10e2bcfb84f635a75b9/src/openai/_client.py) for an up-to-date list of the initialization parameters. ## Azure OpenAI models @@ -57,93 +43,15 @@ Outlines also supports Azure OpenAI models: from outlines import models -model = models.azure_openai( +model = models.AzureOpenAI( "azure-deployment-name", - "gpt-4o-mini", api_version="2024-07-18", azure_endpoint="https://example-endpoint.openai.azure.com", ) ``` -!!! Question "Why do I need to specify model and deployment name?" - - The model name is needed to load the correct tokenizer for the model. The tokenizer is necessary for structured generation. - - -You can pass any parameter you would pass to `openai.AsyncAzureOpenAI`. You can consult the [OpenAI SDK's code](https://github.com/openai/openai-python/blob/54a5911f5215148a0bdeb10e2bcfb84f635a75b9/src/openai/lib/azure.py) for an up-to-date list. - -**Parameters:** - - -| **Parameters** | **Type** | **Description** | **Default** | -|----------------|:---------|:----------------|:------------| -| `azure_endpoint` | `str` | Azure endpoint, including the resource. Infered from `AZURE_OPENAI_ENDPOINT` if not specified | `None` | -| `api_version` | `str` | API version. Infered from `AZURE_OPENAI_API_KEY` if not specified | `None` | -| `api_key` | `str` | OpenAI API key. Infered from `OPENAI_API_KEY` if not specified | `None` | -| `azure_ad_token` | `str` | Azure active directory token. Inference from `AZURE_OPENAI_AD_TOKEN` if not specified | `None` | -| `azure_ad_token_provider` | `AzureADTokenProvider` | A function that returns an Azure Active Directory token | `None` | -| `organization` | `str` | OpenAI organization id. Infered from `OPENAI_ORG_ID` if not specified | `None` | -| `project` | `str` | OpenAI project id. Infered from `OPENAI_PROJECT_ID` if not specified.| `None` | -| `base_url` | `str | https.URL` | Base URL for the endpoint. Infered from `OPENAI_BASE_URL` if not specified. | `None` | -| `timeout` | `float` | Request timeout.| `NOT_GIVEN` | -| `max_retries` | `int` | Maximum number of retries for failing requests | `2` | -| `default_headers` | `Mapping[str, str]` | Default HTTP headers | `None` | -| `default_query` | `Mapping[str, str]` | Custom parameters added to the HTTP queries | `None` | -| `http_client` | `https.AsyncClient` | User-specified `httpx` client | `None` | - -## Models that follow the OpenAI standard - -Outlines supports models that follow the OpenAI standard. You will need to initialize the OpenAI client properly configured and pass it to `outlines.models.openai` - -```python -import os -from openai import AsyncOpenAI -from outlines import models -from outlines.models.openai import OpenAIConfig - - -client = AsyncOpenAI( - api_key=os.environ.get("PROVIDER_KEY"), - base_url="http://other.provider.server.com" -) -config = OpenAIConfig("model_name") -model = models.openai(client, config) -``` - -!!! Warning - - You need to pass the async client to be able to do batch inference. - -## Structured Generation Support - -Outlines provides support for [OpenAI Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs/json-mode) via `outlines.generate.json`, `outlines.generate.choice` - -```python -from pydantic import BaseModel, ConfigDict -import outlines.models as models -from outlines import generate - -model = models.openai("gpt-4o-mini") - -class Person(BaseModel): - model_config = ConfigDict(extra='forbid') # required for openai - first_name: str - last_name: str - age: int - -generate.json(model, Person) -generator("current indian prime minister on january 1st 2023") -# Person(first_name='Narendra', last_name='Modi', age=72) - -generator = generate.choice(model, ["Chicken", "Egg"]) -print(generator("Which came first?")) -# Chicken -``` - -!!! Warning - - Structured generation support only provided to OpenAI-compatible endpoints which conform to OpenAI's standard. Additionally, `generate.regex` and `generate.cfg` are not supported. +You can pass any parameter you would pass to `openai.AzureOpenAI`. You can consult the [OpenAI SDK's code](https://github.com/openai/openai-python/blob/54a5911f5215148a0bdeb10e2bcfb84f635a75b9/src/openai/lib/azure.py) for an up-to-date list. ## Advanced configuration @@ -163,42 +71,64 @@ client = AsyncOpenAI( transport=httpx.HTTPTransport(local_address="0.0.0.0"), ), ) -config = OpenAIConfig("model_name") -model = models.openai(client, config) ``` -It is possible to specify the values for `seed`, `presence_penalty`, `frequence_penalty`, `top_p` by passing an instance of `OpenAIConfig` when initializing the model: +## Models that follow the OpenAI standard + +Outlines supports models that follow the OpenAI standard. You will need to initialize the OpenAI client properly configured and pass it to `outlines.models.OpenAI` ```python -from outlines.models.openai import OpenAIConfig +import os +from openai import AsyncOpenAI from outlines import models +from outlines.models.openai import OpenAIConfig -config = OpenAIConfig( - presence_penalty=1., - frequency_penalty=1., - top_p=.95, - seed=0, +model = models.OpenAI( + "model_name", + api_key=os.environ.get("PROVIDER_KEY"), + base_url="http://other.provider.server.com" ) -model = models.openai("gpt-4o-mini", config) ``` -## Monitoring API use +## Text generation -It is important to be able to track your API usage when working with OpenAI's API. The number of prompt tokens and completion tokens is directly accessible via the model instance: +To generate text using an OpenAI model you need to build a `Generator` object, possibly with the desired output type. You can then call the model by calling the `Generator`. The method accepts every argument that you could pass to the `client.completions.create` function, as keyword arguments: ```python -from openai import AsyncOpenAI -import outlines.models +from outlines import models, Generator + +model = models.OpenAI("gpt-4o-mini") +generator = Generator(model) +result = generator("Prompt", seed=10) +``` + +See the [OpenAI SDK documentation](https://github.com/openai/openai-python/blob/6974a981aec1814b5abba429a8ea21be9ac58538/src/openai/types/completion_create_params.py#L13) for the list of available arguments. + +### Structured Generation Support +Outlines provides support for [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs/json-mode). Currently only JSON-Schema is supported: -model = models.openai("gpt-4o") +```python +from pydantic import BaseModel +from outlines import models, Generator +from outlines.types import Json + +model = models.OpenAI("gpt-4o-mini") -print(model.prompt_tokens) -# 0 +class Person(BaseModel): + first_name: str + last_name: str + age: int -print(model.completion_tokens) -# 0 +generator = Generator(model, Json(Person)) +generator("current indian prime minister on january 1st 2023") +# Person(first_name='Narendra', last_name='Modi', age=72) ``` -These numbers are updated every time you call the model. +The following objects can be used to define the structure of the Json object: +- A string that represents a Json Schema +- A dictionary that represents a Json Schema +- A Pydantic model +- A TypedDict +- A dataclass From e57995dd198a64c19d2afc2ad4bda0702f978f24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 2 Dec 2024 11:25:31 +0100 Subject: [PATCH 08/10] Update the documentation for the Gemini integration --- docs/reference/models/gemini.md | 88 +++++++++++++++++++++++++++++++++ pyproject.toml | 1 + 2 files changed, 89 insertions(+) create mode 100644 docs/reference/models/gemini.md diff --git a/docs/reference/models/gemini.md b/docs/reference/models/gemini.md new file mode 100644 index 000000000..07ec9aa5a --- /dev/null +++ b/docs/reference/models/gemini.md @@ -0,0 +1,88 @@ +# Gemini + +!!! Installation + + You need to install the `google-generativeai` library to be able to use the Gemini API in Outlines. Or alternatively you can run: + + ```bash + pip install "outlines[gemini]" + ``` + +## Gemini models + +Outlines supports models available via the Gemini API, e.g. Gemini 1.5. You can initialize the model by passing the model name to `outlines.models.Gemini`: + +```python +from outlines import models + +model = models.Gemini("gemini-1-5-flash") +model = models.Gemini("gemini-1-5-pro") +``` + +Check the [Gemini documentation](https://ai.google.dev/gemini-api/docs/models/gemini) for an up-to-date list of available models. + +## Text generation + +To generate text using a Gemini model you need to build a `Generator` object, possibly with the desired output type. You can then call the model by calling the `Generator`. The method accepts every argument that you could pass to the `client.completions.create` function, as keyword arguments: + +```python +from outlines import models, Generator + +model = models.Gemini("gemini-1-5-flash") +generator = Generator(model) +result = generator("Prompt", max_tokens=1024) +``` + +### Structured generation + +Gemini provides support for structured outputs. + +#### Json Schema + +Outlines provides support for JSON Schema-based structured generation with the Gemini models: + +```python +from collections import TypedDict +from outlines import Generator, models +from outlines.types import Json + +model = models.Gemini("gemini-1-5-flash") + +class Person(TypedDict): + first_name: str + last_name: str + age: int + +generator = Generator(model, Json(Person)) +generator("current indian prime minister on january 1st 2023") +# Person(first_name='Narendra', last_name='Modi', age=72) +``` + +Because of the current limitations of the Gemini SDK only The following objects can be used to define the structure of the Json object: +- A Pydantic model +- A TypedDict + +#### Multiple choices + +Outlines provides support for multiple-choices structured generation. Enums and lists of choices are supported: + +```python +from enum import Enum +from outlines import Generator, models +from outlines.types import Choice + +model = models.Gemini("gemini-1-5-flash") + +class Foo(Enum): + foo = "Foo" + fizz = "Fizz" + fuzz = "Fuzz" + +generator = Generator(model, Choice(Foo)) +generator("current indian prime minister on january 1st 2023") +# Person(first_name='Narendra', last_name='Modi', age=72) +``` + +The following objects can be used to define the choices: +- An Enum object +- A Python list diff --git a/pyproject.toml b/pyproject.toml index d54101265..88d6b7e12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ vllm = ["vllm", "transformers", "numpy2"] transformers = ["transformers", "accelerate", "datasets", "numpy<2"] mlxlm = ["mlx-lm", "datasets"] openai = ["openai"] +gemini = ["google-generativeai"] llamacpp = ["llama-cpp-python", "transformers", "datasets", "numpy<2"] exllamav2 = ["exllamav2"] test = [ From d6d56dc26ed7a44917fb4d2627d1c651c28bfb90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 2 Dec 2024 11:26:01 +0100 Subject: [PATCH 09/10] Update the documentation for the Anthropic integration --- docs/reference/models/anthropic.md | 45 ++++++++++++++++++++++++++++++ pyproject.toml | 1 + 2 files changed, 46 insertions(+) create mode 100644 docs/reference/models/anthropic.md diff --git a/docs/reference/models/anthropic.md b/docs/reference/models/anthropic.md new file mode 100644 index 000000000..ffc510f43 --- /dev/null +++ b/docs/reference/models/anthropic.md @@ -0,0 +1,45 @@ +# Anthropic + +!!! Installation + + You need to install the `anthropic` library to be able to use the Anthropic API in Outlines. Or alternatively you can run: + + ```bash + pip install "outlines[anthropic]" + ``` + +## Anthropic models + +Outlines supports models available via the Anthropic API, e.g. Claude 3.5 Haiku or Claude 3.5 Sonner. You can initialize the model by passing the model name to `outlines.models.Anthropic`: + +```python +from outlines import models + +model = models.Anthropic("claude-3-5-haiku-latest") +model = models.Anthropic("claude-3-5-sonnet-latest") +``` + +Check the [Anthropic documentation](https://docs.anthropic.com/en/docs/about-claude/models) for an up-to-date list of available models. You can pass any paramater you would pass to the Anthropic SDK as keyword arguments: + +```python +model = models.Anthropic( + "claude-3.5-haiku-latest", + api_key="" +) +``` + +## Text generation + +To generate text using an Anthropic model you need to build a `Generator` object, possibly with the desired output type. You can then call the model by calling the `Generator`. The method accepts every argument that you could pass to the `client.completions.create` function, as keyword arguments: + +```python +from outlines import models, Generator + +model = models.Anthropic("claude-3-5-haiku-latest") +generator = Generator(model) +result = generator("Prompt", max_tokens=1024) +``` + +See the [Anthropic SDK documentation](https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/resources/messages.py) for the list of available arguments. + +The Anthropic API currently does not support structured generation. diff --git a/pyproject.toml b/pyproject.toml index 88d6b7e12..a1650d58f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ vllm = ["vllm", "transformers", "numpy2"] transformers = ["transformers", "accelerate", "datasets", "numpy<2"] mlxlm = ["mlx-lm", "datasets"] openai = ["openai"] +anthropic = ["anthropic"] gemini = ["google-generativeai"] llamacpp = ["llama-cpp-python", "transformers", "datasets", "numpy<2"] exllamav2 = ["exllamav2"] From 98c1e30d3a109b4857b562adc00d8a0251ef2d2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 16 Dec 2024 13:41:42 +0100 Subject: [PATCH 10/10] Run tests when PR against `v1.0` branch --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 36e6f8526..558002745 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -2,7 +2,7 @@ name: Tests on: pull_request: - branches: [main] + branches: [main,v1.0] push: branches: [main]