From 670616f2fa8d6da154b01710051bb1dd6e7dae7f Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 23 Apr 2024 14:11:59 -0700 Subject: [PATCH] Pull cloud url from env var (#756) --- .../griptape_cloud_event_listener_driver.py | 13 ++++++++++++- .../test_griptape_cloud_event_listener_driver.py | 15 +++++++++++---- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py index 51b62aeac..7f52b3519 100644 --- a/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py +++ b/griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py @@ -12,7 +12,18 @@ @define class GriptapeCloudEventListenerDriver(BaseEventListenerDriver): - base_url: str = field(default="https://cloud.griptape.ai", kw_only=True) + """Driver for publishing events to Griptape Cloud. + + Attributes: + base_url: The base URL of Griptape Cloud. Defaults to the GT_CLOUD_BASE_URL environment variable. + api_key: The API key to authenticate with Griptape Cloud. + headers: The headers to use when making requests to Griptape Cloud. Defaults to include the Authorization header. + run_id: The ID of the Structure Run to publish events to. Defaults to the GT_CLOUD_RUN_ID environment variable. + """ + + base_url: str = field( + default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")), kw_only=True + ) api_key: str = field(kw_only=True) headers: dict = field( default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True diff --git a/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py index c459fc8a2..bbffa1059 100644 --- a/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_griptape_cloud_event_listener_driver.py @@ -1,8 +1,11 @@ +import os from unittest.mock import Mock -from pytest import fixture + import pytest -from tests.mocks.mock_event import MockEvent +from pytest import fixture + from griptape.drivers.event_listener.griptape_cloud_event_listener_driver import GriptapeCloudEventListenerDriver +from tests.mocks.mock_event import MockEvent class TestGriptapeCloudEventListenerDriver: @@ -17,17 +20,21 @@ def mock_post(self, mocker): @fixture() def driver(self): - return GriptapeCloudEventListenerDriver(api_key="foo bar", run_id="baz") + os.environ["GT_CLOUD_BASE_URL"] = "https://cloud123.griptape.ai" + + return GriptapeCloudEventListenerDriver(api_key="foo bar", run_id="bar baz") def test_init(self, driver): assert driver + assert driver.api_key == "foo bar" + assert driver.run_id == "bar baz" def test_try_publish_event(self, mock_post, driver): event = MockEvent() driver.try_publish_event(event=event) mock_post.assert_called_once_with( - url=f"https://cloud.griptape.ai/api/structure-runs/{driver.run_id}/events", + url="https://cloud123.griptape.ai/api/structure-runs/bar baz/events", json=event.to_dict(), headers={"Authorization": "Bearer foo bar"}, )