From 422769ad2089d72fa72fcdc5e62a1b0eb1e3db60 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 17 Jul 2024 11:50:21 +0000 Subject: [PATCH 1/7] remove github submodule --- .github/workflows/push-docker-image.yml | 5 +---- Dockerfile | 2 -- README.md | 28 +++---------------------- hivemind_source | 1 - pydantic_config | 1 - requirements.txt | 7 +++++-- 6 files changed, 9 insertions(+), 35 deletions(-) delete mode 160000 hivemind_source delete mode 160000 pydantic_config diff --git a/.github/workflows/push-docker-image.yml b/.github/workflows/push-docker-image.yml index bb1614d..47296eb 100644 --- a/.github/workflows/push-docker-image.yml +++ b/.github/workflows/push-docker-image.yml @@ -20,10 +20,7 @@ jobs: # Link to discussion: https://github.com/orgs/community/discussions/25678 - name: Checkout - uses: actions/checkout@v3 - with: - submodules: true - + uses: actions/checkout@v3 - name: Docker meta id: meta uses: crazy-max/ghaction-docker-meta@v2 diff --git a/Dockerfile b/Dockerfile index f556805..4ddc56c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -35,8 +35,6 @@ RUN echo "export PATH=\"/opt/conda/bin:/root/.cargo/bin:\$PATH\"" >> /root/.bash # Install Python dependencies (The gradual copies help with caching) WORKDIR open_diloco RUN pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu -COPY hivemind_source hivemind_source -RUN pip install --no-cache-dir ./hivemind_source COPY requirements.txt requirements.txt RUN pip install --no-cache-dir -r requirements.txt COPY requirements-dev.txt requirements-dev.txt diff --git a/README.md b/README.md index 31ee3d1..b0f3d97 100644 --- a/README.md +++ b/README.md @@ -30,18 +30,6 @@ source .venv/bin/activate Install python dependencies: ```bash -# Hivemind -cd hivemind_source -pip install . -cp build/lib/hivemind/proto/* hivemind/proto/. -pip install -e ".[all]" -cd .. -# Requirements -pip install -r requirements.txt -# Others -pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu -pip install -e ./pydantic_config -# OpenDiLoCo pip install . ``` @@ -49,7 +37,7 @@ Optionally, you can install flash-attn to use Flash Attention 2. This requires your system to have cuda compiler set up. ``` # (Optional) flash-attn -pip install flash-attn==2.5.8 +pip install flash-attn>=2.5.8 ``` ## Docker container @@ -305,20 +293,10 @@ We recommend using `bf16` to avoid scaling and desynchronization issues with hiv # Debugging Issues -1. `hivemind` or `pydantic_config` - If you are having issues with `hivemind` or `pydantic_config`, the issue could be related to submodules. - You can clean and reinitialize the submodules from the root of the repository with the following commands: - - ``` - git submodule deinit -f . - git clean -xdf - git submodule update --init --recursive - ``` - -2. `RuntimeError: CUDA error: invalid device ordinal` +1. `RuntimeError: CUDA error: invalid device ordinal` A possible culprit is that your `--nproc-per-node` argument for the torchrun launcher is set incorrectly. Please set it to an integer less than equal to the number of gpus you have on your machine. -3. `torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate...` +2. `torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate...` A possible culprit is that your `--per-device-train-batch-size` is too high. Try a smaller value. 
diff --git a/hivemind_source b/hivemind_source deleted file mode 160000 index ad080ed..0000000 --- a/hivemind_source +++ /dev/null @@ -1 +0,0 @@ -Subproject commit ad080ed0461e8e68fbed4d28b735ccfbdd84113e diff --git a/pydantic_config b/pydantic_config deleted file mode 160000 index 8e19e05..0000000 --- a/pydantic_config +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 8e19e05d20c0acc7efc27622c0f5c41f3d7c78b1 diff --git a/requirements.txt b/requirements.txt index 52e9232..e918dce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,9 @@ transformers~=4.40 datasets>=2.19.1 -wandb==0.16.4 +wandb>=0.16.4 cyclopts>=2.6.1 fsspec[gcs]>=2024.3.1 -torch==2.3.1 +torch>=2.3.1 +hivemind @ git+https://github.com/learning-at-home/hivemind.git@213bff9 +pydantic_config @ git+https://github.com/samsja/pydantic_config.git@8e19e05 + From ff545dff069af23e27e42347951121a2054ad39e Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 17 Jul 2024 11:54:25 +0000 Subject: [PATCH 2/7] add flash attention build to dockerfile --- Dockerfile | 1 + 1 file changed, 1 insertion(+) diff --git a/Dockerfile b/Dockerfile index 4ddc56c..0a2b942 100644 --- a/Dockerfile +++ b/Dockerfile @@ -35,6 +35,7 @@ RUN echo "export PATH=\"/opt/conda/bin:/root/.cargo/bin:\$PATH\"" >> /root/.bash # Install Python dependencies (The gradual copies help with caching) WORKDIR open_diloco RUN pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu +RUN pip install flash-attn>=2.5.8 COPY requirements.txt requirements.txt RUN pip install --no-cache-dir -r requirements.txt COPY requirements-dev.txt requirements-dev.txt From 0c92f363e7115e1e4d7fd3eefdd1f82a71ffb6ea Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 16 Jul 2024 16:04:05 +0000 Subject: [PATCH 3/7] add llama i add llama --- open_diloco/configs/config_1b.json | 16 +- open_diloco/configs/config_2m.json | 15 +- open_diloco/llama.py | 608 +++++++++++++++++++++++++++++ open_diloco/train_fsdp.py | 31 +- requirements.txt | 2 +- tests/test_llama.py | 207 ++++++++++ tests/test_training/test_train.py | 11 +- 7 files changed, 859 insertions(+), 31 deletions(-) create mode 100644 open_diloco/llama.py create mode 100644 tests/test_llama.py diff --git a/open_diloco/configs/config_1b.json b/open_diloco/configs/config_1b.json index cdd57c0..8a7584d 100644 --- a/open_diloco/configs/config_1b.json +++ b/open_diloco/configs/config_1b.json @@ -1,13 +1,9 @@ { - "architectures": [ - "LlamaForCausalLM" - ], - "model_type": "llama", - "hidden_size": 2048, + "name": "llama", + "n_embd": 2048, "intermediate_size": 5632, - "num_attention_heads": 32, - "num_hidden_layers": 22, - "use_cache": false, - "rms_norm_eps": 1e-05, - "num_key_value_heads": 4 + "n_head": 32, + "n_layer": 22, + "n_query_groups": 4, + "vocab_size": 1024 } diff --git a/open_diloco/configs/config_2m.json b/open_diloco/configs/config_2m.json index 12a7825..c2bdac6 100644 --- a/open_diloco/configs/config_2m.json +++ b/open_diloco/configs/config_2m.json @@ -1,15 +1,10 @@ { - "architectures": [ - "LlamaForCausalLM" - ], - "model_type": "llama", - "hidden_size": 64, + "name": "llama_2m", + "n_embd": 64, "intermediate_size": 256, - "num_attention_heads": 2, - "num_hidden_layers": 2, - "rms_norm_eps": 1e-05, - "use_cache": false, + "n_head": 2, + "n_layer": 2, "vocab_size": 1024 - } +} \ No newline at end of file diff --git a/open_diloco/llama.py b/open_diloco/llama.py new file mode 100644 index 0000000..95ccae5 --- /dev/null +++ b/open_diloco/llama.py @@ -0,0 +1,608 @@ +# Copyright Lightning AI. 
Licensed under the Apache License 2.0, see LICENSE file. + +"""Full definition of a decoder-only transformer-based language model, all of it in this single file. + +Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and +https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model. +""" + +import math +from typing import Any, Iterable, Literal, Optional, Tuple +from pydantic import model_validator + +import torch +import torch.nn as nn +from typing_extensions import Self +from pydantic_config import BaseConfig + +try: + import xformers.ops as xops + + XFORMERS_AVAILABLE = True +except ImportError: + XFORMERS_AVAILABLE = False + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + + FLASH_AVAILABLE = True +except ImportError: + FLASH_AVAILABLE = False + +from einops import rearrange + + +def find_multiple(n: int, k: int) -> int: + assert k > 0 + if n % k == 0: + return n + return n + k - (n % k) + + +class Config(BaseConfig): + name: str = "" + scale_embeddings: bool = False + block_size: int = 4096 + vocab_size: int = 50254 + padding_multiple: int = 512 + padded_vocab_size: Optional[int] = None + n_layer: int = 16 + n_head: int = 32 + head_size: int + n_embd: int = 4096 + rotary_percentage: float = 0.25 + parallel_residual: bool = True + bias: bool = False + lm_head_bias: bool = False + # to use multi-head attention (MHA), set this to `n_head` (default) + # to use multi-query attention (MQA), set this to 1 + # to use grouped-query attention (GQA), set this to a value in between + # Example with `n_head=4` + # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + # │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │ + # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + # │ │ │ │ │ │ │ + # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + # │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │ + # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + # │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐ + # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ + # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ + # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ + # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶ + # MHA GQA MQA + # n_query_groups=4 n_query_groups=2 n_query_groups=1 + # + # credit https://arxiv.org/pdf/2305.13245.pdf + n_query_groups: Optional[int] = None + shared_attention_norm: bool = False + norm_eps: float = 1e-5 + mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP" + gelu_approximate: str = "none" + intermediate_size: int + rope_condense_ratio: int = 1 + rope_base: int = 10000 + n_expert: int = 0 + n_expert_per_token: int = 0 + attention_impl: Literal["sdpa", "fa", "xformers"] = "sdpa" + padded_vocab_size: int + rope_n_elem: Optional[int] = None + + @model_validator(mode="before") + def set_padded_vocab_size(cls, values: dict[str, Any]): + """Set the padded vocab size to the next multiple of 64 if not provided.""" + vocab_size = values.get("vocab_size") + padded_vocab_size = values.get("padded_vocab_size") + + if padded_vocab_size is None and vocab_size is not None: + values["padded_vocab_size"] = find_multiple(vocab_size, 64) + return values + + @model_validator(mode="before") + def set_head_size(cls, values: dict[str, Any]): + head_size = values.get("head_size") + if head_size is None: + n_embd = values.get("n_embd") + n_head = values.get("n_head") + assert n_embd % n_head == 0 + values["head_size"] = n_embd // n_head + + return values + + @model_validator(mode="before") + def set_n_query_groups(cls, values: 
dict[str, Any]): + n_query_groups = values.get("n_query_groups") + n_head = values.get("n_head") + if n_query_groups is None: + values["n_query_groups"] = n_head + else: + assert n_head % n_query_groups == 0 + return values + + @model_validator(mode="before") + def set_intermediate_size(cls, values: dict[str, Any]): + intermediate_size = values.get("intermediate_size") + mlp_class_name = values.get("mlp_class_name") + if intermediate_size is None: + if mlp_class_name == "LLaMAMLP": + raise ValueError("The config needs to set the `intermediate_size`") + values["intermediate_size"] = 4 * values.get("n_embd") + return values + + @model_validator(mode="after") + def set_rope_n_elem(self): + self.rope_n_elem = int(self.rotary_percentage * self.head_size) + return self + + +class GPT(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + assert config.padded_vocab_size is not None + self.config = config + + self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.padded_vocab_size, config.n_embd), + h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), + ln_f=RMSNorm(config.n_embd, eps=config.norm_eps), + ) + ) + self.max_seq_length = self.config.block_size + self.mask_cache: Optional[torch.Tensor] = None + + @property + def max_seq_length(self) -> int: + return self._max_seq_length + + @max_seq_length.setter + def max_seq_length(self, value: int) -> None: + """ + When doing inference, the sequences used might be shorter than the model's context length. + This allows setting a smaller number to avoid allocating unused memory + """ + if value > self.config.block_size: + raise ValueError(f"Cannot attend to {value}, block size is only {self.config.block_size}") + self._max_seq_length = value + if not hasattr(self, "cos"): + # first call + cos, sin = self.rope_cache() + self.register_buffer("cos", cos, persistent=False) + self.register_buffer("sin", sin, persistent=False) + # override + elif value != self.cos.size(0): + self.cos, self.sin = self.rope_cache(device=self.cos.device) + # the mask and kv cache size will get updated on `set_kv_cache`. 
we cannot update it here because we don't know + # if the kv cache is expected + + def reset_parameters(self) -> None: + # Trigger resetting the rope-cache + self.cos, self.sin = self.rope_cache(device=self.cos.device) + + def _init_weights(self, module: nn.Module) -> None: + """Meant to be used with `gpt.apply(gpt._init_weights)`.""" + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def forward( + self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seqlens: Optional[Iterable[int]] = None + ) -> torch.Tensor: + T = idx.size(1) + if self.max_seq_length < T: + raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") + + if input_pos is not None: # use the kv cache + cos = self.cos.index_select(0, input_pos) + sin = self.sin.index_select(0, input_pos) + if self.mask_cache is None: + raise TypeError("You need to call `gpt.set_kv_cache()`") + mask = self.mask_cache.index_select(2, input_pos) + else: + cos = self.cos[:T] + sin = self.sin[:T] + mask = None + + x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + if self.config.scale_embeddings: + x = x * (self.config.n_embd**0.5) + + for block in self.transformer.h: + x = block(x, cos, sin, mask, input_pos, seqlens) + x = self.transformer.ln_f(x) + return self.lm_head(x) # (b, t, vocab_size) + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + return cls(Config.from_name(name, **kwargs)) + + def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]: + return build_rope_cache( + seq_len=self.max_seq_length, + n_elem=self.config.rope_n_elem, + device=device, + condense_ratio=self.config.rope_condense_ratio, + base=self.config.rope_base, + ) + + def set_kv_cache( + self, + batch_size: int, + rope_cache_length: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + if rope_cache_length is None: + rope_cache_length = self.cos.size(-1) + max_seq_length = self.max_seq_length + + # initialize the kv cache for all blocks + for block in self.transformer.h: + block.attn.kv_cache = block.attn.build_kv_cache( + batch_size, max_seq_length, rope_cache_length, device, dtype + ) + + if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length: + # passing `attn_mask` to SDPA disables the flash implementation. 
since we only need the mask + # for the kv-cache support (only during inference), we only create it in that situation + self.mask_cache = build_mask_cache(max_seq_length, device) + + def clear_kv_cache(self) -> None: + self.mask_cache = None + for block in self.transformer.h: + block.attn.kv_cache = None + + +class Block(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.norm_1 = RMSNorm(config.n_embd, eps=config.norm_eps) + self.attn = CausalSelfAttention(config) + self.norm_2 = None if config.shared_attention_norm else RMSNorm(config.n_embd, eps=config.norm_eps) + self.mlp = LLaMAMLP(config) + + self.config = config + + def forward( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + seqlens: Optional[Iterable[int]] = None, + ) -> torch.Tensor: + n_1 = self.norm_1(x) + h = self.attn(n_1, cos, sin, mask, input_pos, seqlens) + if self.config.parallel_residual: + n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) + x = self.mlp(n_2) + h + x + else: + if self.config.shared_attention_norm: + raise NotImplementedError( + "No checkpoint amongst the ones we support uses this configuration" + " (non-parallel residual and shared attention norm)." + ) + x = h + x + x = self.mlp(self.norm_2(x)) + x + return x + + +class CausalSelfAttention(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + shape = (config.n_head + 2 * config.n_query_groups) * config.head_size + # key, query, value projections for all heads, but in a batch + self.attn = nn.Linear(config.n_embd, shape, bias=config.bias) + # output projection + # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` + self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) + # disabled by default + self.kv_cache: Optional[KVCache] = None + + self.config = config + + def forward( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + seqlens: Optional[Iterable[int]] = None, + ) -> torch.Tensor: + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + qkv = self.attn(x) + # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) + q_per_kv = self.config.n_head // self.config.n_query_groups + total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value + qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) + qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) + + # split batched computation into three + q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) + + # maybe repeat k and v if for the non multi-head attention cases + # training: flash attention requires it + # inference: multi-query would require a full kv cache so avoid it to limit its memory usage + if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1): + k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) + v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) + + q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) + k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) + v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) + + q_roped = apply_rope(q[..., 
: self.config.rope_n_elem], cos, sin) + k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) + q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) + k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) + + if input_pos is not None: + if not isinstance(self.kv_cache, KVCache): + raise TypeError("You need to call `gpt.set_kv_cache()`") + k, v = self.kv_cache(input_pos, k, v) + + scale = 1.0 / math.sqrt(self.config.head_size) + y = self.scaled_dot_product_attention(q, k, v, scale, mask, seqlens) + y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side + + # output projection + return self.proj(y) + + def scaled_dot_product_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float, + mask: Optional[torch.Tensor] = None, + seqlens: Optional[Iterable[int]] = None, + ) -> torch.Tensor: + if seqlens is not None: + if mask is not None: + raise ValueError("context stuffing is not compatible with custom mask") + + if self.config.attention_impl == "sdpa": + raise ValueError("context stuffing is not supported with sdpa") + elif self.config.attention_impl == "xformers": + return self._xformers_attention_with_seqlens(q, k, v, scale, seqlens) + elif self.config.attention_impl == "fa": + return self._fa_attention_with_seqlens(q, k, v, scale, seqlens) + else: + raise ValueError(f"Unknown attention implementation: {self.config.attention_impl}") + + if self.config.attention_impl == "sdpa": + return self._sdpa_attention(q, k, v, scale, mask) + elif self.config.attention_impl == "xformers": + if XFORMERS_AVAILABLE: + return self._xformers_attention(q, k, v, scale, mask) + else: + raise ImportError("Xformers is not available, please install xformers library to use it.") + elif self.config.attention_impl == "fa": + if FLASH_AVAILABLE: + return self._fa_attention(q, k, v, scale, mask) + else: + raise ImportError("Flash attention is not available, please install flash_attn library to use it.") + else: + raise ValueError(f"Unknown attention implementation: {self.config.attention_impl}") + + def _sdpa_attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: float, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None + ) + return y.transpose(1, 2) + + def _xformers_attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: float, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if mask is not None: + raise ValueError("custom mask not supporting with xformers") + + attn_bias = xops.LowerTriangularMask() + + q = rearrange(q, "b n t h -> b t n h") + k = rearrange(k, "b n t h -> b t n h") + v = rearrange(v, "b n t h -> b t n h") + + return xops.memory_efficient_attention( + query=q, + key=k, + value=v, + scale=scale, + attn_bias=attn_bias, + op=xops.MemoryEfficientAttentionFlashAttentionOp, + ) + + def _xformers_attention_with_seqlens( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: float, seqlens: Iterable[int] + ): + attn_bias = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen=seqlens) + attn_bias = attn_bias.make_causal() + + q = rearrange(q, "b n t h -> 1 (b t) n h") + k = rearrange(k, "b n t h -> 1 (b t) n h") + v = rearrange(v, "b n t h -> 1 (b t) n h") + + return xops.memory_efficient_attention( + query=q, + key=k, + value=v, + scale=scale, + attn_bias=attn_bias, + 
op=xops.MemoryEfficientAttentionFlashAttentionOp, + ) + + def _fa_attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: float, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + q = rearrange(q, "b n t h -> b t n h") + k = rearrange(k, "b n t h -> b t n h") + v = rearrange(v, "b n t h -> b t n h") + # q/k/b is [b, nh, t, hs] but fa2 expected [b , t, nh, hs] + return flash_attn_func(q, k, v, causal=True, softmax_scale=scale) + + def _fa_attention_with_seqlens( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: float, seqlens: Iterable[int] + ): + scale = 1.0 / math.sqrt(self.config.head_size) + b = q.shape[0] + seqlens = torch.tensor(seqlens, dtype=torch.int32) + cu_seqlens = torch.concat([torch.tensor([0]), seqlens.cumsum(0)], dim=0).to(torch.int32).to(q.device) + max_seqlen = seqlens.max().item() + + q = rearrange(q, "b n t h -> (b t) n h") + k = rearrange(k, "b n t h -> (b t) n h") + v = rearrange(v, "b n t h -> (b t) n h") + # q/k/v is [b, nh, t, hs] but fa expected [b * t, nh, hs] + y = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + softmax_scale=scale, + ) + + y = rearrange(y, "(b t) n h -> b t n h", b=b) + return y + + def build_kv_cache( + self, + batch_size: int, + max_seq_length: int, + rope_cache_length: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> "KVCache": + heads = 1 if self.config.n_query_groups == 1 else self.config.n_head + v_shape = (batch_size, heads, max_seq_length, self.config.head_size) + if rope_cache_length is None: + if self.config.rotary_percentage != 1.0: + raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value") + k_shape = v_shape + else: + k_shape = ( + batch_size, + heads, + max_seq_length, + rope_cache_length + self.config.head_size - self.config.rope_n_elem, + ) + return KVCache(k_shape, v_shape, device=device, dtype=dtype) + + +class LLaMAMLP(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) + + self.config = config + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_fc_1 = self.fc_1(x) + x_fc_2 = self.fc_2(x) + x = torch.nn.functional.silu(x_fc_1) * x_fc_2 + return self.proj(x) + + +def build_rope_cache( + seq_len: int, n_elem: int, device: Optional[torch.device] = None, base: int = 10000, condense_ratio: int = 1 +) -> Tuple[torch.Tensor, torch.Tensor]: + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. 
+ """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=device) / condense_ratio + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).repeat(1, 2) + + return torch.cos(idx_theta), torch.sin(idx_theta) + + +def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + head_size = x.size(-1) + x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) + x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) + rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) + roped = (x * cos) + (rotated * sin) + return roped.to(dtype=x.dtype) + + +class KVCache(nn.Module): + def __init__( + self, + k_shape: Tuple[int, int, int, int], + v_shape: Tuple[int, int, int, int], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__() + self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False) + self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False) + + def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # move the buffer to the activation dtype for when AMP is used + self.k = self.k.to(k.dtype) + self.v = self.v.to(v.dtype) + # update the cache + k = self.k.index_copy_(2, input_pos, k) + v = self.v.index_copy_(2, input_pos, v) + return k, v + + def reset_parameters(self) -> None: + torch.nn.init.zeros_(self.k) + torch.nn.init.zeros_(self.v) + + +def build_mask_cache(max_seq_length: int, device: Optional[torch.device] = None) -> torch.Tensor: + ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool) + return torch.tril(ones).unsqueeze(0).unsqueeze(0) + + +class RMSNorm(torch.nn.Module): + """Root Mean Square Layer Normalization. + + Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License: + https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. 
+ """ + + def __init__(self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(size)) + self.eps = eps + self.dim = dim + self.add_unit_offset = add_unit_offset + + def forward(self, x: torch.Tensor) -> torch.Tensor: + dtype = x.dtype + x = x.float() + # NOTE: the original RMSNorm paper implementation is not equivalent + norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) + x_normed = x * torch.rsqrt(norm_x + self.eps) + x_normed = x_normed.to(dtype=dtype) + if self.add_unit_offset: + # Gemma model requires a unit offset + # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176 + return x_normed * (1 + self.weight) + return x_normed * self.weight + + def reset_parameters(self) -> None: + torch.nn.init.ones_(self.weight) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 666bcea..5d209ef 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -7,11 +7,13 @@ """ from functools import partial +import json import os import time from contextlib import nullcontext import datetime from typing import Any, Literal +from einops import rearrange import fsspec from pydantic import model_validator @@ -26,8 +28,6 @@ from transformers import ( AutoTokenizer, DataCollatorForLanguageModeling, - LlamaConfig, - LlamaForCausalLM, get_cosine_schedule_with_warmup, ) from torch.distributed.fsdp import ( @@ -39,7 +39,7 @@ from torch.distributed import broadcast_object_list from open_diloco.ckpt_utils import load_checkpoint, save_checkpoint from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer - +from open_diloco.llama import GPT, Config as ModelConfig from hivemind.dht.dht import DHT from hivemind.utils.networking import log_visible_maddrs @@ -114,7 +114,7 @@ def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]: class Config(BaseConfig): - path_model: str = "PrimeIntellect/llama-150m-fresh" + llama_config: str | ModelConfig = "open_diloco/configs/config_1b.json" torch_compile: bool = True attn_implementation: str = "sdpa" # Data @@ -181,10 +181,14 @@ def tokenize_function(data): ) -def get_model(config: Config) -> LlamaForCausalLM: +def get_model(config: Config) -> GPT: # Load model - config_model = LlamaConfig.from_pretrained(config.path_model, attn_implementation=config.attn_implementation) - return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model) + if isinstance(config.llama_config, ModelConfig): + return GPT(config.llama_config) + else: + with open(config.llama_config) as f: + llama_config = ModelConfig(**json.load(f)) + return GPT(llama_config) def train(config: Config): @@ -389,9 +393,18 @@ def scheduler_fn(opt): batch[key] = batch[key].to("cuda") with model.no_sync() if is_accumulating else nullcontext(): - outputs = model(**batch) - loss = outputs.loss / gradient_accumulation_steps + inputs_ids = batch["input_ids"] + + input_ids = inputs_ids[:, :-1] + target = inputs_ids[:, 1:] + + output = model(input_ids) + + flatten_logits = rearrange(output, "b seq vocab -> (b seq) vocab") + flatten_target = rearrange(target, "b seq -> (b seq)") + loss = torch.nn.functional.cross_entropy(flatten_logits, flatten_target) + loss = loss / gradient_accumulation_steps loss_batch += loss.detach() scaler.scale(loss).backward() diff --git a/requirements.txt b/requirements.txt index e918dce..f5fba10 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,4 @@ 
fsspec[gcs]>=2024.3.1 torch>=2.3.1 hivemind @ git+https://github.com/learning-at-home/hivemind.git@213bff9 pydantic_config @ git+https://github.com/samsja/pydantic_config.git@8e19e05 - +xformers diff --git a/tests/test_llama.py b/tests/test_llama.py new file mode 100644 index 0000000..4733431 --- /dev/null +++ b/tests/test_llama.py @@ -0,0 +1,207 @@ +from typing import Tuple +import pytest +import torch + +from open_diloco.llama import GPT, CausalSelfAttention, Config, build_rope_cache +from xformers.ops.fmha.common import AttentionFwOpBase + + +@pytest.fixture +def config() -> Config: + return Config( + name="llama", + n_embd=64, + n_head=2, + n_layer=2, + vocab_size=1024, + ) + + +@pytest.mark.parametrize("attention_impl", ["sdpa", "xformers", "fa"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_gpt(config: Config, attention_impl: str, dtype: torch.dtype): + config.attention_impl = attention_impl + _test_gpt(config, dtype) + + +def _test_gpt(config: Config, dtype: torch.dtype): + device = torch.device("cuda") + model = GPT(config).to(dtype).to(device) + + BATCH_SIZE = 16 + SEQ_LEN = 8 + VOCAB_SIZE = 1024 + + input = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN)).to(device) + + output = model(input) + + assert output is not None + assert not output.isnan().any() + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("attention_impl", ["xformers", "fa"]) +def test_gpt_output(config: Config, dtype: torch.dtype, attention_impl: str): + """ + in this test we compare the output of the GPT with sdpa and xformers/fa + """ + + device = torch.device("cuda") + model = GPT(config).to(dtype).to(device) + + BATCH_SIZE = 16 + SEQ_LEN = 8 + VOCAB_SIZE = 1024 + + input = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN)).to(device) + + ### SDPA + config.attention_impl = "sdpa" + output_sdpa = model(input) + + ### XFORMERS + config.attention_impl = attention_impl + output_xformers = model(input) + + ### TESTING + assert output_sdpa.shape == output_xformers.shape + + ### xformers has a higher tolerance + atol = AttentionFwOpBase.ERROR_ATOL[dtype] + rtol = AttentionFwOpBase.ERROR_RTOL[dtype] + torch.testing.assert_close(output_sdpa, output_xformers, atol=atol, rtol=rtol) + + +def get_cos_and_sin_attn(config: Config, seq_len: int, device) -> Tuple[torch.Tensor, torch.Tensor]: + cos, sin = build_rope_cache( + seq_len=seq_len, + n_elem=config.rope_n_elem, + device=device, + condense_ratio=config.rope_condense_ratio, + base=config.rope_base, + ) + cos = cos[:seq_len] + sin = sin[:seq_len] + return cos, sin + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("attention_impl", ["xformers", "fa"]) +def test_attn_output(config: Config, dtype: torch.dtype, attention_impl: str): + """ + in this test we compare the output of the GPT with sdpa and xformers/fa + """ + + device = torch.device("cuda") + model = CausalSelfAttention(config).to(dtype).to(device) + + BATCH_SIZE = 16 + SEQ_LEN = 8 + + input = torch.rand(BATCH_SIZE, SEQ_LEN, config.n_embd).to(dtype).to(device) + cos, sin = get_cos_and_sin_attn(config, SEQ_LEN, device) + + ### SDPA + config.attention_impl = "sdpa" + output_sdpa = model(input, cos, sin) + + ### XFORMERS + config.attention_impl = attention_impl + output_xformers = model(input, cos, sin) + + ### TESTING + assert output_sdpa.shape == output_xformers.shape + + ### xformers has a higher tolerance + atol = AttentionFwOpBase.ERROR_ATOL[dtype] + rtol = 
AttentionFwOpBase.ERROR_RTOL[dtype] + torch.testing.assert_close(output_sdpa, output_xformers, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("attention_impl", ["xformers", "fa"]) +def test_context_stuffing_attn(config: Config, dtype: torch.dtype, attention_impl: str): + """ + In this test we compare normal pad attention with stuffing. + + input is [[2, 1, 4, 8, PAD, PAD, PAD, PAD], [1, 4, 2, 7, PAD, PAD, PAD, PAD]] + for padded input and [[2, 1, 4, 8, 1, 4, 2, 7]] for stuffed input + + we then compare the output of the two and should be the same. + """ + device = torch.device("cuda") + model = CausalSelfAttention(config).to(dtype).to(device) + + config.attention_impl = attention_impl + + SEQ_LEN = 8 + + emb = torch.nn.Embedding(10, config.n_embd).to(dtype).to(device) + + pad_id = 0 + input_raw = ( + torch.Tensor([[2, 1, 4, 8, pad_id, pad_id, pad_id, pad_id], [1, 4, 2, 7, pad_id, pad_id, pad_id, pad_id]]) + .long() + .to(device) + ) + input = emb(input_raw) + + input_stuff_raw = torch.Tensor([[2, 1, 4, 8, 1, 4, 2, 7]]).long().to(device) + seqlens = [4, 4] + input_stuff = emb(input_stuff_raw) + + cos, sin = get_cos_and_sin_attn(config, SEQ_LEN, device) + + ### batch + output_ctx_stuff = model(input, cos, sin) + + output_ctx_stuff = output_ctx_stuff[:, :4, :] # remove padding token + + output_xformers_stuff = model(input_stuff, cos, sin, seqlens=seqlens) + output_xformers_stuff = output_xformers_stuff.reshape(2, 4, config.n_embd) + + ### TESTING + assert output_ctx_stuff.shape == output_xformers_stuff.shape + + ### xformers has a higher tolerance + atol = AttentionFwOpBase.ERROR_ATOL[dtype] + rtol = AttentionFwOpBase.ERROR_RTOL[dtype] + torch.testing.assert_close(output_ctx_stuff, output_xformers_stuff, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("attention_impl", ["xformers", "fa"]) +def test_context_stuffing_attn_2(config: Config, dtype: torch.dtype, attention_impl: str): + """ + this test is slightu different from the one above, it tests + that passing two time the same input in a stuff way yield the same results. 
+ """ + device = torch.device("cuda") + + config.attention_impl = attention_impl + model = CausalSelfAttention(config).to(dtype).to(device) + + SEQ_LEN = 8 + + emb = torch.nn.Embedding(10, config.n_embd).to(dtype).to(device) + + seq = [2, 1, 4, 8] + input_stuff_raw = torch.Tensor([seq + seq]).long().to(device) + seqlens = [len(seq), len(seq)] + input_stuff = emb(input_stuff_raw) + + cos, sin = get_cos_and_sin_attn(config, SEQ_LEN, device) + + output = model(input_stuff, cos, sin, seqlens=seqlens) + + output_left = output[:, :4, :] + output_right = output[:, 4:, :] + + ### TESTING + assert output_left.shape == output_right.shape + + ### xformers has a higher tolerance + atol = AttentionFwOpBase.ERROR_ATOL[dtype] + rtol = AttentionFwOpBase.ERROR_RTOL[dtype] + torch.testing.assert_close(output_left, output_right, atol=atol, rtol=rtol) diff --git a/tests/test_training/test_train.py b/tests/test_training/test_train.py index 12e532c..543cd4e 100644 --- a/tests/test_training/test_train.py +++ b/tests/test_training/test_train.py @@ -10,6 +10,7 @@ from hivemind.dht.dht import DHT from open_diloco.train_fsdp import train, Config, ddp_setup, destroy_process_group, HvConfig +from open_diloco.llama import Config as ModelConfig @pytest.fixture(autouse=True) @@ -46,8 +47,16 @@ def random_available_port(): @pytest.fixture def config() -> Config: + model_config = ModelConfig( + name="llama_2m", + n_embd=64, + intermediate_size=256, + n_head=2, + n_layer=2, + vocab_size=1024, + ) return Config( - path_model="tests/models/llama-2m-fresh", + llama_config=model_config, fake_data=True, torch_compile=False, lr=1e-2, From 9a8f29c684a1f1e5a530d571938d5bc613079dfe Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 16 Jul 2024 17:16:42 +0000 Subject: [PATCH 4/7] fix config --- open_diloco/configs/config_150m.json | 21 +++++++++------------ open_diloco/configs/config_1b.json | 5 +++-- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/open_diloco/configs/config_150m.json b/open_diloco/configs/config_150m.json index 7e27472..80e118d 100644 --- a/open_diloco/configs/config_150m.json +++ b/open_diloco/configs/config_150m.json @@ -1,13 +1,10 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "model_type": "llama", - "hidden_size": 1024, - "intermediate_size": 2688, - "num_attention_heads": 16, - "num_hidden_layers": 12, - "use_cache": false, - "rms_norm_eps": 1e-05 -} +{ + "name": "llama150m", + "n_embd": 1024, + "intermediate_size": 4096, + "n_head": 16, + "n_layer": 12, + "vocab_size": 32000, + "block_size": 1024 +} \ No newline at end of file diff --git a/open_diloco/configs/config_1b.json b/open_diloco/configs/config_1b.json index 8a7584d..d2a98d8 100644 --- a/open_diloco/configs/config_1b.json +++ b/open_diloco/configs/config_1b.json @@ -1,9 +1,10 @@ { - "name": "llama", + "name": "llama1b", "n_embd": 2048, "intermediate_size": 5632, "n_head": 32, "n_layer": 22, "n_query_groups": 4, - "vocab_size": 1024 + "vocab_size": 32000, + "block_size": 1024 } From 682d365211d6914a29a5ed21006c0306eab1c37a Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 17 Jul 2024 08:57:58 +0000 Subject: [PATCH 5/7] add custom attn impl --- open_diloco/train_fsdp.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 5d209ef..782dbe7 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -116,7 +116,7 @@ def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]: class Config(BaseConfig): llama_config: str 
| ModelConfig = "open_diloco/configs/config_1b.json" torch_compile: bool = True - attn_implementation: str = "sdpa" + attention_impl: Literal["sdpa", "fa", "xformers"] = "sdpa" # Data dataset_name_or_path: str = "allenai/c4" seq_length: int = 1024 @@ -184,11 +184,13 @@ def tokenize_function(data): def get_model(config: Config) -> GPT: # Load model if isinstance(config.llama_config, ModelConfig): - return GPT(config.llama_config) + llama_config = config.llama_config else: with open(config.llama_config) as f: llama_config = ModelConfig(**json.load(f)) - return GPT(llama_config) + + llama_config.attention_impl = config.attention_impl + return GPT(llama_config) def train(config: Config): From ab8c6dbc6d3adec3159649352062daca5de9217c Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Mon, 22 Jul 2024 18:21:40 +0000 Subject: [PATCH 6/7] add optional profiler --- open_diloco/train_fsdp.py | 254 +++++++++++++++++++++----------------- 1 file changed, 140 insertions(+), 114 deletions(-) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 782dbe7..79ab2fd 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -37,6 +37,8 @@ ) from torch.distributed.device_mesh import DeviceMesh from torch.distributed import broadcast_object_list +from torch.profiler import profile, ProfilerActivity + from open_diloco.ckpt_utils import load_checkpoint, save_checkpoint from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer from open_diloco.llama import GPT, Config as ModelConfig @@ -140,6 +142,7 @@ class Config(BaseConfig): hv: HvConfig | None = None # if no hv config then hivemind is disabled fake_data: bool = False max_steps: int | None = None + profiler: bool = False def get_dataloader(tokenizer, world_size, rank, local_rank, config: Config) -> StatefulDataLoader: @@ -182,7 +185,7 @@ def tokenize_function(data): def get_model(config: Config) -> GPT: - # Load model + # Load model1 if isinstance(config.llama_config, ModelConfig): llama_config = config.llama_config else: @@ -193,6 +196,27 @@ def get_model(config: Config) -> GPT: return GPT(llama_config) +def get_profiler(enable: bool, rank: int): + def trace_handler(p): + if rank == 0: + output = p.key_averages(group_by_stack_n=5).table(sort_by="self_cuda_memory_usage", row_limit=20) + print(output) + p.export_chrome_trace("./trace_" + str(p.step_num) + ".json") + + if enable: + return profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=torch.profiler.schedule(wait=1, warmup=1, active=1), + on_trace_ready=trace_handler, + profile_memory=True, + with_stack=False, + experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True), + ) + + else: + return nullcontext() + + def train(config: Config): sharding_strategy = get_sharding_strategy(config.sharding_strategy) local_rank = int(os.environ["LOCAL_RANK"]) @@ -375,151 +399,153 @@ def scheduler_fn(opt): log(f"starting from step {start_step}") loss_batch = 0 + with get_profiler(enable=config.profiler, rank=rank) as profiler: + for step, batch in enumerate(iterable=train_dataloader, start=start_step * gradient_accumulation_steps): + real_step = (step + 1) // gradient_accumulation_steps + is_accumulating = bool((step + 1) % gradient_accumulation_steps) - for step, batch in enumerate(iterable=train_dataloader, start=start_step * gradient_accumulation_steps): - real_step = (step + 1) // gradient_accumulation_steps - is_accumulating = bool((step + 1) % gradient_accumulation_steps) - - logging_activations_steps = ( - 
config.log_activations_steps is not None and real_step % config.log_activations_steps == 0 - ) - - if logging_activations_steps: - activation_monitor = ActivationNormMetric( - target_layers=TARGET_LAYER_ACTIVATIONS, - gradient_accumulation_steps=gradient_accumulation_steps, + logging_activations_steps = ( + config.log_activations_steps is not None and real_step % config.log_activations_steps == 0 ) - activation_monitor.register_metrics_hooks(model) - - for key in batch.keys(): - batch[key] = batch[key].to("cuda") - - with model.no_sync() if is_accumulating else nullcontext(): - inputs_ids = batch["input_ids"] - - input_ids = inputs_ids[:, :-1] - target = inputs_ids[:, 1:] - output = model(input_ids) - - flatten_logits = rearrange(output, "b seq vocab -> (b seq) vocab") - flatten_target = rearrange(target, "b seq -> (b seq)") - - loss = torch.nn.functional.cross_entropy(flatten_logits, flatten_target) - loss = loss / gradient_accumulation_steps - loss_batch += loss.detach() + if logging_activations_steps: + activation_monitor = ActivationNormMetric( + target_layers=TARGET_LAYER_ACTIVATIONS, + gradient_accumulation_steps=gradient_accumulation_steps, + ) + activation_monitor.register_metrics_hooks(model) - scaler.scale(loss).backward() + for key in batch.keys(): + batch[key] = batch[key].to("cuda") - if not is_accumulating: - if world_messenger_hv: - scaler.unscale_(optimizer=optimizer.inner_optimizer) - else: - scaler.unscale_(optimizer=optimizer) + with model.no_sync() if is_accumulating else nullcontext(): + inputs_ids = batch["input_ids"] - torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # gradient clipping + input_ids = inputs_ids[:, :-1] + target = inputs_ids[:, 1:] - if world_messenger_hv: - optimizer.step(scaler=scaler) + output = model(input_ids) - # todo(sami): refactor to use built in pytorch mechanism to handle scaler manually - # should allow to just do scaler.step(optimizer) - else: - scaler.step(optimizer) + flatten_logits = rearrange(output, "b seq vocab -> (b seq) vocab") + flatten_target = rearrange(target, "b seq -> (b seq)") - scaler.update() + loss = torch.nn.functional.cross_entropy(flatten_logits, flatten_target) + loss = loss / gradient_accumulation_steps + loss_batch += loss.detach() - scheduler.step() - optimizer.zero_grad() + scaler.scale(loss).backward() - if logging_activations_steps: - activation_monitor.remove_hooks() + if not is_accumulating: + if world_messenger_hv: + scaler.unscale_(optimizer=optimizer.inner_optimizer) + else: + scaler.unscale_(optimizer=optimizer) - if config.hv is not None: - if int(real_step) % config.hv.local_steps == 0: - for param in model.parameters(): - torch.distributed.broadcast(param.data, src=0) + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # gradient clipping - if rank == 0: - total_samples = real_step * config.total_batch_size - effective_step = real_step + if world_messenger_hv: + optimizer.step(scaler=scaler) - if config.hv is not None: - # Note that this assumes that we have the right amount of worker since t0. - # Not robust to off/on ramping - effective_step = real_step * config.hv.galaxy_size - total_samples = real_step * config.total_batch_size * config.hv.galaxy_size - - metrics = { - "Loss": loss_batch.item(), - "step": real_step, - "lr": [group["lr"] for group in optimizer.param_groups][0], - "Perplexity": torch.exp(loss_batch).item(), - "effective_step": effective_step, # at each step the we have compute total_batch_size. 
Independent of the number of GPUs - "total_samples": total_samples, - "time_taken": time.time() - current_time, - "tokens_per_second": config.seq_length * config.total_batch_size / (time.time() - current_time), - } + # todo(sami): refactor to use built in pytorch mechanism to handle scaler manually + # should allow to just do scaler.step(optimizer) + else: + scaler.step(optimizer) - if world_messenger_hv: - outer_lr = [group["lr"] for group in optimizer.state_averager.optimizer.param_groups][0] - num_peers = optimizer.tracker.global_progress.num_peers - if num_peers == 0: - num_peers = 1 + scaler.update() - metrics["outer_lr"] = outer_lr - metrics["num_peers"] = num_peers + scheduler.step() + optimizer.zero_grad() if logging_activations_steps: - metrics.update(activation_monitor.log_activations) - - current_time = time.time() + activation_monitor.remove_hooks() - wandb.log(metrics) + if config.hv is not None: + if int(real_step) % config.hv.local_steps == 0: + for param in model.parameters(): + torch.distributed.broadcast(param.data, src=0) + + if rank == 0: + total_samples = real_step * config.total_batch_size + effective_step = real_step + + if config.hv is not None: + # Note that this assumes that we have the right amount of worker since t0. + # Not robust to off/on ramping + effective_step = real_step * config.hv.galaxy_size + total_samples = real_step * config.total_batch_size * config.hv.galaxy_size + + metrics = { + "Loss": loss_batch.item(), + "step": real_step, + "lr": [group["lr"] for group in optimizer.param_groups][0], + "Perplexity": torch.exp(loss_batch).item(), + "effective_step": effective_step, # at each step the we have compute total_batch_size. Independent of the number of GPUs + "total_samples": total_samples, + "time_taken": time.time() - current_time, + "tokens_per_second": config.seq_length * config.total_batch_size / (time.time() - current_time), + } + + if world_messenger_hv: + outer_lr = [group["lr"] for group in optimizer.state_averager.optimizer.param_groups][0] + num_peers = optimizer.tracker.global_progress.num_peers + if num_peers == 0: + num_peers = 1 + + metrics["outer_lr"] = outer_lr + metrics["num_peers"] = num_peers + + if logging_activations_steps: + metrics.update(activation_monitor.log_activations) + + current_time = time.time() + + wandb.log(metrics) + + if config.hv is None: + log( + f"step: {real_step}, loss: {loss_batch.item()}, lr {[group['lr'] for group in optimizer.param_groups][0]}" + ) - if config.hv is None: - log( - f"step: {real_step}, loss: {loss_batch.item()}, lr {[group['lr'] for group in optimizer.param_groups][0]}" + # Save checkpoint every 'checkpoint_interval' steps + if config.checkpoint_interval is not None and real_step % config.checkpoint_interval == 0: + log(f"saving at step {real_step}, step {step+1}") + ckpt_path = os.path.join( + get_ckpt_folder(config.checkpoint_path, training_date, config.project, run_id), + f"model_step_{int(real_step)}", ) - # Save checkpoint every 'checkpoint_interval' steps - if config.checkpoint_interval is not None and real_step % config.checkpoint_interval == 0: - log(f"saving at step {real_step}, step {step+1}") - ckpt_path = os.path.join( - get_ckpt_folder(config.checkpoint_path, training_date, config.project, run_id), - f"model_step_{int(real_step)}", - ) - - if world_messenger_hv: - assert isinstance(optimizer, DiLoCoOptimizer) - with optimizer.tracker.pause_updates(): + if world_messenger_hv: + assert isinstance(optimizer, DiLoCoOptimizer) + with optimizer.tracker.pause_updates(): + 
save_checkpoint( + checkpoint_path=ckpt_path, + model=model, + optimizer=optimizer.inner_optimizer, + scheduler=scheduler, + outer_optimizer=optimizer.state_averager.optimizer, + loss=loss_batch.item(), + scaler=scaler, + data_loader=train_dataloader, + save_global_state=True, + ) + else: save_checkpoint( checkpoint_path=ckpt_path, model=model, - optimizer=optimizer.inner_optimizer, + optimizer=optimizer, scheduler=scheduler, - outer_optimizer=optimizer.state_averager.optimizer, loss=loss_batch.item(), scaler=scaler, data_loader=train_dataloader, - save_global_state=True, + save_global_state=rank == 0, ) - else: - save_checkpoint( - checkpoint_path=ckpt_path, - model=model, - optimizer=optimizer, - scheduler=scheduler, - loss=loss_batch.item(), - scaler=scaler, - data_loader=train_dataloader, - save_global_state=rank == 0, - ) - loss_batch = 0 + loss_batch = 0 + if config.profiler: + profiler.step() + if config.max_steps is not None and real_step >= config.max_steps: + break - if config.max_steps is not None and real_step >= config.max_steps: - break log("Training completed.") wandb.finish() From 00822da522caf6f07014419e987368af2a3ad279 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 23 Jul 2024 09:06:37 +0000 Subject: [PATCH 7/7] add optional profiler --- open_diloco/train_fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 79ab2fd..347c8a4 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -206,7 +206,7 @@ def trace_handler(p): if enable: return profile( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - schedule=torch.profiler.schedule(wait=1, warmup=1, active=1), + schedule=torch.profiler.schedule(wait=1, warmup=1, active=2), on_trace_ready=trace_handler, profile_memory=True, with_stack=False,
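
---

For reference, the net effect of this patch series is to drop the `hivemind_source`/`pydantic_config` submodules in favour of pip git dependencies and to replace the Hugging Face `LlamaForCausalLM` path with the local `open_diloco.llama.GPT` model configured from the JSON files under `open_diloco/configs/`. A minimal sketch of how the new model path is exercised, mirroring `get_model` and the shifted next-token loss in `train_fsdp.py` — the config path, batch shape, and sequence length below are illustrative, not fixed by the patches:

```python
import json

import torch

from open_diloco.llama import GPT, Config as ModelConfig

# Load one of the JSON model configs added above (path shown for illustration).
with open("open_diloco/configs/config_150m.json") as f:
    model_config = ModelConfig(**json.load(f))

# Select the attention backend added in the later patches: "sdpa", "fa", or "xformers".
# "fa"/"xformers" need the corresponding packages and a CUDA device; "sdpa" runs anywhere.
model_config.attention_impl = "sdpa"

model = GPT(model_config)

# Next-token prediction as in the train loop: inputs are shifted by one against targets.
input_ids = torch.randint(0, model_config.vocab_size, (2, 128))
logits = model(input_ids[:, :-1])  # (batch, seq - 1, padded_vocab_size)
loss = torch.nn.functional.cross_entropy(
    logits.flatten(0, 1), input_ids[:, 1:].flatten()
)
print(loss.item())
```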