From dc3d1355d80d78e7e97c567b038b069f2cae1b9b Mon Sep 17 00:00:00 2001
From: Collin Dutter
Date: Wed, 24 Jul 2024 13:19:55 -0700
Subject: [PATCH] Add furb ruff rule (#1014)

---
 griptape/artifacts/boolean_artifact.py                  |  2 +-
 .../drivers/file_manager/local_file_manager_driver.py   |  6 ++----
 .../conversation/local_conversation_memory_driver.py    | 11 +++++------
 .../vector/amazon_opensearch_vector_store_driver.py     |  2 +-
 .../vector/azure_mongodb_vector_store_driver.py         |  4 ++--
 griptape/drivers/vector/base_vector_store_driver.py     |  4 ++--
 griptape/drivers/vector/local_vector_store_driver.py    |  5 +++--
 griptape/drivers/vector/marqo_vector_store_driver.py    |  4 ++--
 .../vector/mongodb_atlas_vector_store_driver.py         |  4 ++--
 .../drivers/vector/opensearch_vector_store_driver.py    |  4 ++--
 .../drivers/vector/pinecone_vector_store_driver.py      |  4 ++--
 griptape/drivers/vector/redis_vector_store_driver.py    |  2 +-
 griptape/loaders/email_loader.py                        |  2 +-
 griptape/mixins/media_artifact_file_output_mixin.py     |  4 ++--
 griptape/tasks/base_image_generation_task.py            |  4 ++--
 griptape/tokenizers/openai_tokenizer.py                 |  4 ++--
 griptape/tools/audio_transcription_client/tool.py       |  4 ++--
 griptape/tools/base_tool.py                             |  2 +-
 griptape/tools/computer/tool.py                         |  5 ++---
 griptape/tools/image_query_client/tool.py               |  4 ++--
 griptape/utils/conversation.py                          |  3 +--
 griptape/utils/file_utils.py                            |  4 ++--
 pyproject.toml                                          |  2 ++
 .../file_manager/test_local_file_manager_driver.py      |  6 ++----
 .../test_trafilatura_web_scraper_driver.py              |  2 +-
 .../engines/summary/test_prompt_summary_engine.py       |  4 ++--
 tests/unit/loaders/conftest.py                          |  6 ++----
 tests/unit/tools/test_transcription_client.py           |  8 ++++++++
 28 files changed, 59 insertions(+), 57 deletions(-)

diff --git a/griptape/artifacts/boolean_artifact.py b/griptape/artifacts/boolean_artifact.py
index ac45a1967..5bcdfac9b 100644
--- a/griptape/artifacts/boolean_artifact.py
+++ b/griptape/artifacts/boolean_artifact.py
@@ -12,7 +12,7 @@ class BooleanArtifact(BaseArtifact):
     value: bool = field(converter=bool, metadata={"serializable": True})

     @classmethod
-    def parse_bool(cls, value: Union[str, bool]) -> BooleanArtifact:
+    def parse_bool(cls, value: Union[str, bool]) -> BooleanArtifact:  # noqa: FBT001
         """Convert a string literal or bool to a BooleanArtifact. The string must be either "true" or "false" with any casing."""
         if value is not None:
             if isinstance(value, str):
diff --git a/griptape/drivers/file_manager/local_file_manager_driver.py b/griptape/drivers/file_manager/local_file_manager_driver.py
index 4c2fcf5e6..a6f1f0726 100644
--- a/griptape/drivers/file_manager/local_file_manager_driver.py
+++ b/griptape/drivers/file_manager/local_file_manager_driver.py
@@ -31,16 +31,14 @@ def try_load_file(self, path: str) -> bytes:
         full_path = self._full_path(path)
         if self._is_dir(full_path):
             raise IsADirectoryError
-        with open(full_path, "rb") as file:
-            return file.read()
+        return Path(full_path).read_bytes()

     def try_save_file(self, path: str, value: bytes) -> None:
         full_path = self._full_path(path)
         if self._is_dir(full_path):
             raise IsADirectoryError
         os.makedirs(os.path.dirname(full_path), exist_ok=True)
-        with open(full_path, "wb") as file:
-            file.write(value)
+        Path(full_path).write_bytes(value)

     def _full_path(self, path: str) -> str:
         path = path.lstrip("/")
diff --git a/griptape/drivers/memory/conversation/local_conversation_memory_driver.py b/griptape/drivers/memory/conversation/local_conversation_memory_driver.py
index d8c7e992e..8d6399e13 100644
--- a/griptape/drivers/memory/conversation/local_conversation_memory_driver.py
+++ b/griptape/drivers/memory/conversation/local_conversation_memory_driver.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import os
+from pathlib import Path
 from typing import Optional

 from attrs import define, field
@@ -14,15 +15,13 @@ class LocalConversationMemoryDriver(BaseConversationMemoryDriver):
     file_path: str = field(default="griptape_memory.json", kw_only=True, metadata={"serializable": True})

     def store(self, memory: BaseConversationMemory) -> None:
-        with open(self.file_path, "w") as file:
-            file.write(memory.to_json())
+        Path(self.file_path).write_text(memory.to_json())

     def load(self) -> Optional[BaseConversationMemory]:
         if not os.path.exists(self.file_path):
             return None
-        with open(self.file_path) as file:
-            memory = BaseConversationMemory.from_json(file.read())
+        memory = BaseConversationMemory.from_json(Path(self.file_path).read_text())

-            memory.driver = self
+        memory.driver = self

-            return memory
+        return memory
diff --git a/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py b/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py
index d54b98ec2..b1d881958 100644
--- a/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py
+++ b/griptape/drivers/vector/amazon_opensearch_vector_store_driver.py
@@ -63,7 +63,7 @@ def upsert_vector(
         If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
         Metadata associated with the vector can also be provided.
         """
-        vector_id = vector_id if vector_id else str_to_hash(str(vector))
+        vector_id = vector_id or str_to_hash(str(vector))
         doc = {"vector": vector, "namespace": namespace, "metadata": meta}
         doc.update(kwargs)
         if self.service == "aoss":
diff --git a/griptape/drivers/vector/azure_mongodb_vector_store_driver.py b/griptape/drivers/vector/azure_mongodb_vector_store_driver.py
index 60e9df097..993f7a300 100644
--- a/griptape/drivers/vector/azure_mongodb_vector_store_driver.py
+++ b/griptape/drivers/vector/azure_mongodb_vector_store_driver.py
@@ -30,8 +30,8 @@ def query(
         # Using the embedding driver to convert the query string into a vector
         vector = self.embedding_driver.embed_string(query)

-        count = count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
-        offset = offset if offset else 0
+        count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
+        offset = offset or 0

         pipeline = []

diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py
index 2b00266a8..d1da78188 100644
--- a/griptape/drivers/vector/base_vector_store_driver.py
+++ b/griptape/drivers/vector/base_vector_store_driver.py
@@ -87,7 +87,7 @@ def upsert_text_artifact(
         else:
             meta["artifact"] = artifact.to_json()

-        vector = artifact.embedding if artifact.embedding else artifact.generate_embedding(self.embedding_driver)
+        vector = artifact.embedding or artifact.generate_embedding(self.embedding_driver)

         if isinstance(vector, list):
             return self.upsert_vector(vector, vector_id=vector_id, namespace=namespace, meta=meta, **kwargs)
@@ -112,7 +112,7 @@ def upsert_text(
             self.embedding_driver.embed_string(string),
             vector_id=vector_id,
             namespace=namespace,
-            meta=meta if meta else {},
+            meta=meta or {},
             **kwargs,
         )

diff --git a/griptape/drivers/vector/local_vector_store_driver.py b/griptape/drivers/vector/local_vector_store_driver.py
index ab59f332c..2b42b19f5 100644
--- a/griptape/drivers/vector/local_vector_store_driver.py
+++ b/griptape/drivers/vector/local_vector_store_driver.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import json
+import operator
 import os
 import threading
 from dataclasses import asdict
@@ -58,7 +59,7 @@ def upsert_vector(
         meta: Optional[dict] = None,
         **kwargs,
     ) -> str:
-        vector_id = vector_id if vector_id else utils.str_to_hash(str(vector))
+        vector_id = vector_id or utils.str_to_hash(str(vector))

         with self.thread_lock:
             self.entries[self._namespaced_vector_id(vector_id, namespace=namespace)] = self.Entry(
@@ -101,7 +102,7 @@ def query(
         entries_and_relatednesses = [
             (entry, self.relatedness_fn(query_embedding, entry.vector)) for entry in entries.values()
         ]
-        entries_and_relatednesses.sort(key=lambda x: x[1], reverse=True)
+        entries_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True)

         result = [
             BaseVectorStoreDriver.Entry(id=er[0].id, vector=er[0].vector, score=er[1], meta=er[0].meta)
diff --git a/griptape/drivers/vector/marqo_vector_store_driver.py b/griptape/drivers/vector/marqo_vector_store_driver.py
index 7f6b52103..caab118b8 100644
--- a/griptape/drivers/vector/marqo_vector_store_driver.py
+++ b/griptape/drivers/vector/marqo_vector_store_driver.py
@@ -123,7 +123,7 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti
         if result and "_tensor_facets" in result and len(result["_tensor_facets"]) > 0:
             return BaseVectorStoreDriver.Entry(
                 id=result["_id"],
-                meta={k: v for k, v in result.items() if k not in ["_id"]},
+                meta={k: v for k, v in result.items() if k != "_id"},
                 vector=result["_tensor_facets"][0]["_embedding"],
             )
         else:
@@ -190,7 +190,7 @@ def query(
             The list of query results.
         """
         params = {
-            "limit": count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
+            "limit": count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
             "attributes_to_retrieve": ["*"] if include_metadata else ["_id"],
             "filter_string": f"namespace:{namespace}" if namespace else None,
         } | kwargs
diff --git a/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py b/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
index b17aaf4e7..34b1d3a5e 100644
--- a/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
+++ b/griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
@@ -133,8 +133,8 @@ def query(
         # Using the embedding driver to convert the query string into a vector
         vector = self.embedding_driver.embed_string(query)

-        count = count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
-        offset = offset if offset else 0
+        count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
+        offset = offset or 0

         pipeline = [
             {
diff --git a/griptape/drivers/vector/opensearch_vector_store_driver.py b/griptape/drivers/vector/opensearch_vector_store_driver.py
index e701e4084..267b549b7 100644
--- a/griptape/drivers/vector/opensearch_vector_store_driver.py
+++ b/griptape/drivers/vector/opensearch_vector_store_driver.py
@@ -60,7 +60,7 @@ def upsert_vector(
         If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
         Metadata associated with the vector can also be provided.
         """
-        vector_id = vector_id if vector_id else utils.str_to_hash(str(vector))
+        vector_id = vector_id or utils.str_to_hash(str(vector))
         doc = {"vector": vector, "namespace": namespace, "metadata": meta}
         doc.update(kwargs)
         response = self.client.index(index=self.index_name, id=vector_id, body=doc)
@@ -138,7 +138,7 @@ def query(
         Returns:
             A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace.
         """
-        count = count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
+        count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
         vector = self.embedding_driver.embed_string(query)
         # Base k-NN query
         query_body = {"size": count, "query": {"knn": {field_name: {"vector": vector, "k": count}}}}
diff --git a/griptape/drivers/vector/pinecone_vector_store_driver.py b/griptape/drivers/vector/pinecone_vector_store_driver.py
index 028ddebd6..a3a132ab3 100644
--- a/griptape/drivers/vector/pinecone_vector_store_driver.py
+++ b/griptape/drivers/vector/pinecone_vector_store_driver.py
@@ -36,7 +36,7 @@ def upsert_vector(
         meta: Optional[dict] = None,
         **kwargs,
     ) -> str:
-        vector_id = vector_id if vector_id else str_to_hash(str(vector))
+        vector_id = vector_id or str_to_hash(str(vector))

         params: dict[str, Any] = {"namespace": namespace} | kwargs

@@ -95,7 +95,7 @@ def query(
         vector = self.embedding_driver.embed_string(query)

         params = {
-            "top_k": count if count else BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
+            "top_k": count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
             "namespace": namespace,
             "include_values": include_vectors,
             "include_metadata": include_metadata,
diff --git a/griptape/drivers/vector/redis_vector_store_driver.py b/griptape/drivers/vector/redis_vector_store_driver.py
index 06aa853c6..0abf2c985 100644
--- a/griptape/drivers/vector/redis_vector_store_driver.py
+++ b/griptape/drivers/vector/redis_vector_store_driver.py
@@ -60,7 +60,7 @@ def upsert_vector(
         If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
         Metadata associated with the vector can also be provided.
         """
-        vector_id = vector_id if vector_id else str_to_hash(str(vector))
+        vector_id = vector_id or str_to_hash(str(vector))
         key = self._generate_key(vector_id, namespace)
         bytes_vector = json.dumps(vector).encode("utf-8")

diff --git a/griptape/loaders/email_loader.py b/griptape/loaders/email_loader.py
index a54d5a063..82f34bd8a 100644
--- a/griptape/loaders/email_loader.py
+++ b/griptape/loaders/email_loader.py
@@ -56,7 +56,7 @@ def load(self, source: EmailQuery, *args, **kwargs) -> ListArtifact | ErrorArtif
             top_n = max(0, messages_count - max_count) if max_count else 0

             for i in range(messages_count, top_n, -1):
-                result, data = client.fetch(str(i), "(RFC822)")
+                _result, data = client.fetch(str(i), "(RFC822)")

                 if data is None or not data or data[0] is None:
                     continue
diff --git a/griptape/mixins/media_artifact_file_output_mixin.py b/griptape/mixins/media_artifact_file_output_mixin.py
index 4097960bd..9b9f34911 100644
--- a/griptape/mixins/media_artifact_file_output_mixin.py
+++ b/griptape/mixins/media_artifact_file_output_mixin.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import os
+from pathlib import Path
 from typing import TYPE_CHECKING, Optional

 from attrs import Attribute, define, field
@@ -41,5 +42,4 @@ def _write_to_file(self, artifact: BlobArtifact) -> None:
         if os.path.dirname(outfile):
             os.makedirs(os.path.dirname(outfile), exist_ok=True)

-        with open(outfile, "wb") as f:
-            f.write(artifact.value)
+        Path(outfile).write_bytes(artifact.value)
diff --git a/griptape/tasks/base_image_generation_task.py b/griptape/tasks/base_image_generation_task.py
index 73b3cd42b..d32e8f142 100644
--- a/griptape/tasks/base_image_generation_task.py
+++ b/griptape/tasks/base_image_generation_task.py
@@ -2,6 +2,7 @@

 import os
 from abc import ABC
+from pathlib import Path
 from typing import TYPE_CHECKING

 from attrs import Attribute, define, field
@@ -60,5 +61,4 @@ def all_negative_rulesets(self) -> list[Ruleset]:

     def _read_from_file(self, path: str) -> MediaArtifact:
         self.structure.logger.info("Reading image from %s", os.path.abspath(path))
-        with open(path, "rb") as file:
-            return ImageLoader().load(file.read())
+        return ImageLoader().load(Path(path).read_bytes())
diff --git a/griptape/tokenizers/openai_tokenizer.py b/griptape/tokenizers/openai_tokenizer.py
index b839109fa..15220b4bd 100644
--- a/griptape/tokenizers/openai_tokenizer.py
+++ b/griptape/tokenizers/openai_tokenizer.py
@@ -64,7 +64,7 @@ def _default_max_input_tokens(self) -> int:
         tokens = next((v for k, v in self.MODEL_PREFIXES_TO_MAX_INPUT_TOKENS.items() if self.model.startswith(k)), None)
         offset = 0 if self.model in self.EMBEDDING_MODELS else self.TOKEN_OFFSET

-        return (tokens if tokens else self.DEFAULT_MAX_TOKENS) - offset
+        return (tokens or self.DEFAULT_MAX_TOKENS) - offset

     def _default_max_output_tokens(self) -> int:
         tokens = next(
@@ -84,7 +84,7 @@ def count_tokens(self, text: str | list[dict], model: Optional[str] = None) -> i
             https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb.
         """
         if isinstance(text, list):
-            model = model if model else self.model
+            model = model or self.model

             try:
                 encoding = tiktoken.encoding_for_model(model)
diff --git a/griptape/tools/audio_transcription_client/tool.py b/griptape/tools/audio_transcription_client/tool.py
index bf2dfc67f..62cd9e7a5 100644
--- a/griptape/tools/audio_transcription_client/tool.py
+++ b/griptape/tools/audio_transcription_client/tool.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast

 from attrs import Factory, define, field
@@ -31,8 +32,7 @@ class AudioTranscriptionClient(BaseTool):
     def transcribe_audio_from_disk(self, params: dict) -> TextArtifact | ErrorArtifact:
         audio_path = params["values"]["path"]

-        with open(audio_path, "rb") as f:
-            audio_artifact = self.audio_loader.load(f.read())
+        audio_artifact = self.audio_loader.load(Path(audio_path).read_bytes())

         return self.engine.run(audio_artifact)

diff --git a/griptape/tools/base_tool.py b/griptape/tools/base_tool.py
index 3fb6af26b..bfb754bc0 100644
--- a/griptape/tools/base_tool.py
+++ b/griptape/tools/base_tool.py
@@ -183,7 +183,7 @@ def tool_dir(self) -> str:
         return os.path.dirname(os.path.abspath(class_file))

     def install_dependencies(self, env: Optional[dict[str, str]] = None) -> None:
-        env = env if env else {}
+        env = env or {}

         command = [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"]

diff --git a/griptape/tools/computer/tool.py b/griptape/tools/computer/tool.py
index 9ab290137..4e996c63c 100644
--- a/griptape/tools/computer/tool.py
+++ b/griptape/tools/computer/tool.py
@@ -142,8 +142,7 @@ def execute_code_in_container(self, filename: str, code: str) -> BaseArtifact:
             local_file_path = os.path.join(local_workdir, filename)

             try:
-                with open(local_file_path, "w") as f:
-                    f.write(code)
+                Path(local_file_path).write_text(code)

                 return self.execute_command_in_container(f"python {container_file_path}")
             except Exception as e:
@@ -188,7 +187,7 @@ def build_image(self, tool: BaseTool) -> None:

     def dependencies(self) -> list[str]:
         with open(self.requirements_txt_path) as file:
-            return [line.strip() for line in file.readlines()]
+            return [line.strip() for line in file]

     def __del__(self) -> None:
         if self._tempdir:
diff --git a/griptape/tools/image_query_client/tool.py b/griptape/tools/image_query_client/tool.py
index 1b8cea534..a10929b13 100644
--- a/griptape/tools/image_query_client/tool.py
+++ b/griptape/tools/image_query_client/tool.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast

 from attrs import Factory, define, field
@@ -40,8 +41,7 @@ def query_image_from_disk(self, params: dict) -> TextArtifact | ErrorArtifact:

         image_artifacts = []
         for image_path in image_paths:
-            with open(image_path, "rb") as f:
-                image_artifacts.append(self.image_loader.load(f.read()))
+            image_artifacts.append(self.image_loader.load(Path(image_path).read_bytes()))

         return self.image_query_engine.run(query, image_artifacts)

diff --git a/griptape/utils/conversation.py b/griptape/utils/conversation.py
index 25dd310e1..97318c426 100644
--- a/griptape/utils/conversation.py
+++ b/griptape/utils/conversation.py
@@ -16,8 +16,7 @@ def lines(self) -> list[str]:
         lines = []

         for run in self.memory.runs:
-            lines.append(f"Q: {run.input}")
-            lines.append(f"A: {run.output}")
+            lines.extend((f"Q: {run.input}", f"A: {run.output}"))

         return lines

diff --git a/griptape/utils/file_utils.py b/griptape/utils/file_utils.py
index f730034d3..19c9f699c 100644
--- a/griptape/utils/file_utils.py
+++ b/griptape/utils/file_utils.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 from concurrent import futures
+from pathlib import Path
 from typing import Optional

 import griptape.utils as utils
@@ -15,8 +16,7 @@ def load_file(path: str) -> bytes:
     Returns:
         The content of the file.
     """
-    with open(path, "rb") as f:
-        return f.read()
+    return Path(path).read_bytes()


 def load_files(paths: list[str], futures_executor: Optional[futures.ThreadPoolExecutor] = None) -> dict[str, bytes]:
diff --git a/pyproject.toml b/pyproject.toml
index 2b11746d8..98bbe4a03 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -255,6 +255,7 @@ select = [
     "TCH", # flake8-type-checking
     "ERA", # eradicate
     "PGH", # pygrep-hooks
+    "FURB", # refurb
 ]
 ignore = [
     "UP007", # non-pep604-annotation
@@ -278,6 +279,7 @@ ignore = [
     "ANN401", # any-type
     "PT011", # pytest-raises-too-broad
 ]
+preview = true

 [tool.ruff.lint.pydocstyle]
 convention = "google"
diff --git a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py
index 0d2827683..a7c244f09 100644
--- a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py
+++ b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py
@@ -17,8 +17,7 @@ def temp_dir(self):
         def write_file(path: str, content: bytes) -> None:
             full_path = os.path.join(temp_dir, path)
             os.makedirs(os.path.dirname(full_path), exist_ok=True)
-            with open(full_path, "wb") as f:
-                f.write(content)
+            Path(full_path).write_bytes(content)

         def mkdir(path: str) -> None:
             full_path = os.path.join(temp_dir, path)
@@ -28,8 +27,7 @@ def copy_test_resources(resource_path: str) -> None:
             file_dir = os.path.dirname(__file__)
             full_path = os.path.join(file_dir, "../../../resources", resource_path)
             full_path = os.path.normpath(full_path)
-            with open(full_path, "rb") as source:
-                content = source.read()
+            content = Path(full_path).read_bytes()
             dest_path = os.path.join(temp_dir, "resources", resource_path)
             write_file(dest_path, content)

diff --git a/tests/unit/drivers/web_scraper/test_trafilatura_web_scraper_driver.py b/tests/unit/drivers/web_scraper/test_trafilatura_web_scraper_driver.py
index 53ddf4500..31d3016e9 100644
--- a/tests/unit/drivers/web_scraper/test_trafilatura_web_scraper_driver.py
+++ b/tests/unit/drivers/web_scraper/test_trafilatura_web_scraper_driver.py
@@ -11,7 +11,7 @@ def _mock_fetch_url(self, mocker):
         # characters to the body.
         mocker.patch(
             "trafilatura.fetch_url"
-        ).return_value = f'{"x"*243}foobar'
+        ).return_value = f'{"x" * 243}foobar'

     @pytest.fixture()
     def web_scraper(self):
diff --git a/tests/unit/engines/summary/test_prompt_summary_engine.py b/tests/unit/engines/summary/test_prompt_summary_engine.py
index e826a2b4d..4d9c65e03 100644
--- a/tests/unit/engines/summary/test_prompt_summary_engine.py
+++ b/tests/unit/engines/summary/test_prompt_summary_engine.py
@@ -1,4 +1,5 @@
 import os
+from pathlib import Path

 import pytest

@@ -38,7 +39,6 @@ def copy_test_resource(resource_path: str):
             file_dir = os.path.dirname(__file__)
             full_path = os.path.join(file_dir, "../../../resources", resource_path)
             full_path = os.path.normpath(full_path)
-            with open(full_path) as f:
-                return f.read()
+            return Path(full_path).read_text()

         assert engine.summarize_text(copy_test_resource("test.txt") * 50)
diff --git a/tests/unit/loaders/conftest.py b/tests/unit/loaders/conftest.py
index 494916be6..1f698738a 100644
--- a/tests/unit/loaders/conftest.py
+++ b/tests/unit/loaders/conftest.py
@@ -15,8 +15,7 @@ def create_source(resource_path: str) -> Path:
 @pytest.fixture()
 def bytes_from_resource_path(path_from_resource_path):
     def create_source(resource_path: str) -> bytes:
-        with open(path_from_resource_path(resource_path), "rb") as f:
-            return f.read()
+        return Path(path_from_resource_path(resource_path)).read_bytes()

     return create_source

@@ -24,7 +23,6 @@ def create_source(resource_path: str) -> bytes:
 @pytest.fixture()
 def str_from_resource_path(path_from_resource_path):
     def test_csv_str(resource_path: str) -> str:
-        with open(path_from_resource_path(resource_path)) as f:
-            return f.read()
+        return Path(path_from_resource_path(resource_path)).read_text()

     return test_csv_str
diff --git a/tests/unit/tools/test_transcription_client.py b/tests/unit/tools/test_transcription_client.py
index 7768792d0..8b54e891b 100644
--- a/tests/unit/tools/test_transcription_client.py
+++ b/tests/unit/tools/test_transcription_client.py
@@ -18,6 +18,14 @@ def audio_loader(self) -> Mock:

         return loader

+    @pytest.fixture(
+        autouse=True,
+    )
+    def mock_path(self, mocker) -> Mock:
+        mocker.patch("pathlib.Path.read_bytes", return_value=b"transcription")
+
+        return mocker
+
     def test_init_transcription_client(self, transcription_engine, audio_loader) -> None:
         assert AudioTranscriptionClient(engine=transcription_engine, audio_loader=audio_loader)
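For reference, the FURB-driven rewrites in the hunks above reduce to a handful of recurring idioms. Below is a minimal, self-contained sketch of those idioms; the function names and the default constant are hypothetical stand-ins for illustration, not griptape APIs or part of this patch.

import operator
from pathlib import Path
from typing import Optional

DEFAULT_QUERY_COUNT = 5  # hypothetical stand-in for BaseVectorStoreDriver.DEFAULT_QUERY_COUNT


def save_text(path: str, payload: str) -> None:
    # open(path, "w") + file.write(...) collapses into a single pathlib call
    Path(path).write_text(payload)


def load_bytes(path: str) -> bytes:
    # open(path, "rb") + file.read() collapses into a single pathlib call
    return Path(path).read_bytes()


def resolve_count(count: Optional[int]) -> int:
    # `x if x else default` simplifies to `x or default`
    return count or DEFAULT_QUERY_COUNT


def rank(pairs: list[tuple[str, float]]) -> list[tuple[str, float]]:
    # `key=lambda x: x[1]` becomes `key=operator.itemgetter(1)`
    return sorted(pairs, key=operator.itemgetter(1), reverse=True)

The pathlib form opens and closes the file in one expression, the `or` form preserves the falsy-fallback behavior of the original conditional expression, and `operator.itemgetter` avoids defining a throwaway lambda for the sort key.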