diff --git a/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py b/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py index 3aac74090d..046032c5ff 100644 --- a/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py @@ -22,8 +22,8 @@ class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver): Attributes: thread_id: The ID of the Thread to store the conversation memory in. If not provided, the driver will attempt to - retrieve the ID from the environment variable `GT_CLOUD_THREAD_ID`. If that is not set, a new Thread will be - created. + retrieve the ID from the environment variable `GT_CLOUD_THREAD_ID`. + alias: The alias of the Thread to store the conversation memory in. base_url: The base URL of the Griptape Cloud API. Defaults to the value of the environment variable `GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`. api_key: The API key to use for authenticating with the Griptape Cloud API. If not provided, the driver will @@ -33,7 +33,11 @@ class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver): ValueError: If `api_key` is not provided. """ - thread_id: str = field( + thread_id: Optional[str] = field( + default=None, + metadata={"serializable": True}, + ) + alias: Optional[str] = field( default=None, metadata={"serializable": True}, ) @@ -46,16 +50,40 @@ class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver): init=False, ) - def __attrs_post_init__(self) -> None: - if self.thread_id is None: - self.thread_id = os.getenv("GT_CLOUD_THREAD_ID", self._get_thread_id()) - @api_key.validator # pyright: ignore[reportAttributeAccessIssue] def validate_api_key(self, _: Attribute, value: Optional[str]) -> str: if value is None: raise ValueError(f"{self.__class__.__name__} requires an API key") return value + @property + def thread(self) -> dict: + """Try to get the Thread by ID, alias, or create a new one.""" + thread = None + if self.thread_id is None: + self.thread_id = os.getenv("GT_CLOUD_THREAD_ID") + + if self.thread_id is not None: + res = self._call_api("get", f"/threads/{self.thread_id}", raise_for_status=False) + if res.status_code == 200: + thread = res.json() + + # use name as 'alias' to get thread + if thread is None and self.alias is not None: + res = self._call_api("get", f"/threads?alias={self.alias}").json() + if res.get("threads"): + thread = res["threads"][0] + self.thread_id = thread.get("thread_id") + + # no thread by name or thread_id + if thread is None: + data = {"name": uuid.uuid4().hex} if self.alias is None else {"name": self.alias, "alias": self.alias} + thread = self._call_api("post", "/threads", data).json() + self.thread_id = thread["thread_id"] + self.alias = thread.get("alias") + + return thread + def store(self, runs: list[Run], metadata: dict[str, Any]) -> None: # serialize the run artifacts to json strings messages = [ @@ -79,25 +107,19 @@ def store(self, runs: list[Run], metadata: dict[str, Any]) -> None: # patch the Thread with the new messages and metadata # all old Messages are replaced with the new ones - response = requests.patch( - self._get_url(f"/threads/{self.thread_id}"), - json=body, - headers=self.headers, - ) - response.raise_for_status() + thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id + self._call_api("patch", f"/threads/{thread_id}", body) def load(self) -> tuple[list[Run], dict[str, Any]]: from griptape.memory.structure import Run + thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id + # get the Messages from the Thread - messages_response = requests.get(self._get_url(f"/threads/{self.thread_id}/messages"), headers=self.headers) - messages_response.raise_for_status() - messages_response = messages_response.json() + messages_response = self._call_api("get", f"/threads/{thread_id}/messages").json() # retrieve the Thread to get the metadata - thread_response = requests.get(self._get_url(f"/threads/{self.thread_id}"), headers=self.headers) - thread_response.raise_for_status() - thread_response = thread_response.json() + thread_response = self._call_api("get", f"/threads/{thread_id}").json() runs = [ Run( @@ -110,11 +132,14 @@ def load(self) -> tuple[list[Run], dict[str, Any]]: ] return runs, thread_response.get("metadata", {}) - def _get_thread_id(self) -> str: - res = requests.post(self._get_url("/threads"), json={"name": uuid.uuid4().hex}, headers=self.headers) - res.raise_for_status() - return res.json().get("thread_id") - def _get_url(self, path: str) -> str: path = path.lstrip("/") return urljoin(self.base_url, f"/api/{path}") + + def _call_api( + self, method: str, path: str, json: Optional[dict] = None, *, raise_for_status: bool = True + ) -> requests.Response: + res = requests.request(method, self._get_url(path), json=json, headers=self.headers) + if raise_for_status: + res.raise_for_status() + return res diff --git a/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py index dccdc9fd05..0c76d6ecdb 100644 --- a/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py @@ -13,61 +13,70 @@ class TestGriptapeCloudConversationMemoryDriver: @pytest.fixture(autouse=True) def _mock_requests(self, mocker): - def get(*args, **kwargs): - if str(args[0]).endswith("/messages"): - return mocker.Mock( - raise_for_status=lambda: None, - json=lambda: { - "messages": [ - { - "message_id": "123", - "input": '{"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}', - "output": '{"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}', - "index": 0, - "metadata": {"run_id": "1234"}, - } - ] - }, - ) - else: - thread_id = args[0].split("/")[-1] - return mocker.Mock( - raise_for_status=lambda: None, - json=lambda: { - "metadata": {"foo": "bar"}, - "name": "test", - "thread_id": "test_metadata", - } - if thread_id == "test_metadata" - else {"name": "test", "thread_id": "test"}, - ) - - mocker.patch( - "requests.get", - side_effect=get, - ) - - def post(*args, **kwargs): - if str(args[0]).endswith("/threads"): - return mocker.Mock( - raise_for_status=lambda: None, - json=lambda: {"thread_id": "test", "name": "test"}, - ) + def request(*args, **kwargs): + if args[0] == "get": + if "/messages" in str(args[1]): + thread_id = args[1].split("/")[-2] + return mocker.Mock( + raise_for_status=lambda: None, + json=lambda: { + "messages": [ + { + "message_id": f"{thread_id}_message", + "input": '{"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}', + "output": '{"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}', + "metadata": {"run_id": "1234"}, + } + ] + } + if thread_id != "no_messages" + else {"messages": []}, + status_code=200, + ) + elif "/threads/" in str(args[1]): + thread_id = args[1].split("/")[-1] + return mocker.Mock( + # raise for status if thread_id is == not_found + raise_for_status=lambda: None if thread_id != "not_found" else ValueError, + json=lambda: { + "metadata": {"foo": "bar"}, + "name": "test", + "thread_id": thread_id, + }, + status_code=200 if thread_id != "not_found" else 404, + ) + elif "/threads?alias=" in str(args[1]): + alias = args[1].split("=")[-1] + return mocker.Mock( + raise_for_status=lambda: None, + json=lambda: {"threads": [{"thread_id": alias, "alias": alias, "metadata": {"foo": "bar"}}]} + if alias != "not_found" + else {"threads": []}, + status_code=200, + ) + else: + return mocker.Mock() + elif args[0] == "post": + if str(args[1]).endswith("/threads"): + body = kwargs["json"] + body["thread_id"] = "test" + return mocker.Mock( + raise_for_status=lambda: None, + json=lambda: body, + ) + else: + return mocker.Mock( + raise_for_status=lambda: None, + json=lambda: {"message_id": "test"}, + ) else: return mocker.Mock( raise_for_status=lambda: None, - json=lambda: {"message_id": "test"}, ) mocker.patch( - "requests.post", - side_effect=post, - ) - mocker.patch( - "requests.patch", - return_value=mocker.Mock( - raise_for_status=lambda: None, - ), + "requests.request", + side_effect=request, ) @pytest.fixture() @@ -80,12 +89,22 @@ def test_no_api_key(self): def test_thread_id(self): driver = GriptapeCloudConversationMemoryDriver(api_key="test") + assert driver.thread_id is None + assert driver.thread.get("thread_id") == "test" assert driver.thread_id == "test" os.environ["GT_CLOUD_THREAD_ID"] = "test_env" driver = GriptapeCloudConversationMemoryDriver(api_key="test") + assert driver.thread_id is None + assert driver.thread.get("thread_id") == "test_env" assert driver.thread_id == "test_env" - driver = GriptapeCloudConversationMemoryDriver(api_key="test", thread_id="test_init") - assert driver.thread_id == "test_init" + os.environ.pop("GT_CLOUD_THREAD_ID") + + def test_thread_alias(self): + driver = GriptapeCloudConversationMemoryDriver(api_key="test", alias="test") + assert driver.thread_id is None + assert driver.thread.get("thread_id") == "test" + assert driver.thread_id == "test" + assert driver.alias == "test" def test_store(self, driver: GriptapeCloudConversationMemoryDriver): runs = [ @@ -98,8 +117,4 @@ def test_load(self, driver): runs, metadata = driver.load() assert len(runs) == 1 assert runs[0].id == "1234" - assert metadata == {} - driver.thread_id = "test_metadata" - runs, metadata = driver.load() - assert len(runs) == 1 assert metadata == {"foo": "bar"}