Commit 4bb5c37

update aio for cache

luyaxi committed Sep 18, 2024
1 parent f339418 commit 4bb5c37

Showing 5 changed files with 87 additions and 58 deletions.
2 changes: 1 addition & 1 deletion codelinker/config.py

@@ -26,7 +26,6 @@ class Config:
     max_retry_times: int = 3
 
 class RequestConfig(BaseModel):
-    lib: str = 'openai'
     format: Literal[
         "chat",
         "tool_call",
@@ -35,6 +34,7 @@ class RequestConfig(BaseModel):
     schema_validation: bool = True
     dynamic_json_fix: bool = True
 
+    default_request_lib: str = 'openai'
     default_completions_model: str = "gpt-3.5-turbo-16k"
     default_embeding_model: str = "text-embedding-ada-002"
     default_timeout: int = 600
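
For context, a minimal sketch of the renamed field in use, assuming `RequestConfig` is an ordinary pydantic model (only fields visible in this diff are reproduced; the override value is hypothetical):

```python
from pydantic import BaseModel

# Minimal stand-in for RequestConfig in codelinker/config.py; only the
# fields visible in this diff are reproduced here.
class RequestConfig(BaseModel):
    schema_validation: bool = True
    dynamic_json_fix: bool = True

    default_request_lib: str = 'openai'  # renamed from `lib` by this commit
    default_completions_model: str = "gpt-3.5-turbo-16k"

cfg = RequestConfig()
assert cfg.default_request_lib == 'openai'

# The "default_" prefix signals that the backend is now only a fallback;
# callers can override it per request (see linker.py and objGen.py below).
cfg_local = RequestConfig(default_request_lib='vllm')  # hypothetical backend name
```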
9 changes: 7 additions & 2 deletions codelinker/linker.py

@@ -49,7 +49,8 @@ async def request(
     messages: list = [],
     reasoning_format: StructureSchema = None,
     schema_validation: bool = None,
-    dynamic_json_fix: bool = None,):
+    dynamic_json_fix: bool = None,
+    ):
     resolve_none_object = False
     schema = return_type.json_schema()
     if schema == TypeAdapter(str).json_schema():
@@ -179,9 +180,13 @@ async def exec(
     messages: list = [],
     reasoning_format: StructureSchema = None,
     schema_validation: bool = None,
-    dynamic_json_fix: bool = None) -> T:
+    dynamic_json_fix: bool = None,
+    request_lib: str = None) -> T:
     if model:
         completions_kwargs["model"] = model
+    if request_lib:
+        completions_kwargs["request_lib"] = request_lib
+
     return await request(
         prompt=prompt,
         return_type=TypeAdapter(return_type),
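
A hedged usage sketch of the new per-call override: `request_lib` is simply forwarded via `completions_kwargs` into `_chatcompletion_request` (see objGen.py below), falling back to `default_request_lib` when omitted. The `cl` object and `Summary` model are illustrative, not part of the diff:

```python
from pydantic import BaseModel

class Summary(BaseModel):  # illustrative structured return type
    title: str
    bullets: list[str]

async def summarize(cl, text: str) -> Summary:
    # `cl` is assumed to be a CodeLinker instance exposing the patched
    # `exec` method; every keyword except request_lib predates this commit.
    return await cl.exec(
        prompt="Summarize the following text.",
        messages=[{"role": "user", "content": text}],
        return_type=Summary,
        request_lib="openai",  # per-call backend; omit to fall back to
                               # config.request.default_request_lib
    )
```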
42 changes: 27 additions & 15 deletions codelinker/request/objGen.py

@@ -6,6 +6,8 @@
 import jsonschema
 import jsonschema.exceptions
 import importlib
+import aiofiles
+import asyncio
 
 from typing import Literal
 from copy import deepcopy
@@ -22,47 +24,57 @@ def __init__(self, config: CodeLinkerConfig, logger: Logger):
         self.logger = logger
         self.chatcompletion_request_funcs = {}
 
         if config.request.use_cache:
+            loop = asyncio.get_event_loop()
+            loop.run_until_complete(self._load_cache_files())
+        if self.config.request.save_completions:
+            os.makedirs(self.config.request.save_completions_path, exist_ok=True)
+
+    async def _load_cache_files(self):
+        if self.config.request.use_cache:
             self.logger.warning("use_cache is enabled, loading completions from cache...")
             self.hash2files = {}
-            files = glob.glob(os.path.join(config.request.save_completions_path, "*.json"))
+            files = glob.glob(os.path.join(self.config.request.save_completions_path, "*.json"))
             files.sort(key=os.path.getmtime)
-            for file in files:
-                with open(file, 'r') as f:
-                    data = json.load(f)
+
+            async def load_file(file):
+                async with aiofiles.open(file, 'r') as f:
+                    data = json.loads(await f.read())
                     self.hash2files[hash(json.dumps(data["request"],sort_keys=True))] = file
+            tasks = [load_file(file) for file in files]
+            await asyncio.gather(*tasks)
             self.logger.warning("Cache loaded and enabled, which may cause unexpected behavior.")
 
     async def _chatcompletion_request(self, *, request_lib: Literal["openai",] = None, **kwargs) -> dict:
         if self.config.request.use_cache:
             hash_ = hash(json.dumps(kwargs,sort_keys=True))
             self.logger.debug(f"Request hash: {hash_}")
             if hash_ in self.hash2files:
-                with open(self.hash2files[hash_], 'r') as f:
-                    data = json.load(f)
+                async with aiofiles.open(self.hash2files[hash_], 'r') as f:
+                    data = json.loads(await f.read())
                 if data["request"] == kwargs:
                     self.logger.info(f"Cache hit file {self.hash2files[hash_]}, return cached completions.")
                     # remove cache to avoid duplicate
                     self.hash2files.pop(hash_)
                     return data["response"]
 
 
-        request_lib = request_lib if request_lib is not None else self.config.request.lib
+        request_lib = request_lib if request_lib is not None else self.config.request.default_request_lib
         response = await self._get_chatcompletion_request_func(request_lib)(config=self.config,**kwargs)
 
 
         if self.config.request.save_completions:
-            os.makedirs(self.config.request.save_completions_path, exist_ok=True)
-            with open(os.path.join(self.config.request.save_completions_path,
+            async with aiofiles.open(os.path.join(self.config.request.save_completions_path,
                 datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")+f"{uuid.uuid4().hex}.json"
                 ), 'w') as f:
-                json.dump({
+                data = json.dumps({
                     "request": kwargs,
                     "response": response
-                }, f, indent=4)
-
+                },)
+                await f.write(data)
 
         return response
 
     def register_request_lib(self, request_type: str, request_func):
         self.chatcompletion_request_funcs[request_type] = request_func
 
     def _get_chatcompletion_request_func(self, request_type: str):
         if request_type not in self.chatcompletion_request_funcs:
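
The cache keys completions by `hash()` of a canonical `json.dumps(request, sort_keys=True)` and, after this commit, loads all cached files concurrently. Two properties of the pattern are worth noting: `asyncio.get_event_loop().run_until_complete(...)` in `__init__` only works when no event loop is already running, and because `asyncio.gather` completes in arbitrary order, the mtime sort no longer guarantees which file wins if two requests collide on the same hash. A standalone sketch of the loading pattern, with illustrative names rather than the library's API:

```python
import asyncio
import glob
import json
import os

import aiofiles  # async file I/O; added to pyproject.toml by this commit

async def load_cache(cache_dir: str) -> dict[int, str]:
    """Map hash(canonical request JSON) -> cache file path, read concurrently."""
    hash2files: dict[int, str] = {}
    files = sorted(glob.glob(os.path.join(cache_dir, "*.json")), key=os.path.getmtime)

    async def load_file(path: str) -> None:
        async with aiofiles.open(path, 'r') as f:
            data = json.loads(await f.read())
        # sort_keys=True makes the dump canonical, so identical request
        # kwargs always hash to the same key within one process.
        hash2files[hash(json.dumps(data["request"], sort_keys=True))] = path

    await asyncio.gather(*(load_file(p) for p in files))
    return hash2files

# Mirrors the __init__ call site: safe only where no loop is running yet.
# hash2files = asyncio.run(load_cache("completions/"))
```

Since Python salts `hash()` for strings per process, these keys are only stable within a single run; that is sufficient here because the table is rebuilt from the JSON files at startup.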
91 changes: 51 additions & 40 deletions poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml

@@ -12,6 +12,7 @@ tenacity = ">=8.2.3,<9.0"
 jsonschema = ">=4.21.1,<5.0"
 toml = ">=0.10.2,<1.0"
 tiktoken = ">=0.7.0,<1.0"
+aiofiles = ">=24.1.0"
 
 
 [build-system]
