Skip to content

Commit

Permalink
Update dependencies, fix emergent issues (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Aug 15, 2024
1 parent babc56a commit adb660e
Show file tree
Hide file tree
Showing 9 changed files with 1,038 additions and 999 deletions.
3 changes: 2 additions & 1 deletion griptape/config/drivers/amazon_bedrock_driver_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
AmazonBedrockTitanEmbeddingDriver,
BaseEmbeddingDriver,
BaseImageGenerationDriver,
BaseImageQueryDriver,
BasePromptDriver,
BaseVectorStoreDriver,
BedrockClaudeImageQueryModelDriver,
Expand Down Expand Up @@ -63,7 +64,7 @@ class AmazonBedrockDriverConfig(DriverConfig):
kw_only=True,
metadata={"serializable": True},
)
image_query: BaseImageGenerationDriver = field(
image_query: BaseImageQueryDriver = field(
default=Factory(
lambda self: AmazonBedrockImageQueryDriver(
session=self.session,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

@define
class OpenAiAudioTranscriptionDriver(BaseAudioTranscriptionDriver):
api_type: str = field(default=openai.api_type, kw_only=True)
api_type: Optional[str] = field(default=openai.api_type, kw_only=True)
api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
from typing import Optional
from urllib.parse import urljoin

import requests
Expand All @@ -25,12 +26,14 @@ class GriptapeCloudEventListenerDriver(BaseEventListenerDriver):
default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
kw_only=True,
)
api_key: str = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY")), kw_only=True)
api_key: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY")), kw_only=True)
headers: dict = field(
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
kw_only=True,
)
structure_run_id: str = field(default=Factory(lambda: os.getenv("GT_CLOUD_STRUCTURE_RUN_ID")), kw_only=True)
structure_run_id: Optional[str] = field(
default=Factory(lambda: os.getenv("GT_CLOUD_STRUCTURE_RUN_ID")), kw_only=True
)

@structure_run_id.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_run_id(self, _: Attribute, structure_run_id: str) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class OpenAiImageGenerationDriver(BaseImageGenerationDriver):
a base64 encoded image in a JSON object.
"""

api_type: str = field(default=openai.api_type, kw_only=True)
api_type: Optional[str] = field(default=openai.api_type, kw_only=True)
api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
Expand Down
2 changes: 1 addition & 1 deletion griptape/drivers/image_query/openai_image_query_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
@define
class OpenAiImageQueryDriver(BaseImageQueryDriver):
model: str = field(kw_only=True, metadata={"serializable": True})
api_type: str = field(default=openai.api_type, kw_only=True)
api_type: Optional[str] = field(default=openai.api_type, kw_only=True)
api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ class GriptapeCloudObservabilityDriver(OpenTelemetryObservabilityDriver):
base_url: str = field(
default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")), kw_only=True
)
api_key: str = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY")), kw_only=True)
api_key: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY")), kw_only=True)
headers: dict = field(
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True
)
structure_run_id: str = field(default=Factory(lambda: os.getenv("GT_CLOUD_STRUCTURE_RUN_ID")), kw_only=True)
structure_run_id: Optional[str] = field(
default=Factory(lambda: os.getenv("GT_CLOUD_STRUCTURE_RUN_ID")), kw_only=True
)
span_processor: SpanProcessor = field(
default=Factory(
lambda self: import_optional_dependency("opentelemetry.sdk.trace.export").BatchSpanProcessor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class OpenAiTextToSpeechDriver(BaseTextToSpeechDriver):
metadata={"serializable": True},
)
format: Literal["mp3", "opus", "aac", "flac"] = field(default="mp3", kw_only=True, metadata={"serializable": True})
api_type: str = field(default=openai.api_type, kw_only=True)
api_type: Optional[str] = field(default=openai.api_type, kw_only=True)
api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True)
Expand Down
Loading

0 comments on commit adb660e

Please sign in to comment.