diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py
index 567c70499e9..0f3877f2e81 100644
--- a/src/accelerate/test_utils/testing.py
+++ b/src/accelerate/test_utils/testing.py
@@ -66,9 +66,9 @@ def get_backend():
     elif is_cuda_available():
         return "cuda", torch.cuda.device_count(), torch.cuda.memory_allocated
     elif is_mps_available(min_version="2.0"):
-        return "mps", 1, torch.mps.current_allocated_memory()
+        return "mps", 1, torch.mps.current_allocated_memory
     elif is_mps_available():
-        return "mps", 1, 0
+        return "mps", 1, lambda: 0
     elif is_mlu_available():
         return "mlu", torch.mlu.device_count(), torch.mlu.memory_allocated
     elif is_npu_available():
@@ -76,7 +76,7 @@
     elif is_xpu_available():
         return "xpu", torch.xpu.device_count(), torch.xpu.memory_allocated
     else:
-        return "cpu", 1, 0
+        return "cpu", 1, lambda: 0
 
 
 torch_device, device_count, memory_allocated_func = get_backend()
diff --git a/src/accelerate/utils/memory.py b/src/accelerate/utils/memory.py
index c344e6eb616..97af2726d13 100644
--- a/src/accelerate/utils/memory.py
+++ b/src/accelerate/utils/memory.py
@@ -26,6 +26,21 @@
 from .imports import is_mlu_available, is_mps_available, is_npu_available, is_xpu_available
 
 
+def clear_device_cache():
+    gc.collect()
+
+    if is_xpu_available():
+        torch.xpu.empty_cache()
+    elif is_mlu_available():
+        torch.mlu.empty_cache()
+    elif is_npu_available():
+        torch.npu.empty_cache()
+    elif is_mps_available(min_version="2.0"):
+        torch.mps.empty_cache()
+    else:
+        torch.cuda.empty_cache()
+
+
 def release_memory(*objects):
     """
     Releases memory from `objects` by setting them to `None` and calls `gc.collect()` and `torch.cuda.empty_cache()`.
@@ -52,17 +67,7 @@ def release_memory(*objects):
     objects = list(objects)
     for i in range(len(objects)):
         objects[i] = None
-    gc.collect()
-    if is_xpu_available():
-        torch.xpu.empty_cache()
-    elif is_mlu_available():
-        torch.mlu.empty_cache()
-    elif is_npu_available():
-        torch.npu.empty_cache()
-    elif is_mps_available(min_version="2.0"):
-        torch.mps.empty_cache()
-    else:
-        torch.cuda.empty_cache()
+    clear_device_cache()
     return objects
 
 
@@ -118,15 +123,7 @@ def find_executable_batch_size(function: callable = None, starting_batch_size: i
 
     def decorator(*args, **kwargs):
         nonlocal batch_size
-        gc.collect()
-        if is_xpu_available():
-            torch.xpu.empty_cache()
-        elif is_mlu_available():
-            torch.mlu.empty_cache()
-        elif is_npu_available():
-            torch.npu.empty_cache()
-        else:
-            torch.cuda.empty_cache()
+        clear_device_cache()
         params = list(inspect.signature(function).parameters.keys())
         # Guard against user error
         if len(params) < (len(args) + 1):
@@ -142,15 +139,7 @@ def decorator(*args, **kwargs):
             return function(batch_size, *args, **kwargs)
         except Exception as e:
             if should_reduce_batch_size(e):
-                gc.collect()
-                if is_xpu_available():
-                    torch.xpu.empty_cache()
-                elif is_mlu_available():
-                    torch.mlu.empty_cache()
-                elif is_npu_available():
-                    torch.npu.empty_cache()
-                else:
-                    torch.cuda.empty_cache()
+                clear_device_cache()
                 batch_size //= 2
             else:
                 raise
diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index b57b476df41..8ab79fab036 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -39,6 +39,7 @@
     is_torch_xla_available,
     is_xpu_available,
 )
+from .memory import clear_device_cache
 from .offload import load_offloaded_weight, offload_weight, save_offload_index
 from .tqdm import is_tqdm_available, tqdm
 from .versions import compare_versions
@@ -456,14 +457,7 @@ def set_module_tensor_to_device(
                 module.weight = module.weight.cuda(device_index)
     # clean pre and post foward hook
     if device != "cpu":
-        if is_npu_available():
-            torch.npu.empty_cache()
-        elif is_mlu_available():
-            torch.mlu.empty_cache()
-        elif is_xpu_available():
-            torch.xpu.empty_cache()
-        else:
-            torch.cuda.empty_cache()
+        clear_device_cache()
 
     # When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in
     # order to avoid duplicating memory, see above.
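
For reviewers, a short usage sketch (not part of the patch) of what the refactor enables: `memory_allocated_func` returned by `get_backend()` is now a callable on every backend (hence `lambda: 0` on CPU and pre-2.0 MPS), and `clear_device_cache()` centralizes the `gc.collect()` plus backend-specific `empty_cache()` sequence that was previously duplicated. The `consume_and_release` helper below is hypothetical and only illustrates the intended behavior.

import torch

from accelerate.test_utils.testing import memory_allocated_func, torch_device
from accelerate.utils.memory import clear_device_cache, release_memory


def consume_and_release():
    # Hypothetical illustration only; not added by this diff.
    # memory_allocated_func is always callable now: torch.cuda.memory_allocated,
    # torch.mps.current_allocated_memory, or `lambda: 0` on CPU / pre-2.0 MPS.
    tensor = torch.ones(1024, 1024, device=torch_device)
    print(f"{torch_device} bytes allocated: {memory_allocated_func()}")

    # release_memory() nulls the passed references and calls clear_device_cache(),
    # which runs gc.collect() followed by the backend-specific empty_cache().
    (tensor,) = release_memory(tensor)

    # clear_device_cache() can also be called directly when there is nothing to release.
    clear_device_cache()
    print(f"{torch_device} bytes allocated after release: {memory_allocated_func()}")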