From bf808e1fb3c5dac15b963c33c77f9989bb7d8313 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 29 Apr 2024 10:25:08 -0700 Subject: [PATCH] Load relative files --- tests/unit/utils/test_file_utils.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/unit/utils/test_file_utils.py b/tests/unit/utils/test_file_utils.py index ba1e845fa..dbcf1044b 100644 --- a/tests/unit/utils/test_file_utils.py +++ b/tests/unit/utils/test_file_utils.py @@ -1,3 +1,4 @@ +import os from griptape.loaders import TextLoader from griptape import utils from concurrent import futures @@ -8,26 +9,30 @@ class TestFileUtils: def test_load_file(self): - file = utils.load_file("tests/resources/foobar-many.txt") + dirname = os.path.dirname(__file__) + file = utils.load_file(os.path.join(dirname, "../../resources/foobar-many.txt")) assert file.decode("utf-8").startswith("foobar foobar foobar") assert len(file.decode("utf-8")) == 4563 def test_load_files(self): - sources = ["tests/resources/foobar-many.txt", "tests/resources/foobar-many.txt", "tests/resources/small.png"] + dirname = os.path.dirname(__file__) + sources = ["resources/foobar-many.txt", "resources/foobar-many.txt", "resources/small.png"] + sources = [os.path.join(dirname, "../../", source) for source in sources] files = utils.load_files(sources, futures_executor=futures.ThreadPoolExecutor(max_workers=1)) assert len(files) == 2 - test_file = files[utils.str_to_hash("tests/resources/foobar-many.txt")] + test_file = files[utils.str_to_hash(sources[0])] assert len(test_file) == 4563 assert test_file.decode("utf-8").startswith("foobar foobar foobar") - small_file = files[utils.str_to_hash("tests/resources/small.png")] + small_file = files[utils.str_to_hash(sources[2])] assert len(small_file) == 97 assert small_file[:8] == b"\x89PNG\r\n\x1a\n" def test_load_file_with_loader(self): - file = utils.load_file("tests/resources/foobar-many.txt") + dirname = os.path.dirname(__file__) + file = utils.load_file(os.path.join(dirname, "../../", "resources/foobar-many.txt")) artifacts = TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()).load(file) assert len(artifacts) == 39 @@ -35,12 +40,14 @@ def test_load_file_with_loader(self): assert artifacts[0].value.startswith("foobar foobar foobar") def test_load_files_with_loader(self): - sources = ["tests/resources/foobar-many.txt"] + dirname = os.path.dirname(__file__) + sources = ["resources/foobar-many.txt"] + sources = [os.path.join(dirname, "../../", source) for source in sources] files = utils.load_files(sources) loader = TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()) collection = loader.load_collection(list(files.values())) - test_file_artifacts = collection[loader.to_key(files[utils.str_to_hash("tests/resources/foobar-many.txt")])] + test_file_artifacts = collection[loader.to_key(files[utils.str_to_hash(sources[0])])] assert len(test_file_artifacts) == 39 assert isinstance(test_file_artifacts, list) assert test_file_artifacts[0].value.startswith("foobar foobar foobar")