diff --git a/outlines/caching.py b/outlines/caching.py
index ecaa950cc..1bb66936f 100644
--- a/outlines/caching.py
+++ b/outlines/caching.py
@@ -1,5 +1,5 @@
 import os
-from typing import Callable
+from typing import Callable, Optional

 from perscache import Cache, NoCache
 from perscache.serializers import JSONSerializer
@@ -10,8 +10,11 @@
 memory = Cache(serializer=JSONSerializer(), storage=LocalFileStorage(cache_dir))


-def cache(fn: Callable):
-    return memory.cache()(fn)
+def cache(ignore: Optional[str]):
+    def cache_fn(fn: Callable):
+        return memory.cache(ignore=ignore)(fn)
+
+    return cache_fn


 def get_cache():
diff --git a/outlines/models/openai.py b/outlines/models/openai.py
index 8c6ea0647..887b87129 100644
--- a/outlines/models/openai.py
+++ b/outlines/models/openai.py
@@ -9,6 +9,7 @@
 import numpy as np

 import outlines
+from outlines.caching import cache

 __all__ = ["OpenAI", "openai"]

@@ -287,6 +288,7 @@ def __repr__(self):
         return str(self.config)


+@cache(ignore="client")
 @functools.partial(outlines.vectorize, signature="(),(),()->(s)")
 async def generate_chat(
     prompt: str, client: "AsyncOpenAI", config: OpenAIConfig
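
Note: a minimal usage sketch of the new decorator factory, assuming (as the hunk above relies on) that perscache's "ignore" option simply drops the named argument from the cache key. The function "make_request" and its arguments are hypothetical stand-ins for generate_chat, not part of this patch:

    from outlines.caching import cache

    # The "client" argument is excluded from the cache key, so repeated calls
    # with the same prompt share one cache entry even when a fresh client
    # instance is constructed for each call.
    @cache(ignore="client")
    def make_request(prompt: str, client: object) -> str:
        return f"response for: {prompt}"

The point of ignore="client" in the decorated generate_chat is the same: the AsyncOpenAI client is not part of the cache key, so only the prompt and config determine cache hits.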