Skip to content

Commit

Permalink
Load relative files
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Apr 29, 2024
1 parent 86661a4 commit bf808e1
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions tests/unit/utils/test_file_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from griptape.loaders import TextLoader
from griptape import utils
from concurrent import futures
Expand All @@ -8,39 +9,45 @@

class TestFileUtils:
def test_load_file(self):
file = utils.load_file("tests/resources/foobar-many.txt")
dirname = os.path.dirname(__file__)
file = utils.load_file(os.path.join(dirname, "../../resources/foobar-many.txt"))

assert file.decode("utf-8").startswith("foobar foobar foobar")
assert len(file.decode("utf-8")) == 4563

def test_load_files(self):
sources = ["tests/resources/foobar-many.txt", "tests/resources/foobar-many.txt", "tests/resources/small.png"]
dirname = os.path.dirname(__file__)
sources = ["resources/foobar-many.txt", "resources/foobar-many.txt", "resources/small.png"]
sources = [os.path.join(dirname, "../../", source) for source in sources]
files = utils.load_files(sources, futures_executor=futures.ThreadPoolExecutor(max_workers=1))
assert len(files) == 2

test_file = files[utils.str_to_hash("tests/resources/foobar-many.txt")]
test_file = files[utils.str_to_hash(sources[0])]
assert len(test_file) == 4563
assert test_file.decode("utf-8").startswith("foobar foobar foobar")

small_file = files[utils.str_to_hash("tests/resources/small.png")]
small_file = files[utils.str_to_hash(sources[2])]
assert len(small_file) == 97
assert small_file[:8] == b"\x89PNG\r\n\x1a\n"

def test_load_file_with_loader(self):
file = utils.load_file("tests/resources/foobar-many.txt")
dirname = os.path.dirname(__file__)
file = utils.load_file(os.path.join(dirname, "../../", "resources/foobar-many.txt"))
artifacts = TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver()).load(file)

assert len(artifacts) == 39
assert isinstance(artifacts, list)
assert artifacts[0].value.startswith("foobar foobar foobar")

def test_load_files_with_loader(self):
sources = ["tests/resources/foobar-many.txt"]
dirname = os.path.dirname(__file__)
sources = ["resources/foobar-many.txt"]
sources = [os.path.join(dirname, "../../", source) for source in sources]
files = utils.load_files(sources)
loader = TextLoader(max_tokens=MAX_TOKENS, embedding_driver=MockEmbeddingDriver())
collection = loader.load_collection(list(files.values()))

test_file_artifacts = collection[loader.to_key(files[utils.str_to_hash("tests/resources/foobar-many.txt")])]
test_file_artifacts = collection[loader.to_key(files[utils.str_to_hash(sources[0])])]
assert len(test_file_artifacts) == 39
assert isinstance(test_file_artifacts, list)
assert test_file_artifacts[0].value.startswith("foobar foobar foobar")

0 comments on commit bf808e1

Please sign in to comment.