The current approach to multi-slice training is not robust or reliable on spot instances.
There have been some discussions inside CRFM and with the GCP team on this topic. I'm creating this issue to capture the main ideas and threads.
Main Objectives
Implement a more robust system that provisions multiple TPU slices and allocates workload to whichever slices are available, by continuously monitoring slice status.
Going a step further, use Ray to coordinate multiple training jobs on a single slice as well as a single training run across multiple slices.
Challenges
For spot instances in TRC, we cannot use GKE directly, because there is no way to disable TPU billing via GKE. As a result, many of the tricks and techniques built on GKE, including Pathways, are not available to us.
Possible Ideas
Use Ray for scheduling
Allen Wang from Google proposed using Ray to schedule and run workloads across slices. He put together a quick gist on how to run both single-slice and multi-slice workloads via Ray (>= 2.10.0). This covers the job-scheduling aspect and works regardless of whether the cluster is provisioned directly on VMs or on GKE.
To mitigate potential race conditions, Allen also added placement groups to pre-reserve existing TPU pod slices (ray_tpu.py) and an example of how they can be used to run tasks (ray_tpu_task.py).
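A minimal sketch of the placement-group idea, assuming Ray auto-detects a "TPU" resource on each TPU VM host; the host count and chips-per-host below are illustrative, not taken from Allen's gist:

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init(address="auto")

# Reserve every host of one pod slice up front (here: 2 hosts with 4 chips each)
# so another job cannot grab part of the slice while we schedule onto it.
NUM_HOSTS, CHIPS_PER_HOST = 2, 4
pg = placement_group([{"TPU": CHIPS_PER_HOST}] * NUM_HOSTS, strategy="STRICT_SPREAD")
ray.get(pg.ready())

@ray.remote(resources={"TPU": CHIPS_PER_HOST})
def run_on_host():
    import jax
    return jax.device_count()

futures = [
    run_on_host.options(
        scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg)
    ).remote()
    for _ in range(NUM_HOSTS)
]
print(ray.get(futures))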
David's summary:
We would spin up a Ray head node as a job scheduler.
TPU slices register with the head node (each worker runs ray start).
We launch a multislice TPU job by getting a set of named worker 0s, setting the environment variables, and then launching the real job (see the sketch below).
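A rough sketch of that launch step, assuming one Ray task per slice host and the MEGASCALE_* environment variables read by JAX's multislice runtime; the exact variable names and the training entrypoint are assumptions to verify against Allen's gist and Levanter:

import ray

CHIPS_PER_HOST = 4  # assumed host topology

@ray.remote(resources={"TPU": CHIPS_PER_HOST})
def run_slice_host(coordinator_addr: str, num_slices: int, slice_id: int):
    import os
    # These env vars tell the multislice runtime how to stitch slices together;
    # the names are my assumption and should be checked against the gist.
    os.environ["MEGASCALE_COORDINATOR_ADDRESS"] = coordinator_addr
    os.environ["MEGASCALE_NUM_SLICES"] = str(num_slices)
    os.environ["MEGASCALE_SLICE_ID"] = str(slice_id)

    import jax
    jax.distributed.initialize()
    # Placeholder for the real training entrypoint, e.g. levanter's train script.
    return jax.process_index()

The driver would resolve worker 0 of slice 0 (e.g. via a named Ray actor), pass its address as coordinator_addr, and then fan this task out to every host of every slice.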
Use a host to coordinate work and communicate gradients
@dlwh's idea:
Spin up individual slices as more or less atomic units.
Each runs its own SPMD job as a Levanter instance.
However, they also phone home to some coordinator machine.
The coordinator assigns work units (batches) to slices and tells them how to communicate the gradients. JAX has experimental jax.lax.infeed and jax.lax.outfeed primitives for sending values to and receiving values from the host. The host receives the appropriate gradients from the device (using outfeed), communicates them to the other hosts (using a tree reduction or something fancier), then sends the accumulated gradients back via infeed.
The trick will be scheduling this to maximize throughput, since I don’t know how to tell XLA how long something will take.
To make things reproducible, you’ll have to be very careful, ensuring that you reduce batches in the same order even in the presence of slices dropping out. To do this, you will likely have to either use a lot of host memory and/or accept recomputing batches.
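A schematic of the coordinator's control flow, ignoring the infeed/outfeed mechanics; recv_grad and send_grad are hypothetical transport callables, and the point is the fixed batch-id reduction order that keeps the result reproducible when slices drop out:

import numpy as np

def coordinator_step(live_slices, batch_ids, recv_grad, send_grad):
    """One optimizer step of a hypothetical coordinator.

    live_slices: slice ids currently healthy
    batch_ids:   micro-batch indices making up this global step
    recv_grad:   callable(slice_id, batch_id) -> np.ndarray (e.g. backed by outfeed)
    send_grad:   callable(slice_id, np.ndarray)             (e.g. backed by infeed)
    """
    # Assign micro-batches round-robin to whatever slices are alive right now.
    assignment = {b: live_slices[i % len(live_slices)] for i, b in enumerate(batch_ids)}

    # Collect gradients, then reduce in batch-id order (not arrival order) so the
    # result is identical even if a slice drops out and its batches are recomputed
    # elsewhere on a later attempt.
    grads = {b: recv_grad(s, b) for b, s in assignment.items()}
    total = sum(grads[b] for b in sorted(grads))

    # Broadcast the accumulated gradient back to every live slice.
    for s in live_slices:
        send_grad(s, total)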
I could reproduce Allen's script on a single-slice v4 TPU, but not on multiple slices. This should not be a blocker for now, if we are not prioritizing multi-slice training.
Now that I think about it more, I realize this is not the shortest path. I should instead use Marin's existing Ray + TPU framework for launching data-preparation jobs as a reference. It seems to be a more applicable guide.