Add documentation for JAXJob
Signed-off-by: Sandipan Panda <[email protected]>
sandipanpanda committed Sep 24, 2024
1 parent b83d113 commit 788422a
Showing 6 changed files with 110 additions and 4 deletions.
Binary file not shown.
5 changes: 3 additions & 2 deletions content/en/docs/components/training/overview.md
@@ -10,7 +10,7 @@ weight = 10

The Training Operator is a Kubernetes-native project for fine-tuning and scalable
distributed training of machine learning (ML) models created with different ML frameworks such as
-PyTorch, TensorFlow, XGBoost, and others.
+PyTorch, TensorFlow, XGBoost, [JAX](https://jax.readthedocs.io/en/latest/), and others.

You can integrate other ML libraries such as [HuggingFace](https://huggingface.co),
[DeepSpeed](https://github.com/microsoft/DeepSpeed), or [Megatron-LM](https://github.com/NVIDIA/Megatron-LM)
@@ -26,7 +26,7 @@ supports running Message Passing Interface (MPI) on Kubernetes which is heavily
The Training Operator implements the V1 API version of MPI Operator. For the MPI Operator V2 version,
please follow [this guide](/docs/components/training/user-guides/mpi/) to install MPI Operator V2.

-<img src="/docs/components/training/images/training-operator-overview.drawio.png"
+<img src="/docs/components/training/images/training-operator-overview.drawio.svg"
alt="Training Operator Overview"
class="mt-3 mb-3">

@@ -70,6 +70,7 @@ for each ML framework:
| XGBoost | [XGBoostJob](/docs/components/training/user-guides/xgboost/) |
| MPI | [MPIJob](/docs/components/training/user-guides/mpi/) |
| PaddlePaddle | [PaddleJob](/docs/components/training/user-guides/paddle/) |
| JAX | [JAXJob](/docs/components/training/user-guides/jax/) |

## Next steps

101 changes: 101 additions & 0 deletions content/en/docs/components/training/user-guides/jax.md
@@ -0,0 +1,101 @@
+++
title = "JAX Training (JAXJob)"
description = "Using JAXJob to train a model with JAX"
weight = 60
+++

This page describes `JAXJob` for training a machine learning model with [JAX](https://jax.readthedocs.io/en/latest/).

`JAXJob` is a Kubernetes
[custom resource](https://kubernetes.io/docs/concepts/extend-kubernetes/api-extension/custom-resources/)
for running JAX training jobs on Kubernetes. The Kubeflow implementation of
`JAXJob` is in the [`training-operator`](https://github.com/kubeflow/training-operator) repository.

The current `JAXJob` custom resource has been tested with multiple processes running on CPUs, using [gloo](https://github.com/facebookincubator/gloo) for communication between the processes.
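
As a rough illustration of what multi-process CPU training with gloo involves, the sketch below shows how a training script might set up JAX's distributed runtime. The environment variable names (`COORDINATOR_ADDRESS`, `COORDINATOR_PORT`, `NUM_PROCESSES`, `PROCESS_ID`), the default port, and the helper names are assumptions made for illustration and are not confirmed by this page. Check the example's `train.py` for the names actually used.

```python
import os


def read_distributed_config(env=None):
    """Collect arguments for jax.distributed.initialize from the environment.

    The variable names and defaults are assumptions for illustration only.
    """
    env = os.environ if env is None else env
    address = env.get("COORDINATOR_ADDRESS", "localhost")
    port = env.get("COORDINATOR_PORT", "6666")
    return {
        "coordinator_address": f"{address}:{port}",
        "num_processes": int(env.get("NUM_PROCESSES", "1")),
        "process_id": int(env.get("PROCESS_ID", "0")),
    }


def init_distributed(env=None):
    """Initialize JAX for multi-process CPU training (hypothetical helper)."""
    # Imported lazily so the config helper works without JAX installed.
    import jax

    # Select gloo for cross-process CPU collectives. The option name reflects
    # JAX releases from 2024 and may differ in other versions.
    jax.config.update("jax_cpu_collectives_implementation", "gloo")
    jax.distributed.initialize(**read_distributed_config(env))
    return jax.process_index(), jax.process_count()
```

Keeping the environment parsing separate from the JAX call makes it easy to check the configuration of each replica without starting a distributed run.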

## Creating a JAX training job

You can create a training job by defining a `JAXJob` config file. See the manifest for the [simple JAXJob example](https://github.com/kubeflow/training-operator/blob/master/examples/jax/cpu-demo/demo.yaml).
Adjust the config file to fit your requirements.
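
If you prefer to generate the config programmatically, a minimal sketch could look like the following. The helper `build_jaxjob` and the script name `make_jaxjob.py` are hypothetical. The field values mirror the sample output in the Monitoring section of this page, not necessarily the full set of fields the API supports.

```python
import json


def build_jaxjob(name="jaxjob-simple", namespace="kubeflow", replicas=2,
                 image="docker.io/kubeflow/jaxjob-simple:latest"):
    """Return a JAXJob manifest as a plain dict (hypothetical helper)."""
    return {
        "apiVersion": "kubeflow.org/v1",
        "kind": "JAXJob",
        "metadata": {"name": name, "namespace": namespace},
        "spec": {
            "jaxReplicaSpecs": {
                "Worker": {
                    "replicas": replicas,
                    "restartPolicy": "OnFailure",
                    "template": {
                        "spec": {
                            "containers": [
                                {
                                    "name": "jax",
                                    "image": image,
                                    "imagePullPolicy": "Always",
                                    "command": ["python3", "train.py"],
                                }
                            ]
                        }
                    },
                }
            }
        },
    }


if __name__ == "__main__":
    # kubectl accepts JSON as well as YAML, for example:
    #   python3 make_jaxjob.py | kubectl create -f -
    print(json.dumps(build_jaxjob(replicas=4), indent=2))
```

Building the manifest as a dict keeps the number of replicas and the image easy to vary per experiment while the rest of the spec stays fixed.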

Deploy the `JAXJob` resource to start training:

```
kubectl create -f https://raw.githubusercontent.com/kubeflow/training-operator/refs/heads/master/examples/jax/cpu-demo/demo.yaml
```

You should now be able to see the created pods matching the specified number of replicas.

```
kubectl get pods -n kubeflow -l training.kubeflow.org/job-name=jaxjob-simple
```

Training takes 5-10 minutes on a CPU cluster. Inspect the logs to follow the training progress.

```
PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-simple,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o name -n kubeflow)
kubectl logs -f ${PODNAME} -n kubeflow
```
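
The label selectors used in the commands above can also be assembled in a script. The sketch below uses the same label keys that appear in those commands. The helper names `pod_selector` and `get_pods_command` are hypothetical, and the code only builds the `kubectl` arguments without running them.

```python
def pod_selector(job_name, replica_type=None, replica_index=None):
    """Build a label selector string for the pods of a training job."""
    labels = {"training.kubeflow.org/job-name": job_name}
    if replica_type is not None:
        labels["training.kubeflow.org/replica-type"] = replica_type
    if replica_index is not None:
        labels["training.kubeflow.org/replica-index"] = str(replica_index)
    # Dicts preserve insertion order, so the labels appear in a stable order.
    return ",".join(f"{key}={value}" for key, value in labels.items())


def get_pods_command(job_name, namespace="kubeflow", **selector_args):
    """Return the kubectl argument list that prints the matching pod names."""
    selector = pod_selector(job_name, **selector_args)
    return ["kubectl", "get", "pods", "-n", namespace, "-l", selector, "-o", "name"]
```

The returned list can be passed to `subprocess.run` on a machine where `kubectl` is configured for the cluster.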

## Monitoring a JAXJob

Retrieve the `JAXJob` resource to check its status:

```
kubectl get -o yaml jaxjobs jaxjob-simple -n kubeflow
```

Check the `status` section to monitor the job. Here is sample output for a job that completed successfully.

```yaml
apiVersion: kubeflow.org/v1
kind: JAXJob
metadata:
  annotations:
    kubectl.kubernetes.io/last-applied-configuration: |
      {"apiVersion":"kubeflow.org/v1","kind":"JAXJob","metadata":{"annotations":{},"name":"jaxjob-simple","namespace":"kubeflow"},"spec":{"jaxReplicaSpecs":{"Worker":{"replicas":2,"restartPolicy":"OnFailure","template":{"spec":{"containers":[{"command":["python3","train.py"],"image":"docker.io/kubeflow/jaxjob-simple:latest","imagePullPolicy":"Always","name":"jax"}]}}}}}}
  creationTimestamp: "2024-09-22T20:07:59Z"
  generation: 1
  name: jaxjob-simple
  namespace: kubeflow
  resourceVersion: "1972"
  uid: eb20c874-44fc-459b-b9a8-09f5c3ff46d3
spec:
  jaxReplicaSpecs:
    Worker:
      replicas: 2
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - command:
                - python3
                - train.py
              image: docker.io/kubeflow/jaxjob-simple:latest
              imagePullPolicy: Always
              name: jax
status:
  completionTime: "2024-09-22T20:11:34Z"
  conditions:
    - lastTransitionTime: "2024-09-22T20:07:59Z"
      lastUpdateTime: "2024-09-22T20:07:59Z"
      message: JAXJob jaxjob-simple is created.
      reason: JAXJobCreated
      status: "True"
      type: Created
    - lastTransitionTime: "2024-09-22T20:11:28Z"
      lastUpdateTime: "2024-09-22T20:11:28Z"
      message: JAXJob kubeflow/jaxjob-simple is running.
      reason: JAXJobRunning
      status: "False"
      type: Running
    - lastTransitionTime: "2024-09-22T20:11:34Z"
      lastUpdateTime: "2024-09-22T20:11:34Z"
      message: JAXJob kubeflow/jaxjob-simple successfully completed.
      reason: JAXJobSucceeded
      status: "True"
      type: Succeeded
  replicaStatuses:
    Worker:
      selector: training.kubeflow.org/job-name=jaxjob-simple,training.kubeflow.org/operator-name=jaxjob-controller,training.kubeflow.org/replica-type=worker
      succeeded: 2
  startTime: "2024-09-22T20:07:59Z"
```
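
To read the conditions in a script rather than by eye, one possible approach is to take the most recent condition whose `status` is `"True"`. The sketch below assumes the job has been loaded into a dict with the same structure as the sample output above (for example from `kubectl get -o json`). The helper `job_phase` is hypothetical and not part of any Kubeflow tooling.

```python
def job_phase(job):
    """Return the type of the most recent condition whose status is "True"."""
    conditions = job.get("status", {}).get("conditions", [])
    active = [c for c in conditions if c.get("status") == "True"]
    if not active:
        return "Unknown"
    # The timestamps share one UTC format, so string comparison orders them.
    latest = max(active, key=lambda c: c["lastTransitionTime"])
    return latest["type"]


# Conditions copied from the sample output above.
sample_job = {
    "status": {
        "conditions": [
            {"type": "Created", "status": "True",
             "lastTransitionTime": "2024-09-22T20:07:59Z"},
            {"type": "Running", "status": "False",
             "lastTransitionTime": "2024-09-22T20:11:28Z"},
            {"type": "Succeeded", "status": "True",
             "lastTransitionTime": "2024-09-22T20:11:34Z"},
        ]
    }
}
```

With the sample conditions, the `Running` entry is skipped because its status is `"False"`, and `Succeeded` is chosen over `Created` because it transitioned later.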
@@ -1,7 +1,7 @@
+++
title = "Job Scheduling"
description = "How to schedule a job with gang-scheduling"
-weight = 60
+weight = 70
+++

This guide describes how to use [Kueue](https://kueue.sigs.k8s.io/),
2 changes: 1 addition & 1 deletion content/en/docs/components/training/user-guides/mpi.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
+++
title = "MPI Training (MPIJob)"
description = "Instructions for using MPI for training"
-weight = 60
+weight = 70
+++

{{% beta-status