diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index cc83840ba7..8848e95034 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -442,14 +442,7 @@ def load_model( if cfg.ddp and not load_in_8bit: model.to(f"cuda:{cfg.local_rank}") - if ( - torch.cuda.device_count() > 1 - and int(os.getenv("WORLD_SIZE", "1")) > 1 - and (cfg.load_in_4bit) - ): - # llama is PROBABLY model parallelizable, but the default isn't that it is - # so let's only set it for the 4bit, see - # https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133 + if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1: setattr(model, "is_parallelizable", True) setattr(model, "model_parallel", True)