From 271295a0f76701cdf64217db38a458e781f12670 Mon Sep 17 00:00:00 2001
From: Can Balioglu
Date: Tue, 3 Oct 2023 13:46:33 -0700
Subject: [PATCH] Refactor Embedding

---
 src/fairseq2/models/llama/builder.py          |  4 +-
 src/fairseq2/models/nllb/builder.py           |  4 +-
 .../models/s2t_transformer/builder.py         |  4 +-
 src/fairseq2/nn/embedding.py                  | 78 +++++++++++++------
 4 files changed, 61 insertions(+), 29 deletions(-)

diff --git a/src/fairseq2/models/llama/builder.py b/src/fairseq2/models/llama/builder.py
index 584a9e436..8cefd89fa 100644
--- a/src/fairseq2/models/llama/builder.py
+++ b/src/fairseq2/models/llama/builder.py
@@ -14,7 +14,7 @@
     TransformerFrontend,
 )
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
-from fairseq2.nn.embedding import Embedding
+from fairseq2.nn.embedding import StandardEmbedding
 from fairseq2.nn.normalization import LayerNorm, RMSNorm
 from fairseq2.nn.position_encoder import RotaryEncoder
 from fairseq2.nn.transformer import (
@@ -240,7 +240,7 @@ def build_model(self) -> TransformerDecoderModel:
 
     def build_frontend(self) -> TransformerFrontend:
         """Build a Transformer decoder front-end."""
-        embed = Embedding(
+        embed = StandardEmbedding(
             num_embeddings=self.config.vocabulary_size,
             embedding_dim=self.config.model_dim,
             device=self.device,
diff --git a/src/fairseq2/models/nllb/builder.py b/src/fairseq2/models/nllb/builder.py
index 499a86228..172f550fd 100644
--- a/src/fairseq2/models/nllb/builder.py
+++ b/src/fairseq2/models/nllb/builder.py
@@ -14,7 +14,7 @@
     TransformerModel,
 )
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
-from fairseq2.nn.embedding import Embedding
+from fairseq2.nn.embedding import Embedding, StandardEmbedding
 from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
 from fairseq2.nn.projection import TiedProjection
 from fairseq2.nn.transformer import (
@@ -182,7 +182,7 @@ def build_model(self) -> TransformerModel:
 
     def build_embedding(self) -> Embedding:
         """Build an embedding table."""
-        return Embedding(
+        return StandardEmbedding(
             num_embeddings=self.config.vocabulary_size,
             embedding_dim=self.config.model_dim,
             pad_idx=self.config.pad_idx,
diff --git a/src/fairseq2/models/s2t_transformer/builder.py b/src/fairseq2/models/s2t_transformer/builder.py
index 1803f6106..677cd3185 100644
--- a/src/fairseq2/models/s2t_transformer/builder.py
+++ b/src/fairseq2/models/s2t_transformer/builder.py
@@ -20,7 +20,7 @@
     TransformerModel,
 )
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
-from fairseq2.nn.embedding import Embedding
+from fairseq2.nn.embedding import StandardEmbedding
 from fairseq2.nn.position_encoder import PositionEncoder, SinusoidalPositionEncoder
 from fairseq2.nn.transformer import (
     SDPA,
@@ -283,7 +283,7 @@ def build_encoder_frontend(self) -> TransformerFrontend:
 
     def build_decoder_frontend(self) -> TransformerFrontend:
         """Build a Transformer decoder front-end."""
-        embed = Embedding(
+        embed = StandardEmbedding(
             num_embeddings=self.config.target_vocabulary_size,
             embedding_dim=self.config.model_dim,
             pad_idx=self.config.target_pad_idx,
diff --git a/src/fairseq2/nn/embedding.py b/src/fairseq2/nn/embedding.py
index 16638fc69..19135ddd2 100644
--- a/src/fairseq2/nn/embedding.py
+++ b/src/fairseq2/nn/embedding.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from abc import ABC, abstractmethod
 from typing import Optional, final
 
 import torch
@@ -13,17 +14,65 @@
 from torch.nn.functional import embedding
 from torch.nn.parameter import Parameter
 
-from fairseq2.typing import DataType, Device
+from fairseq2.typing import DataType, Device, finaloverride
 
 
-@final
-class Embedding(Module):
+class Embedding(Module, ABC):
     """Stores embeddings of a fixed dictionary and size."""
 
     num_embeddings: int
     embedding_dim: int
     pad_idx: Optional[int]
     padding_idx: Optional[int]  # Compat
+
+    def __init__(
+        self, num_embeddings: int, embedding_dim: int, pad_idx: Optional[int] = None
+    ) -> None:
+        """
+        :param num_embeddings:
+            The size of the embedding table.
+        :param embedding_dim:
+            The dimensionality of returned embeddings.
+        :param pad_idx:
+            If not ``None``, entries at ``pad_idx`` do not contribute to the
+            gradient; therefore, the embedding at ``pad_idx`` is not updated
+            during training.
+        """
+        super().__init__()
+
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.pad_idx = pad_idx
+
+        # Alias field for compatibility with `torch.nn.Embedding`.
+        self.padding_idx = pad_idx
+
+    @abstractmethod
+    def forward(self, x: Tensor) -> Tensor:
+        """
+        :param x:
+            The embedding indices. *Shape:* Any.
+
+        :returns:
+            The embeddings corresponding to the specified indices. *Shape:*
+            :math:`(*,E)`, where :math:`*` is the input shape and :math:`E` is
+            the dimensionality of the embeddings.
+        """
+
+    def extra_repr(self) -> str:
+        """:meta private:"""
+        s = f"num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}"
+
+        if self.pad_idx is not None:
+            s = f"{s}, pad_idx={self.pad_idx}"
+
+        return s
+
+
+@final
+class StandardEmbedding(Embedding):
+    """Stores embeddings of a fixed dictionary and size in an in-memory table."""
+
     scaled: bool
     weight: Parameter
 
@@ -51,16 +100,10 @@ def __init__(
             :math:`\\mathcal{N}(0, \\frac{1}{\\text{embedding_dim}})`; otherwise,
             from :math:`\\mathcal{N}(0, 1)`.
         """
-        super().__init__()
+        super().__init__(num_embeddings, embedding_dim, pad_idx)
 
-        self.num_embeddings = num_embeddings
-        self.embedding_dim = embedding_dim
-        self.pad_idx = pad_idx
         self.scaled = scaled
 
-        # Alias field for compatibility with `torch.nn.Embedding`.
-        self.padding_idx = pad_idx
-
         self.weight = Parameter(
             torch.empty((num_embeddings, embedding_dim), device=device, dtype=dtype)
         )
@@ -78,24 +121,13 @@ def reset_parameters(self) -> None:
         with torch.no_grad():
             self.weight[self.pad_idx].fill_(0.0)
 
+    @finaloverride
     def forward(self, x: Tensor) -> Tensor:
-        """
-        :param x:
-            The embedding indices. *Shape:* Any.
-
-        :returns:
-            The embeddings corresponding to the specified indices. *Shape:*
-            :math:`(*,E)`, where :math:`*` is the input shape and :math:`E` is
-            the dimensionality of the embeddings.
-        """
         return embedding(x, self.weight, self.pad_idx)
 
     def extra_repr(self) -> str:
         """:meta private:"""
-        s = f"num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}"
-
-        if self.pad_idx is not None:
-            s = f"{s}, pad_idx={self.pad_idx}"
+        s = super().extra_repr()
 
         if self.scaled:
             s = f"{s}, scaled=True"
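
Usage sketch (not part of the commit itself): a minimal example of how the refactored classes fit together, assuming a fairseq2 build that includes this change. The class names and constructor parameters come from the diff above; the vocabulary size, embedding width, and index values are made-up toy numbers.

    import torch

    from fairseq2.nn.embedding import Embedding, StandardEmbedding

    # Concrete in-memory table, as the model builders above now construct.
    embed: Embedding = StandardEmbedding(num_embeddings=32, embedding_dim=8, pad_idx=0)

    # Indices may have any shape; the output gains a trailing embedding dimension.
    seqs = torch.tensor([[0, 5, 7], [3, 0, 1]])

    out = embed(seqs)

    print(out.shape)  # torch.Size([2, 3, 8])
    print(embed)      # StandardEmbedding(num_embeddings=32, embedding_dim=8, pad_idx=0)

With Embedding now an abstract base class, a builder such as the NLLB build_embedding() above can keep returning the abstract Embedding type while instantiating StandardEmbedding, and alternative storage schemes only need to subclass Embedding and implement forward().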