diff --git a/open_diloco/configs/config_1b.json b/open_diloco/configs/config_1b.json index cdd57c0..8a7584d 100644 --- a/open_diloco/configs/config_1b.json +++ b/open_diloco/configs/config_1b.json @@ -1,13 +1,9 @@ { - "architectures": [ - "LlamaForCausalLM" - ], - "model_type": "llama", - "hidden_size": 2048, + "name": "llama", + "n_embd": 2048, "intermediate_size": 5632, - "num_attention_heads": 32, - "num_hidden_layers": 22, - "use_cache": false, - "rms_norm_eps": 1e-05, - "num_key_value_heads": 4 + "n_head": 32, + "n_layer": 22, + "n_query_groups": 4, + "vocab_size": 1024 } diff --git a/open_diloco/configs/config_2m.json b/open_diloco/configs/config_2m.json index 12a7825..c2bdac6 100644 --- a/open_diloco/configs/config_2m.json +++ b/open_diloco/configs/config_2m.json @@ -1,15 +1,10 @@ { - "architectures": [ - "LlamaForCausalLM" - ], - "model_type": "llama", - "hidden_size": 64, + "name": "llama_2m", + "n_embd": 64, "intermediate_size": 256, - "num_attention_heads": 2, - "num_hidden_layers": 2, - "rms_norm_eps": 1e-05, - "use_cache": false, + "n_head": 2, + "n_layer": 2, "vocab_size": 1024 - } +} \ No newline at end of file diff --git a/open_diloco/llama.py b/open_diloco/llama.py new file mode 100644 index 0000000..95ccae5 --- /dev/null +++ b/open_diloco/llama.py @@ -0,0 +1,608 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +"""Full definition of a decoder-only transformer-based language model, all of it in this single file. + +Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and +https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model. +""" + +import math +from typing import Any, Iterable, Literal, Optional, Tuple +from pydantic import model_validator + +import torch +import torch.nn as nn +from typing_extensions import Self +from pydantic_config import BaseConfig + +try: + import xformers.ops as xops + + XFORMERS_AVAILABLE = True +except ImportError: + XFORMERS_AVAILABLE = False + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + + FLASH_AVAILABLE = True +except ImportError: + FLASH_AVAILABLE = False + +from einops import rearrange + + +def find_multiple(n: int, k: int) -> int: + assert k > 0 + if n % k == 0: + return n + return n + k - (n % k) + + +class Config(BaseConfig): + name: str = "" + scale_embeddings: bool = False + block_size: int = 4096 + vocab_size: int = 50254 + padding_multiple: int = 512 + padded_vocab_size: Optional[int] = None + n_layer: int = 16 + n_head: int = 32 + head_size: int + n_embd: int = 4096 + rotary_percentage: float = 0.25 + parallel_residual: bool = True + bias: bool = False + lm_head_bias: bool = False + # to use multi-head attention (MHA), set this to `n_head` (default) + # to use multi-query attention (MQA), set this to 1 + # to use grouped-query attention (GQA), set this to a value in between + # Example with `n_head=4` + # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + # │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │ + # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + # │ │ │ │ │ │ │ + # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + # │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │ + # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + # │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐ + # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ + # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ + # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ + # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶ + # MHA GQA MQA + # n_query_groups=4 n_query_groups=2 n_query_groups=1 + # + # credit https://arxiv.org/pdf/2305.13245.pdf + n_query_groups: Optional[int] = None + shared_attention_norm: bool = False + norm_eps: float = 1e-5 + mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP" + gelu_approximate: str = "none" + intermediate_size: int + rope_condense_ratio: int = 1 + rope_base: int = 10000 + n_expert: int = 0 + n_expert_per_token: int = 0 + attention_impl: Literal["sdpa", "fa", "xformers"] = "sdpa" + padded_vocab_size: int + rope_n_elem: Optional[int] = None + + @model_validator(mode="before") + def set_padded_vocab_size(cls, values: dict[str, Any]): + """Set the padded vocab size to the next multiple of 64 if not provided.""" + vocab_size = values.get("vocab_size") + padded_vocab_size = values.get("padded_vocab_size") + + if padded_vocab_size is None and vocab_size is not None: + values["padded_vocab_size"] = find_multiple(vocab_size, 64) + return values + + @model_validator(mode="before") + def set_head_size(cls, values: dict[str, Any]): + head_size = values.get("head_size") + if head_size is None: + n_embd = values.get("n_embd") + n_head = values.get("n_head") + assert n_embd % n_head == 0 + values["head_size"] = n_embd // n_head + + return values + + @model_validator(mode="before") + def set_n_query_groups(cls, values: dict[str, Any]): + n_query_groups = values.get("n_query_groups") + n_head = values.get("n_head") + if n_query_groups is None: + values["n_query_groups"] = n_head + else: + assert n_head % n_query_groups == 0 + return values + + @model_validator(mode="before") + def set_intermediate_size(cls, values: dict[str, Any]): + intermediate_size = values.get("intermediate_size") + mlp_class_name = values.get("mlp_class_name") + if intermediate_size is None: + if mlp_class_name == "LLaMAMLP": + raise ValueError("The config needs to set the `intermediate_size`") + values["intermediate_size"] = 4 * values.get("n_embd") + return values + + @model_validator(mode="after") + def set_rope_n_elem(self): + self.rope_n_elem = int(self.rotary_percentage * self.head_size) + return self + + +class GPT(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + assert config.padded_vocab_size is not None + self.config = config + + self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.padded_vocab_size, config.n_embd), + h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), + ln_f=RMSNorm(config.n_embd, eps=config.norm_eps), + ) + ) + self.max_seq_length = self.config.block_size + self.mask_cache: Optional[torch.Tensor] = None + + @property + def max_seq_length(self) -> int: + return self._max_seq_length + + @max_seq_length.setter + def max_seq_length(self, value: int) -> None: + """ + When doing inference, the sequences used might be shorter than the model's context length. + This allows setting a smaller number to avoid allocating unused memory + """ + if value > self.config.block_size: + raise ValueError(f"Cannot attend to {value}, block size is only {self.config.block_size}") + self._max_seq_length = value + if not hasattr(self, "cos"): + # first call + cos, sin = self.rope_cache() + self.register_buffer("cos", cos, persistent=False) + self.register_buffer("sin", sin, persistent=False) + # override + elif value != self.cos.size(0): + self.cos, self.sin = self.rope_cache(device=self.cos.device) + # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know + # if the kv cache is expected + + def reset_parameters(self) -> None: + # Trigger resetting the rope-cache + self.cos, self.sin = self.rope_cache(device=self.cos.device) + + def _init_weights(self, module: nn.Module) -> None: + """Meant to be used with `gpt.apply(gpt._init_weights)`.""" + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def forward( + self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seqlens: Optional[Iterable[int]] = None + ) -> torch.Tensor: + T = idx.size(1) + if self.max_seq_length < T: + raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") + + if input_pos is not None: # use the kv cache + cos = self.cos.index_select(0, input_pos) + sin = self.sin.index_select(0, input_pos) + if self.mask_cache is None: + raise TypeError("You need to call `gpt.set_kv_cache()`") + mask = self.mask_cache.index_select(2, input_pos) + else: + cos = self.cos[:T] + sin = self.sin[:T] + mask = None + + x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + if self.config.scale_embeddings: + x = x * (self.config.n_embd**0.5) + + for block in self.transformer.h: + x = block(x, cos, sin, mask, input_pos, seqlens) + x = self.transformer.ln_f(x) + return self.lm_head(x) # (b, t, vocab_size) + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + return cls(Config.from_name(name, **kwargs)) + + def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]: + return build_rope_cache( + seq_len=self.max_seq_length, + n_elem=self.config.rope_n_elem, + device=device, + condense_ratio=self.config.rope_condense_ratio, + base=self.config.rope_base, + ) + + def set_kv_cache( + self, + batch_size: int, + rope_cache_length: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + if rope_cache_length is None: + rope_cache_length = self.cos.size(-1) + max_seq_length = self.max_seq_length + + # initialize the kv cache for all blocks + for block in self.transformer.h: + block.attn.kv_cache = block.attn.build_kv_cache( + batch_size, max_seq_length, rope_cache_length, device, dtype + ) + + if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length: + # passing `attn_mask` to SDPA disables the flash implementation. since we only need the mask + # for the kv-cache support (only during inference), we only create it in that situation + self.mask_cache = build_mask_cache(max_seq_length, device) + + def clear_kv_cache(self) -> None: + self.mask_cache = None + for block in self.transformer.h: + block.attn.kv_cache = None + + +class Block(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.norm_1 = RMSNorm(config.n_embd, eps=config.norm_eps) + self.attn = CausalSelfAttention(config) + self.norm_2 = None if config.shared_attention_norm else RMSNorm(config.n_embd, eps=config.norm_eps) + self.mlp = LLaMAMLP(config) + + self.config = config + + def forward( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + seqlens: Optional[Iterable[int]] = None, + ) -> torch.Tensor: + n_1 = self.norm_1(x) + h = self.attn(n_1, cos, sin, mask, input_pos, seqlens) + if self.config.parallel_residual: + n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) + x = self.mlp(n_2) + h + x + else: + if self.config.shared_attention_norm: + raise NotImplementedError( + "No checkpoint amongst the ones we support uses this configuration" + " (non-parallel residual and shared attention norm)." + ) + x = h + x + x = self.mlp(self.norm_2(x)) + x + return x + + +class CausalSelfAttention(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + shape = (config.n_head + 2 * config.n_query_groups) * config.head_size + # key, query, value projections for all heads, but in a batch + self.attn = nn.Linear(config.n_embd, shape, bias=config.bias) + # output projection + # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` + self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) + # disabled by default + self.kv_cache: Optional[KVCache] = None + + self.config = config + + def forward( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + seqlens: Optional[Iterable[int]] = None, + ) -> torch.Tensor: + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + qkv = self.attn(x) + # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) + q_per_kv = self.config.n_head // self.config.n_query_groups + total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value + qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) + qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) + + # split batched computation into three + q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) + + # maybe repeat k and v if for the non multi-head attention cases + # training: flash attention requires it + # inference: multi-query would require a full kv cache so avoid it to limit its memory usage + if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1): + k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) + v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) + + q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) + k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) + v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) + + q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) + k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) + q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) + k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) + + if input_pos is not None: + if not isinstance(self.kv_cache, KVCache): + raise TypeError("You need to call `gpt.set_kv_cache()`") + k, v = self.kv_cache(input_pos, k, v) + + scale = 1.0 / math.sqrt(self.config.head_size) + y = self.scaled_dot_product_attention(q, k, v, scale, mask, seqlens) + y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side + + # output projection + return self.proj(y) + + def scaled_dot_product_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float, + mask: Optional[torch.Tensor] = None, + seqlens: Optional[Iterable[int]] = None, + ) -> torch.Tensor: + if seqlens is not None: + if mask is not None: + raise ValueError("context stuffing is not compatible with custom mask") + + if self.config.attention_impl == "sdpa": + raise ValueError("context stuffing is not supported with sdpa") + elif self.config.attention_impl == "xformers": + return self._xformers_attention_with_seqlens(q, k, v, scale, seqlens) + elif self.config.attention_impl == "fa": + return self._fa_attention_with_seqlens(q, k, v, scale, seqlens) + else: + raise ValueError(f"Unknown attention implementation: {self.config.attention_impl}") + + if self.config.attention_impl == "sdpa": + return self._sdpa_attention(q, k, v, scale, mask) + elif self.config.attention_impl == "xformers": + if XFORMERS_AVAILABLE: + return self._xformers_attention(q, k, v, scale, mask) + else: + raise ImportError("Xformers is not available, please install xformers library to use it.") + elif self.config.attention_impl == "fa": + if FLASH_AVAILABLE: + return self._fa_attention(q, k, v, scale, mask) + else: + raise ImportError("Flash attention is not available, please install flash_attn library to use it.") + else: + raise ValueError(f"Unknown attention implementation: {self.config.attention_impl}") + + def _sdpa_attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: float, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None + ) + return y.transpose(1, 2) + + def _xformers_attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: float, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if mask is not None: + raise ValueError("custom mask not supporting with xformers") + + attn_bias = xops.LowerTriangularMask() + + q = rearrange(q, "b n t h -> b t n h") + k = rearrange(k, "b n t h -> b t n h") + v = rearrange(v, "b n t h -> b t n h") + + return xops.memory_efficient_attention( + query=q, + key=k, + value=v, + scale=scale, + attn_bias=attn_bias, + op=xops.MemoryEfficientAttentionFlashAttentionOp, + ) + + def _xformers_attention_with_seqlens( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: float, seqlens: Iterable[int] + ): + attn_bias = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen=seqlens) + attn_bias = attn_bias.make_causal() + + q = rearrange(q, "b n t h -> 1 (b t) n h") + k = rearrange(k, "b n t h -> 1 (b t) n h") + v = rearrange(v, "b n t h -> 1 (b t) n h") + + return xops.memory_efficient_attention( + query=q, + key=k, + value=v, + scale=scale, + attn_bias=attn_bias, + op=xops.MemoryEfficientAttentionFlashAttentionOp, + ) + + def _fa_attention( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: float, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + q = rearrange(q, "b n t h -> b t n h") + k = rearrange(k, "b n t h -> b t n h") + v = rearrange(v, "b n t h -> b t n h") + # q/k/b is [b, nh, t, hs] but fa2 expected [b , t, nh, hs] + return flash_attn_func(q, k, v, causal=True, softmax_scale=scale) + + def _fa_attention_with_seqlens( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: float, seqlens: Iterable[int] + ): + scale = 1.0 / math.sqrt(self.config.head_size) + b = q.shape[0] + seqlens = torch.tensor(seqlens, dtype=torch.int32) + cu_seqlens = torch.concat([torch.tensor([0]), seqlens.cumsum(0)], dim=0).to(torch.int32).to(q.device) + max_seqlen = seqlens.max().item() + + q = rearrange(q, "b n t h -> (b t) n h") + k = rearrange(k, "b n t h -> (b t) n h") + v = rearrange(v, "b n t h -> (b t) n h") + # q/k/v is [b, nh, t, hs] but fa expected [b * t, nh, hs] + y = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + softmax_scale=scale, + ) + + y = rearrange(y, "(b t) n h -> b t n h", b=b) + return y + + def build_kv_cache( + self, + batch_size: int, + max_seq_length: int, + rope_cache_length: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> "KVCache": + heads = 1 if self.config.n_query_groups == 1 else self.config.n_head + v_shape = (batch_size, heads, max_seq_length, self.config.head_size) + if rope_cache_length is None: + if self.config.rotary_percentage != 1.0: + raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value") + k_shape = v_shape + else: + k_shape = ( + batch_size, + heads, + max_seq_length, + rope_cache_length + self.config.head_size - self.config.rope_n_elem, + ) + return KVCache(k_shape, v_shape, device=device, dtype=dtype) + + +class LLaMAMLP(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) + self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) + + self.config = config + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_fc_1 = self.fc_1(x) + x_fc_2 = self.fc_2(x) + x = torch.nn.functional.silu(x_fc_1) * x_fc_2 + return self.proj(x) + + +def build_rope_cache( + seq_len: int, n_elem: int, device: Optional[torch.device] = None, base: int = 10000, condense_ratio: int = 1 +) -> Tuple[torch.Tensor, torch.Tensor]: + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=device) / condense_ratio + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).repeat(1, 2) + + return torch.cos(idx_theta), torch.sin(idx_theta) + + +def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + head_size = x.size(-1) + x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) + x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) + rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) + roped = (x * cos) + (rotated * sin) + return roped.to(dtype=x.dtype) + + +class KVCache(nn.Module): + def __init__( + self, + k_shape: Tuple[int, int, int, int], + v_shape: Tuple[int, int, int, int], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__() + self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False) + self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False) + + def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # move the buffer to the activation dtype for when AMP is used + self.k = self.k.to(k.dtype) + self.v = self.v.to(v.dtype) + # update the cache + k = self.k.index_copy_(2, input_pos, k) + v = self.v.index_copy_(2, input_pos, v) + return k, v + + def reset_parameters(self) -> None: + torch.nn.init.zeros_(self.k) + torch.nn.init.zeros_(self.v) + + +def build_mask_cache(max_seq_length: int, device: Optional[torch.device] = None) -> torch.Tensor: + ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool) + return torch.tril(ones).unsqueeze(0).unsqueeze(0) + + +class RMSNorm(torch.nn.Module): + """Root Mean Square Layer Normalization. + + Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License: + https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. + """ + + def __init__(self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(size)) + self.eps = eps + self.dim = dim + self.add_unit_offset = add_unit_offset + + def forward(self, x: torch.Tensor) -> torch.Tensor: + dtype = x.dtype + x = x.float() + # NOTE: the original RMSNorm paper implementation is not equivalent + norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) + x_normed = x * torch.rsqrt(norm_x + self.eps) + x_normed = x_normed.to(dtype=dtype) + if self.add_unit_offset: + # Gemma model requires a unit offset + # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176 + return x_normed * (1 + self.weight) + return x_normed * self.weight + + def reset_parameters(self) -> None: + torch.nn.init.ones_(self.weight) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 22f4696..84e8258 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -7,11 +7,13 @@ """ from functools import partial +import json import os import time from contextlib import nullcontext import datetime from typing import Any, Literal +from einops import rearrange import fsspec from pydantic import model_validator @@ -26,8 +28,6 @@ from transformers import ( AutoTokenizer, DataCollatorForLanguageModeling, - LlamaConfig, - LlamaForCausalLM, get_cosine_schedule_with_warmup, ) from torch.distributed.fsdp import ( @@ -39,7 +39,7 @@ from torch.distributed import broadcast_object_list from open_diloco.ckpt_utils import load_checkpoint, save_checkpoint from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer - +from open_diloco.llama import GPT, Config as ModelConfig from hivemind.dht.dht import DHT from hivemind.utils.networking import log_visible_maddrs @@ -114,7 +114,7 @@ def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]: class Config(BaseConfig): - path_model: str = "PrimeIntellect/llama-150m-fresh" + llama_config: str | ModelConfig = "open_diloco/configs/config_1b.json" torch_compile: bool = True attn_implementation: str = "sdpa" # Data @@ -181,10 +181,14 @@ def tokenize_function(data): ) -def get_model(config: Config) -> LlamaForCausalLM: +def get_model(config: Config) -> GPT: # Load model - config_model = LlamaConfig.from_pretrained(config.path_model, attn_implementation=config.attn_implementation) - return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model) + if isinstance(config.llama_config, ModelConfig): + return GPT(config.llama_config) + else: + with open(config.llama_config) as f: + llama_config = ModelConfig(**json.load(f)) + return GPT(llama_config) def train(config: Config): @@ -389,9 +393,18 @@ def scheduler_fn(opt): batch[key] = batch[key].to("cuda") with model.no_sync() if is_accumulating else nullcontext(): - outputs = model(**batch) - loss = outputs.loss / gradient_accumulation_steps + inputs_ids = batch["input_ids"] + + input_ids = inputs_ids[:, :-1] + target = inputs_ids[:, 1:] + + output = model(input_ids) + + flatten_logits = rearrange(output, "b seq vocab -> (b seq) vocab") + flatten_target = rearrange(target, "b seq -> (b seq)") + loss = torch.nn.functional.cross_entropy(flatten_logits, flatten_target) + loss = loss / gradient_accumulation_steps loss_batch += loss.detach() scaler.scale(loss).backward() diff --git a/requirements.txt b/requirements.txt index e918dce..f5fba10 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,4 @@ fsspec[gcs]>=2024.3.1 torch>=2.3.1 hivemind @ git+https://github.com/learning-at-home/hivemind.git@213bff9 pydantic_config @ git+https://github.com/samsja/pydantic_config.git@8e19e05 - +xformers diff --git a/tests/test_llama.py b/tests/test_llama.py new file mode 100644 index 0000000..4733431 --- /dev/null +++ b/tests/test_llama.py @@ -0,0 +1,207 @@ +from typing import Tuple +import pytest +import torch + +from open_diloco.llama import GPT, CausalSelfAttention, Config, build_rope_cache +from xformers.ops.fmha.common import AttentionFwOpBase + + +@pytest.fixture +def config() -> Config: + return Config( + name="llama", + n_embd=64, + n_head=2, + n_layer=2, + vocab_size=1024, + ) + + +@pytest.mark.parametrize("attention_impl", ["sdpa", "xformers", "fa"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_gpt(config: Config, attention_impl: str, dtype: torch.dtype): + config.attention_impl = attention_impl + _test_gpt(config, dtype) + + +def _test_gpt(config: Config, dtype: torch.dtype): + device = torch.device("cuda") + model = GPT(config).to(dtype).to(device) + + BATCH_SIZE = 16 + SEQ_LEN = 8 + VOCAB_SIZE = 1024 + + input = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN)).to(device) + + output = model(input) + + assert output is not None + assert not output.isnan().any() + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("attention_impl", ["xformers", "fa"]) +def test_gpt_output(config: Config, dtype: torch.dtype, attention_impl: str): + """ + in this test we compare the output of the GPT with sdpa and xformers/fa + """ + + device = torch.device("cuda") + model = GPT(config).to(dtype).to(device) + + BATCH_SIZE = 16 + SEQ_LEN = 8 + VOCAB_SIZE = 1024 + + input = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN)).to(device) + + ### SDPA + config.attention_impl = "sdpa" + output_sdpa = model(input) + + ### XFORMERS + config.attention_impl = attention_impl + output_xformers = model(input) + + ### TESTING + assert output_sdpa.shape == output_xformers.shape + + ### xformers has a higher tolerance + atol = AttentionFwOpBase.ERROR_ATOL[dtype] + rtol = AttentionFwOpBase.ERROR_RTOL[dtype] + torch.testing.assert_close(output_sdpa, output_xformers, atol=atol, rtol=rtol) + + +def get_cos_and_sin_attn(config: Config, seq_len: int, device) -> Tuple[torch.Tensor, torch.Tensor]: + cos, sin = build_rope_cache( + seq_len=seq_len, + n_elem=config.rope_n_elem, + device=device, + condense_ratio=config.rope_condense_ratio, + base=config.rope_base, + ) + cos = cos[:seq_len] + sin = sin[:seq_len] + return cos, sin + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("attention_impl", ["xformers", "fa"]) +def test_attn_output(config: Config, dtype: torch.dtype, attention_impl: str): + """ + in this test we compare the output of the GPT with sdpa and xformers/fa + """ + + device = torch.device("cuda") + model = CausalSelfAttention(config).to(dtype).to(device) + + BATCH_SIZE = 16 + SEQ_LEN = 8 + + input = torch.rand(BATCH_SIZE, SEQ_LEN, config.n_embd).to(dtype).to(device) + cos, sin = get_cos_and_sin_attn(config, SEQ_LEN, device) + + ### SDPA + config.attention_impl = "sdpa" + output_sdpa = model(input, cos, sin) + + ### XFORMERS + config.attention_impl = attention_impl + output_xformers = model(input, cos, sin) + + ### TESTING + assert output_sdpa.shape == output_xformers.shape + + ### xformers has a higher tolerance + atol = AttentionFwOpBase.ERROR_ATOL[dtype] + rtol = AttentionFwOpBase.ERROR_RTOL[dtype] + torch.testing.assert_close(output_sdpa, output_xformers, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("attention_impl", ["xformers", "fa"]) +def test_context_stuffing_attn(config: Config, dtype: torch.dtype, attention_impl: str): + """ + In this test we compare normal pad attention with stuffing. + + input is [[2, 1, 4, 8, PAD, PAD, PAD, PAD], [1, 4, 2, 7, PAD, PAD, PAD, PAD]] + for padded input and [[2, 1, 4, 8, 1, 4, 2, 7]] for stuffed input + + we then compare the output of the two and should be the same. + """ + device = torch.device("cuda") + model = CausalSelfAttention(config).to(dtype).to(device) + + config.attention_impl = attention_impl + + SEQ_LEN = 8 + + emb = torch.nn.Embedding(10, config.n_embd).to(dtype).to(device) + + pad_id = 0 + input_raw = ( + torch.Tensor([[2, 1, 4, 8, pad_id, pad_id, pad_id, pad_id], [1, 4, 2, 7, pad_id, pad_id, pad_id, pad_id]]) + .long() + .to(device) + ) + input = emb(input_raw) + + input_stuff_raw = torch.Tensor([[2, 1, 4, 8, 1, 4, 2, 7]]).long().to(device) + seqlens = [4, 4] + input_stuff = emb(input_stuff_raw) + + cos, sin = get_cos_and_sin_attn(config, SEQ_LEN, device) + + ### batch + output_ctx_stuff = model(input, cos, sin) + + output_ctx_stuff = output_ctx_stuff[:, :4, :] # remove padding token + + output_xformers_stuff = model(input_stuff, cos, sin, seqlens=seqlens) + output_xformers_stuff = output_xformers_stuff.reshape(2, 4, config.n_embd) + + ### TESTING + assert output_ctx_stuff.shape == output_xformers_stuff.shape + + ### xformers has a higher tolerance + atol = AttentionFwOpBase.ERROR_ATOL[dtype] + rtol = AttentionFwOpBase.ERROR_RTOL[dtype] + torch.testing.assert_close(output_ctx_stuff, output_xformers_stuff, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("attention_impl", ["xformers", "fa"]) +def test_context_stuffing_attn_2(config: Config, dtype: torch.dtype, attention_impl: str): + """ + this test is slightu different from the one above, it tests + that passing two time the same input in a stuff way yield the same results. + """ + device = torch.device("cuda") + + config.attention_impl = attention_impl + model = CausalSelfAttention(config).to(dtype).to(device) + + SEQ_LEN = 8 + + emb = torch.nn.Embedding(10, config.n_embd).to(dtype).to(device) + + seq = [2, 1, 4, 8] + input_stuff_raw = torch.Tensor([seq + seq]).long().to(device) + seqlens = [len(seq), len(seq)] + input_stuff = emb(input_stuff_raw) + + cos, sin = get_cos_and_sin_attn(config, SEQ_LEN, device) + + output = model(input_stuff, cos, sin, seqlens=seqlens) + + output_left = output[:, :4, :] + output_right = output[:, 4:, :] + + ### TESTING + assert output_left.shape == output_right.shape + + ### xformers has a higher tolerance + atol = AttentionFwOpBase.ERROR_ATOL[dtype] + rtol = AttentionFwOpBase.ERROR_RTOL[dtype] + torch.testing.assert_close(output_left, output_right, atol=atol, rtol=rtol) diff --git a/tests/test_training/test_train.py b/tests/test_training/test_train.py index 12e532c..543cd4e 100644 --- a/tests/test_training/test_train.py +++ b/tests/test_training/test_train.py @@ -10,6 +10,7 @@ from hivemind.dht.dht import DHT from open_diloco.train_fsdp import train, Config, ddp_setup, destroy_process_group, HvConfig +from open_diloco.llama import Config as ModelConfig @pytest.fixture(autouse=True) @@ -46,8 +47,16 @@ def random_available_port(): @pytest.fixture def config() -> Config: + model_config = ModelConfig( + name="llama_2m", + n_embd=64, + intermediate_size=256, + n_head=2, + n_layer=2, + vocab_size=1024, + ) return Config( - path_model="tests/models/llama-2m-fresh", + llama_config=model_config, fake_data=True, torch_compile=False, lr=1e-2,