diff --git a/share/example/feed_forward.py b/share/example/feed_forward.py index 0c33b7a..c57f6ae 100644 --- a/share/example/feed_forward.py +++ b/share/example/feed_forward.py @@ -46,8 +46,8 @@ class AllowEmptyClassImageFolder(ImageFolder): ''' def find_classes(self, directory): with os.scandir(directory) as scanit: - class_info = sorted((entry.name, entry.stat().st_size) for entry in scanit if entry.is_dir()) - class_to_idx = {class_name: index for index, (class_name, st_size) in enumerate(class_info) if st_size} + class_info = sorted((entry.name, len(list(os.scandir(entry.path)))) for entry in scanit if entry.is_dir()) + class_to_idx = {class_name: index for index, (class_name, n_members) in enumerate(class_info) if n_members} if not class_to_idx: raise FileNotFoundError(f'No non-empty classes found in \'{directory}\'.') return list(class_to_idx), class_to_idx