diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv
index dfd31091c..a5126f1dd 100644
--- a/benchmark/data/all_benchmark_data.csv
+++ b/benchmark/data/all_benchmark_data.csv
@@ -619,3 +619,51 @@ layer_norm,huggingface,full,memory,MB,N,hidden size,2048,160.09375,160.09375,160
 layer_norm,huggingface,full,memory,MB,N,hidden size,4096,320.15625,320.15625,320.15625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1
 layer_norm,huggingface,full,memory,MB,N,hidden size,8192,640.28125,640.28125,640.28125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1
 layer_norm,huggingface,full,memory,MB,N,hidden size,16384,1280.53125,1280.53125,1280.53125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1
+fused_linear_orpo_loss,liger,forward,speed,ms,B,B,2,116.00621032714844,116.00621032714844,116.00621032714844,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:05,0.4.0
+fused_linear_orpo_loss,liger,forward,speed,ms,B,B,4,230.83609008789062,230.83609008789062,230.83609008789062,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:05,0.4.0
+fused_linear_orpo_loss,liger,forward,speed,ms,B,B,8,461.9543151855469,461.9543151855469,461.9543151855469,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:05,0.4.0
+fused_linear_orpo_loss,liger,forward,speed,ms,B,B,16,922.994384765625,922.994384765625,922.994384765625,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:05,0.4.0
+fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,2,39.558860778808594,39.52657699584961,39.591148376464844,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:36,0.4.0
+fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,4,79.9734115600586,79.9734115600586,79.9734115600586,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:36,0.4.0
+fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,8,160.071044921875,160.071044921875,160.071044921875,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:36,0.4.0
+fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,16,321.4681091308594,321.4681091308594,321.4681091308594,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:24:36,0.4.0
+fused_linear_orpo_loss,liger,full,speed,ms,B,B,2,116.56009674072266,116.56009674072266,116.56009674072266,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:17,0.4.0
+fused_linear_orpo_loss,liger,full,speed,ms,B,B,4,232.43980407714844,232.43980407714844,232.43980407714844,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:17,0.4.0
+fused_linear_orpo_loss,liger,full,speed,ms,B,B,8,464.5750732421875,464.5750732421875,464.5750732421875,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:17,0.4.0
+fused_linear_orpo_loss,liger,full,speed,ms,B,B,16,926.3385009765625,926.3385009765625,926.3385009765625,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:17,0.4.0
+fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,2,120.68428802490234,120.68428802490234,120.68428802490234,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:58,0.4.0
+fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,4,241.15061950683594,241.15061950683594,241.15061950683594,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:58,0.4.0
+fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,8,492.5342102050781,492.5342102050781,492.5342102050781,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:58,0.4.0
+fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,16,1000.8460693359375,1000.8460693359375,1000.8460693359375,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:25:58,0.4.0
+fused_linear_orpo_loss,liger,full,memory,MB,B,B,2,14556.626953125,14556.626953125,14556.626953125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:26:42,0.4.0
+fused_linear_orpo_loss,liger,full,memory,MB,B,B,4,14748.689453125,14748.689453125,14748.689453125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:26:42,0.4.0
+fused_linear_orpo_loss,liger,full,memory,MB,B,B,8,15132.814453125,15132.814453125,15132.814453125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:26:42,0.4.0
+fused_linear_orpo_loss,liger,full,memory,MB,B,B,16,15901.064453125,15901.064453125,15901.064453125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:26:42,0.4.0
+fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,2,12488.501953125,12488.501953125,12488.501953125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:27:10,0.4.0
+fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,4,19630.564453125,19630.564453125,19630.564453125,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:27:10,0.4.0
+fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,8,33914.6875,33914.6875,33914.6875,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:27:10,0.4.0
+fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,16,62482.9375,62482.9375,62482.9375,"{""T"": 4096, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 21:27:10,0.4.0
+fused_linear_orpo_loss,liger,forward,speed,ms,B,B,2,31.02783966064453,31.027551651000977,31.164947509765625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:30,0.4.0
+fused_linear_orpo_loss,liger,forward,speed,ms,B,B,4,60.88966369628906,60.88966369628906,60.88966369628906,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:30,0.4.0
+fused_linear_orpo_loss,liger,forward,speed,ms,B,B,8,121.08070373535156,121.08070373535156,121.08070373535156,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:30,0.4.0
+fused_linear_orpo_loss,liger,forward,speed,ms,B,B,16,244.36968994140625,244.36968994140625,244.36968994140625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:30,0.4.0
+fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,2,12.9093599319458,12.874624252319336,12.947936058044434,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:57,0.4.0
+fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,4,25.557632446289062,25.526700973510742,25.703763961791992,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:57,0.4.0
+fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,8,51.75590515136719,51.75590515136719,51.75590515136719,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:57,0.4.0
+fused_linear_orpo_loss,huggingface,forward,speed,ms,B,B,16,103.8515853881836,103.8515853881836,103.8515853881836,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:06:57,0.4.0
+fused_linear_orpo_loss,liger,full,speed,ms,B,B,2,32.52537536621094,32.49258041381836,32.558170318603516,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:28,0.4.0
+fused_linear_orpo_loss,liger,full,speed,ms,B,B,4,63.16300964355469,63.16300964355469,63.16300964355469,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:28,0.4.0
+fused_linear_orpo_loss,liger,full,speed,ms,B,B,8,123.02518463134766,123.02518463134766,123.02518463134766,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:28,0.4.0
+fused_linear_orpo_loss,liger,full,speed,ms,B,B,16,247.44105529785156,247.44105529785156,247.44105529785156,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:28,0.4.0
+fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,2,39.32752227783203,39.32701873779297,39.32802200317383,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:59,0.4.0
+fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,4,77.9202880859375,77.9202880859375,77.9202880859375,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:59,0.4.0
+fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,8,151.6084442138672,151.6084442138672,151.6084442138672,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:59,0.4.0
+fused_linear_orpo_loss,huggingface,full,speed,ms,B,B,16,304.4580993652344,304.4580993652344,304.4580993652344,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:07:59,0.4.0
+fused_linear_orpo_loss,liger,full,memory,MB,B,B,2,8161.34619140625,8161.34619140625,8161.34619140625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:30,0.4.0
+fused_linear_orpo_loss,liger,full,memory,MB,B,B,4,8209.361328125,8209.361328125,8209.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:30,0.4.0
+fused_linear_orpo_loss,liger,full,memory,MB,B,B,8,8305.392578125,8305.392578125,8305.392578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:30,0.4.0
+fused_linear_orpo_loss,liger,full,memory,MB,B,B,16,8497.455078125,8497.455078125,8497.455078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:30,0.4.0
+fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314453125,8645.314453125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0
+fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0
+fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0
+fused_linear_orpo_loss,huggingface,full,memory,MB,B,B,16,33418.421875,33418.421875,33418.421875,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-13 22:08:56,0.4.0
diff --git a/benchmark/scripts/benchmark_orpo_loss.py b/benchmark/scripts/benchmark_orpo_loss.py
new file mode 100644
--- /dev/null
+++ b/benchmark/scripts/benchmark_orpo_loss.py
@@ -0,0 +1,207 @@
+import os
+import sys
+
+import torch
+import triton
+from utils import (
+    QUANTILES,
+    SingleBenchmarkRunInput,
+    SingleBenchmarkRunOutput,
+    _test_memory,
+    parse_benchmark_script_args,
+    run_benchmarks,
+)
+
+from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
+
+
+class TorchLMHeadORPO(torch.nn.Module):
+    """Reference LM head followed by the Hugging Face style ORPO loss.
+
+    :param H: hidden size
+    :param V: vocab size
+    :param dtype: dtype of the linear layer weight
+    :param ignore_index: target value excluded from the loss
+    """
+
+    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
+        from test.chunked_loss.test_orpo_loss import HF_ORPO_Loss
+
+        super().__init__()
+        self.lin = torch.nn.Linear(
+            in_features=H, out_features=V, bias=False, dtype=dtype
+        )
+        self.orpo_loss = HF_ORPO_Loss(ignore_index=ignore_index).get_batch_loss_metrics
+
+    def forward(self, x, y):
+        return self.orpo_loss(x, self.lin.weight, y)
+
+
+class LigerLMHeadORPO(torch.nn.Module):
+    """LM head followed by the Liger fused linear ORPO loss.
+
+    :param H: hidden size
+    :param V: vocab size
+    :param dtype: dtype of the linear layer weight
+    :param ignore_index: target value excluded from the loss
+    """
+
+    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
+        super().__init__()
+        self.lin = torch.nn.Linear(
+            in_features=H, out_features=V, bias=False, dtype=dtype
+        )
+        self.ignore_index = ignore_index
+        self.orpo_loss = LigerFusedLinearORPOFunction.apply
+
+    def forward(self, x, y):
+        # Positional arguments after ``y`` are (bias, ignore_index).
+        return self.orpo_loss(x, self.lin.weight, y, None, self.ignore_index)
+
+
+#############################################################################
+# Test the memory consumption of the fused linear ORPO loss
+#############################################################################
+
+
+def bench_memory_fused_linear_orpo_loss(
+    input: SingleBenchmarkRunInput,
+) -> SingleBenchmarkRunOutput:
+    """Measure memory of a full forward + backward pass for one batch size."""
+    B = input.x
+    T = input.extra_benchmark_config["T"]
+    H = input.extra_benchmark_config["H"]
+    V = input.extra_benchmark_config["V"]
+    dtype = input.extra_benchmark_config["dtype"]
+    provider = input.kernel_provider
+
+    device = "cuda"
+    torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
+    liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
+
+    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
+    target = torch.randint(V, (B, T), dtype=torch.long, device=device)
+
+    def fwd():
+        if provider == "liger":
+            return liger_lm_head_orpo(_input, target)
+        elif provider == "huggingface":
+            return torch_lm_head_orpo(_input, target)
+        raise ValueError(f"Unsupported kernel provider: {provider}")
+
+    def full():
+        y = fwd()
+        y.backward()
+
+    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
+    return SingleBenchmarkRunOutput(
+        y_20=mem_20,
+        y_50=mem_50,
+        y_80=mem_80,
+    )
+
+
+# #############################################################################
+# # Test the speed of the fused linear ORPO loss
+# #############################################################################
+
+
+def bench_speed_fused_linear_orpo_loss(
+    input: SingleBenchmarkRunInput,
+) -> SingleBenchmarkRunOutput:
+    """Time the forward, backward or full pass for one batch size."""
+    B = input.x
+    T = input.extra_benchmark_config["T"]
+    H = input.extra_benchmark_config["H"]
+    V = input.extra_benchmark_config["V"]
+    dtype = input.extra_benchmark_config["dtype"]
+    provider = input.kernel_provider
+    mode = input.kernel_operation_mode
+
+    device = "cuda"
+
+    torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
+    liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
+
+    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
+    target = torch.randint(V, (B, T), dtype=torch.long, device=device)
+
+    def fwd():
+        if provider == "liger":
+            return liger_lm_head_orpo(_input, target)
+        elif provider == "huggingface":
+            return torch_lm_head_orpo(_input, target)
+        raise ValueError(f"Unsupported kernel provider: {provider}")
+
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            fwd,
+            rep=100,
+            quantiles=QUANTILES,
+        )
+    elif mode == "backward":
+        y = fwd()
+
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: y.backward(retain_graph=True),
+            grad_to_none=[_input],
+            rep=100,
+            quantiles=QUANTILES,
+        )
+    elif mode == "full":
+
+        def full():
+            y = fwd()
+            y.backward()
+
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            full,
+            rep=100,
+            quantiles=QUANTILES,
+        )
+    else:
+        raise ValueError(f"Unsupported kernel operation mode: {mode}")
+    return SingleBenchmarkRunOutput(
+        y_20=ms_20,
+        y_50=ms_50,
+        y_80=ms_80,
+    )
+
+
+if __name__ == "__main__":
+    args = parse_benchmark_script_args()
+
+    common_configs = {
+        "kernel_name": "fused_linear_orpo_loss",
+        "x_name": "B",
+        "x_label": "B",
+        "x_values": [2**i for i in range(1, 5)],
+        "kernel_providers": ["liger", "huggingface"],
+        "extra_benchmark_configs": [
+            {
+                "T": 1024,
+                "H": 4096,
+                "V": 128256,
+                "mode": "forward",
+                "dtype": torch.bfloat16,
+            }
+        ],
+        "overwrite": args.overwrite,
+    }
+
+    run_benchmarks(
+        bench_test_fn=bench_speed_fused_linear_orpo_loss,
+        kernel_operation_modes=["forward", "full"],
+        metric_name="speed",
+        metric_unit="ms",
+        **common_configs
+    )
+    run_benchmarks(
+        bench_test_fn=bench_memory_fused_linear_orpo_loss,
+        kernel_operation_modes=["full"],
+        metric_name="memory",
+        metric_unit="MB",
+        **common_configs
+    )
diff --git a/src/liger_kernel/chunked_loss/README.md b/src/liger_kernel/chunked_loss/README.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/liger_kernel/chunked_loss/__init__.py b/src/liger_kernel/chunked_loss/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py
new file mode 100644
--- /dev/null
+++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -0,0 +1,127 @@
+import torch
+
+
+class LigerFusedLinearPreferenceBase(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        loss_fn=None,
+        chunk_size=1,
+        compiled=True,
+    ):
+        """
+        Base class for fused linear layer with preference loss.
+        Expects _input to be stacked with chosen and rejected inputs on the batch dimension.
+
+        Args:
+            _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
+            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+            target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
+            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
+            compiled (bool): Whether to use torch compile for chunk accumulation.
+
+        Returns:
+            torch.Tensor: Scalar loss accumulated over all chunks.
+
+        Raises:
+            ValueError: If loss_fn is missing or the batch dimension is odd.
+        """
+        if loss_fn is None:
+            raise ValueError("loss_fn must be provided")
+        if _input.shape[0] % 2 != 0:
+            raise ValueError(
+                "_input must stack chosen and rejected samples, so its batch "
+                f"dimension must be even (got {_input.shape[0]})"
+            )
+
+        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
+        CHUNK_SIZE = chunk_size
+
+        grad_weight = torch.zeros_like(weight)
+        grad_chosen_inputs = []
+        grad_rejected_inputs = []
+        grad_bias = torch.zeros_like(bias) if bias is not None else None
+        loss_acc = torch.zeros((), device=_input.device)
+
+        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
+
+        def accumulate_chunk(input_chunk, target_chunk):
+            if bias is not None:
+                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
+                    chunk_loss,
+                    (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
+                ) = torch.func.grad_and_value(loss_fn, argnums=(0, 1, 3), has_aux=True)(
+                    input_chunk, weight, target_chunk, bias
+                )
+                grad_bias.add_(chunk_grad_bias)
+            else:
+                (chunk_grad_input, chunk_grad_weight), (
+                    chunk_loss,
+                    (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
+                ) = torch.func.grad_and_value(loss_fn, argnums=(0, 1), has_aux=True)(
+                    input_chunk, weight, target_chunk
+                )
+            grad_weight.add_(chunk_grad_weight)
+            loss_acc.add_(chunk_loss)
+            return chunk_grad_input
+
+        # Wrap once, outside the loop: re-wrapping on every iteration would stack
+        # torch.compile wrappers on top of each other.
+        if compiled:
+            accumulate_chunk = torch.compile(accumulate_chunk)
+
+        len_chosen = target.shape[0] // 2
+        _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
+        _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
+        _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
+        _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)
+
+        for (
+            chosen_input_chunk,
+            rejected_input_chunk,
+            chosen_target_chunk,
+            rejected_target_chunk,
+        ) in zip(
+            _chosen_input_chunks,
+            _rejected_input_chunks,
+            _chosen_target_chunks,
+            _rejected_target_chunks,
+        ):
+            input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
+            target_chunk = torch.cat(
+                [chosen_target_chunk, rejected_target_chunk], dim=0
+            )
+
+            grad_input = accumulate_chunk(input_chunk, target_chunk)
+
+            grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
+            grad_rejected_inputs.append(grad_input[chosen_target_chunk.shape[0] :])
+
+        # combine grad_chosen_inputs and grad_rejected_inputs
+        grad_inputs = grad_chosen_inputs + grad_rejected_inputs
+
+        ctx.save_for_backward(
+            torch.cat(grad_inputs, dim=0),
+            grad_weight,
+            grad_bias,
+        )
+        return loss_acc
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """Scale the gradients cached in forward by the incoming grad_output."""
+        grad_input, grad_weight, grad_bias = ctx.saved_tensors
+        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
+            grad_input = grad_input * grad_output
+            grad_weight = grad_weight * grad_output
+            grad_bias = grad_bias * grad_output if grad_bias is not None else None
+
+        # One entry per forward argument after ctx:
+        # (_input, weight, target, bias, loss_fn, chunk_size, compiled).
+        return grad_input, grad_weight, None, grad_bias, None, None, None
diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py
new file mode 100644
--- /dev/null
+++ b/src/liger_kernel/chunked_loss/orpo_loss.py
@@ -0,0 +1,122 @@
+from functools import partial
+
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_preference import (
+    LigerFusedLinearPreferenceBase,
+)
+
+
+def odds_ratio_loss(chosen_logps, rejected_logps, beta=0.1):
+    """
+    Compute odds-ratio loss.
+    Args:
+        chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
+        rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+        beta (float): Weight for the odds ratio loss.
+    Returns:
+        torch.Tensor: Scalar, beta * sum of logsigmoid(log odds ratio) over the batch.
+    """
+    log_odds = (chosen_logps - rejected_logps) - (
+        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
+    )
+    ratio = F.logsigmoid(log_odds)
+    return beta * ratio.sum()
+
+
+def _compute_orpo_loss(
+    input_chunk,
+    weight,
+    target_chunk,
+    bias=None,
+    full_target=None,
+    ignore_index=-100,
+    beta=0.1,
+    compute_nll_loss=True,
+):
+    """
+    Compute ORPO loss for a chunk of input and target.
+    Args:
+        input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
+        weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
+        target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
+        bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+        full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
+        ignore_index (int): Index to ignore for loss computation.
+        beta (float): Weight for the odds ratio loss.
+        compute_nll_loss (bool): Whether to add the NLL loss of the chosen samples.
+    Returns:
+        tuple: (loss, (or_loss, chosen_logps, rejected_logps)) for this chunk.
+    """
+    len_chosen_chunk = target_chunk.shape[0] // 2
+
+    logits_chunk = input_chunk @ weight.t()  # (2 * chunk_size, sequence_length, V)
+    if bias is not None:
+        logits_chunk = logits_chunk + bias
+    log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
+
+    chosen_nll_loss = 0.0
+    if compute_nll_loss:
+        chosen_nll_loss = F.nll_loss(
+            log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
+            target_chunk[:len_chosen_chunk].view(-1),
+            reduction="sum",
+            ignore_index=ignore_index,
+        )
+        chosen_nll_loss = (
+            chosen_nll_loss
+            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
+        )
+
+    loss_mask = target_chunk != ignore_index
+    label_chunk = torch.where(loss_mask, target_chunk, 0)
+
+    per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
+    average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+
+    chosen_logps = average_log_prob[:len_chosen_chunk]
+    rejected_logps = average_log_prob[len_chosen_chunk:]
+
+    or_loss = odds_ratio_loss(chosen_logps, rejected_logps, beta=beta)
+    or_loss = or_loss / (full_target.shape[0] // 2)
+
+    loss = chosen_nll_loss - or_loss
+    return loss, (or_loss, chosen_logps, rejected_logps)
+
+
+class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase):
+    @staticmethod
+    def forward(
+        ctx,
+        _input,
+        weight,
+        target,
+        bias=None,
+        ignore_index=-100,
+        beta=0.1,
+        compute_nll_loss=True,
+        compiled=True,
+    ):
+        """
+        Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
+        Handles both the forward and backward pass of the final linear layer with ORPO loss.
+        Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
+        """
+        orpo_loss_fn = partial(
+            _compute_orpo_loss,
+            full_target=target,
+            ignore_index=ignore_index,
+            beta=beta,
+            compute_nll_loss=compute_nll_loss,
+        )
+        return LigerFusedLinearPreferenceBase.forward(
+            ctx, _input, weight, target, bias, loss_fn=orpo_loss_fn, compiled=compiled
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # Get gradients for _input, weight, target, and bias from the base class
+        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
+        # Return these gradients, followed by None for the remaining inputs
+        return *grads, None, None, None, None
diff --git a/test/chunked_loss/__init__.py b/test/chunked_loss/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/chunked_loss/test_orpo_loss.py b/test/chunked_loss/test_orpo_loss.py
new file mode 100644
--- /dev/null
+++ b/test/chunked_loss/test_orpo_loss.py
@@ -0,0 +1,233 @@
+from test.utils import assert_verbose_allclose, set_seed
+from typing import Tuple
+
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
+
+# set random seed globally
+set_seed()
+
+
+class HF_ORPO_Loss:
+    """
+    Implementation of the Odds Ratio Preference Optimization (ORPO) loss,
+    adapted from Hugging Face's implementation.
+    Reference: https://github.com/huggingface/trl/blob/main/trl/trainer/orpo_trainer.py
+    """
+
+    def __init__(self, ignore_index: int = -100, beta: float = 0.1):
+        self.ignore_index = ignore_index
+        self.beta = beta
+
+    def get_batch_logps(
+        self,
+        logits: torch.FloatTensor,
+        labels: torch.LongTensor,
+        average_log_prob: bool = False,
+    ) -> torch.FloatTensor:
+        """Compute the log probabilities of the given labels under the given logits.
+
+        Args:
+            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
+            labels: Labels for which to compute the log probabilities. Label tokens with a value of ignore_index are ignored. Shape: (batch_size, sequence_length)
+            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
+
+        Returns:
+            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
+        """
+        if logits.shape[:-1] != labels.shape:
+            raise ValueError(
+                "Logits (batch and sequence length dim) and labels must have the same shape."
+            )
+
+        loss_mask = labels != self.ignore_index
+
+        # dummy token; we'll ignore the losses on these tokens later
+        labels = torch.where(labels == self.ignore_index, 0, labels)
+
+        per_token_logps = torch.gather(
+            logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
+        ).squeeze(2)
+
+        if average_log_prob:
+            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+        else:
+            return (per_token_logps * loss_mask).sum(-1)
+
+    def odds_ratio_loss(
+        self,
+        policy_chosen_logps: torch.FloatTensor,
+        policy_rejected_logps: torch.FloatTensor,
+    ) -> torch.FloatTensor:
+        """Compute ORPO's odds ratio (OR) loss for a batch of policy model log probabilities.
+
+        Args:
+            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
+            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
+
+        Returns:
+            A tensor of shape (batch_size,) holding `beta * log(sigmoid(log_odds))` for each example in the batch.
+        """
+
+        # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
+        log_odds = (policy_chosen_logps - policy_rejected_logps) - (
+            torch.log1p(-torch.exp(policy_chosen_logps))
+            - torch.log1p(-torch.exp(policy_rejected_logps))
+        )
+        ratio = F.logsigmoid(log_odds)
+        losses = self.beta * ratio
+
+        return losses
+
+    def concatenated_forward(
+        self,
+        _input: torch.FloatTensor,
+        weight: torch.FloatTensor,
+        target: torch.LongTensor,
+        bias: torch.FloatTensor = None,
+    ) -> Tuple[
+        torch.FloatTensor,
+        torch.FloatTensor,
+        torch.FloatTensor,
+        torch.FloatTensor,
+        torch.FloatTensor,
+    ]:
+        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
+
+        We do this to avoid doing two forward passes, because it's faster for FSDP.
+
+        Returns:
+            (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss).
+        """
+        len_chosen = _input.shape[0] // 2
+
+        outputs = _input @ weight.t()
+        if bias is not None:
+            outputs = outputs + bias
+        all_logits = outputs.float()
+
+        def cross_entropy_loss(logits, labels):
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
+            logits = logits.view(-1, logits.shape[-1])
+            labels = labels.view(-1)
+            # Enable model parallelism
+            labels = labels.to(logits.device)
+            loss = loss_fct(logits, labels)
+            return loss
+
+        labels = target
+        chosen_nll_loss = cross_entropy_loss(
+            all_logits[:len_chosen], labels[:len_chosen]
+        )
+
+        all_logps = self.get_batch_logps(
+            all_logits,
+            target,
+            average_log_prob=True,
+        )
+
+        chosen_logps = all_logps[:len_chosen]
+        rejected_logps = all_logps[len_chosen:]
+
+        chosen_logits = all_logits[:len_chosen]
+        rejected_logits = all_logits[len_chosen:]
+
+        return (
+            chosen_logps,
+            rejected_logps,
+            chosen_logits,
+            rejected_logits,
+            chosen_nll_loss,
+        )
+
+    def get_batch_loss_metrics(
+        self,
+        _input: torch.FloatTensor,
+        weight: torch.FloatTensor,
+        target: torch.LongTensor,
+        bias: torch.FloatTensor = None,
+    ):
+        """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
+
+        forward_output = self.concatenated_forward(_input, weight, target, bias)
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+            policy_nll_loss,
+        ) = forward_output[:5]
+
+        losses = self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
+        # full ORPO loss
+        loss = policy_nll_loss - losses.mean()
+        return loss
+
+
+@pytest.mark.parametrize(
+    "B, T, H, V",
+    [
+        (8, 128, 1024, 4096),
+        (3, 47, 31, 123),  # random shape
+    ],
+)
+@pytest.mark.parametrize(
+    "scalar, dtype, atol, rtol",
+    [
+        (1.0, torch.bfloat16, 5e-2, 5e-1),
+        (1.0, torch.float32, 1e-5, 5e-4),
+    ],
+)
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)])
+def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta):
+    B = 2 * B  # orpo loss requires B to be even
+
+    _input = torch.randn(B, T, H, device="cuda", dtype=dtype) * scalar
+    input1 = _input.detach().clone().requires_grad_(True)
+    input2 = _input.detach().clone().requires_grad_(True)
+
+    target = torch.randint(
+        0,
+        V,
+        (
+            B,
+            T,
+        ),
+        device="cuda",
+        dtype=torch.long,
+    )
+    # Assign some random number of elements as ignore_index
+    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
+    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
+    target.view(-1)[indices_to_assign] = ignore_index
+
+    _weight = torch.randn(V, H, device="cuda", dtype=dtype)
+    weight1 = _weight.detach().clone().requires_grad_(True)
+    weight2 = _weight.detach().clone().requires_grad_(True)
+
+    _bias = torch.randn(V, device="cuda", dtype=dtype) if bias else None
+    bias1 = _bias.detach().clone().requires_grad_(True) if bias else None
+    bias2 = _bias.detach().clone().requires_grad_(True) if bias else None
+
+    loss1 = HF_ORPO_Loss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics(
+        input1, weight1, target, bias1
+    )
+    loss2 = LigerFusedLinearORPOFunction.apply(
+        input2, weight2, target, bias2, ignore_index, beta, True
+    )
+
+    assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
+
+    loss1.backward()
+    loss2.backward()
+
+    assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol)
+    assert_verbose_allclose(weight1.grad, weight2.grad, atol=atol, rtol=rtol)
+    if bias:
+        assert_verbose_allclose(bias1.grad, bias2.grad, atol=atol, rtol=rtol)