
Merge branch 'huggingface:main' into stateful-dataloader
byi8220 authored Jul 5, 2024
2 parents f273abc + 2471eac commit 8a46eb6
Showing 4 changed files with 13 additions and 11 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -48,7 +48,7 @@
 
 setup(
     name="accelerate",
-    version="0.32.0.dev0",
+    version="0.33.0.dev0",
     description="Accelerate",
     long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
2 changes: 1 addition & 1 deletion src/accelerate/__init__.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-__version__ = "0.32.0.dev0"
+__version__ = "0.33.0.dev0"
 
 from .accelerator import Accelerator
 from .big_modeling import (
15 changes: 10 additions & 5 deletions src/accelerate/utils/memory.py
@@ -26,8 +26,13 @@
 from .imports import is_mlu_available, is_mps_available, is_npu_available, is_xpu_available
 
 
-def clear_device_cache():
-    gc.collect()
+def clear_device_cache(garbage_collection=False):
+    """
+    Clears the device cache by calling `torch.{backend}.empty_cache`. Can also run `gc.collect()`, but do note that
+    this is a *considerable* slowdown and should be used sparingly.
+    """
+    if garbage_collection:
+        gc.collect()
 
     if is_xpu_available():
         torch.xpu.empty_cache()
@@ -67,7 +72,7 @@ def release_memory(*objects):
     objects = list(objects)
     for i in range(len(objects)):
         objects[i] = None
-    clear_device_cache()
+    clear_device_cache(garbage_collection=True)
     return objects
 

Expand Down Expand Up @@ -123,7 +128,7 @@ def find_executable_batch_size(function: callable = None, starting_batch_size: i

def decorator(*args, **kwargs):
nonlocal batch_size
clear_device_cache()
clear_device_cache(garbage_collection=True)
params = list(inspect.signature(function).parameters.keys())
# Guard against user error
if len(params) < (len(args) + 1):
@@ -139,7 +144,7 @@ def decorator(*args, **kwargs):
                 return function(batch_size, *args, **kwargs)
             except Exception as e:
                 if should_reduce_batch_size(e):
-                    clear_device_cache()
+                    clear_device_cache(garbage_collection=True)
                     batch_size //= 2
                 else:
                     raise
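For context, a minimal sketch of how the updated helpers in this file could be called after the change. The import path `accelerate.utils.memory` and the two helper names come from the diff above; the tensors and the CUDA device are illustrative assumptions, not part of the commit.

# Sketch only: exercises the new garbage_collection flag added in this commit.
# Assumes a CUDA device is available; substitute your own tensors or models.
import torch

from accelerate.utils.memory import clear_device_cache, release_memory

a = torch.ones(1000, 1000, device="cuda")
b = torch.ones(1000, 1000, device="cuda")

# New default: only empty the backend cache (e.g. torch.cuda.empty_cache() on CUDA),
# skipping the comparatively slow gc.collect().
clear_device_cache()

# Opt in to the heavier behaviour when a full collection is really needed.
clear_device_cache(garbage_collection=True)

# release_memory() still clears its arguments and, per the hunk above, now calls
# clear_device_cache(garbage_collection=True) internally.
a, b = release_memory(a, b)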
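The two decorator hunks mean that each retry of `find_executable_batch_size` now runs the garbage-collecting variant of the cache clear before halving the batch size. A rough usage sketch follows; the function name `training_loop` and its body are placeholders, only the decorator and its `starting_batch_size` parameter come from the module changed here.

# Sketch only: the decorated function receives batch_size as its first argument
# and is called without it; the decorator injects and adjusts the value.
from accelerate.utils.memory import find_executable_batch_size


@find_executable_batch_size(starting_batch_size=128)
def training_loop(batch_size):
    # Placeholder body: a real loop would build dataloaders and models here.
    # If it raises an out-of-memory error, the decorator clears the device cache
    # (now with garbage collection) and retries with batch_size // 2.
    print(f"Running with batch size {batch_size}")


training_loop()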
5 changes: 1 addition & 4 deletions tests/test_modeling_utils.py
@@ -748,10 +748,7 @@ def test_load_state_dict(self):
 
         for param, device in device_map.items():
             device = device if device != "disk" else "cpu"
-            expected_device = (
-                torch.device(f"{torch_device}:{device}") if isinstance(device, int) else torch.device(device)
-            )
-            assert loaded_state_dict[param].device == expected_device
+            assert loaded_state_dict[param].device == torch.device(device)
 
     def test_convert_file_size(self):
         result = convert_file_size_to_int("0MB")
