Commit
* Add TE skeleton
* Update NeoXArgs docs automatically
* Added option for TE version of norms
* Import TERMSNorm
* Add TE norm options to norm arg
* Add TE objects in weight decay function
* Reformat
* Add TERMSNorm and TELayerNorm
* Update NeoXArgs docs automatically
* Add Fused RMS Norm from apex
* Make it consistent with how layernorm looks
* Merged transformer engine and apex fused layernorm branches
* Added assertion if TE is used
* Removed unnecessary transformer-engine import
* Changed ImportError text for TE
* Added requirements/requirements-transformerengine.txt
* Update comments
* Precommit

---------

Co-authored-by: Quentin Anthony <[email protected]>
Co-authored-by: github-actions <[email protected]>
Co-authored-by: lintangsutawika <lintang@stella-ord-0.stella-ord.tenant-eleutherai.svc.tenant.chi.local>
Co-authored-by: lintangsutawika <[email protected]>
Co-authored-by: dmahan93 <[email protected]>
Co-authored-by: aurelion-source <[email protected]>
Co-authored-by: aurelion-source <[email protected]>
1 parent 77e8158, commit 836aefa. Showing 9 changed files with 293 additions and 11 deletions.
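The change list mentions wiring the new TE norms into the norm argument and into the weight-decay handling. As a rough illustration only, a norm-selection helper might dispatch on that argument to the new wrappers; the helper, the option strings "te_layernorm" / "te_rmsnorm", and the import path below are illustrative assumptions, not the repository's actual implementation.

# Hypothetical sketch: selecting one of the new TE norms from a config value.
# The option names, the helper, and the import path are assumptions for illustration.
from megatron.model.transformer_engine import TELayerNorm, TERMSNorm


def pick_norm(norm: str, hidden_size: int, eps: float = 1e-5):
    if norm == "te_layernorm":
        return TELayerNorm(hidden_size, eps=eps)
    if norm == "te_rmsnorm":
        return TERMSNorm(hidden_size, eps=eps)
    raise ValueError(f"unsupported norm option: {norm}")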
@@ -0,0 +1,137 @@
import torch

try:
    import transformer_engine as te
except ImportError:
    raise ImportError(
        "Unable to import transformer-engine. Please refer to "
        "https://github.com/NVIDIA/TransformerEngine for installation instructions."
    )


class TERMSNorm(torch.nn.Module):
    def __init__(self, dim, eps=1e-8, **kwargs):
        """
        A conditional wrapper to initialize an instance of Transformer Engine's
        `RMSNorm` based on input.
        :param dim: model size
        :param eps: epsilon value, default 1e-8
        """
        super(TERMSNorm, self).__init__()

        self.d = dim
        self.eps = eps
        self.norm = te.pytorch.RMSNorm(
            hidden_size=self.d,
            eps=self.eps,
            **kwargs,
        )

    def forward(self, x):
        return self.norm(x)


class TELayerNorm(torch.nn.Module):
    def __init__(self, dim, eps=1.0e-5, **kwargs):
        """
        A conditional wrapper to initialize an instance of Transformer Engine's
        `LayerNorm` based on input.
        :param dim: model size
        :param eps: epsilon value, default 1.0e-5
        """
        super(TELayerNorm, self).__init__()

        self.d = dim
        self.eps = eps
        self.norm = te.pytorch.LayerNorm(
            hidden_size=self.d,
            eps=self.eps,
            **kwargs,
        )

    def forward(self, x):
        return self.norm(x)


class TELinear(te.pytorch.Linear):
    """
    Wrapper for Transformer Engine's `Linear` layer.
    """

    def __init__(self):
        # TODO
        return

    def forward(self, x):
        # TODO
        return


class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
    """
    Wrapper for Transformer Engine's `LayerNormLinear` layer that combines
    layernorm and linear layers.
    """

    def __init__(self):
        # TODO
        return

    def forward(self, x):
        # TODO
        return


class TEColumnParallelLinear(TELinear):
    """
    Wrapper for Transformer Engine's `Linear` layer, specialized similarly
    to megatron's `ColumnParallelLinear` layer.
    """

    def __init__(self):
        # TODO
        return

    def forward(self, x):
        # TODO
        return


class TERowParallelLinear(TELinear):
    """
    Wrapper for Transformer Engine's `Linear` layer, specialized similarly
    to megatron's `RowParallelLinear` layer.
    """

    def __init__(self):
        # TODO
        return

    def forward(self, x):
        # TODO
        return


class TEDotProductAttention(te.pytorch.DotProductAttention):
    """
    Wrapper for Transformer Engine's `DotProductAttention` layer that also
    has "flash attention" enabled.
    """

    def __init__(self):
        # TODO
        return

    def forward(self, x):
        # TODO
        return


class TEDelayedScaling(te.common.recipe.DelayedScaling):
    """
    Wrapper for Transformer Engine's `DelayedScaling` FP8 recipe.
    """

    def __init__(self):
        # TODO
        return
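The two norm wrappers above are already functional, while the linear, attention, and scaling wrappers remain TODO stubs. A minimal smoke test of the norms might look like the sketch below; it assumes Transformer Engine is installed, a CUDA device is available (TE kernels do not run on CPU), and that the new file is importable as megatron.model.transformer_engine, which is an assumption about its path.

# Smoke-test sketch for the new norm wrappers (not part of this commit).
# Assumes a CUDA device and that the file above lives at this import path.
import torch
from megatron.model.transformer_engine import TELayerNorm, TERMSNorm

hidden = 1024
x = torch.randn(2, 8, hidden, device="cuda")

layer_norm = TELayerNorm(hidden).cuda()
rms_norm = TERMSNorm(hidden).cuda()

print(layer_norm(x).shape)  # torch.Size([2, 8, 1024])
print(rms_norm(x).shape)    # torch.Size([2, 8, 1024])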
@@ -0,0 +1 @@
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
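Per the commit message, the one-line hunk above is requirements/requirements-transformerengine.txt, which installs Transformer Engine from its stable branch. Once it is installed, the FP8 machinery that TEDelayedScaling is meant to wrap can be exercised directly through TE's own API. The sketch below reflects that upstream API as of recent releases, not code from this commit, and it requires an FP8-capable GPU such as an H100.

# Sketch of Transformer Engine's FP8 recipe + autocast usage (upstream API,
# not part of this commit). Requires an FP8-capable GPU (e.g. Hopper).
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,   # E4M3 in the forward pass, E5M2 in the backward pass
    amax_history_len=16,
    amax_compute_algo="max",
)

linear = te.Linear(1024, 1024, bias=True).cuda()
x = torch.randn(32, 1024, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = linear(x)

print(y.shape)  # torch.Size([32, 1024])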