diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 57294885546..3925abd43e6 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -725,7 +725,7 @@ def set_deepspeed_weakref(self): if ds_config["train_batch_size"] == "auto": del ds_config["train_batch_size"] - from transformers.deepspeed import HfDeepSpeedConfig + from transformers.integrations import HfDeepSpeedConfig self.dschf = HfDeepSpeedConfig(ds_config) # keep this object alive # noqa