Skip to content

Commit

Permalink
cache .thread
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo committed Oct 8, 2024
1 parent 2d2ef10 commit baf8e54
Showing 1 changed file with 29 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver):
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
init=False,
)
_thread: Optional[dict] = field(default=None, init=False)

@api_key.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
Expand All @@ -59,30 +60,33 @@ def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
@property
def thread(self) -> dict:
"""Try to get the Thread by ID, alias, or create a new one."""
thread = None
if self.thread_id is None:
self.thread_id = os.getenv("GT_CLOUD_THREAD_ID")

if self.thread_id is not None:
res = self._call_api("get", f"/threads/{self.thread_id}", raise_for_status=False)
if res.status_code == 200:
thread = res.json()

# use name as 'alias' to get thread
if thread is None and self.alias is not None:
res = self._call_api("get", f"/threads?alias={self.alias}").json()
if res.get("threads"):
thread = res["threads"][0]
self.thread_id = thread.get("thread_id")

# no thread by name or thread_id
if thread is None:
data = {"name": uuid.uuid4().hex} if self.alias is None else {"name": self.alias, "alias": self.alias}
thread = self._call_api("post", "/threads", data).json()
self.thread_id = thread["thread_id"]
self.alias = thread.get("alias")

return thread
if self._thread is None:
thread = None
if self.thread_id is None:
self.thread_id = os.getenv("GT_CLOUD_THREAD_ID")

if self.thread_id is not None:
res = self._call_api("get", f"/threads/{self.thread_id}", raise_for_status=False)
if res.status_code == 200:
thread = res.json()

# use name as 'alias' to get thread
if thread is None and self.alias is not None:
res = self._call_api("get", f"/threads?alias={self.alias}").json()
if res.get("threads"):
thread = res["threads"][0]
self.thread_id = thread.get("thread_id")

# no thread by name or thread_id
if thread is None:
data = {"name": uuid.uuid4().hex} if self.alias is None else {"name": self.alias, "alias": self.alias}
thread = self._call_api("post", "/threads", data).json()
self.thread_id = thread["thread_id"]
self.alias = thread.get("alias")

self._thread = thread

return self._thread # pyright: ignore[reportReturnType]

def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
# serialize the run artifacts to json strings
Expand All @@ -109,6 +113,7 @@ def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
# all old Messages are replaced with the new ones
thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id
self._call_api("patch", f"/threads/{thread_id}", body)
self._thread = None

def load(self) -> tuple[list[Run], dict[str, Any]]:
from griptape.memory.structure import Run
Expand Down

0 comments on commit baf8e54

Please sign in to comment.