TransformerEngine Integration #1282

Merged · 30 commits · Dec 19, 2024

Changes from 9 commits

Commits (30)

5382c8a
Implemented ColumnParallelLinear with Transformer-Engine
aurelion-source Sep 15, 2024
fa887b7
Implemented RowParallelLinear with Transformer-Engine
aurelion-source Sep 15, 2024
0a6f140
Implemented LayerNormMLP with Transformer-Engine
aurelion-source Sep 15, 2024
5cba717
Implemented MultiheadAttention with Transformer-Engine
aurelion-source Sep 16, 2024
94e552c
Cleaned up transformer.py
aurelion-source Sep 16, 2024
40e1019
Cleaned up neox_args
aurelion-source Sep 16, 2024
885e72c
Cleaned up neox_args
aurelion-source Sep 16, 2024
fe8f22a
- Fixed TE_MHA and added rope support
aurelion-source Sep 26, 2024
ee42a31
Fixed mixed files.
aurelion-source Sep 26, 2024
8961dd7
Implemented ColumnParallelLinear with Transformer-Engine
aurelion-source Sep 15, 2024
5162d54
Implemented RowParallelLinear with Transformer-Engine
aurelion-source Sep 15, 2024
3cad89c
Implemented LayerNormMLP with Transformer-Engine
aurelion-source Sep 15, 2024
eedb6c2
Implemented MultiheadAttention with Transformer-Engine
aurelion-source Sep 16, 2024
36ad680
Cleaned up transformer.py
aurelion-source Sep 16, 2024
6963103
Cleaned up neox_args
aurelion-source Sep 16, 2024
afc9c92
Cleaned up neox_args
aurelion-source Sep 16, 2024
a0e7acd
- Fixed TE_MHA and added rope support
aurelion-source Sep 26, 2024
0b4bdc5
Fixed mixed files.
aurelion-source Sep 26, 2024
d559be9
Merge branch 'te' of github.com:aurelion-source/gpt-neox into te
aurelion-source Sep 27, 2024
bb76510
Changed get_linear name
aurelion-source Sep 27, 2024
43cf4ee
Added rng tracker to lnmlp and placed rope in te_mha init instead of …
aurelion-source Oct 1, 2024
42716f2
Updated fp8 arguments to te_fp8
aurelion-source Oct 1, 2024
b3255e6
Added EAI copyright
aurelion-source Oct 1, 2024
afeff03
Merge branch 'main' into te
Quentin-Anthony Oct 8, 2024
98f0388
precommit
Quentin-Anthony Oct 8, 2024
7e7dbfb
add sample TE config
Quentin-Anthony Oct 8, 2024
5757be6
add te to readme
Quentin-Anthony Oct 8, 2024
9ea3dcf
remove pip install prefix from reqs file
Quentin-Anthony Oct 8, 2024
f3e40e9
Force TE pytorch in requirements file
aurelion-source Oct 16, 2024
7bf9c44
Merge remote-tracking branch 'upstream/main' into te
Quentin-Anthony Dec 19, 2024
9 changes: 7 additions & 2 deletions megatron/model/positional_embeddings.py
@@ -37,7 +37,7 @@ def forward(self, x, seq_dim=1):

 class RotaryEmbedding(torch.nn.Module):
     def __init__(
-        self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False
+        self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False, return_embeddings=False
     ):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
@@ -49,6 +49,7 @@ def __init__(
         self.max_seq_len = max_seq_len
         self.base = base
         self.dim = dim
+        self.return_embeddings = return_embeddings

         # precompute cos_cached, sin_cached in fp32
         cos_cached, sin_cached, inv_freq = self._prepare_cache(
@@ -67,6 +68,8 @@ def _prepare_cache(self, seq_len, precision, base):
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1)

+        self.emb = emb.reshape(emb.size(0), 1, 1, emb.size(1))
+
         cos_cached = emb.cos()[:, None, None, :]
         sin_cached = emb.sin()[:, None, None, :]

@@ -77,6 +80,8 @@ def _prepare_cache(self, seq_len, precision, base):
         )

     def forward(self, x, seq_dim=0, seq_len=None):
+        if self.return_embeddings:
+            return self.emb.to(self.precision).to(x.device)
         if seq_len is None:
             seq_len = x.shape[seq_dim]

@@ -249,4 +254,4 @@ def forward(self, x):
             a.shape[0], 1, a.shape[2]
         )  # seq_len_k - 1 points to the last token index in the current inference batch.

-        return x + a
+        return x + a
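
The diff above adds a return_embeddings flag to RotaryEmbedding: instead of the usual (cos, sin) path, forward() returns the precomputed angle tensor reshaped to [max_seq_len, 1, 1, dim], presumably so the TE attention wrapper can pass rotary embeddings through in the form Transformer Engine expects (the commit log mentions adding rope support to TE_MHA). A minimal usage sketch, assuming the patched class is importable from megatron.model.positional_embeddings and using placeholder sizes:

```python
import torch
from megatron.model.positional_embeddings import RotaryEmbedding

# Placeholder head dim / sequence length; only the return_embeddings branch is exercised.
rope = RotaryEmbedding(dim=64, max_seq_len=2048, precision=torch.half, return_embeddings=True)

x = torch.randn(16, 1, 8, 64)  # dummy input; this branch only reads x.device
emb = rope(x)                  # the cached angle tensor, not a (cos, sin) pair
print(emb.shape)               # torch.Size([2048, 1, 1, 64]), i.e. [max_seq_len, 1, 1, dim]
```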