Fix unwanted scale/bias while testing and simplify _test_memory function (#50)

## Summary
Fix the unwanted scale/bias applied to the benchmark inputs and simplify the `_test_memory` function.

## Testing Done

- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: shisahni <[email protected]>
shivam15s and shisahni authored Aug 20, 2024
1 parent 27d2d51 commit b3f9e7a
Showing 12 changed files with 22 additions and 36 deletions.
8 changes: 4 additions & 4 deletions benchmark/benchmark_rms_norm.py
@@ -71,8 +71,8 @@ def bench_speed_rms_norm(M, N, dtype, provider, mode, eps=1e-5, device="cuda"):
     triton_rms = LigerRMSNorm(hidden_size=N).to("cuda")
     llama_rms = LlamaRMSNorm(hidden_size=N).to("cuda")
 
-    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
-    dy = 0.1 * torch.randn_like(x)
+    x = torch.randn(x_shape, dtype=dtype, device="cuda")
+    dy = torch.randn_like(x)
     x.requires_grad_(True)
     quantiles = [0.5, 0.2, 0.8]

@@ -144,8 +144,8 @@ def bench_memory_rms_norm(M, N, dtype, provider, mode, eps=1e-5, device="cuda"):
     triton_rms = LigerRMSNorm(hidden_size=N).to("cuda")
     llama_rms = LlamaRMSNorm(hidden_size=N).to("cuda")
 
-    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
-    dy = 0.1 * torch.randn_like(x)
+    x = torch.randn(x_shape, dtype=dtype, device="cuda")
+    dy = torch.randn_like(x)
     x.requires_grad_(True)
 
     # utility functions
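
The two hunks above replace `-2.3 + 0.5 * torch.randn(...)` with plain `torch.randn(...)`. As a rough illustration of what that affine transform does to the input distribution, here is a small standard-library sketch. It does not use PyTorch or the benchmark code; `sample_inputs` and its parameters are hypothetical names introduced only for this example.

```python
import random
import statistics


def sample_inputs(n, scale=1.0, bias=0.0, seed=0):
    """Draw n standard-normal samples z and return bias + scale * z."""
    rng = random.Random(seed)
    return [bias + scale * rng.gauss(0.0, 1.0) for _ in range(n)]


# Mirrors the removed line: x = -2.3 + 0.5 * randn(...)
old_x = sample_inputs(100_000, scale=0.5, bias=-2.3)
# Mirrors the added line: x = randn(...)
new_x = sample_inputs(100_000)

for name, xs in (("old", old_x), ("new", new_x)):
    mean = statistics.fmean(xs)
    std = statistics.pstdev(xs)
    print(f"{name}: mean={mean:+.3f} std={std:.3f}")
```

With the scale and bias in place, the samples cluster around -2.3 with a standard deviation near 0.5. Without them, the inputs are close to zero mean and unit variance.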
@@ -3,5 +3,5 @@ V,Liger,Hugging Face
 8192.0,1.5,3.7
 16384.0,2.8,7.8
 32768.0,5.5,15.7
-65536.0,12.0,30.8
+65536.0,12.0,30.9
 131072.0,25.3,61.4
Binary file modified benchmark/cross_entropy_speed/cross-entropy-fwd-speed-benchmark.png
12 changes: 6 additions & 6 deletions benchmark/rms_norm_memory/rmsnorm-full-memory-benchmark.csv
@@ -1,7 +1,7 @@
 N,Liger,Hugging Face
-1024.0,32.0,79.6
-2048.0,64.0,159.2
-4096.0,128.1,318.5
-8192.0,236.1,636.9
-16384.0,368.2,1273.8
-32768.0,624.4,2547.6
+1024.0,32.0,87.6
+2048.0,64.0,175.2
+4096.0,128.1,350.5
+8192.0,236.1,700.9
+16384.0,368.2,1401.8
+32768.0,624.4,2803.6
Binary file modified benchmark/rms_norm_memory/rmsnorm-full-memory-benchmark.png
4 changes: 2 additions & 2 deletions benchmark/rms_norm_speed/rmsnorm-bwd-speed-benchmark.csv
@@ -1,7 +1,7 @@
 N,Liger,Hugging Face
 1024.0,0.2,0.2
-2048.0,0.1,0.4
+2048.0,0.2,0.4
 4096.0,0.2,0.7
 8192.0,0.2,1.4
 16384.0,0.3,2.6
-32768.0,1.0,5.0
+32768.0,1.0,5.1
Binary file modified benchmark/rms_norm_speed/rmsnorm-bwd-speed-benchmark.png
12 changes: 6 additions & 6 deletions benchmark/rms_norm_speed/rmsnorm-full-speed-benchmark.csv
@@ -1,7 +1,7 @@
 N,Liger,Hugging Face
-1024.0,0.3,0.3
-2048.0,0.3,0.5
-4096.0,0.4,1.0
-8192.0,0.3,1.9
-16384.0,0.4,3.6
-32768.0,1.2,6.9
+1024.0,0.5,0.3
+2048.0,0.5,0.5
+4096.0,0.5,1.0
+8192.0,0.5,1.9
+16384.0,0.5,3.6
+32768.0,1.2,7.0
Binary file modified benchmark/rms_norm_speed/rmsnorm-full-speed-benchmark.png
Binary file modified benchmark/rms_norm_speed/rmsnorm-fwd-speed-benchmark.png
20 changes: 3 additions & 17 deletions benchmark/utils.py
@@ -5,27 +5,13 @@
 import torch
 
 
-def _test_memory_once(func: Callable) -> float:
-
-    torch.cuda.memory._record_memory_history()
-    torch.cuda.memory.reset_peak_memory_stats()
-
-    func()
-
-    mem = torch.cuda.max_memory_allocated()
-
-    # uncomment to save the visual memory snapshot
-    # torch.cuda.memory._dump_snapshot(f"{func.__name__}.pickle")
-
-    torch.cuda.memory._record_memory_history(enabled=None)
-    return mem
-
-
 def _test_memory(func: Callable, _iter: int = 10) -> float:
     total_mem = []
 
     for _ in range(_iter):
-        mem = _test_memory_once(func)
+        torch.cuda.memory.reset_peak_memory_stats()
+        func()
+        mem = torch.cuda.max_memory_allocated()
         total_mem.append(mem)
 
     return sum(total_mem) / len(total_mem)
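
The simplified `_test_memory` follows a reset, run, read-peak, average pattern. The sketch below shows the same loop shape on the CPU using Python's standard `tracemalloc` module (`reset_peak` requires Python 3.9+) in place of the CUDA allocator statistics, so it runs without a GPU. It is an analogue and not the repository's code; `measure_peak_bytes` and `allocate_one_megabyte` are hypothetical names.

```python
import tracemalloc
from typing import Callable


def measure_peak_bytes(func: Callable[[], object], iters: int = 10) -> float:
    """Return the average peak traced memory (in bytes) over `iters` calls."""
    peaks = []
    tracemalloc.start()
    try:
        for _ in range(iters):
            # Analogous to torch.cuda.memory.reset_peak_memory_stats()
            tracemalloc.reset_peak()
            func()
            # Analogous to torch.cuda.max_memory_allocated()
            _, peak = tracemalloc.get_traced_memory()
            peaks.append(peak)
    finally:
        tracemalloc.stop()
    return sum(peaks) / len(peaks)


def allocate_one_megabyte():
    # Temporary 1 MB buffer; freed as soon as the caller drops the result.
    return bytearray(1_000_000)


avg_peak = measure_peak_bytes(allocate_one_megabyte, iters=3)
print(f"average peak: {avg_peak / 1e6:.2f} MB")
```

Resetting the peak counter before each call keeps one iteration's high-water mark from leaking into the next, which is the property the loop in the diff relies on.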