Skip to content

Commit

Permalink
Add thread alias
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo committed Oct 8, 2024
1 parent 59a0a59 commit 9e4607a
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver):
Attributes:
thread_id: The ID of the Thread to store the conversation memory in. If not provided, the driver will attempt to
retrieve the ID from the environment variable `GT_CLOUD_THREAD_ID`. If that is not set, a new Thread will be
created.
retrieve the ID from the environment variable `GT_CLOUD_THREAD_ID`.
alias: The alias of the Thread to store the conversation memory in.
base_url: The base URL of the Griptape Cloud API. Defaults to the value of the environment variable
`GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`.
api_key: The API key to use for authenticating with the Griptape Cloud API. If not provided, the driver will
Expand All @@ -33,7 +33,11 @@ class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver):
ValueError: If `api_key` is not provided.
"""

thread_id: str = field(
thread_id: Optional[str] = field(
default=None,
metadata={"serializable": True},
)
alias: Optional[str] = field(
default=None,
metadata={"serializable": True},
)
Expand All @@ -46,16 +50,40 @@ class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver):
init=False,
)

def __attrs_post_init__(self) -> None:
if self.thread_id is None:
self.thread_id = os.getenv("GT_CLOUD_THREAD_ID", self._get_thread_id())

@api_key.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
if value is None:
raise ValueError(f"{self.__class__.__name__} requires an API key")
return value

@property
def thread(self) -> dict:
"""Try to get the Thread by ID, alias, or create a new one."""
thread = None
if self.thread_id is None:
self.thread_id = os.getenv("GT_CLOUD_THREAD_ID")

if self.thread_id is not None:
res = self._call_api("get", f"/threads/{self.thread_id}", raise_for_status=False)
if res.status_code == 200:
thread = res.json()

# use name as 'alias' to get thread
if thread is None and self.alias is not None:
res = self._call_api("get", f"/threads?alias={self.alias}").json()
if res.get("threads"):
thread = res["threads"][0]
self.thread_id = thread.get("thread_id")

# no thread by name or thread_id
if thread is None:
data = {"name": uuid.uuid4().hex} if self.alias is None else {"name": self.alias, "alias": self.alias}
thread = self._call_api("post", "/threads", data).json()
self.thread_id = thread["thread_id"]
self.alias = thread.get("alias")

return thread

def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
# serialize the run artifacts to json strings
messages = [
Expand All @@ -79,25 +107,19 @@ def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:

# patch the Thread with the new messages and metadata
# all old Messages are replaced with the new ones
response = requests.patch(
self._get_url(f"/threads/{self.thread_id}"),
json=body,
headers=self.headers,
)
response.raise_for_status()
thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id
self._call_api("patch", f"/threads/{thread_id}", body)

def load(self) -> tuple[list[Run], dict[str, Any]]:
from griptape.memory.structure import Run

thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id

# get the Messages from the Thread
messages_response = requests.get(self._get_url(f"/threads/{self.thread_id}/messages"), headers=self.headers)
messages_response.raise_for_status()
messages_response = messages_response.json()
messages_response = self._call_api("get", f"/threads/{thread_id}/messages").json()

# retrieve the Thread to get the metadata
thread_response = requests.get(self._get_url(f"/threads/{self.thread_id}"), headers=self.headers)
thread_response.raise_for_status()
thread_response = thread_response.json()
thread_response = self._call_api("get", f"/threads/{thread_id}").json()

runs = [
Run(
Expand All @@ -110,11 +132,14 @@ def load(self) -> tuple[list[Run], dict[str, Any]]:
]
return runs, thread_response.get("metadata", {})

def _get_thread_id(self) -> str:
res = requests.post(self._get_url("/threads"), json={"name": uuid.uuid4().hex}, headers=self.headers)
res.raise_for_status()
return res.json().get("thread_id")

def _get_url(self, path: str) -> str:
path = path.lstrip("/")
return urljoin(self.base_url, f"/api/{path}")

def _call_api(
self, method: str, path: str, json: Optional[dict] = None, *, raise_for_status: bool = True
) -> requests.Response:
res = requests.request(method, self._get_url(path), json=json, headers=self.headers)
if raise_for_status:
res.raise_for_status()
return res
Original file line number Diff line number Diff line change
Expand Up @@ -13,61 +13,70 @@
class TestGriptapeCloudConversationMemoryDriver:
@pytest.fixture(autouse=True)
def _mock_requests(self, mocker):
def get(*args, **kwargs):
if str(args[0]).endswith("/messages"):
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {
"messages": [
{
"message_id": "123",
"input": '{"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}',
"output": '{"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}',
"index": 0,
"metadata": {"run_id": "1234"},
}
]
},
)
else:
thread_id = args[0].split("/")[-1]
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {
"metadata": {"foo": "bar"},
"name": "test",
"thread_id": "test_metadata",
}
if thread_id == "test_metadata"
else {"name": "test", "thread_id": "test"},
)

mocker.patch(
"requests.get",
side_effect=get,
)

def post(*args, **kwargs):
if str(args[0]).endswith("/threads"):
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {"thread_id": "test", "name": "test"},
)
def request(*args, **kwargs):
if args[0] == "get":
if "/messages" in str(args[1]):
thread_id = args[1].split("/")[-2]
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {
"messages": [
{
"message_id": f"{thread_id}_message",
"input": '{"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}',
"output": '{"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}',
"metadata": {"run_id": "1234"},
}
]
}
if thread_id != "no_messages"
else {"messages": []},
status_code=200,
)
elif "/threads/" in str(args[1]):
thread_id = args[1].split("/")[-1]
return mocker.Mock(
# raise for status if thread_id is == not_found
raise_for_status=lambda: None if thread_id != "not_found" else ValueError,
json=lambda: {
"metadata": {"foo": "bar"},
"name": "test",
"thread_id": thread_id,
},
status_code=200 if thread_id != "not_found" else 404,
)
elif "/threads?alias=" in str(args[1]):
alias = args[1].split("=")[-1]
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {"threads": [{"thread_id": alias, "alias": alias, "metadata": {"foo": "bar"}}]}
if alias != "not_found"
else {"threads": []},
status_code=200,
)
else:
return mocker.Mock()
elif args[0] == "post":
if str(args[1]).endswith("/threads"):
body = kwargs["json"]
body["thread_id"] = "test"
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: body,
)
else:
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {"message_id": "test"},
)
else:
return mocker.Mock(
raise_for_status=lambda: None,
json=lambda: {"message_id": "test"},
)

mocker.patch(
"requests.post",
side_effect=post,
)
mocker.patch(
"requests.patch",
return_value=mocker.Mock(
raise_for_status=lambda: None,
),
"requests.request",
side_effect=request,
)

@pytest.fixture()
Expand All @@ -80,12 +89,22 @@ def test_no_api_key(self):

def test_thread_id(self):
driver = GriptapeCloudConversationMemoryDriver(api_key="test")
assert driver.thread_id is None
assert driver.thread.get("thread_id") == "test"
assert driver.thread_id == "test"
os.environ["GT_CLOUD_THREAD_ID"] = "test_env"
driver = GriptapeCloudConversationMemoryDriver(api_key="test")
assert driver.thread_id is None
assert driver.thread.get("thread_id") == "test_env"
assert driver.thread_id == "test_env"
driver = GriptapeCloudConversationMemoryDriver(api_key="test", thread_id="test_init")
assert driver.thread_id == "test_init"
os.environ.pop("GT_CLOUD_THREAD_ID")

def test_thread_alias(self):
driver = GriptapeCloudConversationMemoryDriver(api_key="test", alias="test")
assert driver.thread_id is None
assert driver.thread.get("thread_id") == "test"
assert driver.thread_id == "test"
assert driver.alias == "test"

def test_store(self, driver: GriptapeCloudConversationMemoryDriver):
runs = [
Expand All @@ -98,8 +117,4 @@ def test_load(self, driver):
runs, metadata = driver.load()
assert len(runs) == 1
assert runs[0].id == "1234"
assert metadata == {}
driver.thread_id = "test_metadata"
runs, metadata = driver.load()
assert len(runs) == 1
assert metadata == {"foo": "bar"}

0 comments on commit 9e4607a

Please sign in to comment.