fix model parallel
winglian committed Nov 2, 2023
1 parent 10388a8 commit 9bae072
Showing 1 changed file with 1 addition and 8 deletions.
9 changes: 1 addition & 8 deletions src/axolotl/utils/models.py
@@ -442,14 +442,7 @@ def load_model(
     if cfg.ddp and not load_in_8bit:
         model.to(f"cuda:{cfg.local_rank}")
 
-    if (
-        torch.cuda.device_count() > 1
-        and int(os.getenv("WORLD_SIZE", "1")) > 1
-        and (cfg.load_in_4bit)
-    ):
-        # llama is PROBABLY model parallelizable, but the default isn't that it is
-        # so let's only set it for the 4bit, see
-        # https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
+    if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
         setattr(model, "is_parallelizable", True)
         setattr(model, "model_parallel", True)
 
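For context, the rewritten condition flips the gating: model parallelism is now enabled whenever multiple GPUs are visible but only a single process is running (WORLD_SIZE == 1, i.e. no torchrun/DDP launcher), regardless of 4-bit loading; under distributed training (WORLD_SIZE > 1) each rank drives its own GPU, so the attributes are left unset. A minimal sketch of the resulting behavior, with the imports made explicit; the helper name is hypothetical and the surrounding load_model plumbing is omitted:

import os

import torch


def maybe_enable_model_parallel(model):
    # Hypothetical helper mirroring the post-commit logic in load_model.
    # WORLD_SIZE is set by torchrun/accelerate for distributed runs and
    # defaults to "1" for a plain single-process invocation.
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    # Multiple visible GPUs driven by one process: mark the model as
    # model-parallel so the HF Trainer treats it as already spread across
    # devices instead of wrapping it in DataParallel.
    if torch.cuda.device_count() > 1 and world_size == 1:
        setattr(model, "is_parallelizable", True)
        setattr(model, "model_parallel", True)
    return model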
