Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pep8-naming ruff rule #993

Merged
merged 2 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: `BaseVectorStoreDriver.query` optional arguments are now keyword-only arguments.
- **BREAKING**: `EventListener.publish_event`'s `flush` argument is now a keyword-only argument.
- **BREAKING**: `BaseEventListenerDriver.publish_event`'s `flush` argument is now a keyword-only argument.
- **BREAKING**: Renamed `DummyException` to `DummyError` for pep8 naming compliance.
- Removed unnecessary `transformers` dependency in `drivers-prompt-huggingface` extra.
- Removed unnecessary `huggingface-hub` dependency in `drivers-prompt-huggingface-pipeline` extra.

Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ agent = Agent(config=CohereStructureConfig(api_key=os.environ["COHERE_API_KEY"])
### Custom Configs

You can create your own [StructureConfig](../../reference/griptape/config/structure_config.md) by overriding relevant Drivers.
The [StructureConfig](../../reference/griptape/config/structure_config.md) class includes "Dummy" Drivers for all types, which throw a [DummyException](../../reference/griptape/exceptions/dummy_exception.md) if invoked without being overridden.
The [StructureConfig](../../reference/griptape/config/structure_config.md) class includes "Dummy" Drivers for all types, which throw a [DummyError](../../reference/griptape/exceptions/dummy_exception.md) if invoked without being overridden.
This approach ensures that you are informed through clear error messages if you attempt to use Structures without proper Driver configurations.

```python
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from attrs import define, field

from griptape.drivers import BaseAudioTranscriptionDriver
from griptape.exceptions import DummyException
from griptape.exceptions import DummyError

if TYPE_CHECKING:
from griptape.artifacts import AudioArtifact, TextArtifact
Expand All @@ -16,4 +16,4 @@
model: str = field(init=False)

def try_run(self, audio: AudioArtifact, prompts: Optional[list] = None) -> TextArtifact:
raise DummyException(__class__.__name__, "try_transcription")
raise DummyError(__class__.__name__, "try_transcription")

Check warning on line 19 in griptape/drivers/audio_transcription/dummy_audio_transcription_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/audio_transcription/dummy_audio_transcription_driver.py#L19

Added line #L19 was not covered by tests
4 changes: 2 additions & 2 deletions griptape/drivers/embedding/dummy_embedding_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from attrs import define, field

from griptape.drivers import BaseEmbeddingDriver
from griptape.exceptions import DummyException
from griptape.exceptions import DummyError


@define
class DummyEmbeddingDriver(BaseEmbeddingDriver):
model: None = field(init=False, default=None, kw_only=True)

def try_embed_chunk(self, chunk: str) -> list[float]:
raise DummyException(__class__.__name__, "try_embed_chunk")
raise DummyError(__class__.__name__, "try_embed_chunk")
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from attrs import define, field

from griptape.drivers import BaseImageGenerationDriver
from griptape.exceptions import DummyException
from griptape.exceptions import DummyError

if TYPE_CHECKING:
from griptape.artifacts import ImageArtifact
Expand All @@ -16,15 +16,15 @@ class DummyImageGenerationDriver(BaseImageGenerationDriver):
model: None = field(init=False, default=None, kw_only=True)

def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
raise DummyException(__class__.__name__, "try_text_to_image")
raise DummyError(__class__.__name__, "try_text_to_image")

def try_image_variation(
self,
prompts: list[str],
image: ImageArtifact,
negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
raise DummyException(__class__.__name__, "try_image_variation")
raise DummyError(__class__.__name__, "try_image_variation")

def try_image_inpainting(
self,
Expand All @@ -33,7 +33,7 @@ def try_image_inpainting(
mask: ImageArtifact,
negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
raise DummyException(__class__.__name__, "try_image_inpainting")
raise DummyError(__class__.__name__, "try_image_inpainting")

def try_image_outpainting(
self,
Expand All @@ -42,4 +42,4 @@ def try_image_outpainting(
mask: ImageArtifact,
negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
raise DummyException(__class__.__name__, "try_image_outpainting")
raise DummyError(__class__.__name__, "try_image_outpainting")
4 changes: 2 additions & 2 deletions griptape/drivers/image_query/dummy_image_query_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from attrs import define, field

from griptape.drivers import BaseImageQueryDriver
from griptape.exceptions import DummyException
from griptape.exceptions import DummyError

if TYPE_CHECKING:
from griptape.artifacts import ImageArtifact, TextArtifact
Expand All @@ -17,4 +17,4 @@ class DummyImageQueryDriver(BaseImageQueryDriver):
max_tokens: None = field(init=False, default=None, kw_only=True)

def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
raise DummyException(__class__.__name__, "try_query")
raise DummyError(__class__.__name__, "try_query")
6 changes: 3 additions & 3 deletions griptape/drivers/prompt/dummy_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from attrs import Factory, define, field

from griptape.drivers import BasePromptDriver
from griptape.exceptions import DummyException
from griptape.exceptions import DummyError
from griptape.tokenizers import DummyTokenizer

if TYPE_CHECKING:
Expand All @@ -20,7 +20,7 @@
tokenizer: DummyTokenizer = field(default=Factory(lambda: DummyTokenizer()), kw_only=True)

def try_run(self, prompt_stack: PromptStack) -> Message:
raise DummyException(__class__.__name__, "try_run")
raise DummyError(__class__.__name__, "try_run")

def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
raise DummyException(__class__.__name__, "try_stream")
raise DummyError(__class__.__name__, "try_stream")

Check warning on line 26 in griptape/drivers/prompt/dummy_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/dummy_prompt_driver.py#L26

Added line #L26 was not covered by tests
23 changes: 11 additions & 12 deletions griptape/drivers/prompt/google_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,18 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
)

def _base_params(self, prompt_stack: PromptStack) -> dict:
GenerationConfig = import_optional_dependency("google.generativeai.types").GenerationConfig
ContentDict = import_optional_dependency("google.generativeai.types").ContentDict
Part = import_optional_dependency("google.generativeai.protos").Part
types = import_optional_dependency("google.generativeai.types")
protos = import_optional_dependency("google.generativeai.protos")
Comment on lines -116 to +117
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I liked it better before, but 🤷

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this one is a little weird, but I think I'd rather go for consistency.


system_messages = prompt_stack.system_messages
if system_messages:
self.model_client._system_instruction = ContentDict(
self.model_client._system_instruction = types.ContentDict(
role="system",
parts=[Part(text=system_message.to_text()) for system_message in system_messages],
parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages],
)

return {
"generation_config": GenerationConfig(
"generation_config": types.GenerationConfig(
**{
# For some reason, providing stop sequences when streaming breaks native functions
# https://github.com/google-gemini/generative-ai-python/issues/446
Expand Down Expand Up @@ -150,10 +149,10 @@ def _default_model_client(self) -> GenerativeModel:
return genai.GenerativeModel(self.model)

def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType:
ContentDict = import_optional_dependency("google.generativeai.types").ContentDict
types = import_optional_dependency("google.generativeai.types")

inputs = [
ContentDict(
types.ContentDict(
{
"role": self.__to_google_role(message),
"parts": [self.__to_google_message_content(content) for content in message.content],
Expand All @@ -172,7 +171,7 @@ def __to_google_role(self, message: Message) -> str:
return "user"

def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]:
FunctionDeclaration = import_optional_dependency("google.generativeai.types").FunctionDeclaration
types = import_optional_dependency("google.generativeai.types")

tool_declarations = []
for tool in tools:
Expand All @@ -183,7 +182,7 @@ def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]:
schema = schema["properties"]["values"]

schema = remove_key_in_dict_recursively(schema, "additionalProperties")
tool_declaration = FunctionDeclaration(
tool_declaration = types.FunctionDeclaration(
name=tool.to_native_tool_name(activity),
description=tool.activity_description(activity),
parameters={
Expand All @@ -198,13 +197,13 @@ def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]:
return tool_declarations

def __to_google_message_content(self, content: BaseMessageContent) -> ContentDict | Part | str:
ContentDict = import_optional_dependency("google.generativeai.types").ContentDict
types = import_optional_dependency("google.generativeai.types")
protos = import_optional_dependency("google.generativeai.protos")

if isinstance(content, TextMessageContent):
return content.artifact.to_text()
elif isinstance(content, ImageMessageContent):
return ContentDict(mime_type=content.artifact.mime_type, data=content.artifact.value)
return types.ContentDict(mime_type=content.artifact.mime_type, data=content.artifact.value)
elif isinstance(content, ActionCallMessageContent):
action = content.artifact.value

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from attrs import define, field

from griptape.drivers import BaseTextToSpeechDriver
from griptape.exceptions import DummyException
from griptape.exceptions import DummyError

if TYPE_CHECKING:
from griptape.artifacts.audio_artifact import AudioArtifact
Expand All @@ -16,4 +16,4 @@
model: None = field(init=False, default=None, kw_only=True)

def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
raise DummyException(__class__.__name__, "try_text_to_audio")
raise DummyError(__class__.__name__, "try_text_to_audio")

Check warning on line 19 in griptape/drivers/text_to_speech/dummy_text_to_speech_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/text_to_speech/dummy_text_to_speech_driver.py#L19

Added line #L19 was not covered by tests
12 changes: 6 additions & 6 deletions griptape/drivers/vector/dummy_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from attrs import Factory, define, field

from griptape.drivers import BaseEmbeddingDriver, BaseVectorStoreDriver, DummyEmbeddingDriver
from griptape.exceptions import DummyException
from griptape.exceptions import DummyError


@define()
Expand All @@ -17,7 +17,7 @@ class DummyVectorStoreDriver(BaseVectorStoreDriver):
)

def delete_vector(self, vector_id: str) -> None:
raise DummyException(__class__.__name__, "delete_vector")
raise DummyError(__class__.__name__, "delete_vector")

def upsert_vector(
self,
Expand All @@ -27,13 +27,13 @@ def upsert_vector(
meta: Optional[dict] = None,
**kwargs,
) -> str:
raise DummyException(__class__.__name__, "upsert_vector")
raise DummyError(__class__.__name__, "upsert_vector")

def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
raise DummyException(__class__.__name__, "load_entry")
raise DummyError(__class__.__name__, "load_entry")

def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
raise DummyException(__class__.__name__, "load_entries")
raise DummyError(__class__.__name__, "load_entries")

def query(
self,
Expand All @@ -44,4 +44,4 @@ def query(
include_vectors: bool = False,
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
raise DummyException(__class__.__name__, "query")
raise DummyError(__class__.__name__, "query")
7 changes: 3 additions & 4 deletions griptape/drivers/vector/pgvector_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,14 @@ def query(
]

def default_vector_model(self) -> Any:
Vector = import_optional_dependency("pgvector.sqlalchemy").Vector
Base = declarative_base()
sqlalchemy = import_optional_dependency("pgvector.sqlalchemy")

@dataclass
class VectorModel(Base):
class VectorModel(declarative_base()):
__tablename__ = self.table_name

id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, unique=True, nullable=False)
vector = Column(Vector())
vector = Column(sqlalchemy.Vector())
namespace = Column(String)
meta = Column(JSON)

Expand Down
4 changes: 2 additions & 2 deletions griptape/drivers/vector/redis_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,13 @@ def query(
Returns:
A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace.
"""
Query = import_optional_dependency("redis.commands.search.query").Query
search_query = import_optional_dependency("redis.commands.search.query")

vector = self.embedding_driver.embed_string(query)

filter_expression = f"(@namespace:{{{namespace}}})" if namespace else "*"
query_expression = (
Query(f"{filter_expression}=>[KNN {count or 10} @vector $vector as score]")
search_query.Query(f"{filter_expression}=>[KNN {count or 10} @vector $vector as score]")
.sort_by("score")
.return_fields("id", "score", "metadata", "vec_string")
.paging(0, count or 10)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ class MarkdownifyWebScraperDriver(BaseWebScraperDriver):

def scrape_url(self, url: str) -> TextArtifact:
sync_playwright = import_optional_dependency("playwright.sync_api").sync_playwright
BeautifulSoup = import_optional_dependency("bs4").BeautifulSoup
MarkdownConverter = import_optional_dependency("markdownify").MarkdownConverter
bs4 = import_optional_dependency("bs4")
markdownify = import_optional_dependency("markdownify")

include_links = self.include_links

# Custom MarkdownConverter to optionally linked urls. If include_links is False only
# the text of the link is returned.
class OptionalLinksMarkdownConverter(MarkdownConverter):
class OptionalLinksMarkdownConverter(markdownify.MarkdownConverter):
def convert_a(self, el: Any, text: str, convert_as_inline: Any) -> str:
if include_links:
return super().convert_a(el, text, convert_as_inline)
Expand Down Expand Up @@ -75,7 +75,7 @@ def skip_loading_images(route: Any) -> Any:
if not content:
raise Exception("can't access URL")

soup = BeautifulSoup(content, "html.parser")
soup = bs4.BeautifulSoup(content, "html.parser")

# Remove unwanted elements
exclude_selector = ",".join(
Expand Down
4 changes: 2 additions & 2 deletions griptape/exceptions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .dummy_exception import DummyException
from .dummy_exception import DummyError

__all__ = ["DummyException"]
__all__ = ["DummyError"]
2 changes: 1 addition & 1 deletion griptape/exceptions/dummy_exception.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
class DummyException(Exception):
class DummyError(Exception):
def __init__(self, dummy_class_name: str, dummy_method_name: str) -> None:
message = (
f"You have attempted to use a {dummy_class_name}'s {dummy_method_name} method. "
Expand Down
6 changes: 3 additions & 3 deletions griptape/loaders/image_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ class ImageLoader(BaseLoader):
}

def load(self, source: bytes, *args, **kwargs) -> ImageArtifact:
Image = import_optional_dependency("PIL.Image")
image = Image.open(BytesIO(source))
pil_image = import_optional_dependency("PIL.Image")
image = pil_image.open(BytesIO(source))

# Normalize format only if requested.
if self.format is not None:
byte_stream = BytesIO()
image.save(byte_stream, format=self.format)
image = Image.open(byte_stream)
image = pil_image.open(byte_stream)
source = byte_stream.getvalue()

image_artifact = ImageArtifact(source, format=image.format.lower(), width=image.width, height=image.height)
Expand Down
4 changes: 2 additions & 2 deletions griptape/loaders/pdf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def load(
*args,
**kwargs,
) -> ErrorArtifact | list[TextArtifact]:
PdfReader = import_optional_dependency("pypdf").PdfReader
reader = PdfReader(BytesIO(source), strict=True, password=password)
pypdf = import_optional_dependency("pypdf")
reader = pypdf.PdfReader(BytesIO(source), strict=True, password=password)
return self._text_to_artifacts("\n".join([p.extract_text() for p in reader.pages]))

def load_collection(self, sources: list[bytes], *args, **kwargs) -> dict[str, ErrorArtifact | list[TextArtifact]]:
Expand Down
20 changes: 7 additions & 13 deletions griptape/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ def _get_field_for_type(cls, field_type: type) -> fields.Field | fields.Nested:
else:
raise ValueError(f"Missing type for list field: {field_type}")
else:
FieldClass = cls.DATACLASS_TYPE_MAPPING[field_class]
field_class = cls.DATACLASS_TYPE_MAPPING[field_class]

return FieldClass(allow_none=optional)
return field_class(allow_none=optional)

@classmethod
def _get_field_type_info(cls, field_type: type) -> tuple[type, tuple[type, ...], bool]:
Expand Down Expand Up @@ -131,14 +131,6 @@ def _resolve_types(cls, attrs_cls: type) -> None:
from griptape.tools import BaseTool
from griptape.utils.import_utils import import_optional_dependency, is_dependency_installed

boto3 = import_optional_dependency("boto3") if is_dependency_installed("boto3") else Any
Client = import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any
GenerativeModel = (
import_optional_dependency("google.generativeai").GenerativeModel
if is_dependency_installed("google.generativeai")
else Any
)

attrs.resolve_types(
attrs_cls,
localns={
Expand All @@ -163,9 +155,11 @@ def _resolve_types(cls, attrs_cls: type) -> None:
"Run": Run,
"Sequence": Sequence,
# Third party modules
"Client": Client,
"GenerativeModel": GenerativeModel,
"boto3": boto3,
"Client": import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any,
"GenerativeModel": import_optional_dependency("google.generativeai").GenerativeModel
if is_dependency_installed("google.generativeai")
else Any,
"boto3": import_optional_dependency("boto3") if is_dependency_installed("boto3") else Any,
},
)

Expand Down
Loading
Loading