Skip to content

Commit

Permalink
is_safetensors_compatible fix (#9741)
Browse files Browse the repository at this point in the history
update
  • Loading branch information
DN6 authored Oct 22, 2024
1 parent 0d9d98f commit 76c00c7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
components.setdefault(component, [])
components[component].append(component_filename)

# If there are no component folders check the main directory for safetensors files
if not components:
return any(".safetensors" in filename for filename in filenames)

# iterate over all files of a component
# check if safetensor files exist for that component
# if variant is provided check if the variant of the safetensors exists
Expand Down
12 changes: 12 additions & 0 deletions tests/pipelines/test_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,18 @@ def test_diffusers_is_compatible_only_variants(self):
]
self.assertTrue(is_safetensors_compatible(filenames))

def test_diffusers_is_compatible_no_components(self):
filenames = [
"diffusion_pytorch_model.bin",
]
self.assertFalse(is_safetensors_compatible(filenames))

def test_diffusers_is_compatible_no_components_only_variants(self):
filenames = [
"diffusion_pytorch_model.fp16.bin",
]
self.assertFalse(is_safetensors_compatible(filenames))


class ProgressBarTests(unittest.TestCase):
def get_dummy_components_image_generation(self):
Expand Down

0 comments on commit 76c00c7

Please sign in to comment.