From 24e2316ca7d84aa65621891613c01b4fe23b3e9a Mon Sep 17 00:00:00 2001
From: Rory
Date: Fri, 15 Dec 2023 14:12:46 +0000
Subject: [PATCH] feat: scanner provides list of extensions that are not on
 valid list

The extension filter runs on the bare file name, before the name is mapped
to its output form. Filtering after the mapping would hand an
extension-only string such as '.jpg' to os.path.splitext, which treats a
leading dot as part of the root and returns '' as the extension, so valid
extensions would be reported as invalid.
---
 app/scanner.py       | 33 ++++++++++++++++++++++++++++-----
 test/test_scanner.py | 33 ++++++++++++++++++++++++++++++---
 2 files changed, 58 insertions(+), 8 deletions(-)

diff --git a/app/scanner.py b/app/scanner.py
index 201566b..c562519 100644
--- a/app/scanner.py
+++ b/app/scanner.py
@@ -1,3 +1,4 @@
+from functools import partial
 import os
 
 VALID_PHOTO_EXTENSIONS = ['.bmp', '.gif', '.jpg', '.jpeg', '.png', '.tif', '.tiff']
@@ -7,10 +8,32 @@
 
 class Scanner:
 
-    def scan_directory(self, source_dir):
+    def valid_filepaths_in(self, source_dir):
+        """Return an iterator over the full paths of files under source_dir with a valid extension."""
+        return self.__scan_directory(source_dir, self.__ext_valid, self.__full_file_path)
+
+    def invalid_extensions_in(self, source_dir):
+        """Return the set of extensions found under source_dir that are not in VALID_EXTENSIONS."""
+        return set(self.__scan_directory(source_dir, self.__ext_invalid, self.__extension_only))
+
+    def __ext_valid(self, file):
+        return self.__extension(file) in VALID_EXTENSIONS
+
+    def __ext_invalid(self, file):
+        return self.__extension(file) not in VALID_EXTENSIONS
+
+    def __extension(self, file):
+        return os.path.splitext(file)[1]
+
+    def __full_file_path(self, root, file):
+        return os.path.join(root, file)
+
+    def __extension_only(self, _, file):
+        return self.__extension(file)
+
+    def __scan_directory(self, source_dir, extension_filter, path_constructor):
         file_tree = os.walk(source_dir)
         for (root, _, files) in file_tree:
-            for file in files:
-                _, extension = os.path.splitext(file)
-                if extension in VALID_EXTENSIONS:
-                    yield os.path.join(root, file)
+            # Filter on the bare file name before mapping it: once a name has been
+            # reduced to its extension, splitext('.jpg') returns '' for the extension.
+            yield from map(partial(path_constructor, root), filter(extension_filter, files))
diff --git a/test/test_scanner.py b/test/test_scanner.py
index 7acd9ae..45f31c5 100644
--- a/test/test_scanner.py
+++ b/test/test_scanner.py
@@ -13,7 +13,7 @@ def teardown():
 def test_scanner_discovers_files_to_be_copied():
     create_files_with_desired_extensions()
 
-    files_to_copy = scanner.scan_directory(source_directory)
+    files_to_copy = scanner.valid_filepaths_in(source_directory)
 
     assert sorted(list(files_to_copy)) == valid_source_filepaths()
 
@@ -22,7 +22,7 @@ def test_scanner_ignores_files_without_desired_extensions():
     create_files_with_desired_extensions()
     create_files_without_desired_extensions()
 
-    files_to_copy = scanner.scan_directory(source_directory)
+    files_to_copy = scanner.valid_filepaths_in(source_directory)
 
     assert sorted(list(files_to_copy)) == valid_source_filepaths()
 
@@ -31,6 +31,33 @@ def test_copies_file_when_provided_source_path_does_not_have_trailing_backslash(
     filename = 'a_file.jpeg'
     source_filepath = create_file(source_directory, filename)
     source_directory_without_trailing_backslash = source_directory[0:-1]
-    files_to_copy = scanner.scan_directory(source_directory_without_trailing_backslash)
+    files_to_copy = scanner.valid_filepaths_in(source_directory_without_trailing_backslash)
 
     assert list(files_to_copy) == [source_filepath]
+
+
+def test_provides_a_set_of_invalid_file_extensions():
+    create_file(source_directory, 'a_file.non')
+    create_file(source_directory, 'a_file_2.non')
+    result = scanner.invalid_extensions_in(source_directory)
+    extension = '.non'
+
+    assert list(result) == [extension]
+    assert len(list(result)) == 1
+
+
+def test_provides_multiple_invalid_extensions():
+    create_file(source_directory, 'a_file.non')
+    create_file(source_directory, 'a_file_2.hlp')
+    result = scanner.invalid_extensions_in(source_directory)
+    extensions = sorted(['.non', '.hlp'])
+
+    assert sorted(list(result)) == extensions
+
+
+def test_does_not_report_valid_extensions_as_invalid():
+    create_file(source_directory, 'a_file.jpeg')
+    create_file(source_directory, 'a_file.non')
+    result = scanner.invalid_extensions_in(source_directory)
+
+    assert result == {'.non'}