From 0e71106cff50d842e4c0fab836c7fec164924fc4 Mon Sep 17 00:00:00 2001
From: Austin Liu
Date: Wed, 13 Nov 2024 14:57:12 +0800
Subject: [PATCH] Clean up bench

Signed-off-by: Austin Liu
---
 benchmark/scripts/benchmark_dpo_loss.py | 39 +++++++++++--------------
 1 file changed, 17 insertions(+), 22 deletions(-)

diff --git a/benchmark/scripts/benchmark_dpo_loss.py b/benchmark/scripts/benchmark_dpo_loss.py
index fa492ab9b..8593985ac 100644
--- a/benchmark/scripts/benchmark_dpo_loss.py
+++ b/benchmark/scripts/benchmark_dpo_loss.py
@@ -67,7 +67,7 @@ def forward(self, x, target):
 
 
 def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    B = input.x // 2
+    B = input.x
     T = input.extra_benchmark_config["T"]
     H = input.extra_benchmark_config["H"]
     V = input.extra_benchmark_config["V"]
@@ -85,14 +85,14 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
         H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
     ).to(device)
 
-    # Input shape: [B*2, T, H]
-    _input = torch.randn(B * 2, T, H, requires_grad=True, dtype=dtype, device=device)
-    # Target shape: [B*2, T]
-    target = torch.randint(V, (B * 2, T), dtype=torch.long, device=device)
+    # Input shape: [B, T, H]
+    _input = torch.randn(B, T, H, device=device, dtype=dtype)
+    # Target shape: [B, T]
+    target = torch.randint(V, (B, T), dtype=torch.long, device=device)
 
     # Add ignore_index tokens to simulate padding
-    num_elements_to_assign = torch.randint(1, B * 2 * T // 2, (1,)).item()
-    indices_to_assign = torch.randperm(B * 2 * T)[:num_elements_to_assign]
+    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
+    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
     target.view(-1)[indices_to_assign] = ignore_index
 
     def fwd():
@@ -114,7 +114,7 @@ def full():
 
 
 def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
-    B = input.x // 2
+    B = input.x
     T = input.extra_benchmark_config["T"]
     H = input.extra_benchmark_config["H"]
     V = input.extra_benchmark_config["V"]
@@ -133,27 +133,22 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
         H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
     ).to(device)
 
-    # Input shape: [B*2, T, H]
-    input1 = torch.randn(
-        B * 2, T, H, device=device, dtype=dtype
-    )  # .detach().clone().requires_grad_(True)
-    input2 = torch.randn(
-        B * 2, T, H, device=device, dtype=dtype
-    )  # .detach().clone().requires_grad_(True)
+    # Input shape: [B, T, H]
+    _input = torch.randn(B, T, H, device=device, dtype=dtype)
 
-    # Target shape: [B*2, T]
-    target = torch.randint(0, V, (B * 2, T), device=device, dtype=torch.long)
+    # Target shape: [B, T]
+    target = torch.randint(V, (B, T), device=device, dtype=torch.long)
 
     # Add ignore_index tokens
-    num_elements_to_assign = torch.randint(1, B * 2 * T // 2, (1,)).item()
-    indices_to_assign = torch.randperm(B * 2 * T)[:num_elements_to_assign]
+    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
+    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
     target.view(-1)[indices_to_assign] = ignore_index
 
     def fwd():
         if provider == "liger":
-            return liger_dpo_loss(input1, target)
+            return liger_dpo_loss(_input, target)
         elif provider == "huggingface":
-            return torch_dpo_loss(input2, target)
+            return torch_dpo_loss(_input, target)
 
     if mode == "forward":
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
@@ -165,7 +160,7 @@ def fwd():
         y = fwd()
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
             lambda: y.backward(retain_graph=True),
-            grad_to_none=[input1, input2],
+            grad_to_none=[_input],
             rep=100,
             quantiles=QUANTILES,
         )
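
Note: for context, below is a minimal standalone sketch of the ignore_index padding step that both bench_memory_dpo_loss and bench_speed_dpo_loss perform after this patch. The values of B, T, V, and ignore_index here are illustrative placeholders, not the benchmark's configured values.

import torch

# Illustrative placeholder sizes; the real benchmark sweeps B and reads
# T, V, and ignore_index from its extra_benchmark_config.
B, T, V = 4, 512, 32000
ignore_index = -100

# Target shape: [B, T], random token ids in [0, V).
target = torch.randint(V, (B, T), dtype=torch.long)

# Overwrite a random subset of positions (at most half) with ignore_index
# so the loss implementations have padding tokens to skip, mirroring the
# benchmark code above.
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index

# Sanity check: fraction of positions marked as padding.
print((target == ignore_index).float().mean().item())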