From 7fca3cab5d4120a6a111d4c126cf33148968c6ba Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:27:03 +0100 Subject: [PATCH 01/11] add init function to the builders --- src/fairseq2/models/jepa/factory.py | 121 +++++++++++++++++-- src/fairseq2/models/jepa/model.py | 17 ++- src/fairseq2/models/vit/feature_extractor.py | 29 +++++ src/fairseq2/nn/normalization.py | 18 ++- 4 files changed, 161 insertions(+), 24 deletions(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 0038299a2..329f2055a 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -6,10 +6,14 @@ from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass, field +from functools import partial +import math from typing import Final, cast -from torch.nn import GELU +import torch +from torch.nn import GELU, Module from fairseq2.config_registry import ConfigRegistry from fairseq2.models.jepa.model import JepaModel @@ -23,6 +27,7 @@ from fairseq2.nn import ( InterpolatedPositionEncoder, LayerNorm, + Linear, Sinusoidal2dPositionEncoder, Sinusoidal3dPositionEncoder, StandardLayerNorm, @@ -39,6 +44,7 @@ TransformerNormOrder, create_default_sdpa, ) +from fairseq2.nn.transformer.residual import DropPathResidualConnect from fairseq2.typing import DataType, Device JEPA_FAMILY: Final = "jepa" @@ -96,10 +102,16 @@ class JepaEncoderConfig: The ratio of the dimensionality of the inner projection layers in feed-forward networks to :attr:`model_dim`. """ + + init_std: float = 0.02 + """std to initialize the weights and bias for linear and LayerNorm layers""" dropout_p: float = 0.0 """The dropout probability on outputs of Transformer layers.""" + droppath_p: float = 0.0 + """The probability of output sequence drop.""" + uniform_power: bool = False """ If ``True``, each patch dimension will have equal representation in the @@ -181,6 +193,8 @@ def build_frontend(self) -> TransformerFrontend: def build_feature_extractor(self) -> PatchFeatureExtractor: config = self._config + + conv_init_fn = partial(init_with_explicit_bounds, std=config.init_std) num_patch_dims = len(config.patch_dims) @@ -191,6 +205,7 @@ def build_feature_extractor(self) -> PatchFeatureExtractor: config.num_input_channels, config.model_dim, patch_3d_dims, + init_fn=conv_init_fn, device=self._device, dtype=self._dtype, ) @@ -201,6 +216,7 @@ def build_feature_extractor(self) -> PatchFeatureExtractor: config.num_input_channels, config.model_dim, patch_2d_dims, + init_fn=conv_init_fn, device=self._device, dtype=self._dtype, ) @@ -255,7 +271,7 @@ def build_encoder(self) -> TransformerEncoder: num_layers = config.num_encoder_layers - layers = [self.build_encoder_layer() for _ in range(num_layers)] + layers = [self.build_encoder_layer(i) for i in range(num_layers)] return StandardTransformerEncoder( layers, @@ -265,12 +281,14 @@ def build_encoder(self) -> TransformerEncoder: dtype=self._dtype, ) - def build_encoder_layer(self) -> TransformerEncoderLayer: + def build_encoder_layer(self, layer_id: int) -> TransformerEncoderLayer: config = self._config - self_attn = self.build_attention() + self_attn = self.build_attention(layer_id) - ffn = self.build_ffn() + ffn = self.build_ffn(layer_id) + + drop_path = DropPathResidualConnect(drop_p=config.droppath_p) return StandardTransformerEncoderLayer( self_attn, @@ -278,47 +296,87 @@ def build_encoder_layer(self) -> TransformerEncoderLayer: 
dropout_p=config.dropout_p, norm_order=TransformerNormOrder.PRE, layer_norm_factory=self.build_layer_norm, + self_attn_residual=drop_path, + ffn_residual=drop_path, device=self._device, dtype=self._dtype, ) - def build_attention(self) -> MultiheadAttention: + def build_attention(self, layer_id: int) -> MultiheadAttention: config = self._config sdpa = create_default_sdpa(attn_dropout_p=config.attn_dropout_p) + proj = self.build_projection(layer_id) + return StandardMultiheadAttention( config.model_dim, config.num_encoder_attn_heads, sdpa=sdpa, bias=config.qkv_bias, + output_proj=proj, output_proj_bias=True, device=self._device, dtype=self._dtype, ) - def build_ffn(self) -> FeedForwardNetwork: + def build_projection(self, layer_id: int) -> Linear: + config = self._config + + proj_init_fn: Callable[[Linear], None] = partial( + init_with_explicit_bounds, std=config.init_std + ) + + proj = Linear( + config.model_dim, + config.model_dim, + bias=True, + init_fn=proj_init_fn, + device=self._device, + dtype=self._dtype, + ) + + # rescale the linear layer + proj.weight.data.div_(math.sqrt(2.0 * layer_id)) + + return proj + + def build_ffn(self, layer_id: int) -> FeedForwardNetwork: config = self._config - return StandardFeedForwardNetwork( + proj_init_fn = partial(init_with_explicit_bounds, std=config.init_std) + + ffn = StandardFeedForwardNetwork( config.model_dim, int(config.model_dim * config.ffn_inner_dim_ratio), bias=True, inner_activation=GELU(), + proj_init_fn=proj_init_fn, norm_order=TransformerNormOrder.PRE, device=self._device, dtype=self._dtype, ) - @staticmethod + # rescale the last layer + proj = ffn.output_proj + assert isinstance(proj, Linear), f"Invalid projection type: {type(proj)}" + proj.weight.data.div_(math.sqrt(2.0 * layer_id)) + + return ffn + def build_layer_norm( + self, model_dim: int, *, device: Device | None = None, dtype: DataType | None = None, ) -> LayerNorm: + config = self._config + + layer_norm_init_fn = partial(init_with_explicit_bounds, std=config.init_std) + return StandardLayerNorm( - model_dim, bias=True, eps=1e-6, device=device, dtype=dtype + model_dim, bias=True, eps=1e-6, init_fn=layer_norm_init_fn, device=device, dtype=dtype ) @@ -329,3 +387,46 @@ def create_jepa_model( dtype: DataType | None = None, ) -> JepaModel: return JepaBuilder(config, device=device, dtype=dtype).build_model() + + +def _norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + +def normalize_truncate( + tensor: torch.Tensor, + *, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> None: + + lower = _norm_cdf((a - mean) / std) + upper = _norm_cdf((b - mean) / std) + + tensor.uniform_(2 * lower - 1, 2 * upper - 1) + tensor.erfinv_() + + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + tensor.clamp_(min=a, max=b) + + +def init_with_explicit_bounds( + m: Module, + *, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +): + if not hasattr(m, "weight") or not hasattr(m, "bias"): + raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") + + with torch.no_grad(): + normalize_truncate(m.weight, mean=mean, std=std, a=a, b=b) + if m.bias is not None: + torch.nn.init.zeros_(m.bias) diff --git a/src/fairseq2/models/jepa/model.py b/src/fairseq2/models/jepa/model.py index 8312901e0..281e8187b 100644 --- a/src/fairseq2/models/jepa/model.py +++ b/src/fairseq2/models/jepa/model.py @@ -6,7 +6,6 @@ from __future__ import annotations -from dataclasses 
import dataclass from typing import final from torch.nn import Module @@ -43,11 +42,11 @@ def __init__( self.encoder_frontend = encoder_frontend self.encoder = encoder - def forward(self, batch: SequenceBatch) -> JepaOutput: - raise NotImplementedError() - - -@final -@dataclass -class JepaOutput: - pass + def forward(self, batch: SequenceBatch) -> SequenceBatch: + seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.padding_mask) + out_seqs, out_mask = self.encoder(seqs, padding_mask) # type: ignore[no-any-return] + + return SequenceBatch( + seqs=out_seqs, + padding_mask=out_mask, + ) diff --git a/src/fairseq2/models/vit/feature_extractor.py b/src/fairseq2/models/vit/feature_extractor.py index dce324d40..b3ec35033 100644 --- a/src/fairseq2/models/vit/feature_extractor.py +++ b/src/fairseq2/models/vit/feature_extractor.py @@ -7,6 +7,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Callable from typing import final from torch import Tensor @@ -54,6 +55,7 @@ class Conv2dPatchFeatureExtractor(PatchFeatureExtractor): """Extracts patch features from 2-dimensional inputs using convolution.""" conv: Conv2d + init_fn: Callable[[Conv2d], None] | None def __init__( self, @@ -61,6 +63,7 @@ def __init__( feature_dim: int, patch_dims: tuple[int, int], *, + init_fn: Callable[[Conv2d], None] | None = None, device: Device | None = None, dtype: DataType | None = None, ) -> None: @@ -81,6 +84,18 @@ def __init__( dtype=dtype, ) + self.init_fn = init_fn + + self.reset_parameters() + + def reset_parameters(self) -> None: + """Reset the parameters and buffers of the module.""" + if self.init_fn is not None: + self.init_fn(self.conv) + else: + self.conv.reset_parameters() + + @override def forward(self, x: Tensor) -> Tensor: # (N, C, H_inp, W_inp) -> (N, H_out, W_out, E) @@ -92,6 +107,7 @@ class Conv3dPatchFeatureExtractor(PatchFeatureExtractor): """Extracts patch features from 3-dimensional inputs using convolution.""" conv: Conv3d + init_fn: Callable[[Conv3d], None] | None def __init__( self, @@ -99,6 +115,7 @@ def __init__( feature_dim: int, patch_dims: tuple[int, int, int], *, + init_fn: Callable[[Conv2d], None] | None = None, device: Device | None = None, dtype: DataType | None = None, ) -> None: @@ -118,6 +135,18 @@ def __init__( device=device, dtype=dtype, ) + + self.init_fn = init_fn + + self.reset_parameters() + + def reset_parameters(self) -> None: + """Reset the parameters and buffers of the module.""" + if self.init_fn is not None: + self.init_fn(self.conv) + else: + self.conv.reset_parameters() + @override def forward(self, x: Tensor) -> Tensor: diff --git a/src/fairseq2/nn/normalization.py b/src/fairseq2/nn/normalization.py index 44fe37287..64656507e 100644 --- a/src/fairseq2/nn/normalization.py +++ b/src/fairseq2/nn/normalization.py @@ -7,7 +7,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Callable, Sequence from typing import Any, Literal, final import torch @@ -39,6 +39,8 @@ class LayerNorm(Module, ABC): elementwise_affine: bool weight: Parameter | None bias: Parameter | None + init_fn: Callable[[LayerNorm], None] | None + def __init__( self, @@ -47,6 +49,7 @@ def __init__( *, eps: float = 1e-5, elementwise_affine: bool = True, + init_fn: Callable[[LayerNorm], None] | None = None, device: Device | None = None, dtype: DataType | None = None, ) -> None: @@ -87,16 +90,21 @@ def __init__( ) else: self.register_parameter("bias", None) + + 
self.init_fn = init_fn self.reset_parameters() def reset_parameters(self) -> None: """Reset the parameters and buffers of the module.""" - if self.weight is not None: - nn.init.ones_(self.weight) + if self.init_fn is not None: + self.init_fn(self) + else: + if self.weight is not None: + nn.init.ones_(self.weight) - if self.bias is not None: - nn.init.zeros_(self.bias) + if self.bias is not None: + nn.init.zeros_(self.bias) @abstractmethod def forward(self, x: Tensor) -> Tensor: From 7b959fe21e9fe83dbb68e79c7f46fd7704392141 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:58:52 +0100 Subject: [PATCH 02/11] refactor init_module function --- src/fairseq2/models/jepa/factory.py | 57 +++-------------------------- src/fairseq2/nn/utils/module.py | 43 ++++++++++++++++++++++ 2 files changed, 49 insertions(+), 51 deletions(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 329f2055a..91059ef50 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -12,8 +12,7 @@ import math from typing import Final, cast -import torch -from torch.nn import GELU, Module +from torch.nn import GELU from fairseq2.config_registry import ConfigRegistry from fairseq2.models.jepa.model import JepaModel @@ -45,6 +44,7 @@ create_default_sdpa, ) from fairseq2.nn.transformer.residual import DropPathResidualConnect +from fairseq2.nn.utils.module import init_truncated_uniforma_weights_and_bias as init_module from fairseq2.typing import DataType, Device JEPA_FAMILY: Final = "jepa" @@ -194,7 +194,7 @@ def build_frontend(self) -> TransformerFrontend: def build_feature_extractor(self) -> PatchFeatureExtractor: config = self._config - conv_init_fn = partial(init_with_explicit_bounds, std=config.init_std) + conv_init_fn = partial(init_module, std=config.init_std) num_patch_dims = len(config.patch_dims) @@ -323,9 +323,7 @@ def build_attention(self, layer_id: int) -> MultiheadAttention: def build_projection(self, layer_id: int) -> Linear: config = self._config - proj_init_fn: Callable[[Linear], None] = partial( - init_with_explicit_bounds, std=config.init_std - ) + proj_init_fn: Callable[[Linear], None] = partial(init_module, std=config.init_std) proj = Linear( config.model_dim, @@ -344,7 +342,7 @@ def build_projection(self, layer_id: int) -> Linear: def build_ffn(self, layer_id: int) -> FeedForwardNetwork: config = self._config - proj_init_fn = partial(init_with_explicit_bounds, std=config.init_std) + proj_init_fn = partial(init_module, std=config.init_std) ffn = StandardFeedForwardNetwork( config.model_dim, @@ -373,7 +371,7 @@ def build_layer_norm( ) -> LayerNorm: config = self._config - layer_norm_init_fn = partial(init_with_explicit_bounds, std=config.init_std) + layer_norm_init_fn = partial(init_module, std=config.init_std) return StandardLayerNorm( model_dim, bias=True, eps=1e-6, init_fn=layer_norm_init_fn, device=device, dtype=dtype @@ -387,46 +385,3 @@ def create_jepa_model( dtype: DataType | None = None, ) -> JepaModel: return JepaBuilder(config, device=device, dtype=dtype).build_model() - - -def _norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - -def normalize_truncate( - tensor: torch.Tensor, - *, - mean: float = 0.0, - std: float = 1.0, - a: float = -2.0, - b: float = 2.0, -) -> None: - - lower = _norm_cdf((a - mean) / std) - upper = _norm_cdf((b - mean) / std) - - tensor.uniform_(2 * lower - 1, 2 * 
upper - 1) - tensor.erfinv_() - - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - tensor.clamp_(min=a, max=b) - - -def init_with_explicit_bounds( - m: Module, - *, - mean: float = 0.0, - std: float = 1.0, - a: float = -2.0, - b: float = 2.0, -): - if not hasattr(m, "weight") or not hasattr(m, "bias"): - raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") - - with torch.no_grad(): - normalize_truncate(m.weight, mean=mean, std=std, a=a, b=b) - if m.bias is not None: - torch.nn.init.zeros_(m.bias) diff --git a/src/fairseq2/nn/utils/module.py b/src/fairseq2/nn/utils/module.py index e76200599..fd464eefa 100644 --- a/src/fairseq2/nn/utils/module.py +++ b/src/fairseq2/nn/utils/module.py @@ -6,6 +6,7 @@ from __future__ import annotations +import math import re from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from dataclasses import dataclass @@ -570,3 +571,45 @@ def get_module_size(module: Module) -> ModuleSizeInfo: info.total_size_bytes += size_bytes return info + + +def normalize_truncate( + tensor: Tensor, + *, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> None: + + def _norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + lower = _norm_cdf((a - mean) / std) + upper = _norm_cdf((b - mean) / std) + + tensor.uniform_(2 * lower - 1, 2 * upper - 1) + tensor.erfinv_() + + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + tensor.clamp_(min=a, max=b) + + +def init_truncated_uniforma_weights_and_bias( + m: Module, + *, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +): + if not hasattr(m, "weight") or not hasattr(m, "bias"): + raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") + + with torch.no_grad(): + normalize_truncate(m.weight, mean=mean, std=std, a=a, b=b) + if m.bias is not None: + torch.nn.init.zeros_(m.bias) From f745beabfe350770855dcfab514f0ed18c6eff22 Mon Sep 17 00:00:00 2001 From: Can Balioglu Date: Thu, 19 Dec 2024 15:22:13 +0000 Subject: [PATCH 03/11] Cosmetic updates --- src/fairseq2/models/jepa/factory.py | 102 +++++++++++-------- src/fairseq2/models/jepa/model.py | 10 +- src/fairseq2/models/vit/feature_extractor.py | 6 +- src/fairseq2/nn/normalization.py | 3 +- src/fairseq2/nn/utils/module.py | 15 ++- 5 files changed, 74 insertions(+), 62 deletions(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 91059ef50..48cffbd22 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -6,12 +6,12 @@ from __future__ import annotations -from collections.abc import Callable +import math from dataclasses import dataclass, field from functools import partial -import math from typing import Final, cast +import torch from torch.nn import GELU from fairseq2.config_registry import ConfigRegistry @@ -44,7 +44,9 @@ create_default_sdpa, ) from fairseq2.nn.transformer.residual import DropPathResidualConnect -from fairseq2.nn.utils.module import init_truncated_uniforma_weights_and_bias as init_module +from fairseq2.nn.utils.module import ( + init_truncated_uniforma_weights_and_bias as init_module, +) from fairseq2.typing import DataType, Device JEPA_FAMILY: Final = "jepa" @@ -102,15 +104,21 @@ class JepaEncoderConfig: The ratio of the dimensionality of the inner projection layers in feed-forward networks to :attr:`model_dim`. 
""" - + init_std: float = 0.02 - """std to initialize the weights and bias for linear and LayerNorm layers""" + """ + The standard deviation to initialize weights and biases of projection and + normalization layers. + """ dropout_p: float = 0.0 """The dropout probability on outputs of Transformer layers.""" droppath_p: float = 0.0 - """The probability of output sequence drop.""" + """ + The probability of dropping sequences from outputs of multi-head attention + and feed-forward network layers before adding residuals. + """ uniform_power: bool = False """ @@ -193,8 +201,10 @@ def build_frontend(self) -> TransformerFrontend: def build_feature_extractor(self) -> PatchFeatureExtractor: config = self._config - - conv_init_fn = partial(init_module, std=config.init_std) + + init_std = config.init_std + + init_conv = partial(init_module, std=init_std) num_patch_dims = len(config.patch_dims) @@ -205,7 +215,7 @@ def build_feature_extractor(self) -> PatchFeatureExtractor: config.num_input_channels, config.model_dim, patch_3d_dims, - init_fn=conv_init_fn, + init_fn=init_conv, device=self._device, dtype=self._dtype, ) @@ -216,7 +226,7 @@ def build_feature_extractor(self) -> PatchFeatureExtractor: config.num_input_channels, config.model_dim, patch_2d_dims, - init_fn=conv_init_fn, + init_fn=init_conv, device=self._device, dtype=self._dtype, ) @@ -281,13 +291,13 @@ def build_encoder(self) -> TransformerEncoder: dtype=self._dtype, ) - def build_encoder_layer(self, layer_id: int) -> TransformerEncoderLayer: + def build_encoder_layer(self, layer_idx: int) -> TransformerEncoderLayer: config = self._config - self_attn = self.build_attention(layer_id) + self_attn = self.build_attention(layer_idx) + + ffn = self.build_ffn(layer_idx) - ffn = self.build_ffn(layer_id) - drop_path = DropPathResidualConnect(drop_p=config.droppath_p) return StandardTransformerEncoderLayer( @@ -302,66 +312,67 @@ def build_encoder_layer(self, layer_id: int) -> TransformerEncoderLayer: dtype=self._dtype, ) - def build_attention(self, layer_id: int) -> MultiheadAttention: + def build_attention(self, layer_idx: int) -> MultiheadAttention: config = self._config sdpa = create_default_sdpa(attn_dropout_p=config.attn_dropout_p) - proj = self.build_projection(layer_id) + output_proj = self.build_mha_output_projection(layer_idx) return StandardMultiheadAttention( config.model_dim, config.num_encoder_attn_heads, sdpa=sdpa, bias=config.qkv_bias, - output_proj=proj, - output_proj_bias=True, + output_proj=output_proj, device=self._device, dtype=self._dtype, ) - def build_projection(self, layer_id: int) -> Linear: + def build_mha_output_projection(self, layer_idx: int) -> Linear: config = self._config - proj_init_fn: Callable[[Linear], None] = partial(init_module, std=config.init_std) + init_std = config.init_std - proj = Linear( + def init_projection(proj: Linear) -> None: + init_module(proj, std=init_std) + + with torch.no_grad(): + proj.weight.div_(math.sqrt(2.0 * layer_idx)) + + return Linear( config.model_dim, config.model_dim, bias=True, - init_fn=proj_init_fn, + init_fn=init_projection, device=self._device, dtype=self._dtype, ) - # rescale the linear layer - proj.weight.data.div_(math.sqrt(2.0 * layer_id)) + def build_ffn(self, layer_idx: int) -> FeedForwardNetwork: + config = self._config - return proj + init_std = config.init_std - def build_ffn(self, layer_id: int) -> FeedForwardNetwork: - config = self._config + def init_projection(proj: Linear) -> None: + init_module(proj, std=init_std) + + with torch.no_grad(): + 
proj.weight.div_(math.sqrt(2.0 * layer_idx)) - proj_init_fn = partial(init_module, std=config.init_std) + inner_dim = int(config.model_dim * config.ffn_inner_dim_ratio) - ffn = StandardFeedForwardNetwork( + return StandardFeedForwardNetwork( config.model_dim, - int(config.model_dim * config.ffn_inner_dim_ratio), + inner_dim, bias=True, inner_activation=GELU(), - proj_init_fn=proj_init_fn, + proj_init_fn=init_projection, norm_order=TransformerNormOrder.PRE, device=self._device, dtype=self._dtype, ) - # rescale the last layer - proj = ffn.output_proj - assert isinstance(proj, Linear), f"Invalid projection type: {type(proj)}" - proj.weight.data.div_(math.sqrt(2.0 * layer_id)) - - return ffn - def build_layer_norm( self, model_dim: int, @@ -370,11 +381,18 @@ def build_layer_norm( dtype: DataType | None = None, ) -> LayerNorm: config = self._config - - layer_norm_init_fn = partial(init_module, std=config.init_std) - + + init_std = config.init_std + + init_layer_norm = partial(init_module, std=init_std) + return StandardLayerNorm( - model_dim, bias=True, eps=1e-6, init_fn=layer_norm_init_fn, device=device, dtype=dtype + model_dim, + bias=True, + eps=1e-6, + init_fn=init_layer_norm, + device=device, + dtype=dtype, ) diff --git a/src/fairseq2/models/jepa/model.py b/src/fairseq2/models/jepa/model.py index 281e8187b..b1413328a 100644 --- a/src/fairseq2/models/jepa/model.py +++ b/src/fairseq2/models/jepa/model.py @@ -44,9 +44,7 @@ def __init__( def forward(self, batch: SequenceBatch) -> SequenceBatch: seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.padding_mask) - out_seqs, out_mask = self.encoder(seqs, padding_mask) # type: ignore[no-any-return] - - return SequenceBatch( - seqs=out_seqs, - padding_mask=out_mask, - ) + + seqs, padding_mask = self.encoder(seqs, padding_mask) + + return SequenceBatch(seqs, padding_mask) diff --git a/src/fairseq2/models/vit/feature_extractor.py b/src/fairseq2/models/vit/feature_extractor.py index b3ec35033..50747eaaf 100644 --- a/src/fairseq2/models/vit/feature_extractor.py +++ b/src/fairseq2/models/vit/feature_extractor.py @@ -95,7 +95,6 @@ def reset_parameters(self) -> None: else: self.conv.reset_parameters() - @override def forward(self, x: Tensor) -> Tensor: # (N, C, H_inp, W_inp) -> (N, H_out, W_out, E) @@ -115,7 +114,7 @@ def __init__( feature_dim: int, patch_dims: tuple[int, int, int], *, - init_fn: Callable[[Conv2d], None] | None = None, + init_fn: Callable[[Conv3d], None] | None = None, device: Device | None = None, dtype: DataType | None = None, ) -> None: @@ -135,7 +134,7 @@ def __init__( device=device, dtype=dtype, ) - + self.init_fn = init_fn self.reset_parameters() @@ -147,7 +146,6 @@ def reset_parameters(self) -> None: else: self.conv.reset_parameters() - @override def forward(self, x: Tensor) -> Tensor: # (N, C, D_inp, H_inp, W_inp) -> (N, D_out, H_out, W_out, E) diff --git a/src/fairseq2/nn/normalization.py b/src/fairseq2/nn/normalization.py index 64656507e..c371bc156 100644 --- a/src/fairseq2/nn/normalization.py +++ b/src/fairseq2/nn/normalization.py @@ -41,7 +41,6 @@ class LayerNorm(Module, ABC): bias: Parameter | None init_fn: Callable[[LayerNorm], None] | None - def __init__( self, normalized_shape: int | Sequence[int] | Size, @@ -90,7 +89,7 @@ def __init__( ) else: self.register_parameter("bias", None) - + self.init_fn = init_fn self.reset_parameters() diff --git a/src/fairseq2/nn/utils/module.py b/src/fairseq2/nn/utils/module.py index fd464eefa..ddf4c7bf4 100644 --- a/src/fairseq2/nn/utils/module.py +++ 
b/src/fairseq2/nn/utils/module.py @@ -581,22 +581,21 @@ def normalize_truncate( a: float = -2.0, b: float = 2.0, ) -> None: - - def _norm_cdf(x): + def _norm_cdf(x: float) -> float: # Computes standard normal cumulative distribution function return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 lower = _norm_cdf((a - mean) / std) upper = _norm_cdf((b - mean) / std) - + tensor.uniform_(2 * lower - 1, 2 * upper - 1) tensor.erfinv_() - + tensor.mul_(std * math.sqrt(2.0)) tensor.add_(mean) - + tensor.clamp_(min=a, max=b) - + def init_truncated_uniforma_weights_and_bias( m: Module, @@ -605,10 +604,10 @@ def init_truncated_uniforma_weights_and_bias( std: float = 1.0, a: float = -2.0, b: float = 2.0, -): +) -> None: if not hasattr(m, "weight") or not hasattr(m, "bias"): raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") - + with torch.no_grad(): normalize_truncate(m.weight, mean=mean, std=std, a=a, b=b) if m.bias is not None: From 854b68e1304f5a607842c8b9741df13ab3fe919a Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:41:23 +0100 Subject: [PATCH 04/11] Can's comments --- src/fairseq2/models/jepa/factory.py | 32 ++++++++++++++++++++++++++++- src/fairseq2/nn/utils/module.py | 4 +++- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 48cffbd22..7b51ab8af 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -11,8 +11,14 @@ from functools import partial from typing import Final, cast +<<<<<<< HEAD import torch from torch.nn import GELU +======= +from torch import Tensor +import torch +from torch.nn import GELU, Module +>>>>>>> a09ad4fe (Can's comments) from fairseq2.config_registry import ConfigRegistry from fairseq2.models.jepa.model import JepaModel @@ -338,7 +344,7 @@ def init_projection(proj: Linear) -> None: init_module(proj, std=init_std) with torch.no_grad(): - proj.weight.div_(math.sqrt(2.0 * layer_idx)) + proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1))) return Linear( config.model_dim, @@ -373,6 +379,13 @@ def init_projection(proj: Linear) -> None: dtype=self._dtype, ) + # rescale the last layer + proj = ffn.output_proj + assert isinstance(proj, Linear), f"Invalid projection type: {type(proj)}" + proj.weight.data.div_(math.sqrt(2.0 * (layer_id + 1))) + + return ffn + def build_layer_norm( self, model_dim: int, @@ -403,3 +416,20 @@ def create_jepa_model( dtype: DataType | None = None, ) -> JepaModel: return JepaBuilder(config, device=device, dtype=dtype).build_model() + + +def init_truncated_uniforma_weights_and_bias( + m: Module, + *, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +): + if not hasattr(m, "weight") or not hasattr(m, "bias"): + raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") + + with torch.no_grad(): + torch.nn.init.trunc_normal_(m.weight, mean=mean, std=std, a=a, b=b) + if m.bias is not None: + torch.nn.init.zeros_(m.bias) diff --git a/src/fairseq2/nn/utils/module.py b/src/fairseq2/nn/utils/module.py index ddf4c7bf4..baa240d85 100644 --- a/src/fairseq2/nn/utils/module.py +++ b/src/fairseq2/nn/utils/module.py @@ -6,7 +6,6 @@ from __future__ import annotations -import math import re from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from dataclasses import dataclass @@ -571,6 +570,7 @@ def get_module_size(module: Module) -> ModuleSizeInfo: info.total_size_bytes += size_bytes return info 
+<<<<<<< HEAD def normalize_truncate( @@ -612,3 +612,5 @@ def init_truncated_uniforma_weights_and_bias( normalize_truncate(m.weight, mean=mean, std=std, a=a, b=b) if m.bias is not None: torch.nn.init.zeros_(m.bias) +======= +>>>>>>> a09ad4fe (Can's comments) From 328f8ca1256d1b39a92abd8fd7bf20cd3c4e7847 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:43:17 +0100 Subject: [PATCH 05/11] fix git rebase --- src/fairseq2/models/jepa/factory.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 7b51ab8af..5d6dbc79b 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -11,14 +11,8 @@ from functools import partial from typing import Final, cast -<<<<<<< HEAD -import torch -from torch.nn import GELU -======= -from torch import Tensor import torch from torch.nn import GELU, Module ->>>>>>> a09ad4fe (Can's comments) from fairseq2.config_registry import ConfigRegistry from fairseq2.models.jepa.model import JepaModel From f57108d96243605a9c0779b3e76b4c0d8ce29153 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:45:43 +0100 Subject: [PATCH 06/11] fix git rebase --- src/fairseq2/nn/utils/module.py | 29 +---------------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/src/fairseq2/nn/utils/module.py b/src/fairseq2/nn/utils/module.py index baa240d85..981ab0ac5 100644 --- a/src/fairseq2/nn/utils/module.py +++ b/src/fairseq2/nn/utils/module.py @@ -570,31 +570,6 @@ def get_module_size(module: Module) -> ModuleSizeInfo: info.total_size_bytes += size_bytes return info -<<<<<<< HEAD - - -def normalize_truncate( - tensor: Tensor, - *, - mean: float = 0.0, - std: float = 1.0, - a: float = -2.0, - b: float = 2.0, -) -> None: - def _norm_cdf(x: float) -> float: - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - lower = _norm_cdf((a - mean) / std) - upper = _norm_cdf((b - mean) / std) - - tensor.uniform_(2 * lower - 1, 2 * upper - 1) - tensor.erfinv_() - - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - tensor.clamp_(min=a, max=b) def init_truncated_uniforma_weights_and_bias( @@ -609,8 +584,6 @@ def init_truncated_uniforma_weights_and_bias( raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") with torch.no_grad(): - normalize_truncate(m.weight, mean=mean, std=std, a=a, b=b) + torch.nn.init.trunc_normal_(m.weight, mean=mean, std=std, a=a, b=b) if m.bias is not None: torch.nn.init.zeros_(m.bias) -======= ->>>>>>> a09ad4fe (Can's comments) From 2d27bab861b1d66bb85386ba3e179741ce72f741 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:54:01 +0100 Subject: [PATCH 07/11] lint --- src/fairseq2/models/jepa/factory.py | 17 +++++++---------- src/fairseq2/nn/utils/module.py | 17 ----------------- 2 files changed, 7 insertions(+), 27 deletions(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 5d6dbc79b..88ae20283 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -44,9 +44,6 @@ create_default_sdpa, ) from fairseq2.nn.transformer.residual import DropPathResidualConnect -from fairseq2.nn.utils.module import ( - init_truncated_uniforma_weights_and_bias as init_module, -) from fairseq2.typing import DataType, Device 
JEPA_FAMILY: Final = "jepa" @@ -204,7 +201,7 @@ def build_feature_extractor(self) -> PatchFeatureExtractor: init_std = config.init_std - init_conv = partial(init_module, std=init_std) + init_conv = partial(init_truncated_uniforma_weights_and_bias, std=init_std) num_patch_dims = len(config.patch_dims) @@ -335,7 +332,7 @@ def build_mha_output_projection(self, layer_idx: int) -> Linear: init_std = config.init_std def init_projection(proj: Linear) -> None: - init_module(proj, std=init_std) + init_truncated_uniforma_weights_and_bias(proj, std=init_std) with torch.no_grad(): proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1))) @@ -355,14 +352,14 @@ def build_ffn(self, layer_idx: int) -> FeedForwardNetwork: init_std = config.init_std def init_projection(proj: Linear) -> None: - init_module(proj, std=init_std) + init_truncated_uniforma_weights_and_bias(proj, std=init_std) with torch.no_grad(): proj.weight.div_(math.sqrt(2.0 * layer_idx)) inner_dim = int(config.model_dim * config.ffn_inner_dim_ratio) - return StandardFeedForwardNetwork( + ffn = StandardFeedForwardNetwork( config.model_dim, inner_dim, bias=True, @@ -376,7 +373,7 @@ def init_projection(proj: Linear) -> None: # rescale the last layer proj = ffn.output_proj assert isinstance(proj, Linear), f"Invalid projection type: {type(proj)}" - proj.weight.data.div_(math.sqrt(2.0 * (layer_id + 1))) + proj.weight.data.div_(math.sqrt(2.0 * (layer_idx + 1))) return ffn @@ -391,7 +388,7 @@ def build_layer_norm( init_std = config.init_std - init_layer_norm = partial(init_module, std=init_std) + init_layer_norm = partial(init_truncated_uniforma_weights_and_bias, std=init_std) return StandardLayerNorm( model_dim, @@ -419,7 +416,7 @@ def init_truncated_uniforma_weights_and_bias( std: float = 1.0, a: float = -2.0, b: float = 2.0, -): +) -> None: if not hasattr(m, "weight") or not hasattr(m, "bias"): raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") diff --git a/src/fairseq2/nn/utils/module.py b/src/fairseq2/nn/utils/module.py index 981ab0ac5..e76200599 100644 --- a/src/fairseq2/nn/utils/module.py +++ b/src/fairseq2/nn/utils/module.py @@ -570,20 +570,3 @@ def get_module_size(module: Module) -> ModuleSizeInfo: info.total_size_bytes += size_bytes return info - - -def init_truncated_uniforma_weights_and_bias( - m: Module, - *, - mean: float = 0.0, - std: float = 1.0, - a: float = -2.0, - b: float = 2.0, -) -> None: - if not hasattr(m, "weight") or not hasattr(m, "bias"): - raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") - - with torch.no_grad(): - torch.nn.init.trunc_normal_(m.weight, mean=mean, std=std, a=a, b=b) - if m.bias is not None: - torch.nn.init.zeros_(m.bias) From 8d7dfafedad4fbfbe229e09ce5567727731cf3bb Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:55:47 +0100 Subject: [PATCH 08/11] lint --- src/fairseq2/models/jepa/factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 88ae20283..006715928 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -355,7 +355,7 @@ def init_projection(proj: Linear) -> None: init_truncated_uniforma_weights_and_bias(proj, std=init_std) with torch.no_grad(): - proj.weight.div_(math.sqrt(2.0 * layer_idx)) + proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1))) inner_dim = int(config.model_dim * config.ffn_inner_dim_ratio) From e86afaa3fa4737165cbcad43afd0126ff4464cff Mon Sep 17 
00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 18:02:33 +0100 Subject: [PATCH 09/11] flake8 --- src/fairseq2/models/jepa/factory.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 006715928..a06bc8c74 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -374,7 +374,7 @@ def init_projection(proj: Linear) -> None: proj = ffn.output_proj assert isinstance(proj, Linear), f"Invalid projection type: {type(proj)}" proj.weight.data.div_(math.sqrt(2.0 * (layer_idx + 1))) - + return ffn def build_layer_norm( @@ -407,7 +407,7 @@ def create_jepa_model( dtype: DataType | None = None, ) -> JepaModel: return JepaBuilder(config, device=device, dtype=dtype).build_model() - + def init_truncated_uniforma_weights_and_bias( m: Module, @@ -419,7 +419,7 @@ def init_truncated_uniforma_weights_and_bias( ) -> None: if not hasattr(m, "weight") or not hasattr(m, "bias"): raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") - + with torch.no_grad(): torch.nn.init.trunc_normal_(m.weight, mean=mean, std=std, a=a, b=b) if m.bias is not None: From f4aaf33c5ed7e41130560a1848de7779ff8d34ca Mon Sep 17 00:00:00 2001 From: tuantran user Date: Thu, 19 Dec 2024 17:25:39 +0000 Subject: [PATCH 10/11] black --- src/fairseq2/models/jepa/factory.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index a06bc8c74..63fe44e71 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -388,7 +388,9 @@ def build_layer_norm( init_std = config.init_std - init_layer_norm = partial(init_truncated_uniforma_weights_and_bias, std=init_std) + init_layer_norm = partial( + init_truncated_uniforma_weights_and_bias, std=init_std + ) return StandardLayerNorm( model_dim, From 8ecc54421e0a630d0e70fc523a89476c96f691da Mon Sep 17 00:00:00 2001 From: Can Balioglu Date: Fri, 20 Dec 2024 21:00:24 +0000 Subject: [PATCH 11/11] Remove remnant code --- src/fairseq2/models/jepa/factory.py | 47 ++++++++++------------------- 1 file changed, 16 insertions(+), 31 deletions(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 63fe44e71..d654e32a4 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -12,6 +12,7 @@ from typing import Final, cast import torch +import torch.nn as nn from torch.nn import GELU, Module from fairseq2.config_registry import ConfigRegistry @@ -201,7 +202,7 @@ def build_feature_extractor(self) -> PatchFeatureExtractor: init_std = config.init_std - init_conv = partial(init_truncated_uniforma_weights_and_bias, std=init_std) + init_conv = partial(init_truncated_normal, std=init_std) num_patch_dims = len(config.patch_dims) @@ -332,7 +333,7 @@ def build_mha_output_projection(self, layer_idx: int) -> Linear: init_std = config.init_std def init_projection(proj: Linear) -> None: - init_truncated_uniforma_weights_and_bias(proj, std=init_std) + init_truncated_normal(proj, std=init_std) with torch.no_grad(): proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1))) @@ -352,14 +353,14 @@ def build_ffn(self, layer_idx: int) -> FeedForwardNetwork: init_std = config.init_std def init_projection(proj: Linear) -> None: - init_truncated_uniforma_weights_and_bias(proj, std=init_std) + init_truncated_normal(proj, std=init_std) with torch.no_grad(): 
proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1))) inner_dim = int(config.model_dim * config.ffn_inner_dim_ratio) - ffn = StandardFeedForwardNetwork( + return StandardFeedForwardNetwork( config.model_dim, inner_dim, bias=True, @@ -370,13 +371,6 @@ def init_projection(proj: Linear) -> None: dtype=self._dtype, ) - # rescale the last layer - proj = ffn.output_proj - assert isinstance(proj, Linear), f"Invalid projection type: {type(proj)}" - proj.weight.data.div_(math.sqrt(2.0 * (layer_idx + 1))) - - return ffn - def build_layer_norm( self, model_dim: int, @@ -388,9 +382,7 @@ def build_layer_norm( init_std = config.init_std - init_layer_norm = partial( - init_truncated_uniforma_weights_and_bias, std=init_std - ) + init_layer_norm = partial(init_truncated_normal, std=init_std) return StandardLayerNorm( model_dim, @@ -402,6 +394,16 @@ def build_layer_norm( ) +def init_truncated_normal(module: Module, *, std: float = 1.0) -> None: + if not hasattr(module, "weight"): + raise ValueError("`module` does not have a parameter with name `weight`.") + + nn.init.trunc_normal_(module.weight, std=std) + + if hasattr(module, "bias") and module.bias is not None: + nn.init.zeros_(module.bias) + + def create_jepa_model( config: JepaConfig, *, @@ -409,20 +411,3 @@ def create_jepa_model( dtype: DataType | None = None, ) -> JepaModel: return JepaBuilder(config, device=device, dtype=dtype).build_model() - - -def init_truncated_uniforma_weights_and_bias( - m: Module, - *, - mean: float = 0.0, - std: float = 1.0, - a: float = -2.0, - b: float = 2.0, -) -> None: - if not hasattr(m, "weight") or not hasattr(m, "bias"): - raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") - - with torch.no_grad(): - torch.nn.init.trunc_normal_(m.weight, mean=mean, std=std, a=a, b=b) - if m.bias is not None: - torch.nn.init.zeros_(m.bias)
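
Note for readers of this series (not one of the patches): the scheme that lands in PATCH 11/11 is a single `init_truncated_normal` helper (truncated-normal weights with std = `init_std`, zero bias) applied to the patch feature extractor, the projection layers, and the layer norms, with the attention and feed-forward output projections additionally divided by sqrt(2 * (layer_idx + 1)) so that deeper residual branches start with smaller weights. The sketch below is a minimal, standalone restatement of that scheme in plain PyTorch, given only as an illustration: the `init_rescaled_projection` name, the 12-layer/768-dim figures, and the use of `torch.nn.Linear` directly (instead of fairseq2's `Linear` with its `init_fn` hook) are assumptions made for this example.

    import math

    import torch
    import torch.nn as nn


    def init_truncated_normal(module: nn.Module, *, std: float = 1.0) -> None:
        # Truncated-normal weights, zero bias -- mirrors the helper in PATCH 11/11.
        if not hasattr(module, "weight"):
            raise ValueError("`module` does not have a parameter with name `weight`.")

        nn.init.trunc_normal_(module.weight, std=std)

        if getattr(module, "bias", None) is not None:
            nn.init.zeros_(module.bias)


    def init_rescaled_projection(proj: nn.Linear, *, std: float, layer_idx: int) -> None:
        # Hypothetical helper: truncated-normal init followed by the per-layer
        # rescaling the series applies to attention/FFN output projections.
        init_truncated_normal(proj, std=std)

        with torch.no_grad():
            proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1)))


    # Illustrative usage: one output projection per encoder layer, init_std = 0.02.
    model_dim, num_layers = 768, 12

    for layer_idx in range(num_layers):
        output_proj = nn.Linear(model_dim, model_dim, bias=True)
        init_rescaled_projection(output_proj, std=0.02, layer_idx=layer_idx)

In the fairseq2 builder itself, the same effect is obtained by passing such functions (via `functools.partial` or a closure over `layer_idx`) as the `init_fn` / `proj_init_fn` arguments introduced by these patches, so that `reset_parameters()` reruns them whenever the module is re-initialized.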