From db36fc35aaeec244e295adcbde15dc944aa61545 Mon Sep 17 00:00:00 2001
From: Vibhu Jawa
Date: Wed, 6 Nov 2024 13:41:06 -0800
Subject: [PATCH] fix loading multiple models

Signed-off-by: Vibhu Jawa
---
 crossfit/backend/torch/hf/model.py              | 12 ++++++------
 crossfit/backend/torch/model.py                 |  8 +++-----
 tests/backend/pytorch_backend/test_torch_ops.py |  8 ++++----
 3 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/crossfit/backend/torch/hf/model.py b/crossfit/backend/torch/hf/model.py
index 85224cc..174d62d 100644
--- a/crossfit/backend/torch/hf/model.py
+++ b/crossfit/backend/torch/hf/model.py
@@ -81,14 +81,14 @@ def __init__(
         )
 
     def load_on_worker(self, worker, device="cuda"):
-        worker.torch_model = self.load_model(device)
-        worker.cfg = self.load_cfg()
+        setattr(worker, f"torch_model_{self.path_or_name}", self.load_model(device))
+        setattr(worker, f"cfg_{self.path_or_name}", self.load_cfg())
 
     def unload_from_worker(self, worker):
-        if hasattr(worker, "torch_model"):
-            delattr(worker, "torch_model")
-        if hasattr(worker, "cfg"):
-            delattr(worker, "cfg")
+        if hasattr(worker, f"torch_model_{self.path_or_name}"):
+            delattr(worker, f"torch_model_{self.path_or_name}")
+        if hasattr(worker, f"cfg_{self.path_or_name}"):
+            delattr(worker, f"cfg_{self.path_or_name}")
         cleanup_torch_cache()
 
     def load_model(self, device="cuda"):
diff --git a/crossfit/backend/torch/model.py b/crossfit/backend/torch/model.py
index 0e5a3e9..eaeaebb 100644
--- a/crossfit/backend/torch/model.py
+++ b/crossfit/backend/torch/model.py
@@ -64,14 +64,12 @@ def unload_from_worker(self, worker):
         raise NotImplementedError()
 
     def call_on_worker(self, worker, *args, **kwargs):
-        return worker.torch_model(*args, **kwargs)
+        return getattr(worker, f"torch_model_{self.path_or_name}")(*args, **kwargs)
 
     def get_model(self, worker):
-        # TODO: We should not hard code the attribute name
-        # to torch_model. We should use the path_or_name_model
-        if not hasattr(worker, "torch_model"):
+        if not hasattr(worker, f"torch_model_{self.path_or_name}"):
             self.load_on_worker(worker)
-        return worker.torch_model
+        return getattr(worker, f"torch_model_{self.path_or_name}")
 
     def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int:
         raise NotImplementedError()
diff --git a/tests/backend/pytorch_backend/test_torch_ops.py b/tests/backend/pytorch_backend/test_torch_ops.py
index c732189..81a4cf5 100644
--- a/tests/backend/pytorch_backend/test_torch_ops.py
+++ b/tests/backend/pytorch_backend/test_torch_ops.py
@@ -39,13 +39,13 @@ def mock_worker(self):
 
     def test_unload_from_worker(self, model, mock_worker):
         model.load_on_worker(mock_worker)
-        assert hasattr(mock_worker, "torch_model")
-        assert hasattr(mock_worker, "cfg")
+        assert hasattr(mock_worker, f"torch_model_{model.path_or_name}")
+        assert hasattr(mock_worker, f"cfg_{model.path_or_name}")
 
         model.unload_from_worker(mock_worker)
 
-        assert not hasattr(mock_worker, "torch_model")
-        assert not hasattr(mock_worker, "cfg")
+        assert not hasattr(mock_worker, f"torch_model_{model.path_or_name}")
+        assert not hasattr(mock_worker, f"cfg_{model.path_or_name}")
 
 
 class DummyModelWithDictOutput(torch.nn.Module):
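
A note on the fix (illustration, not part of the patch): before this change,
every model wrote to the same worker attributes (worker.torch_model and
worker.cfg), so loading a second model on the same Dask worker silently
clobbered the first. Keying the attribute names by self.path_or_name gives
each model its own slot on the worker. The sketch below is a minimal,
self-contained Python illustration of that pattern; MockWorker and Model here
are hypothetical stand-ins, not crossfit classes.

# Minimal sketch of the per-model attribute scheme; MockWorker and Model
# are hypothetical stand-ins for a Dask worker and crossfit's model wrapper.


class MockWorker:
    """Stand-in for a Dask worker; model attributes are attached dynamically."""


class Model:
    def __init__(self, path_or_name):
        self.path_or_name = path_or_name

    def load_on_worker(self, worker):
        # One attribute per model, keyed by path_or_name, so loading a second
        # model no longer overwrites the first model's weights.
        setattr(worker, f"torch_model_{self.path_or_name}", f"weights:{self.path_or_name}")

    def unload_from_worker(self, worker):
        if hasattr(worker, f"torch_model_{self.path_or_name}"):
            delattr(worker, f"torch_model_{self.path_or_name}")

    def get_model(self, worker):
        if not hasattr(worker, f"torch_model_{self.path_or_name}"):
            self.load_on_worker(worker)
        return getattr(worker, f"torch_model_{self.path_or_name}")


worker = MockWorker()
model_a = Model("model-a")
model_b = Model("model-b")
model_a.load_on_worker(worker)
model_b.load_on_worker(worker)

# Both models now coexist on one worker; unloading one leaves the other intact.
assert model_a.get_model(worker) == "weights:model-a"
model_a.unload_from_worker(worker)
assert not hasattr(worker, "torch_model_model-a")
assert model_b.get_model(worker) == "weights:model-b"

Because these attributes are only ever touched through setattr/getattr/hasattr
and delattr, characters such as "/" in a Hugging Face model name (a common
form of path_or_name) are harmless in the generated attribute name.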