Skip to content

Commit

Permalink
use sqlite to store session history and enable reproduction in k-means
Browse files Browse the repository at this point in the history
  • Loading branch information
peteryang1 authored Aug 20, 2024
1 parent 48ff804 commit 8256067
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
44 changes: 25 additions & 19 deletions rdagent/oai/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@ def __init__(self, cache_location: str) -> None:
)
""",
)
self.c.execute(
"""
CREATE TABLE message_cache (
conversation_id TEXT PRIMARY KEY,
message TEXT
)
""",
)
self.conn.commit()

def chat_get(self, key: str) -> str | None:
Expand Down Expand Up @@ -144,33 +152,31 @@ def embedding_set(self, content_to_embedding_dict: dict) -> None:
)
self.conn.commit()

def message_get(self, conversation_id: str) -> list[str]:
    """Return the message list stored for ``conversation_id``.

    Selects the ``message`` column of the ``message_cache`` row whose
    ``conversation_id`` matches and decodes it from JSON.

    Args:
        conversation_id: Key of the conversation to look up.

    Returns:
        The JSON-decoded ``message`` value, or an empty list when no row
        exists for ``conversation_id``.
    """
    query = "SELECT message FROM message_cache WHERE conversation_id=?"
    self.c.execute(query, (conversation_id,))
    row = self.c.fetchone()
    # A missing conversation is reported as an empty history, not as None.
    return [] if row is None else json.loads(row[0])

def message_set(self, conversation_id: str, message_value: list[str]) -> None:
    """Store ``message_value`` as the message list for ``conversation_id``.

    The list is serialised to JSON and written to the ``message_cache``
    table with ``INSERT OR REPLACE``, so an existing row with the same
    ``conversation_id`` is overwritten. The connection is committed
    before returning.

    Args:
        conversation_id: Key of the conversation to write.
        message_value: Messages to persist for that conversation.
    """
    serialized = json.dumps(message_value)
    self.c.execute(
        "INSERT OR REPLACE INTO message_cache (conversation_id, message) VALUES (?, ?)",
        (conversation_id, serialized),
    )
    self.conn.commit()


class SessionChatHistoryCache(SingletonBaseClass):
    """Session chat history store that delegates to a ``SQliteLazyCache``.

    Each conversation's message list is read and written through the
    ``message_get`` / ``message_set`` methods of the underlying cache
    object, so no per-conversation JSON files are loaded or written here.
    """

    def __init__(self) -> None:
        """Open the cache located at ``RD_AGENT_SETTINGS.prompt_cache_path``."""
        # The cache object is the single source of truth; the previous
        # directory scan populated a dict that was immediately overwritten.
        self.cache = SQliteLazyCache(cache_location=RD_AGENT_SETTINGS.prompt_cache_path)

    def message_get(self, conversation_id: str) -> list[str]:
        """Return the messages stored for ``conversation_id`` by the cache."""
        return self.cache.message_get(conversation_id)

    def message_set(self, conversation_id: str, message_value: list[str]) -> None:
        """Store ``message_value`` for ``conversation_id`` through the cache."""
        self.cache.message_set(conversation_id, message_value)


class ChatSession:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def find_closest_cluster_cosine_similarity(
return np.argmax(similarity, axis=1)

# Initializes the cluster center
rng = np.random.default_rng()
rng = np.random.default_rng(seed=42)
centroids = rng.choice(x_normalized, size=k, replace=False)

# Iterate until convergence or the maximum number of iterations is reached
Expand Down

0 comments on commit 8256067

Please sign in to comment.