From 8ecc54421e0a630d0e70fc523a89476c96f691da Mon Sep 17 00:00:00 2001
From: Can Balioglu
Date: Fri, 20 Dec 2024 21:00:24 +0000
Subject: [PATCH] Remove remnant code

---
 src/fairseq2/models/jepa/factory.py | 47 ++++++++++-------------------
 1 file changed, 16 insertions(+), 31 deletions(-)

diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py
index 63fe44e71..d654e32a4 100644
--- a/src/fairseq2/models/jepa/factory.py
+++ b/src/fairseq2/models/jepa/factory.py
@@ -12,6 +12,7 @@
 from typing import Final, cast
 
 import torch
+import torch.nn as nn
 from torch.nn import GELU, Module
 
 from fairseq2.config_registry import ConfigRegistry
@@ -201,7 +202,7 @@ def build_feature_extractor(self) -> PatchFeatureExtractor:
 
         init_std = config.init_std
 
-        init_conv = partial(init_truncated_uniforma_weights_and_bias, std=init_std)
+        init_conv = partial(init_truncated_normal, std=init_std)
 
         num_patch_dims = len(config.patch_dims)
 
@@ -332,7 +333,7 @@ def build_mha_output_projection(self, layer_idx: int) -> Linear:
         init_std = config.init_std
 
         def init_projection(proj: Linear) -> None:
-            init_truncated_uniforma_weights_and_bias(proj, std=init_std)
+            init_truncated_normal(proj, std=init_std)
 
             with torch.no_grad():
                 proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1)))
@@ -352,14 +353,14 @@ def build_ffn(self, layer_idx: int) -> FeedForwardNetwork:
         init_std = config.init_std
 
         def init_projection(proj: Linear) -> None:
-            init_truncated_uniforma_weights_and_bias(proj, std=init_std)
+            init_truncated_normal(proj, std=init_std)
 
             with torch.no_grad():
                 proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1)))
 
         inner_dim = int(config.model_dim * config.ffn_inner_dim_ratio)
 
-        ffn = StandardFeedForwardNetwork(
+        return StandardFeedForwardNetwork(
             config.model_dim,
             inner_dim,
             bias=True,
@@ -370,13 +371,6 @@ def init_projection(proj: Linear) -> None:
             dtype=self._dtype,
         )
 
-        # rescale the last layer
-        proj = ffn.output_proj
-        assert isinstance(proj, Linear), f"Invalid projection type: {type(proj)}"
-        proj.weight.data.div_(math.sqrt(2.0 * (layer_idx + 1)))
-
-        return ffn
-
     def build_layer_norm(
         self,
         model_dim: int,
@@ -388,9 +382,7 @@ def build_layer_norm(
 
         init_std = config.init_std
 
-        init_layer_norm = partial(
-            init_truncated_uniforma_weights_and_bias, std=init_std
-        )
+        init_layer_norm = partial(init_truncated_normal, std=init_std)
 
         return StandardLayerNorm(
             model_dim,
@@ -402,6 +394,16 @@
         )
 
 
+def init_truncated_normal(module: Module, *, std: float = 1.0) -> None:
+    if not hasattr(module, "weight"):
+        raise ValueError("`module` does not have a parameter with name `weight`.")
+
+    nn.init.trunc_normal_(module.weight, std=std)
+
+    if hasattr(module, "bias") and module.bias is not None:
+        nn.init.zeros_(module.bias)
+
+
 def create_jepa_model(
     config: JepaConfig,
     *,
@@ -409,20 +411,3 @@ def create_jepa_model(
     dtype: DataType | None = None,
 ) -> JepaModel:
     return JepaBuilder(config, device=device, dtype=dtype).build_model()
-
-
-def init_truncated_uniforma_weights_and_bias(
-    m: Module,
-    *,
-    mean: float = 0.0,
-    std: float = 1.0,
-    a: float = -2.0,
-    b: float = 2.0,
-) -> None:
-    if not hasattr(m, "weight") or not hasattr(m, "bias"):
-        raise ValueError(f"Cannot initialize weights and bias of a {type(m)}")
-
-    with torch.no_grad():
-        torch.nn.init.trunc_normal_(m.weight, mean=mean, std=std, a=a, b=b)
-        if m.bias is not None:
-            torch.nn.init.zeros_(m.bias)
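
Note: the sketch below is not part of the patch. It is a minimal usage example of the new module-level helper, assuming it is importable as fairseq2.models.jepa.factory.init_truncated_normal; the Linear(4, 4) module and std value are arbitrary placeholders for the projections and config.init_std used inside the builder.

    import torch.nn as nn

    from fairseq2.models.jepa.factory import init_truncated_normal

    # Hypothetical stand-in for a builder-created projection.
    proj = nn.Linear(4, 4)

    # Redraws `weight` from a truncated normal distribution and zeroes `bias`;
    # `std` is keyword-only, matching the helper's signature in the patch.
    init_truncated_normal(proj, std=0.02)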