Dataloader question upon restart #58

Open · cjolivier01 opened this issue Jan 7, 2025 · 6 comments
Labels: data (Related to dataloading), enhancement (New feature or request), question (Further information is requested)

@cjolivier01

Upon a torchelastic restart, let's say with train_ddp.py, I haven't been able to find where the dataloader knows from where to resume, or whether it just starts from the "beginning", assuming that the randomness of the sampler will not duplicate its samples. I expect I am missing something, right?

@d4l3k (Member) commented Jan 8, 2025

@cjolivier01 this is something we want to improve by automatically tracking the dataloader step and fast forwarding as needed. It's feasible to do in torchft but we haven't gotten around to implementing it yet

The current recommended approach is to checkpoint the dataloader frequently using torchdata's StatefulDataLoader. To avoid replaying any data you would need to checkpoint it on every step. https://pytorch.org/data/beta/torchdata.stateful_dataloader.html

To minimize the overhead of that, you could write a custom checkpoint implementation that calls the dataloader's .state_dict() periodically (say every 10 steps) and then only saves the offset since that snapshot to disk (e.g. 5 steps). When reloading you would restore the StatefulDataLoader from the checkpoint and then call next() on its iterator 5 times.
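
For illustration, here is a rough sketch of that pattern; the snapshot interval, file names, and helper structure are placeholders, not a torchft API:

import torch
from torchdata.stateful_dataloader import StatefulDataLoader

SNAPSHOT_EVERY = 10  # take a full dataloader snapshot every N steps

def make_loader(dataset, snapshot=None, offset=0):
    loader = StatefulDataLoader(dataset, batch_size=32)
    if snapshot is not None:
        loader.load_state_dict(snapshot)  # restore the last full snapshot
    it = iter(loader)
    for _ in range(offset):  # fast-forward the steps taken since that snapshot
        next(it)
    return loader, it

def train(dataset):
    loader, it = make_loader(dataset)
    for step, batch in enumerate(it):
        ...  # forward / backward / optimizer.step()
        if step % SNAPSHOT_EVERY == 0:
            torch.save(loader.state_dict(), "dl_snapshot.pt")  # periodic full snapshot
        torch.save(step % SNAPSHOT_EVERY, "dl_offset.pt")  # cheap per-step offset

On restart you would torch.load the snapshot and offset and pass them back into make_loader.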

If you're interested in contributing better dataloader management I'm happy to set up some time to chat, otherwise, we'll get to it at some point :)

@d4l3k added the question, data, and enhancement labels Jan 8, 2025
@d4l3k (Member) commented Jan 8, 2025

also see #37

@cjolivier01 (Author)

Thank you for the reply! At the moment I am still trying to figure out how all the machinery works together. Still in the torchelastic realm (which I have not used before), am I correct in assuming that no state is currently saved or restored with respect to StatefulDataLoader when running train_ddp.py? Since there is no active checkpointing (right?) in train_ddp.py (although I do see the inline load_state_dict and state_dict that would seem to handle that, as well as the external CheckpointServer test), that's not happening, right? I'm just not sure if I am running it incorrectly while trying to hook it all up and induce faults, etc.

@d4l3k (Member) commented Jan 9, 2025

@cjolivier01 the example train_ddp.py doesn't do any persistent checkpointing of model state or dataloader -- there is a TODO where it should occur but we don't have it. For completeness we should probably add checkpointing or at least a dummy hook.

The load_state_dict and state_dict methods are used for live recovery (i.e. transfer from a healthy replica to a recovering replica) and do not create a persistent checkpoint of the dataloader, model, or optimizer.
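
For context, a minimal sketch of how those live-recovery callables get wired into the Manager, loosely following torchft's example code; the model and optimizer here are stand-ins, and the exact constructor arguments in the current train_ddp.py may differ:

import torch
from torch import nn
from torchft import Manager, ProcessGroupGloo

model = nn.Linear(2, 2)  # stand-ins for the real model/optimizer in train_ddp.py
optim = torch.optim.Adam(model.parameters())

def state_dict():
    # what a healthy replica serves to a recovering one
    return {"model": model.state_dict(), "optim": optim.state_dict()}

def load_state_dict(sd):
    # how a recovering replica applies the received state
    model.load_state_dict(sd["model"])
    optim.load_state_dict(sd["optim"])

manager = Manager(
    pg=ProcessGroupGloo(),
    load_state_dict=load_state_dict,
    state_dict=state_dict,
)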

You likely want to add checkpoint logic at the end of the train step where the TODO is. Persistent checkpoints should probably happen every 100-1000 steps, but for now you likely need to checkpoint the dataloader on every step if you don't want to retrain on the same examples.
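
As a rough sketch of where such a hook could go (the function signature, intervals, and file paths are placeholders, not what train_ddp.py actually contains; the dataloader is assumed to be a StatefulDataLoader):

import torch

CKPT_EVERY = 100  # persistent model/optimizer checkpoint interval

def train_loop(model, optimizer, criterion, dataloader):
    for step, (inputs, labels) in enumerate(dataloader):
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        # Roughly where the TODO sits: persist model/optimizer periodically...
        if step % CKPT_EVERY == 0:
            torch.save(
                {"model": model.state_dict(), "optim": optimizer.state_dict()},
                f"ckpt_{step}.pt",
            )
        # ...and, for now, persist the dataloader state on every step.
        torch.save(dataloader.state_dict(), "dataloader.pt")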

@cjolivier01 (Author)

I am trying to make torchft kick in, so I run two processes as if they're on different nodes (they're on the same node; only --node_rank differs):

TORCHFT_MANAGER_PORT=29512 \
    TORCHFT_LIGHTHOUSE="http://localhost:29510" \
    torchrun \
    --master_port=29501 \
    --nnodes=1:2  \
    --nproc_per_node=1 \
    --max-restarts=3 \
    --rdzv-id=asdsddsded \
    --rdzv-backend=c10d \
    --rdzv-endpoint=localhost \
    --node_rank=0 \
    ./train_ddp.py

...and

TORCHFT_MANAGER_PORT=29512 \
    TORCHFT_LIGHTHOUSE="http://localhost:29510" \
    torchrun \
    --master_port=29501 \
    --nnodes=1:2  \
    --nproc_per_node=1 \
    --max-restarts=3 \
    --rdzv-id=asdsddsded \
    --rdzv-backend=c10d \
    --rdzv-endpoint=localhost \
    --node_rank=1 \
    ./train_ddp.py

...and I have the latter exit during a training step before optimizer.step(). What I then get on rank 0 is a RuntimeError (timeout) in the Manager's self._client.should_commit(), and this throws the process out into exit; it does not recover. Is this the expected behavior? I must be doing something wrong, correct? Is there a granularity I am not understanding (i.e. is a full replica group expected to fail)?

@d4l3k (Member) commented Jan 10, 2025

@cjolivier01 you want the replica groups to not be part of the same torchelastic instance. These are the commands I use to run locally:

torchft_lighthouse --min_replicas 2 --join_timeout_ms 1000

CUDA_VISIBLE_DEVICES=0 TORCHFT_MANAGER_PORT=29512 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29502 --nnodes 1 --nproc_per_node 1 --max-restarts 10 train_ddp.py

CUDA_VISIBLE_DEVICES=1 TORCHFT_MANAGER_PORT=29513 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 --max-restarts 10 train_ddp.py

The elastic and manager ports should be different between replica groups.

Once #67 lands you won't need to specify the manager port at all
