Commit

add something somewhat working
samsja committed Sep 10, 2024
1 parent 75ede2a commit ee89d33
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions open_diloco/train_pure_fsdp.py
@@ -60,7 +60,7 @@ class Config(BaseConfig):
per_device_train_batch_size: int = 32
warmup_steps: int = 1000
total_steps: int = 88_000
sharding_strategy: str = "SHARD_GRAD_OP"
sharding_strategy: str = "FULL_SHARD"


def get_dataloader(tokenizer, world_size, rank, local_rank, config: Config) -> StatefulDataLoader:
@@ -82,6 +82,10 @@ def get_model(config: Config) -> LlamaForCausalLM:
return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)


def get_offloaded_param(model: LlamaForCausalLM) -> list[torch.Tensor]:
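# detach and clone every parameter so the outer optimizer can later step on this stable copy of the weights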
return [param.data.detach().clone().to("cuda") for param in model.parameters()]


def train(config: Config):
sharding_strategy = get_sharding_strategy(config.sharding_strategy)
local_rank = int(os.environ["LOCAL_RANK"])
@@ -123,7 +127,11 @@ def train(config: Config):

# Setup optimizers
inner_optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=0.1, betas=(0.9, 0.95))
# outer_optimizer = torch.optim.SGD(model.parameters(), lr=config.diloco.outer_lr, momentum=0.9, nesterov=True)

cpu_model = get_offloaded_param(
model
) # todo: in case of sharded grad op we need to offload the cpu model only once per node

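# the outer (DiLoCo) optimizer runs SGD with Nesterov momentum on the offloaded copies rather than on the sharded model parameters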
outer_optimizer = torch.optim.SGD(cpu_model, lr=config.diloco.outer_lr, momentum=0.9, nesterov=True)

scheduler = get_cosine_schedule_with_warmup(
inner_optimizer,
@@ -152,6 +160,7 @@ def train(config: Config):
with model.no_sync() if is_accumulating else nullcontext():
outputs = model(**batch)
loss = outputs.loss / gradient_accumulation_steps
loss.backward()
loss_batch += loss.detach()

model.clip_grad_norm_(1.0) # gradient clipping
@@ -166,8 +175,22 @@ def train(config: Config):

loss_batch = 0

for param in model.parameters(): # todo: make this work the way hybrid shard does
dist.all_reduce(param, op=dist.ReduceOp.AVG, group=global_pg)
### the whole section below is just a PoC. We need to benchmark and optimize what is the most efficient:
## do the all reduce on cpu or on gpu
## do the outer optimizer step on cpu or on gpu

for param_offloaded, param in zip(
cpu_model, model.parameters()
) # there is only one big flat tensor per FSDP unit because FSDP-1 flattens (buckets) the parameters
# todo: check how to handle the SHARD_GRAD_OP strategy, where the weights are replicated across the local devices
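# outer "pseudo-gradient": the drift of the locally trained weights away from the last synchronized copy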
param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)
dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=global_pg)

outer_optimizer.step()
outer_optimizer.zero_grad()

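# write the synchronized weights back into the sharded model before the next round of inner steps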
for param_offloaded, param in zip(cpu_model, model.parameters()):
param.data = param_offloaded.data.to("cuda")

outer_step += 1

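For readers skimming the diff, the sketch below restates the outer update this commit introduces, written against plain torch without FSDP. It is a minimal sketch under stated assumptions: helper names such as make_offloaded_copy and outer_step are illustrative and not part of the repository, and a torch.distributed process group is assumed to already be initialized (e.g. via torchrun) whenever a group is passed in.

import torch
import torch.distributed as dist


def make_offloaded_copy(model: torch.nn.Module) -> list[torch.Tensor]:
    # detached clones of the current weights; the outer optimizer steps on these
    return [p.data.detach().clone() for p in model.parameters()]


def outer_step(model, offloaded, outer_optimizer, group=None):
    # pseudo-gradient = last synchronized weights - locally trained weights,
    # averaged across all workers in `group`
    for copy_p, p in zip(offloaded, model.parameters()):
        copy_p.grad = copy_p.data - p.data.to(copy_p.device)
        if group is not None:
            dist.all_reduce(copy_p.grad, op=dist.ReduceOp.AVG, group=group)

    outer_optimizer.step()      # Nesterov SGD on the offloaded copies
    outer_optimizer.zero_grad()

    # push the synchronized weights back into the live model
    for copy_p, p in zip(offloaded, model.parameters()):
        p.data.copy_(copy_p.data.to(p.device))


if __name__ == "__main__":
    model = torch.nn.Linear(8, 8)
    offloaded = make_offloaded_copy(model)
    # outer lr value is illustrative only
    outer_opt = torch.optim.SGD(offloaded, lr=0.7, momentum=0.9, nesterov=True)

    # pretend some inner training happened
    with torch.no_grad():
        for p in model.parameters():
            p.add_(0.01 * torch.randn_like(p))

    outer_step(model, offloaded, outer_opt)  # single process: no all-reduce

The usage block at the bottom runs on a single process, which is why the all-reduce is skipped; in the actual training script the averaging happens over the global process group that spans nodes.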
