diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py index e6241651..380ad460 100644 --- a/src/nanotron/scaling/parametrization.py +++ b/src/nanotron/scaling/parametrization.py @@ -34,6 +34,7 @@ class StandardParametrizator(Parametrizator): def __init__(self, config: ModelArgs): super().__init__(config) self.MODULE_TO_PARAMETRIZE = { + nn.Linear: self._parametrize_column_linear, TensorParallelColumnLinear: self._parametrize_column_linear, TensorParallelRowLinear: self._parametrize_row_linear, TritonRMSNorm: self._parametrize_layer_norm,