Skip to content

Commit

Permalink
Clear cache upon version upgrade (#566)
Browse files Browse the repository at this point in the history
Fixes #561

Co-authored-by: Andrew Lapp <[email protected]>
  • Loading branch information
lapp0 and Andrew Lapp authored Jan 23, 2024
1 parent 32047ab commit 8a0bafc
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 18 deletions.
48 changes: 30 additions & 18 deletions outlines/caching.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,41 @@
import asyncio
import functools
import hashlib
import os
from typing import Callable, Optional

import cloudpickle
from diskcache import Cache

# Resolve the on-disk cache location: the OUTLINES_CACHE_DIR environment
# variable takes precedence, otherwise fall back to ~/.cache/outlines.
home_dir = os.path.expanduser("~")
cache_dir = os.environ.get("OUTLINES_CACHE_DIR", f"{home_dir}/.cache/outlines")
# Disk-backed store opened at import time. eviction_policy="none" and
# cull_limit=0 presumably turn off diskcache's eviction and culling -- confirm
# against the diskcache documentation.
# NOTE(review): these three assignments are repeated verbatim inside
# get_cache(); in this diff view they look like the pre-change lines -- confirm
# against the actual file before relying on a module-level `memory`.
memory = Cache(cache_dir, eviction_policy="none", cull_limit=0)
# Module-wide switch read by the cache wrapper: when it is false the wrapped
# function is called directly instead of going through the cache.
_caching_enabled = True


@functools.lru_cache(1)
def get_cache():
    """Open the disk cache holding previously-computed return values.

    Cached results spare repeated computations and API calls, which can be
    slow and costly for large models.

    The cache lives in `HOMEDIR/.cache/outlines` unless the
    `OUTLINES_CACHE_DIR` environment variable points somewhere else.

    The call is memoized with `functools.lru_cache`, so the directory is
    opened and the version check runs only on the first call (or again after
    `get_cache.cache_clear()`).
    """
    from outlines._version import __version__ as outlines_version  # type: ignore

    user_home = os.path.expanduser("~")
    default_location = f"{user_home}/.cache/outlines"
    location = os.environ.get("OUTLINES_CACHE_DIR", default_location)
    store = Cache(location, eviction_policy="none", cull_limit=0)

    # Entries written by a different outlines version may no longer match the
    # current code, so wipe them when the recorded version differs.
    recorded_version = store.get("__version__")
    if outlines_version != recorded_version:
        store.clear()
    # Record the running version so the next start-up can compare against it.
    store["__version__"] = outlines_version

    return store


def hash_arguments(*args, **kwargs) -> str:
"""Create a hash out of the args and kwargs provided"""
result = hashlib.md5()
Expand All @@ -35,6 +59,8 @@ def cache(key_function: Optional[Callable] = None):
"""

def decorator(cached_function: Callable):
memory = get_cache()

def wrapper(*args, **kwargs):
if not _caching_enabled:
return cached_function(*args, **kwargs)
Expand Down Expand Up @@ -71,20 +97,6 @@ async def async_wrapper(*args, **kwargs):
return decorator


def get_cache():
    """Get the context object that contains previously-computed return values.

    The cache is used to avoid unnecessary computations and API calls, which can
    be long and expensive for large models.

    The cache directory defaults to `HOMEDIR/.cache/outlines`, but this choice
    can be overridden by the user by setting the value of the `OUTLINES_CACHE_DIR`
    environment variable.

    Returns the module-level `memory` object unchanged.
    """
    # NOTE(review): this definition sits in a hunk reported as 20 -> 6 lines and
    # duplicates the name of the memoized get_cache() earlier in the diff; it
    # looks like the pre-change version -- confirm against the actual file.
    return memory


def disable_cache():
"""Disable the cache for this session.
Expand All @@ -111,5 +123,5 @@ def disable_cache():

def clear_cache():
    """Erase the cache completely.

    Obtains the cache object through `get_cache()` and calls `clear()` on it.
    """
    # NOTE(review): `global memory` and the assignment below sit on adjacent
    # diff lines; the `global` declaration looks like the pre-change line that
    # `memory = get_cache()` replaces -- confirm against the actual file.
    global memory
    memory = get_cache()
    memory.clear()
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ test = [
"pre-commit",
"pytest",
"pytest-cov",
"pytest-mock",
"transformers",
"coverage[toml]>=5.1",
"diff-cover",
Expand Down
45 changes: 45 additions & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,48 @@ def f(x):
outlines.clear_cache()
f(1)
assert len(store) == store_size + 1


def test_version_upgrade_cache_invalidate(test_cache, mocker):
    """Check that a version bump discards results cached by the old version.

    A cached function changes the shape of its return value between
    "restarts". Without a version change the stale value is served (and the
    unpacking fails); after a version change the stale value must be gone.
    """

    import outlines.caching

    def restart_outlines(version=None):
        # Optionally pretend a different outlines version is installed, then
        # drop the memoized cache handle so the next get_cache() call re-opens
        # the disk cache and re-runs the version check. The data on disk is
        # left untouched, which is what a process restart would look like.
        if version is not None:
            mocker.patch("outlines._version.__version__", new=version)
        outlines.caching.get_cache.cache_clear()

    restart_outlines(version="0.0.0")

    # First "session": populate the cache with a 3-tuple result.
    # `test_cache` is presumably the caching decorator under test (the fixture
    # is defined outside this view).
    @test_cache
    def foo():
        return (1, 2, 3)

    a, b, c = foo()

    # Second "session", same version: the function now returns a 2-tuple.
    restart_outlines()

    @test_cache
    def foo():
        return (1, 2)

    # The stale 3-tuple is still served, so unpacking into two names fails.
    with pytest.raises(ValueError):
        a, b = foo()

    # Third "session", upgraded version: the old entries must be pruned.
    restart_outlines(version="0.0.1")

    @test_cache
    def foo():
        return (1, 2)

    # The fresh 2-tuple is computed and unpacks cleanly.
    a, b = foo()

0 comments on commit 8a0bafc

Please sign in to comment.