From dc507cdd38abdf4eafe26b2ae37143862345804e Mon Sep 17 00:00:00 2001 From: "ryoji.nagata" Date: Fri, 15 Nov 2024 09:43:36 +0900 Subject: [PATCH] For some reason, the part I didn't fix got fixed, so I put it back in. --- sentence_transformers/models/Transformer.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py index 9076335ed..aeffd0ce9 100644 --- a/sentence_transformers/models/Transformer.py +++ b/sentence_transformers/models/Transformer.py @@ -294,12 +294,12 @@ def _backend_should_export( """ export = model_args.pop("export", None) - if export is not None: + if export: return export, model_args file_name = model_args.get("file_name", target_file_name) subfolder = model_args.get("subfolder", None) - primary_full_path = Path(subfolder, file_name).as_posix() if subfolder else file_name + primary_full_path = Path(subfolder, file_name).as_posix() if subfolder else Path(file_name).as_posix() secondary_full_path = ( Path(subfolder, self.backend, file_name).as_posix() if subfolder @@ -322,10 +322,10 @@ def _backend_should_export( # First check if the expected file exists in the root of the model directory # If it doesn't, check if it exists in the backend subfolder. # If it does, set the subfolder to include the backend - export = primary_full_path not in model_file_names - if export and "subfolder" not in model_args: - export = secondary_full_path not in model_file_names - if not export: + model_found = primary_full_path in model_file_names + if not model_found and "subfolder" not in model_args: + model_found = secondary_full_path in model_file_names + if model_found: if len(model_file_names) > 1 and "file_name" not in model_args: logger.warning( f"Multiple {backend_name} files found in {load_path.as_posix()!r}: {model_file_names}, defaulting to {secondary_full_path!r}. " @@ -333,6 +333,8 @@ def _backend_should_export( ) model_args["subfolder"] = self.backend model_args["file_name"] = file_name + if export is None: + export = not model_found # If the file_name contains subfolders, set it as the subfolder instead file_name_parts = Path(file_name).parts