
Commit

Merge branch 'huggingface:main' into stateful-dataloader
byi8220 authored Jul 22, 2024
2 parents 4739524 + 2308576 commit 4de9159
Showing 3 changed files with 12 additions and 17 deletions.
9 changes: 0 additions & 9 deletions src/accelerate/state.py

```diff
@@ -199,11 +199,6 @@ def __init__(self, cpu: bool = False, **kwargs):
                 )
                 from deepspeed import comm as dist
 
-                if is_xpu_available() and is_ccl_available():
-                    os.environ["CCL_PROCESS_LAUNCHER"] = "none"
-                    os.environ["CCL_LOCAL_SIZE"] = os.environ.get("LOCAL_WORLD_SIZE", "1")
-                    os.environ["CCL_LOCAL_RANK"] = os.environ.get("LOCAL_RANK", "0")
-
                 if not dist.is_initialized():
                     dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
                 # We need to flag to `use_deepspeed` to be True to override `distributed_type` later
@@ -221,10 +216,6 @@ def __init__(self, cpu: bool = False, **kwargs):
             os.environ["WORLD_SIZE"] = str(dist_information.world_size)
             os.environ["LOCAL_RANK"] = str(dist_information.local_rank)
             os.environ["LOCAL_WORLD_SIZE"] = str(dist_information.local_world_size)
-            if self.backend == "ccl" and self.distributed_type == DistributedType.MULTI_XPU:
-                os.environ["CCL_PROCESS_LAUNCHER"] = "none"
-                os.environ["CCL_LOCAL_SIZE"] = os.environ["LOCAL_WORLD_SIZE"]
-                os.environ["CCL_LOCAL_RANK"] = os.environ["LOCAL_RANK"]
             if not os.environ.get("MASTER_PORT", None):
                 os.environ["MASTER_PORT"] = "29500"
             if (
```
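For reference, the deleted lines followed one pattern: before initializing the distributed backend, export oneCCL's process-placement variables, deriving them from the launcher-provided `LOCAL_WORLD_SIZE`/`LOCAL_RANK` with single-process defaults. If an environment still needs those variables set manually (an assumption, not something this commit states), the removed behavior amounts to this standalone sketch:

```python
import os


def export_ccl_env() -> None:
    """Replicates the env-var exports this commit removes from state.py."""
    # Tell oneCCL not to use its own process launcher; placement info
    # is supplied via the CCL_LOCAL_* variables below instead.
    os.environ["CCL_PROCESS_LAUNCHER"] = "none"
    # Fall back to a single local process when no launcher variables are set.
    os.environ["CCL_LOCAL_SIZE"] = os.environ.get("LOCAL_WORLD_SIZE", "1")
    os.environ["CCL_LOCAL_RANK"] = os.environ.get("LOCAL_RANK", "0")
```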
18 changes: 11 additions & 7 deletions src/accelerate/utils/modeling.py
```diff
@@ -43,7 +43,7 @@
 from .memory import clear_device_cache
 from .offload import load_offloaded_weight, offload_weight, save_offload_index
 from .tqdm import is_tqdm_available, tqdm
-from .versions import compare_versions
+from .versions import compare_versions, is_torch_version
 
 
 if is_npu_available(check_device=False):
```
```diff
@@ -163,6 +163,8 @@ def dtype_byte_size(dtype: torch.dtype):
         return 1 / 2
     elif dtype == CustomDtype.FP8:
         return 1
+    elif is_torch_version(">=", "2.1.0") and dtype == torch.float8_e4m3fn:
+        return 1
     bit_search = re.search(r"[^\d](\d+)$", str(dtype))
     if bit_search is None:
         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
```
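The new branch is needed because `torch.float8_e4m3fn` (available from PyTorch 2.1) ends in `fn`, so the regex that parses a trailing bit width out of the dtype's name cannot match it and would raise. A minimal sketch of the lookup logic, where `approx_dtype_byte_size` is a hypothetical stand-in for the real function (which also handles 4-bit and FP8 custom dtypes not shown here), and `hasattr` stands in for the `is_torch_version` guard:

```python
import re

import torch


def approx_dtype_byte_size(dtype: torch.dtype) -> float:
    """Bytes per element; a trimmed-down version of the logic above."""
    if dtype == torch.bool:
        return 1 / 8
    # float8_e4m3fn only exists in torch >= 2.1, and its name defeats
    # the bit-width regex below, so it needs an explicit branch.
    if hasattr(torch, "float8_e4m3fn") and dtype == torch.float8_e4m3fn:
        return 1
    bit_search = re.search(r"[^\d](\d+)$", str(dtype))  # "torch.int64" -> "64"
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    return int(bit_search.groups()[0]) // 8


assert approx_dtype_byte_size(torch.float16) == 2
assert approx_dtype_byte_size(torch.int64) == 8
```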
```diff
@@ -361,10 +363,15 @@ def set_module_tensor_to_device(
     if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
         raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
 
+    param = module._parameters[tensor_name] if tensor_name in module._parameters else None
+    param_cls = type(param)
+
     if value is not None:
-        if old_value.shape != value.shape:
+        # We can expect mismatches when using bnb 4bit since Params4bit will reshape and pack the weights.
+        # In other cases, we want to make sure we're not loading checkpoints that do not match the config.
+        if old_value.shape != value.shape and param_cls.__name__ != "Params4bit":
             raise ValueError(
-                f'Trying to set a tensor of shape {value.shape} in "{tensor_name}" (which has shape {old_value.shape}), this look incorrect.'
+                f'Trying to set a tensor of shape {value.shape} in "{tensor_name}" (which has shape {old_value.shape}), this looks incorrect.'
             )
 
         if dtype is None:
```
```diff
@@ -373,9 +380,6 @@ def set_module_tensor_to_device(
         elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
             value = value.to(dtype)
 
-    param = module._parameters[tensor_name] if tensor_name in module._parameters else None
-    param_cls = type(param)
-
     device_quantization = None
     with torch.no_grad():
         # leave it on cpu first before moving them to cuda
```
```diff
@@ -419,7 +423,7 @@ def set_module_tensor_to_device(
     elif value is not None or not check_device_same(torch.device(device), module._parameters[tensor_name].device):
         param_cls = type(module._parameters[tensor_name])
         kwargs = module._parameters[tensor_name].__dict__
-        if param_cls.__name__ in ["Int8Params", "FP4Params"]:
+        if param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]:
             if param_cls.__name__ == "Int8Params" and new_value.dtype == torch.float32:
                 # downcast to fp16 if any - needed for 8bit serialization
                 new_value = new_value.to(torch.float16)
```
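The `set_module_tensor_to_device` changes relax the shape check for bitsandbytes' `Params4bit`, which stores weights packed (two 4-bit values per byte), so the parameter held by the module legitimately differs in shape from the checkpoint tensor being loaded. A self-contained sketch of just that check, where `relaxed_shape_check` is an illustrative helper, not an Accelerate API:

```python
import torch


def relaxed_shape_check(old_value: torch.Tensor, value: torch.Tensor, param_cls_name: str) -> None:
    # Params4bit repacks weights, so a shape mismatch there is expected;
    # for every other parameter class it still signals a bad checkpoint.
    if old_value.shape != value.shape and param_cls_name != "Params4bit":
        raise ValueError(
            f"Trying to set a tensor of shape {value.shape} "
            f"(which has shape {old_value.shape}), this looks incorrect."
        )


packed = torch.empty(8, 1, dtype=torch.uint8)      # e.g. a 4x4 fp16 weight, 4-bit packed
unpacked = torch.empty(4, 4, dtype=torch.float16)  # the unpacked checkpoint tensor
relaxed_shape_check(packed, unpacked, "Params4bit")  # accepted: shapes may differ
# relaxed_shape_check(packed, unpacked, "Parameter") # would raise ValueError
```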
2 changes: 1 addition & 1 deletion tests/test_modeling_utils.py
```diff
@@ -186,7 +186,7 @@ def test_set_module_tensor_checks_shape(self):
         set_module_tensor_to_device(model, "linear1.weight", "cpu", value=tensor)
         assert (
             str(cm.exception)
-            == 'Trying to set a tensor of shape torch.Size([2, 2]) in "weight" (which has shape torch.Size([4, 3])), this look incorrect.'
+            == 'Trying to set a tensor of shape torch.Size([2, 2]) in "weight" (which has shape torch.Size([4, 3])), this looks incorrect.'
         )
 
     def test_named_tensors(self):
```
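For context, a test like the one updated above is typically structured as in the sketch below. The test class and model fixture here are assumptions (the repository's actual fixture is not shown in this diff); only the call and the assertion string come from the hunk:

```python
import unittest

import torch
import torch.nn as nn

from accelerate.utils import set_module_tensor_to_device


class SetModuleTensorShapeTest(unittest.TestCase):
    def test_set_module_tensor_checks_shape(self):
        model = nn.Sequential()
        model.add_module("linear1", nn.Linear(3, 4))  # weight has shape [4, 3]
        tensor = torch.zeros(2, 2)                    # deliberately mismatched
        with self.assertRaises(ValueError) as cm:
            set_module_tensor_to_device(model, "linear1.weight", "cpu", value=tensor)
        assert (
            str(cm.exception)
            == 'Trying to set a tensor of shape torch.Size([2, 2]) in "weight" (which has shape torch.Size([4, 3])), this looks incorrect.'
        )
```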
