Commit 3adf920: add readme on how to install vllm for online inference during training
dmahan93 committed Sep 28, 2024 (1 parent cc787b0)

2 changed files, 58 additions, 0 deletions.

post-training/OnlineTraining.MD (new file, 56 additions):

# Online Training

## Prerequisites
Want to use REINFORCE to train your model? You'll first need to build a custom vllm package.

[synth-vllm](https://github.com/SynthLabsAI/synth-vllm) is a fork of [vllm](https://github.com/vllm-project/vllm)
that has been modified to support using the weights in NeoX by sharing the GPU memory location of the model weights.

It currently supports Llama and Pythia models.
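
The key idea, the trainer and the inference engine reading one buffer rather than copying weights back and forth, can be illustrated in miniature. The sketch below is an analogy only, on CPU with the standard library; the names `trainer_weights` and `inference_view` are illustrative and not part of synth-vllm:

```python
from array import array

# The "trainer" owns the parameter buffer.
trainer_weights = array("f", [0.0] * 4)

# The "inference engine" holds a zero-copy view over the SAME memory.
inference_view = memoryview(trainer_weights)

# An in-place update by the trainer ...
trainer_weights[0] = 0.5

# ... is immediately visible through the view: no weight transfer happened.
print(inference_view[0])
```

In synth-vllm the same principle applies to GPU memory: after an optimizer step, the inference engine sees the updated weights without a host-to-device copy.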

### Building the package

Here is a reference for how the package has been built previously, using conda.
(Treat this as a reference only; it may not work as-is, depending on your system configuration.)

```bash
# cd to the synth vllm directory...
conda create -n vllm python=3.10
conda deactivate
conda activate vllm
conda install -y pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
conda install -y nvidia/label/cuda-12.1.0::cuda-toolkit
conda install -y nvidia/label/cuda-12.1.0::cuda-cudart
conda install -y nvidia/label/cuda-12.1.0::cuda-compiler
conda install -y nvidia/label/cuda-12.1.0::cuda-nvcc
conda install -y nvidia/label/cuda-12.1.0::cuda-profiler-api
conda install -y nvidia/label/cuda-12.1.0::cuda-cudart-dev
conda install -y -c nvidia cuda-nvprof=12.1
conda install -y conda-forge::cuda-version=12.1
conda install -y gcc_linux-64=12.3.0
conda install -y -c conda-forge gxx_linux-64=12.3.0
pip install -e .
```
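
Before running `pip install -e .`, it can save time to confirm that the CUDA 12.1 toolchain the build expects is actually on `PATH`. This small check is a convenience sketch, not part of the repository; it parses the release number out of `nvcc --version`:

```python
import shutil
import subprocess

def nvcc_release():
    """Return the CUDA release reported by nvcc (e.g. "12.1"), or None if nvcc is absent."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None
    out = subprocess.run([nvcc, "--version"], capture_output=True, text=True).stdout
    # nvcc prints a line such as: "Cuda compilation tools, release 12.1, V12.1.105"
    for line in out.splitlines():
        if "release" in line:
            return line.split("release")[-1].split(",")[0].strip()
    return None

release = nvcc_release()
print(release)
```

If this prints something other than `12.1` (or prints `None`), the conda installs above did not take effect in the active environment.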

## Training

If you haven't already, run this command to convert the Hugging Face checkpoint into NeoX-format weights:
```bash
python tools/ckpts/convert_hf_llama_to_neox.py --tp 4 --model meta-llama/Meta-Llama-3-8B-Instruct --model_path checkpoints/neox_converted/llama3-8b-instruct
```

[online_example.sh](online_example.sh) and [online_data_example_llama3.py](online_data_example_llama3.py) together form an example of
how to train a model using the synth-vllm package on a single node.

This assumes you are using a conda environment with NeoX installed under the name `neox`.

To run the example, execute the following commands:

```bash
# It may be preferable to run these in two separate terminals
python post-training/online_data_example_llama3.py &
bash post-training/online_example.sh
```

This will train the llama3-8b-instruct model using the synth-vllm package, optimizing for a positive reward
from a sentiment classifier.
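
REINFORCE itself reduces to a simple rule: sample an action from the policy, then nudge the policy's log-probability of that action in proportion to the reward received. The toy below is a self-contained sketch of that rule, unrelated to the actual NeoX training loop or the sentiment reward model; it learns a two-action softmax policy in which only action 1 is rewarded:

```python
import math
import random

random.seed(0)

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

theta = [0.0, 0.0]  # logits of a two-action policy
lr = 0.5

for _ in range(200):
    probs = softmax(theta)
    action = 0 if random.random() < probs[0] else 1
    reward = 1.0 if action == 1 else 0.0  # only action 1 is rewarded
    # REINFORCE update: grad of log pi(action) wrt logits is one-hot(action) - probs
    for i in range(2):
        grad = (1.0 if i == action else 0.0) - probs[i]
        theta[i] += lr * reward * grad

print(softmax(theta)[1])  # probability of the rewarded action, close to 1
```

In the online setup above, the "action" is a sampled completion from vllm, the reward comes from the sentiment classifier, and the gradient step updates the NeoX model whose weights vllm is sharing.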
post-training/README.md (2 additions):

Examples for running post-training with ultrafeedback data for SFT/DPO/RM training.

For REINFORCE-style training, see [Online Training](OnlineTraining.MD).

```bash
python tools/ckpts/convert_hf_llama_to_neox.py --tp 4 --model meta-llama/Meta-Llama-3-8B-Instruct --model_path checkpoints/neox_converted/llama3-8b-instruct
```