
remove GPTJ dma before mha #468

Merged: 6 commits into huggingface:main on Oct 25, 2023

Conversation

BaihuiJin
Contributor

What does this PR do?

Reduce max mem usage and increase perf.

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@BaihuiJin
Contributor Author

BaihuiJin commented Oct 17, 2023

Before this PR
BS64, max_token 100
Throughput (including tokenization) = 2251.854885296207 tokens/second
Memory allocated = 26.74 GB
Max memory allocated = 43.64 GB
Total memory available = 94.46 GB
Graph compilation duration = 11.01004623901099 seconds

@ZhaiFeiyue ZhaiFeiyue added the run-test Run CI for PRs from external contributors label Oct 17, 2023
@BaihuiJin
Contributor Author

After this PR
Stats:
Throughput (including tokenization) = 2983.5717500079877 tokens/second
Memory allocated = 29.66 GB
Max memory allocated = 33.06 GB
Total memory available = 94.46 GB
Graph compilation duration = 12.692105805999745 seconds

@BaihuiJin
Contributor Author

@regisss This PR is still WIP; there's another optimization pending, so we could do two things in one PR :)

@regisss
Collaborator

regisss commented Oct 17, 2023

@BaihuiJin Nice! No problem, let's wait a bit for the 2nd optimization and then we'll merge.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@BaihuiJin
Contributor Author

BaihuiJin commented Oct 23, 2023

Latest perf result; configs are the same as above.
Stats:

Throughput (including tokenization) = 4275.093059788448 tokens/second
Memory allocated = 27.35 GB
Max memory allocated = 28.75 GB
Total memory available = 94.46 GB
Graph compilation duration = 10.280688997008838 seconds
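For reference, a quick arithmetic check of the improvement, using only the numbers reported in this thread (before this PR vs. the latest run above):

```python
# Quick arithmetic on the throughput/memory numbers reported in this thread
# (BS64, max_token 100, before this PR vs. the latest run).
before_tps = 2251.854885296207
after_tps = 4275.093059788448
before_peak_gb = 43.64
after_peak_gb = 28.75

speedup = after_tps / before_tps             # ~1.90x throughput
peak_saved = before_peak_gb - after_peak_gb  # ~14.89 GB lower peak memory

print(f"throughput speedup: {speedup:.2f}x")
print(f"peak memory saved: {peak_saved:.2f} GB")
```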

@BaihuiJin
Contributor Author

@regisss @ZhaiFeiyue Please help review~

@ZhaiFeiyue
Collaborator

@BaihuiJin Nice perf improvement! A few comments added.

Comment on lines +357 to +369
rotary_dim = self.config.rotary_dim
embed_dim = self.config.hidden_size
pos_embd_dim = rotary_dim or embed_dim
max_positions = self.config.max_position_embeddings
embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim).to(torch.bfloat16)
embed_positions = embed_positions.repeat(position_ids.shape[0], 1, 1)
if embed_positions.device != position_ids.device:
    embed_positions = embed_positions.to(position_ids.device)
repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
sincos = torch.gather(embed_positions, 1, repeated_position_ids)
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
sin = sin.contiguous()
cos = cos.contiguous()
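As added context (not part of the PR), the gather pattern in the snippet above can be exercised standalone; the sketch below uses NumPy so it runs without torch or a Gaudi device, and `create_sinusoidal_positions` here is a simplified stand-in for the transformers GPT-J helper:

```python
# Hypothetical standalone sketch of the sin/cos gather pattern above,
# in NumPy; names mirror the PR's snippet but this is not the HPU code.
import numpy as np

def create_sinusoidal_positions(num_pos, dim):
    # [sin | cos] table concatenated along the last axis,
    # mirroring the transformers GPT-J helper of the same name
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
    angles = np.einsum("i,j->ij", np.arange(num_pos, dtype=np.float32), inv_freq)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)

batch, seq_len, rotary_dim, max_positions = 2, 4, 8, 16
position_ids = np.tile(np.arange(seq_len), (batch, 1))           # (batch, seq_len)

embed_positions = create_sinusoidal_positions(max_positions, rotary_dim)
embed_positions = np.tile(embed_positions[None], (batch, 1, 1))  # (batch, max_pos, dim)

# gather the table rows for each position id, then split into sin / cos halves
repeated_ids = np.repeat(position_ids[..., None], embed_positions.shape[-1], axis=-1)
sincos = np.take_along_axis(embed_positions, repeated_ids, axis=1)
sin, cos = np.split(sincos, 2, axis=-1)

assert sin.shape == cos.shape == (batch, seq_len, rotary_dim // 2)
```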
Collaborator


I understand that this piece of code comes from the code blocks that were removed above. Could this be moved to a dedicated method that would be called here please?

Contributor Author


Theoretically it can be done, but testing shows that an additional memcpy occurred; perf drop details are as follows.
Throughput (including tokenization) = 3885.6094019038055 tokens/second
Memory allocated = 27.33 GB
Max memory allocated = 28.73 GB
Total memory available = 94.46 GB
Graph compilation duration = 8.958231755999805 seconds

Contributor Author

@BaihuiJin BaihuiJin Oct 24, 2023


FYI, the change looks like this:
def get_embed_positions(embed_positions, position_ids):
    embed_positions = embed_positions.repeat(position_ids.shape[0], 1, 1)
    if embed_positions.device != position_ids.device:
        embed_positions = embed_positions.to(position_ids.device)
    return embed_positions
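For readers following along, the helper is functionally identical to the inline version; the perf difference reported in this thread is a memcpy effect on device, not a semantic one. A minimal NumPy sketch of that equivalence (hypothetical, not the PR's HPU code; device handling is omitted since NumPy has no device concept):

```python
# Hypothetical NumPy analogue of the proposed helper, showing it returns
# the same batched table as the inline repeat it would replace.
import numpy as np

def get_embed_positions(embed_positions, position_ids):
    # tile the (max_pos, dim) table across the batch dimension
    return np.tile(embed_positions, (position_ids.shape[0], 1, 1))

table = np.arange(16 * 8, dtype=np.float32).reshape(16, 8)  # toy sin/cos table
position_ids = np.zeros((4, 5), dtype=np.int64)             # batch of 4

inline = np.tile(table, (position_ids.shape[0], 1, 1))  # the inline form
helper = get_embed_positions(table, position_ids)       # the helper form

assert helper.shape == (4, 16, 8)
assert np.array_equal(inline, helper)
```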

Collaborator


That's surprising, as objects are passed to functions by reference if I'm not mistaken.
Okay, in that case, could you just add a comment above this block saying which methods it replaces, and also add a blank line right above and below please?

Contributor Author


> That's surprising, as objects are passed to functions by reference if I'm not mistaken. Okay, in that case, could you just add a comment above this block saying which methods it replaces, and also add a blank line right above and below please?

Surprising indeed. Anyway, changed accordingly.

Contributor Author


> That's surprising, as objects are passed to functions by reference if I'm not mistaken. Okay, in that case, could you just add a comment above this block saying which methods it replaces, and also add a blank line right above and below please?

By the way, I think `make style` removed the blank line I added below this block :)

@ZhaiFeiyue ZhaiFeiyue added run-test Run CI for PRs from external contributors and removed run-test Run CI for PRs from external contributors labels Oct 25, 2023
Collaborator

@regisss regisss left a comment


LGTM!

@regisss regisss merged commit a741af7 into huggingface:main Oct 25, 2023
11 of 12 checks passed
@regisss regisss mentioned this pull request Jul 29, 2024