Skip to content

Commit

Permalink
Clean up bench
Browse files Browse the repository at this point in the history
Signed-off-by: Austin Liu <[email protected]>
  • Loading branch information
austin362667 committed Nov 14, 2024
1 parent 5bba92c commit 0e71106
Showing 1 changed file with 17 additions and 22 deletions.
39 changes: 17 additions & 22 deletions benchmark/scripts/benchmark_dpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def forward(self, x, target):


def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
B = input.x // 2
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
Expand All @@ -85,14 +85,14 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunO
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)

# Input shape: [B*2, T, H]
_input = torch.randn(B * 2, T, H, requires_grad=True, dtype=dtype, device=device)
# Target shape: [B*2, T]
target = torch.randint(V, (B * 2, T), dtype=torch.long, device=device)
# Input shape: [B, T, H]
_input = torch.randn(B, T, H, device=device, dtype=dtype)
# Target shape: [B, T]
target = torch.randint(V, (B, T), dtype=torch.long, device=device)

# Add ignore_index tokens to simulate padding
num_elements_to_assign = torch.randint(1, B * 2 * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * 2 * T)[:num_elements_to_assign]
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index

def fwd():
Expand All @@ -114,7 +114,7 @@ def full():


def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
B = input.x // 2
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
Expand All @@ -133,27 +133,22 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOu
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)

# Input shape: [B*2, T, H]
input1 = torch.randn(
B * 2, T, H, device=device, dtype=dtype
) # .detach().clone().requires_grad_(True)
input2 = torch.randn(
B * 2, T, H, device=device, dtype=dtype
) # .detach().clone().requires_grad_(True)
# Input shape: [B, T, H]
_input = torch.randn(B, T, H, device=device, dtype=dtype)

# Target shape: [B*2, T]
target = torch.randint(0, V, (B * 2, T), device=device, dtype=torch.long)
# Target shape: [B, T]
target = torch.randint(V, (B, T), device=device, dtype=torch.long)

# Add ignore_index tokens
num_elements_to_assign = torch.randint(1, B * 2 * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * 2 * T)[:num_elements_to_assign]
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index

def fwd():
if provider == "liger":
return liger_dpo_loss(input1, target)
return liger_dpo_loss(_input, target)
elif provider == "huggingface":
return torch_dpo_loss(input2, target)
return torch_dpo_loss(_input, target)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
Expand All @@ -165,7 +160,7 @@ def fwd():
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[input1, input2],
grad_to_none=[_input],
rep=100,
quantiles=QUANTILES,
)
Expand Down

0 comments on commit 0e71106

Please sign in to comment.