Skip to content

Commit

Permalink
Add Chunked ORPO Loss (#362)
Browse files Browse the repository at this point in the history
## Summary
<!--- This is a required section; please describe the main purpose of
this proposed code change. --->
Adds chunked ORPO loss kernel 
<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->

## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->
Benchmarks
![Speed
ORPO](https://github.com/user-attachments/assets/ae9e6f67-14cd-4189-9d64-9a2f94a3b3c6)
![Mem
ORPO](https://github.com/user-attachments/assets/47c289f4-2876-4530-949c-2c2825bc0f79)

References:
1. #227 
2.
https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899#file-lce_benchmark-py
<!-- 
Replace BLANK with your device type. For example, A100-80G-PCIe

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them. 
-->

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: shisahni_LinkedIn <[email protected]>
  • Loading branch information
shivam15s and shisahni_LinkedIn authored Nov 14, 2024
1 parent 523fd66 commit 6b2fd02
Show file tree
Hide file tree
Showing 8 changed files with 700 additions and 0 deletions.
48 changes: 48 additions & 0 deletions benchmark/data/all_benchmark_data.csv
Original file line number Diff line number Diff line change
Expand Up @@ -619,3 +619,51 @@ layer_norm,huggingface,full,memory,MB,N,hidden size,2048,160.09375,160.09375,160
layer_norm,huggingface,full,memory,MB,N,hidden size,4096,320.15625,320.15625,320.15625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1
layer_norm,huggingface,full,memory,MB,N,hidden size,8192,640.28125,640.28125,640.28125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1
layer_norm,huggingface,full,memory,MB,N,hidden size,16384,1280.53125,1280.53125,1280.53125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1
fused_linear_orpo_loss,liger,forward,speed,ms,B,B,2,116.00621032714844,116.00621032714844,116.00621032714844,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:05,0.4.0
fused_linear_orpo_loss,liger,forward,speed,ms,B,B,4,230.83609008789062,230.83609008789062,230.83609008789062,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:05,0.4.0
fused_linear_orpo_loss,liger,forward,speed,ms,B,B,8,461.9543151855469,461.9543151855469,461.9543151855469,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:05,0.4.0
fused_linear_orpo_loss,liger,forward,speed,ms,B,B,16,922.994384765625,922.994384765625,922.994384765625,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:05,0.4.0
fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,2,39.558860778808594,39.52657699584961,39.591148376464844,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:36,0.4.0
fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,4,79.9734115600586,79.9734115600586,79.9734115600586,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:36,0.4.0
fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,8,160.071044921875,160.071044921875,160.071044921875,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:36,0.4.0
fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,16,321.4681091308594,321.4681091308594,321.4681091308594,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:36,0.4.0
fused_linear_orpo_loss,liger,full,speed,ms,B,B,2,116.56009674072266,116.56009674072266,116.56009674072266,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:17,0.4.0
fused_linear_orpo_loss,liger,full,speed,ms,B,B,4,232.43980407714844,232.43980407714844,232.43980407714844,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:17,0.4.0
fused_linear_orpo_loss,liger,full,speed,ms,B,B,8,464.5750732421875,464.5750732421875,464.5750732421875,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:17,0.4.0
fused_linear_orpo_loss,liger,full,speed,ms,B,B,16,926.3385009765625,926.3385009765625,926.3385009765625,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:17,0.4.0
fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,2,120.68428802490234,120.68428802490234,120.68428802490234,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:58,0.4.0
fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,4,241.15061950683594,241.15061950683594,241.15061950683594,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:58,0.4.0
fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,8,492.5342102050781,492.5342102050781,492.5342102050781,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:58,0.4.0
fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,16,1000.8460693359375,1000.8460693359375,1000.8460693359375,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:58,0.4.0
fused_linear_orpo_loss,liger,full,memory,MB,B,B,2,14556.626953125,14556.626953125,14556.626953125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:26:42,0.4.0
fused_linear_orpo_loss,liger,full,memory,MB,B,B,4,14748.689453125,14748.689453125,14748.689453125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:26:42,0.4.0
fused_linear_orpo_loss,liger,full,memory,MB,B,B,8,15132.814453125,15132.814453125,15132.814453125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:26:42,0.4.0
fused_linear_orpo_loss,liger,full,memory,MB,B,B,16,15901.064453125,15901.064453125,15901.064453125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:26:42,0.4.0
fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,2,12488.501953125,12488.501953125,12488.501953125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:27:10,0.4.0
fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,4,19630.564453125,19630.564453125,19630.564453125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:27:10,0.4.0
fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,8,33914.6875,33914.6875,33914.6875,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:27:10,0.4.0
fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,16,62482.9375,62482.9375,62482.9375,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:27:10,0.4.0
fused_linear_orpo_loss,liger,forward,speed,ms,B,B,2,31.02783966064453,31.027551651000977,31.164947509765625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:30,0.4.0
fused_linear_orpo_loss,liger,forward,speed,ms,B,B,4,60.88966369628906,60.88966369628906,60.88966369628906,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:30,0.4.0
fused_linear_orpo_loss,liger,forward,speed,ms,B,B,8,121.08070373535156,121.08070373535156,121.08070373535156,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:30,0.4.0
fused_linear_orpo_loss,liger,forward,speed,ms,B,B,16,244.36968994140625,244.36968994140625,244.36968994140625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:30,0.4.0
fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,2,12.9093599319458,12.874624252319336,12.947936058044434,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:57,0.4.0
fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,4,25.557632446289062,25.526700973510742,25.703763961791992,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:57,0.4.0
fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,8,51.75590515136719,51.75590515136719,51.75590515136719,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:57,0.4.0
fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,16,103.8515853881836,103.8515853881836,103.8515853881836,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:57,0.4.0
fused_linear_orpo_loss,liger,full,speed,ms,B,B,2,32.52537536621094,32.49258041381836,32.558170318603516,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:28,0.4.0
fused_linear_orpo_loss,liger,full,speed,ms,B,B,4,63.16300964355469,63.16300964355469,63.16300964355469,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:28,0.4.0
fused_linear_orpo_loss,liger,full,speed,ms,B,B,8,123.02518463134766,123.02518463134766,123.02518463134766,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:28,0.4.0
fused_linear_orpo_loss,liger,full,speed,ms,B,B,16,247.44105529785156,247.44105529785156,247.44105529785156,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:28,0.4.0
fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,2,39.32752227783203,39.32701873779297,39.32802200317383,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:59,0.4.0
fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,4,77.9202880859375,77.9202880859375,77.9202880859375,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:59,0.4.0
fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,8,151.6084442138672,151.6084442138672,151.6084442138672,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:59,0.4.0
fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,16,304.4580993652344,304.4580993652344,304.4580993652344,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:59,0.4.0
fused_linear_orpo_loss,liger,full,memory,MB,B,B,2,8161.34619140625,8161.34619140625,8161.34619140625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:30,0.4.0
fused_linear_orpo_loss,liger,full,memory,MB,B,B,4,8209.361328125,8209.361328125,8209.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:30,0.4.0
fused_linear_orpo_loss,liger,full,memory,MB,B,B,8,8305.392578125,8305.392578125,8305.392578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:30,0.4.0
fused_linear_orpo_loss,liger,full,memory,MB,B,B,16,8497.455078125,8497.455078125,8497.455078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:30,0.4.0
fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314453125,8645.314453125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0
fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0
fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0
fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,16,33418.421875,33418.421875,33418.421875,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0
191 changes: 191 additions & 0 deletions benchmark/scripts/benchmark_orpo_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import os
import sys

import torch
import triton
from utils import (
QUANTILES,
SingleBenchmarkRunInput,
SingleBenchmarkRunOutput,
_test_memory,
parse_benchmark_script_args,
run_benchmarks,
)

from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


class TorchLMHeadORPO(torch.nn.Module):
"""Ground truth implementation of the linear fused with torch based cross entropy loss.
:param H: hidden size
:param V: vocab size
:param ignore_index: index to ignore
:param reduction: reduction method
"""

def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
from test.chunked_loss.test_orpo_loss import HF_ORPO_Loss

super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype
)
self.orpo_loss = HF_ORPO_Loss().get_batch_loss_metrics

def forward(self, x, y):
return self.orpo_loss(x, self.lin.weight, y)


class LigerLMHeadORPO(torch.nn.Module):
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype
)
self.orpo_loss = LigerFusedLinearORPOFunction.apply

def forward(self, x, y):
return self.orpo_loss(x, self.lin.weight, y)


#############################################################################
# Test the memory consumption of the linear fused cross entropy loss
#############################################################################


def bench_memory_fused_linear_orpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

device = "cuda"
torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)

def fwd():
if provider == "liger":
return liger_lm_head_orpo(_input, target)
elif provider == "huggingface":
return torch_lm_head_orpo(_input, target)

def full():
y = fwd()
y.backward()

mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)


# #############################################################################
# # Test the speed of the fused linear cross entropy loss
# #############################################################################


def bench_speed_fused_linear_orpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider
mode = input.kernel_operation_mode

device = "cuda"

torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)

def fwd():
if provider == "liger":
return liger_lm_head_orpo(_input, target)
elif provider == "huggingface":
return torch_lm_head_orpo(_input, target)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()

ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[_input],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":

def full():
y = fwd()
y.backward()

ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)


if __name__ == "__main__":
args = parse_benchmark_script_args()

common_configs = {
"kernel_name": "fused_linear_orpo_loss",
"x_name": "B",
"x_label": "B",
"x_values": [2**i for i in range(1, 5)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"T": 1024,
"H": 4096,
"V": 128256,
"mode": "forward",
"dtype": torch.bfloat16,
}
],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_fused_linear_orpo_loss,
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs
)
run_benchmarks(
bench_test_fn=bench_memory_fused_linear_orpo_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs
)
Empty file.
Empty file.
Loading

0 comments on commit 6b2fd02

Please sign in to comment.