Merge branch 'main' into patch_conversion_scripts
Quentin-Anthony authored Jun 7, 2024
2 parents 67cc679 + 2382bd4 commit 714b299
Showing 14 changed files with 157 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pull_request.yml
@@ -1,6 +1,6 @@
name: Pull Request

on: [pull_request]
on: [pull_request, workflow_dispatch]

jobs:
pre-commit:
22 changes: 20 additions & 2 deletions README.md
@@ -96,7 +96,6 @@ To install the remaining basic dependencies, run:
pip install -r requirements/requirements.txt
pip install -r requirements/requirements-wandb.txt # optional, if logging using WandB
pip install -r requirements/requirements-tensorboard.txt # optional, if logging via tensorboard
python ./megatron/fused_kernels/setup.py install # optional, if using fused kernels
```

from the repository root.
@@ -106,6 +105,16 @@ from the repository root.
</aside>

### Fused Kernels
We now support AMD GPUs (MI100, MI250X) through JIT fused-kernel compilation. Fused kernels will be built and loaded as needed. To avoid waiting during job launch, you can also pre-build them manually:

```python
# run from a Python interpreter at the repository root
from megatron.fused_kernels import load
load()
```
This automatically adapts the build process to the GPU vendor (AMD or NVIDIA) without platform-specific code changes. To further test the fused kernels with `pytest`, run `pytest tests/model/test_fused_kernels.py`.

### Flash Attention

To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` and set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.
@@ -640,7 +649,7 @@ If you need to supply a hostfile for use with the MPI-based DeepSpeed launcher,
# Profiling
We support profiling with Nsight Systems and PyTorch Memory Profiling.
We support profiling with Nsight Systems, the PyTorch Profiler, and PyTorch Memory Profiling.
## Nsight Systems Profiling
@@ -656,6 +665,15 @@ The generated output file can then be viewed with the Nsight Systems GUI:
![Alt text](images/nsight_profiling.png)
## PyTorch Profiling
To use the built-in PyTorch profiler, set config options `profile`, `profile_step_start`, and `profile_step_stop`.
The PyTorch profiler will save traces to your `tensorboard` log directory. You can view these traces within
TensorBoard by following the steps [here](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html).
![Alt text](images/pytorch_profiling.png)
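As a rough sketch of what these config options drive (mirroring the `torch.profiler` setup added to `megatron/training.py` in this commit; the step counts and directory below are made-up placeholders, not NeoX defaults):

```python
import torch

# Placeholder values standing in for the profile_step_start / profile_step_stop
# config options and the tensorboard log directory.
profile_step_start = 10
profile_step_stop = 12
tensorboard_dir = "./tensorboard"

prof = torch.profiler.profile(
    schedule=torch.profiler.schedule(
        wait=profile_step_start,
        warmup=1,
        active=profile_step_stop - profile_step_start,
    ),
    # Writes traces that the TensorBoard profiler plugin can display.
    on_trace_ready=torch.profiler.tensorboard_trace_handler(tensorboard_dir),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
)

prof.start()
for iteration in range(profile_step_stop + 2):  # stand-in for the training loop
    # ... run one training step here ...
    prof.step()
prof.stop()
```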
## PyTorch Memory Profiling
To use PyTorch Memory Profiling, set config options `memory_profiling` and `memory_profiling_path`.
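For context, a minimal sketch of the underlying CUDA memory-snapshot workflow, using PyTorch's private `torch.cuda.memory` helpers. These APIs may change between releases, the file name is just an example, and this is not necessarily what `memory_profiling` does internally:

```python
import torch

if torch.cuda.is_available():
    # Start recording allocation history (private API, subject to change).
    torch.cuda.memory._record_memory_history(max_entries=100000)

    # Some workload whose allocations we want to inspect.
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x

    # Dump a snapshot viewable at https://pytorch.org/memory_viz
    torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
    torch.cuda.memory._record_memory_history(enabled=None)
```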
2 changes: 1 addition & 1 deletion configs/README.md
Expand Up @@ -9,7 +9,7 @@ Below is an example configuration `.yaml` to train a ~160M parameter GPT model.

For a detailed list of all the arguments available for neox, see [neox_arguments.md](neox_arguments.md)

Note: yaml arguments may be formatted with either '-' or '_'. The standard separator used is a '_' as shown in the example configurations below. However, the use of '-' as a separator may be deprecated in the future.
Note: yaml arguments may be formatted with either '-' or '\_'. The standard separator used is a '\_' as shown in the example configurations below. However, the use of '-' as a separator may be deprecated in the future.
```yaml
# GPT-3 pretraining setup
{
2 changes: 1 addition & 1 deletion configs/neox_arguments.md
Expand Up @@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = b6cb77e
Default = 8451671

current git hash of repository

Binary file added images/pytorch_profiling.png
12 changes: 6 additions & 6 deletions megatron/data/helpers.cpp
@@ -428,9 +428,9 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
}

} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {

if (!second) {
if (verbose) {
@@ -660,9 +660,9 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {

if (!second) {
if (verbose) {
6 changes: 3 additions & 3 deletions megatron/fused_kernels/__init__.py
@@ -135,8 +135,8 @@ def _cpp_extention_load_helper(
srcpath / "fused_rotary_positional_embedding.cpp",
srcpath / "fused_rotary_positional_embedding_cuda.cu",
]
fused_rotary_positional_embedding_cuda = _cpp_extention_load_helper(
"fused_rotary_positional_embedding_cuda",
fused_rotary_positional_embedding = _cpp_extention_load_helper(
"fused_rotary_positional_embedding",
sources,
extra_cuda_flags,
extra_include_paths,
@@ -174,7 +174,7 @@ def load_fused_kernels():
print(e)
print("=" * 100)
print(
f"ERROR: Fused kernels configured but not properly installed. Please run `pip install {str(srcpath)}` to install them"
f"ERROR: Fused kernels configured but not properly installed. Please run `from megatron.fused_kernels import load()` then `load()` to load them correctly"
)
print("=" * 100)
exit()
7 changes: 5 additions & 2 deletions megatron/model/norms.py
@@ -14,7 +14,6 @@

import torch
from torch.nn import LayerNorm as LayerNorm
from .fused_layer_norm import MixedFusedLayerNorm


def get_norm(neox_args):
@@ -23,7 +22,11 @@ def get_norm(neox_args):
eps = neox_args.rms_norm_epsilon
elif neox_args.norm == "layernorm":
eps = neox_args.layernorm_epsilon
norm = MixedFusedLayerNorm if neox_args.layernorm_fusion else LayerNorm
if neox_args.layernorm_fusion:
from .fused_layer_norm import MixedFusedLayerNorm
norm = MixedFusedLayerNorm
else:
norm = LayerNorm
elif neox_args.norm == "scalenorm":
eps = neox_args.scalenorm_epsilon
norm = ScaleNorm
21 changes: 6 additions & 15 deletions megatron/neox_arguments/arguments.py
@@ -180,7 +180,6 @@ def from_ymls(cls, paths_to_yml_files: List[str], overwrite_values: Dict = None)
config_files = dict()
# iterate of all to be loaded yaml files
for conf_file_name in paths_to_yml_files:

# load file
with open(conf_file_name) as conf_file:
conf = yaml.load(conf_file, Loader=yaml.FullLoader)
@@ -477,7 +476,6 @@ def get_extra_deepspeed_args(self):
return extra_ds_args

def get_deepspeed_main_args(self):

args_list = list()

if self.autotuning_run is not None:
@@ -796,14 +794,11 @@ def calculate_batch_parameters(

# either none of the three parameters are provided or just gradient_accumulation_step is provided
else:
assert (
False
), "Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided"
assert False, "Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided"
return int(train_batch), int(micro_batch), int(grad_acc)

@staticmethod
def check_batch_parameters(dp_world_size, train_batch, micro_batch, grad_acc):

assert (
train_batch > 0
), f"Train batch size: {train_batch} has to be greater than 0"
@@ -1033,10 +1028,7 @@ def calculate_derived(self):
# Update 'is pipe parallel' flag
# if we set pipe_parallel_size to 0 or 1, GPT2ModelPipe.to_sequential() is called, and we run training with
# the sequential model without the PipelineModule wrapper to avoid the overhead it incurs
self.update_value(
"is_pipe_parallel",
self.pipe_parallel_size > 1 and self.moe_num_experts == 1,
)
self.update_value("is_pipe_parallel", self.pipe_parallel_size >= 1)
if self.moe_num_experts > 1:
assert not (
self.is_pipe_parallel or self.pipe_parallel_size > 1
@@ -1070,8 +1062,8 @@
), "Mamba does not yet have dropout implemented"
if "rwkv" in self.attention_config:
assert (
not self.is_pipe_parallel and self.model_parallel_size == 1
), "RWKV not currently compatible with parallelism"
self.model_parallel_size == 1
), "RWKV not currently compatible with model parallelism"
if isinstance(self.zero_stage, int):
assert self.zero_stage <= 2, "Zero stage 3 not compatible with RWKV"
assert (
@@ -1106,8 +1098,8 @@ def calculate_derived(self):
if "flash" in self.attention_config:
_flash_version = packaging.version.Version(version("flash-attn"))
if self.sliding_window_width is not None:
assert _flash_version >= packaging.version.Version(
"2.3.0"
assert (
_flash_version >= packaging.version.Version("2.3.0")
), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.3.0 to support sliding window attention."
if self.pos_emb == "alibi":
if not _flash_version >= packaging.version.Version("2.4.0.post1"):
@@ -1234,7 +1226,6 @@ def validate_values(self):

# Parameters sharing does not work with torch DDP.
if (self.num_unique_layers is not None) and (self.num_layers is not None):

if not (self.num_unique_layers <= self.num_layers):
error_message = (
self.__class__.__name__
22 changes: 22 additions & 0 deletions megatron/training.py
@@ -970,7 +970,28 @@ def train(

# to monitor if we've skipped many iterations in a row and trigger an early exit
overflow_monitor = OverflowMonitor(optimizer)

if neox_args.profile:
schedule = torch.profiler.schedule(
wait=neox_args.profile_step_start,
warmup=1,
active=neox_args.profile_step_stop - neox_args.profile_step_start,
)
prof = torch.profiler.profile(
schedule=schedule,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
neox_args.tensorboard_dir
),
record_shapes=True,
profile_memory=True,
with_flops=True,
with_modules=True,
with_stack=True,
)
prof.start()
while iteration < neox_args.train_iters:
if neox_args.profile:
prof.step()
if neox_args.profile and iteration == neox_args.profile_step_start:
torch.cuda.cudart().cudaProfilerStart()
loss_dict, skipped_iter = train_step(
@@ -983,6 +1004,7 @@
)
if neox_args.profile and iteration == neox_args.profile_step_stop:
torch.cuda.cudart().cudaProfilerStop()
prof.stop()
iteration += 1
neox_args.iteration = iteration
if neox_args.precision == "fp16":
4 changes: 2 additions & 2 deletions requirements/requirements.txt
@@ -1,6 +1,6 @@
git+https://github.com/EleutherAI/DeeperSpeed.git@02e2ebf7dee6aaab3d89094ed470a4609763c742#egg=deepspeed
deepspeed@git+https://github.com/EleutherAI/DeeperSpeed.git@02e2ebf7dee6aaab3d89094ed470a4609763c742#egg=deepspeed
ftfy>=6.0.1
git+https://github.com/EleutherAI/lm_dataformat.git@4eec05349977071bf67fc072290b95e31c8dd836
lm_dataformat@git+https://github.com/EleutherAI/lm_dataformat.git@4eec05349977071bf67fc072290b95e31c8dd836
huggingface_hub>=0.11.0
jinja2==3.1.4
lm_eval>=0.4.0,<=0.4.1
77 changes: 76 additions & 1 deletion tests/README.md
@@ -32,7 +32,7 @@ pytest --forked tests/model/test_model_generation.py

Some tests can run on CPU only. These are marked with the decorator `@pytest.mark.cpu`.
The test cases for CPU can be run with:
````
```
pytest tests -m cpu
```

@@ -49,3 +49,78 @@ If you see this kind of error:
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
```
It usually means that a `torch.cuda` function was called before the test creates its worker processes. However, just importing `from torch.utils import cpp_extension` can also trigger this.
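A minimal, hypothetical illustration of the failure mode and the usual fix (defer anything that initializes CUDA until the forked worker is running):

```python
# test_example.py -- hypothetical, for illustration only.

# BAD: touching CUDA at import time initializes it in the parent process,
# so forked test workers hit "Cannot re-initialize CUDA in forked subprocess".
# import torch
# torch.cuda.init()

def test_forked_cuda_use():
    # GOOD: import and touch CUDA only inside the test body, which runs
    # in the forked worker process when using `pytest --forked`.
    import torch
    if torch.cuda.is_available():
        x = torch.ones(2, device="cuda")
        assert x.sum().item() == 2
```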


## CPU Test Integration

Tests can be run against physical CPUs through GitHub Actions. To have tests run on the physical CPU runners, here is generally how the CI should be written:

### runs-on

The CI needs to be written to target the CPU GitHub Actions runner. The jobs that need to run on CPU should use the hardware runner's labels:
```yaml
jobs:
  cpu-test-job:
    runs-on: [ 'self-hosted', 'aws', 'test'] # these labels tell GitHub to execute on the runner with the 'aws' and 'test' labels
```
### Software dependencies
Hardware tests that need Python and Docker should install them as part of the test execution to make sure the tests run as expected:
```yaml
steps:
  # sample syntax to setup python with pip
  - uses: actions/setup-python@v4
    with:
      python-version: "3.8"
      cache: "pip"

  # sample setup of docker (there's no official Docker setup action)
  - name: Docker setup
    run: | # taken from Docker's installation page: https://docs.docker.com/engine/install/ubuntu/
      # Add Docker's official GPG key:
      sudo apt-get update
      sudo apt-get install ca-certificates curl
      sudo install -m 0755 -d /etc/apt/keyrings
      sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
      sudo chmod a+r /etc/apt/keyrings/docker.asc
      # Add the repository to Apt sources:
      echo \
        "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
        $(. /etc/os-release && echo "$VERSION_CODENAME") stable" | \
        sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
      sudo apt-get update
      sudo apt-get install docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin -y
```
Any other software dependencies should be assumed to be missing and installed as part of the CI.
### Using Docker image
Using the Docker image and running tests in a container is recommended to avoid environment issues. There is a modified docker-compose.yml in the tests/cpu_tests directory intended for CPU tests:
```bash
cp tests/cpu_tests/docker-compose.yml .
# export any env variables here that should be used:
export NEOX_DATA_PATH='./data/enwik8'
docker compose run -d --build --name $CONTAINER gpt-neox tail -f /dev/null
# then can set up and run tests in the container using docker exec
docker exec $CONTAINER pip install -r /workspace/requirements-dev.txt
# etc.
# please clean up the container as part of the CI:
docker rm $CONTAINER
```

At the time of writing there is no built-in method to provide an offline-built Docker image to `jobs.<job-id>.container`.

### Using existing CPU test CI

There is an existing CPU test workflow that can be included in existing CI:

```yaml
steps:
  - name: Run CPU Tests
    # have a look at the reusable workflow here: https://github.com/EleutherAI/gpt-neox/blob/main/tests/cpu_tests/action.yml
    uses:
    with:
      target_test_ref: $GITHUB_REF # replace with the ref/SHA that the tests should be run on
```
4 changes: 1 addition & 3 deletions tests/model/test_fused_kernels.py
@@ -30,9 +30,7 @@
)


@pytest.mark.xfail(
reason="ModuleNotFoundError: No module named 'scaled_masked_softmax_cuda'"
)
@pytest.mark.xfail(reason="SystemExit: None")
def test_load_fused_kernels():
load()
try:
15 changes: 13 additions & 2 deletions tools/ckpts/convert_hf_to_sequential.py
@@ -119,16 +119,27 @@ def shard_sequential_mp(num_mp_ranks, sequential):
ranks = {x: dict() for x in range(num_mp_ranks)}
for k, v in sequential.items():
if reduce(
np.logical_or,
[
x in k
for x in [
"dense_4h_to_h.bias",
"attention.dense.bias",
]
],
):
# Divide by tp_size since they get added together
for x in range(num_mp_ranks):
ranks[x][k] = v / num_mp_ranks
elif reduce(
np.logical_or,
[
x in k
for x in [
"layernorm",
"rotary_emb",
"dense_4h_to_h.bias",
"norm.weight",
"norm.bias",
"attention.dense.bias",
]
],
):
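The `/ num_mp_ranks` above exists because these bias terms are replicated on every tensor-parallel rank and then summed in the all-reduce of the corresponding row-parallel layer. A small numpy sketch of why the division keeps the result unchanged (illustrative only, not the conversion script's code path):

```python
import numpy as np

tp_size = 4
bias = np.array([1.0, 2.0, 3.0])

# Each TP rank holds a copy of the (scaled) bias; the row-parallel layer's
# output all-reduce sums the per-rank partial outputs, biases included.
per_rank_bias = bias / tp_size
reduced = sum(per_rank_bias for _ in range(tp_size))

np.testing.assert_allclose(reduced, bias)
```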
