Skip to content

Commit

Permalink
add nn.linear to init for moe router
Browse files Browse the repository at this point in the history
  • Loading branch information
haeggee committed Aug 5, 2024
1 parent a9dba53 commit 2efffb8
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions src/nanotron/scaling/parametrization.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class StandardParametrizator(Parametrizator):
def __init__(self, config: ModelArgs):
super().__init__(config)
self.MODULE_TO_PARAMETRIZE = {
nn.Linear: self._parametrize_column_linear,
TensorParallelColumnLinear: self._parametrize_column_linear,
TensorParallelRowLinear: self._parametrize_row_linear,
TritonRMSNorm: self._parametrize_layer_norm,
Expand Down

0 comments on commit 2efffb8

Please sign in to comment.