Add normalization support for layer_norm within ConformerConvolution. (#92)

* Add normalization options within the ConformerConvolution module.

* Address nit comments.
kauterry authored Oct 6, 2023
1 parent 455812e commit efccad1
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions src/fairseq2/models/conformer/convolution.py
@@ -9,6 +9,7 @@
 from torch import Tensor
 from torch.nn import GLU, BatchNorm1d, Conv1d, Module, SiLU
 
+from fairseq2.nn.normalization import LayerNorm, StandardLayerNorm
 from fairseq2.nn.utils.mask import apply_padding_mask
 from fairseq2.typing import DataType, Device
 
@@ -21,7 +22,8 @@ class ConformerConvolution(Module):
     pointwise_conv1: Conv1d
     pointwise_conv1_activation: GLU
     depthwise_conv: Conv1d
-    batch_norm: BatchNorm1d
+    batch_norm: Optional[BatchNorm1d]
+    layer_norm: Optional[LayerNorm]
     depthwise_activation: Module
     pointwise_conv2: Conv1d
 
@@ -30,6 +32,7 @@ def __init__(
         model_dim: int,
         depthwise_kernel_size: int,
         *,
+        norm_type: str = "batch_norm",
         depthwise_activation: Optional[Module] = None,
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
@@ -39,6 +42,8 @@
             The dimensionality of the model.
         :param depthwise_kernel_size:
             The kernel size of the depthwise convolution.
+        :param norm_type:
+            The type of norm layer applied after the depthwise convolution.
         :param depthwise_activation:
             The activation to apply to outputs of the depthwise convolution. If
             ``None``, :func:`~torch.nn.SiLU` (a.k.a. swish) will be used.
@@ -74,7 +79,22 @@ def __init__(
             dtype=dtype,
         )
 
-        self.batch_norm = BatchNorm1d(model_dim, device=device, dtype=dtype)
+        if norm_type not in ("batch_norm", "layer_norm"):
+            raise ValueError(
+                f"`norm_type` must be 'batch_norm' or 'layer_norm', but is '{norm_type}' instead."
+            )
+
+        if norm_type == "batch_norm":
+            self.batch_norm = BatchNorm1d(model_dim, device=device, dtype=dtype)
+        else:
+            self.register_module("batch_norm", None)
+
+        if norm_type == "layer_norm":
+            self.layer_norm = StandardLayerNorm(
+                model_dim, bias=True, device=device, dtype=dtype
+            )
+        else:
+            self.register_module("layer_norm", None)
 
         if depthwise_activation is None:
             self.depthwise_activation = SiLU()  # a.k.a. swish
@@ -119,7 +139,11 @@ def forward(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
         # (N, M, S) -> (N, M, S)
         seqs = self.depthwise_conv(seqs)
 
-        seqs = self.batch_norm(seqs)
+        if self.batch_norm is not None:
+            seqs = self.batch_norm(seqs)
+        else:
+            assert self.layer_norm is not None
+            seqs = self.layer_norm(seqs)
 
         seqs = self.depthwise_activation(seqs)
 
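For illustration only (not part of the commit): a minimal usage sketch of the new option. The example sizes, the (batch, seq_len, model_dim) input layout, and the absence of a padding mask are assumptions made for this sketch, not taken from the diff.

import torch

from fairseq2.models.conformer.convolution import ConformerConvolution

# Build the Conformer convolution block with layer normalization instead of
# the default batch normalization.
conv = ConformerConvolution(
    model_dim=512,             # dimensionality of the model (assumed value)
    depthwise_kernel_size=31,  # kernel size of the depthwise convolution (assumed value)
    norm_type="layer_norm",    # option added by this commit; default is "batch_norm"
)

seqs = torch.randn(8, 100, 512)  # assumed (batch, seq_len, model_dim) layout

out = conv(seqs, None)  # no padding mask in this sketch

An unknown norm_type raises the ValueError shown in the diff, and register_module(..., None) keeps the unused attribute registered as None, so exactly one of batch_norm and layer_norm is instantiated per module.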
