Skip to content

Commit

Permalink
Refactor/naming (#1078)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Aug 19, 2024
1 parent 0d19b0a commit f08f0c3
Show file tree
Hide file tree
Showing 69 changed files with 251 additions and 242 deletions.
6 changes: 3 additions & 3 deletions docs/examples/src/multiple_agent_shared_memory_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@
vector_path=MONGODB_VECTOR_PATH,
)

config.drivers = AzureOpenAiDriverConfig(
config.driver_config = AzureOpenAiDriverConfig(
azure_endpoint=AZURE_OPENAI_ENDPOINT_1,
vector_store=mongo_driver,
embedding=embedding_driver,
vector_store_driver=mongo_driver,
embedding_driver=embedding_driver,
)

loader = Agent(
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/src/talk_to_a_video_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from griptape.config.drivers import GoogleDriverConfig
from griptape.structures import Agent

config.drivers = GoogleDriverConfig()
config.driver_config = GoogleDriverConfig()

video_file = genai.upload_file(path="tests/resources/griptape-comfyui.mp4")
while video_file.state.name == "PROCESSING":
Expand Down
12 changes: 6 additions & 6 deletions docs/griptape-framework/drivers/src/embedding_drivers_10.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from griptape.structures import Agent
from griptape.tools import PromptSummaryTool, WebScraperTool

config.drivers = DriverConfig(
prompt=OpenAiChatPromptDriver(model="gpt-4o"),
embedding=VoyageAiEmbeddingDriver(),
config.driver_config = DriverConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"),
embedding_driver=VoyageAiEmbeddingDriver(),
)

config.drivers = DriverConfig(
prompt=OpenAiChatPromptDriver(model="gpt-4o"),
embedding=VoyageAiEmbeddingDriver(),
config.driver_config = DriverConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"),
embedding_driver=VoyageAiEmbeddingDriver(),
)

agent = Agent(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from griptape.rules import Rule
from griptape.structures import Agent

config.drivers = DriverConfig(prompt=OpenAiChatPromptDriver(model="gpt-3.5-turbo", temperature=0.7))
config.driver_config = DriverConfig(prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo", temperature=0.7))
event_bus.add_event_listeners(
[
EventListener(
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/config_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from griptape.config.drivers import OpenAiDriverConfig
from griptape.structures import Agent

config.drivers = OpenAiDriverConfig()
config.driver_config = OpenAiDriverConfig()

agent = Agent()
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/config_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from griptape.config.drivers import AzureOpenAiDriverConfig
from griptape.structures import Agent

config.drivers = AzureOpenAiDriverConfig(
config.driver_config = AzureOpenAiDriverConfig(
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_3"], api_key=os.environ["AZURE_OPENAI_API_KEY_3"]
)

Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/config_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from griptape.config.drivers import AmazonBedrockDriverConfig
from griptape.structures import Agent

config.drivers = AmazonBedrockDriverConfig(
config.driver_config = AmazonBedrockDriverConfig(
session=boto3.Session(
region_name=os.environ["AWS_DEFAULT_REGION"],
aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/config_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from griptape.config.drivers import GoogleDriverConfig
from griptape.structures import Agent

config.drivers = GoogleDriverConfig()
config.driver_config = GoogleDriverConfig()

agent = Agent()
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/config_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from griptape.config.drivers import AnthropicDriverConfig
from griptape.structures import Agent

config.drivers = AnthropicDriverConfig()
config.driver_config = AnthropicDriverConfig()

agent = Agent()
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/config_6.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
from griptape.config.drivers import CohereDriverConfig
from griptape.structures import Agent

config.drivers = CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"])
config.driver_config = CohereDriverConfig(api_key=os.environ["COHERE_API_KEY"])

agent = Agent()
4 changes: 2 additions & 2 deletions docs/griptape-framework/structures/src/config_7.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from griptape.drivers import AnthropicPromptDriver
from griptape.structures import Agent

config.drivers = DriverConfig(
prompt=AnthropicPromptDriver(
config.driver_config = DriverConfig(
prompt_driver=AnthropicPromptDriver(
model="claude-3-sonnet-20240229",
api_key=os.environ["ANTHROPIC_API_KEY"],
)
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/config_8.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@
}
custom_config = AmazonBedrockDriverConfig.from_dict(dict_config)

config.drivers = custom_config
config.driver_config = custom_config

agent = Agent()
4 changes: 2 additions & 2 deletions docs/griptape-framework/structures/src/config_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from griptape.config.logging import TruncateLoggingFilter
from griptape.structures import Agent

config.drivers = OpenAiDriverConfig()
config.driver_config = OpenAiDriverConfig()

logger = logging.getLogger(config.logging.logger_name)
logger = logging.getLogger(config.logging_config.logger_name)
logger.setLevel(logging.ERROR)
logger.addFilter(TruncateLoggingFilter(max_log_length=100))

Expand Down
8 changes: 4 additions & 4 deletions docs/griptape-framework/structures/src/task_memory_6.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from griptape.structures import Agent
from griptape.tools import FileManagerTool, QueryTool, WebScraperTool

config.drivers = OpenAiDriverConfig(
prompt=OpenAiChatPromptDriver(model="gpt-4"),
config.driver_config = OpenAiDriverConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4"),
)

config.drivers = OpenAiDriverConfig(
prompt=OpenAiChatPromptDriver(model="gpt-4"),
config.driver_config = OpenAiDriverConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4"),
)

vector_store_driver = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver())
Expand Down
4 changes: 2 additions & 2 deletions docs/griptape-tools/official-tools/src/rest_api_tool_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from griptape.tasks import ToolkitTask
from griptape.tools import RestApiTool

config.drivers = DriverConfig(
prompt=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.1),
config.driver_config = DriverConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", temperature=0.1),
)

posts_client = RestApiTool(
Expand Down
8 changes: 4 additions & 4 deletions griptape/config/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

@define(kw_only=True)
class BaseConfig(SerializableMixin, ABC):
_logging: Optional[LoggingConfig] = field(alias="logging")
_drivers: Optional[BaseDriverConfig] = field(alias="drivers")
_logging_config: Optional[LoggingConfig] = field(alias="logging")
_driver_config: Optional[BaseDriverConfig] = field(alias="drivers")

def reset(self) -> None:
self._logging = None
self._drivers = None
self._logging_config = None
self._driver_config = None
35 changes: 12 additions & 23 deletions griptape/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from attrs import define, field

from griptape.utils.decorators import lazy_property

from .base_config import BaseConfig
from .drivers.openai_driver_config import OpenAiDriverConfig
from .logging.logging_config import LoggingConfig
Expand All @@ -14,29 +16,16 @@

@define(kw_only=True)
class _Config(BaseConfig):
_logging: Optional[LoggingConfig] = field(default=None, alias="logging")
_drivers: Optional[BaseDriverConfig] = field(default=None, alias="drivers")

@property
def drivers(self) -> BaseDriverConfig:
"""Lazily instantiates the drivers configuration to avoid client errors like missing API key."""
if self._drivers is None:
self._drivers = OpenAiDriverConfig()
return self._drivers

@drivers.setter
def drivers(self, drivers: BaseDriverConfig) -> None:
self._drivers = drivers

@property
def logging(self) -> LoggingConfig:
if self._logging is None:
self._logging = LoggingConfig()
return self._logging

@logging.setter
def logging(self, logging: LoggingConfig) -> None:
self._logging = logging
_logging_config: Optional[LoggingConfig] = field(default=None, alias="logging")
_driver_config: Optional[BaseDriverConfig] = field(default=None, alias="drivers")

@lazy_property()
def driver_config(self) -> BaseDriverConfig:
return OpenAiDriverConfig()

@lazy_property()
def logging_config(self) -> LoggingConfig:
return LoggingConfig()


config = _Config()
10 changes: 5 additions & 5 deletions griptape/config/drivers/amazon_bedrock_driver_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,31 +30,31 @@ class AmazonBedrockDriverConfig(DriverConfig):
)

@lazy_property()
def prompt(self) -> AmazonBedrockPromptDriver:
def prompt_driver(self) -> AmazonBedrockPromptDriver:
return AmazonBedrockPromptDriver(session=self.session, model="anthropic.claude-3-5-sonnet-20240620-v1:0")

@lazy_property()
def embedding(self) -> AmazonBedrockTitanEmbeddingDriver:
def embedding_driver(self) -> AmazonBedrockTitanEmbeddingDriver:
return AmazonBedrockTitanEmbeddingDriver(session=self.session, model="amazon.titan-embed-text-v1")

@lazy_property()
def image_generation(self) -> AmazonBedrockImageGenerationDriver:
def image_generation_driver(self) -> AmazonBedrockImageGenerationDriver:
return AmazonBedrockImageGenerationDriver(
session=self.session,
model="amazon.titan-image-generator-v1",
image_generation_model_driver=BedrockTitanImageGenerationModelDriver(),
)

@lazy_property()
def image_query(self) -> AmazonBedrockImageQueryDriver:
def image_query_driver(self) -> AmazonBedrockImageQueryDriver:
return AmazonBedrockImageQueryDriver(
session=self.session,
model="anthropic.claude-3-5-sonnet-20240620-v1:0",
image_query_model_driver=BedrockClaudeImageQueryModelDriver(),
)

@lazy_property()
def vector_store(self) -> LocalVectorStoreDriver:
def vector_store_driver(self) -> LocalVectorStoreDriver:
return LocalVectorStoreDriver(
embedding_driver=AmazonBedrockTitanEmbeddingDriver(session=self.session, model="amazon.titan-embed-text-v1")
)
8 changes: 4 additions & 4 deletions griptape/config/drivers/anthropic_driver_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@
@define
class AnthropicDriverConfig(DriverConfig):
@lazy_property()
def prompt(self) -> AnthropicPromptDriver:
def prompt_driver(self) -> AnthropicPromptDriver:
return AnthropicPromptDriver(model="claude-3-5-sonnet-20240620")

@lazy_property()
def embedding(self) -> VoyageAiEmbeddingDriver:
def embedding_driver(self) -> VoyageAiEmbeddingDriver:
return VoyageAiEmbeddingDriver(model="voyage-large-2")

@lazy_property()
def vector_store(self) -> LocalVectorStoreDriver:
def vector_store_driver(self) -> LocalVectorStoreDriver:
return LocalVectorStoreDriver(embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2"))

@lazy_property()
def image_query(self) -> AnthropicImageQueryDriver:
def image_query_driver(self) -> AnthropicImageQueryDriver:
return AnthropicImageQueryDriver(model="claude-3-5-sonnet-20240620")
10 changes: 5 additions & 5 deletions griptape/config/drivers/azure_openai_driver_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class AzureOpenAiDriverConfig(DriverConfig):
api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})

@lazy_property()
def prompt(self) -> AzureOpenAiChatPromptDriver:
def prompt_driver(self) -> AzureOpenAiChatPromptDriver:
return AzureOpenAiChatPromptDriver(
model="gpt-4o",
azure_endpoint=self.azure_endpoint,
Expand All @@ -51,7 +51,7 @@ def prompt(self) -> AzureOpenAiChatPromptDriver:
)

@lazy_property()
def embedding(self) -> AzureOpenAiEmbeddingDriver:
def embedding_driver(self) -> AzureOpenAiEmbeddingDriver:
return AzureOpenAiEmbeddingDriver(
model="text-embedding-3-small",
azure_endpoint=self.azure_endpoint,
Expand All @@ -61,7 +61,7 @@ def embedding(self) -> AzureOpenAiEmbeddingDriver:
)

@lazy_property()
def image_generation(self) -> AzureOpenAiImageGenerationDriver:
def image_generation_driver(self) -> AzureOpenAiImageGenerationDriver:
return AzureOpenAiImageGenerationDriver(
model="dall-e-2",
azure_endpoint=self.azure_endpoint,
Expand All @@ -72,7 +72,7 @@ def image_generation(self) -> AzureOpenAiImageGenerationDriver:
)

@lazy_property()
def image_query(self) -> AzureOpenAiImageQueryDriver:
def image_query_driver(self) -> AzureOpenAiImageQueryDriver:
return AzureOpenAiImageQueryDriver(
model="gpt-4o",
azure_endpoint=self.azure_endpoint,
Expand All @@ -82,7 +82,7 @@ def image_query(self) -> AzureOpenAiImageQueryDriver:
)

@lazy_property()
def vector_store(self) -> LocalVectorStoreDriver:
def vector_store_driver(self) -> LocalVectorStoreDriver:
return LocalVectorStoreDriver(
embedding_driver=AzureOpenAiEmbeddingDriver(
model="text-embedding-3-small",
Expand Down
Loading

0 comments on commit f08f0c3

Please sign in to comment.