From 1c0c75c3455e788d575966bfc5edec3ef166835e Mon Sep 17 00:00:00 2001
From: Yun Dai
Date: Tue, 29 Oct 2024 21:59:37 -0700
Subject: [PATCH] fix fused JSD with ignore index (#330)

## Summary

1. Fix a bug in fused linear JSD where we did not extract the subset of the
   labels corresponding to the currently processed chunk, so every chunk was
   masked with the labels belonging to the first chunk's rows.
2. Add tests to make sure results are correct when all tokens are ignored.

## Testing Done

- Hardware Type:
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
---
 src/liger_kernel/ops/fused_linear_jsd.py   |  4 +-
 src/liger_kernel/ops/jsd.py                |  2 +-
 test/transformers/test_fused_linear_jsd.py | 74 ++++++++++++++++++++++
 test/transformers/test_jsd.py              | 37 ++++++++-
 4 files changed, 114 insertions(+), 3 deletions(-)

diff --git a/src/liger_kernel/ops/fused_linear_jsd.py b/src/liger_kernel/ops/fused_linear_jsd.py
index 34cb185c1..599b701d2 100644
--- a/src/liger_kernel/ops/fused_linear_jsd.py
+++ b/src/liger_kernel/ops/fused_linear_jsd.py
@@ -92,7 +92,9 @@ def fused_linear_jsd_forward(
             dX_ptr=student_prob_chunk,
             dX_stride=student_prob_chunk.stride(-2),
             label_ptr=(
-                shift_labels if has_label else torch.empty(1, device=device)
+                shift_labels[start_idx:end_idx]
+                if has_label
+                else torch.empty(1, device=device)
             ),  # dummy ptr if no label
             beta=jsd_beta,
             n_non_ignore=n_non_ignore,
diff --git a/src/liger_kernel/ops/jsd.py b/src/liger_kernel/ops/jsd.py
index 33ec2498c..6ecf8dbe9 100644
--- a/src/liger_kernel/ops/jsd.py
+++ b/src/liger_kernel/ops/jsd.py
@@ -19,7 +19,7 @@ def _jsd_kernel(
     dX_stride,
     label_ptr,
     beta,
-    n_non_ignore,
+    n_non_ignore: int,
     ignore_index: tl.constexpr,
     n_cols,
     BLOCK_SIZE: tl.constexpr,
diff --git a/test/transformers/test_fused_linear_jsd.py b/test/transformers/test_fused_linear_jsd.py
index 2024e054a..6da2b6e54 100644
--- a/test/transformers/test_fused_linear_jsd.py
+++ b/test/transformers/test_fused_linear_jsd.py
@@ -198,6 +198,7 @@ def test_correctness_with_ignore_index(
         dtype=dtype,
         device=device,
         temperature=temperature,
+        ignore_index=ignore_index,
         beta=beta,
     ).to(device)
     liger_lm_head_jsd = LigerLMHeadJSD(
@@ -206,6 +207,7 @@ def test_correctness_with_ignore_index(
         dtype=dtype,
         device=device,
         temperature=temperature,
+        ignore_index=ignore_index,
         beta=beta,
     ).to(device)
 
@@ -329,3 +331,75 @@ def test_correctness_functional(
 
     assert torch.allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol)
     assert torch.allclose(_weight1.grad, _weight2.grad, atol=atol, rtol=rtol)
+
+
+@pytest.mark.parametrize(
+    "B, T, H, V",
+    [
+        (2, 4, 2048, 3200),
+    ],
+)
+@pytest.mark.parametrize(
+    "scalar, dtype, atol, rtol",
+    [
+        (1.0, torch.bfloat16, 5e-3, 5e-2),
+        (1.0, torch.float32, 1e-5, 5e-4),
+    ],
+)
+@pytest.mark.parametrize(
+    "temperature, beta, ignore_index",
+    [
+        (1.0, 0.5, 2),
+        (2.0, 0.1, 42),
+    ],
+)
+def test_correctness_all_ignored(
+    B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol
+):
+    device = "cuda"
+    torch_lm_head_jsd = TorchLMHeadJSD(
+        H=H,
+        V=V,
+        dtype=dtype,
+        device=device,
+        temperature=temperature,
+        ignore_index=ignore_index,
+        beta=beta,
+    ).to(device)
+    liger_lm_head_jsd = LigerLMHeadJSD(
+        H=H,
+        V=V,
+        dtype=dtype,
+        device=device,
+        temperature=temperature,
+        ignore_index=ignore_index,
+        beta=beta,
+    ).to(device)
+
+    # init the linear layers in both FusedLinearJSDs with the same weights
+    torch_lm_head_jsd.student_lin.weight.data = (
+        liger_lm_head_jsd.student_lin.weight.data
+    ) = torch.rand(V, H // 2, device=device, dtype=dtype)
+    torch_lm_head_jsd.teacher_lin.weight.data = (
+        liger_lm_head_jsd.teacher_lin.weight.data
+    ) = torch.rand(V, H, device=device, dtype=dtype)
+
+    _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar
+    _input1 = _tensor.detach().clone().requires_grad_(True)
+    _input2 = _tensor.detach().clone().requires_grad_(True)
+
+    teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar
+
+    label = torch.full((B * T,), ignore_index, device=device, dtype=torch.long)
+
+    output1 = torch_lm_head_jsd(_input1, teacher_input, label)
+    output2 = liger_lm_head_jsd(_input2, teacher_input, label)
+
+    assert torch.allclose(output1, output2, atol=atol, rtol=rtol)
+    assert torch.allclose(output2, torch.zeros_like(output2), atol=atol, rtol=rtol)
+
+    output2.backward()
+
+    assert torch.allclose(
+        torch.zeros_like(_input2.grad), _input2.grad, atol=atol, rtol=rtol
+    )
diff --git a/test/transformers/test_jsd.py b/test/transformers/test_jsd.py
index 37e12180e..220e87271 100644
--- a/test/transformers/test_jsd.py
+++ b/test/transformers/test_jsd.py
@@ -41,7 +41,7 @@ def forward(
             loss = torch.where(label != self.ignore_index, loss, 0.0)
             n_non_ignore = (label != self.ignore_index).sum().item()
             if n_non_ignore == 0:
-                loss = 0.0
+                loss = torch.tensor(0.0, device=loss.device)
             else:
                 loss = (loss / n_non_ignore).sum()
         else:
@@ -294,3 +294,38 @@ def test_correctness_functional(
     _test_correctness_functional(
         B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol
     )
+
+
+def test_correctness_with_all_indices_ignored(
+    B=2,
+    T=10,
+    V=32,
+    dtype=torch.bfloat16,
+    atol=1e-3,
+    rtol=1e-3,
+    device="cuda",
+):
+    ignore_index = -100
+    torch_jsd = JSD(ignore_index=ignore_index, dtype=dtype)
+    liger_jsd = LigerJSD(ignore_index=ignore_index)
+
+    inp = torch.randn(
+        B * T, V, device=device, dtype=dtype, requires_grad=True
+    ).log_softmax(dim=-1)
+
+    x1 = inp.detach().clone().requires_grad_(True)
+    x2 = inp.detach().clone().requires_grad_(True)
+
+    with torch.no_grad():
+        target = torch.randn(B * T, V, dtype=dtype, device=device).log_softmax(dim=-1)
+
+    # every position is labeled ignore_index, so loss and input grads must be zero
+    label = torch.full((B * T,), ignore_index, device=device, dtype=torch.long)
+
+    output = torch_jsd(x1, target, label)
+    output2 = liger_jsd(x2, target, label)
+    assert_verbose_allclose(output, output2, atol=atol, rtol=rtol)
+    assert_verbose_allclose(torch.zeros_like(output2), output2, atol=atol, rtol=rtol)
+
+    output2.backward()
+    assert_verbose_allclose(torch.zeros_like(x2.grad), x2.grad, atol=atol, rtol=rtol)
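
## Note: why the per-chunk label slice matters

The sketch below is a minimal pure-PyTorch illustration of the indexing pattern this patch restores, not the Triton kernel itself: `chunked_jsd_loss`, `chunk_size`, and the logits arguments are hypothetical names, and beta is fixed at 0.5 (the symmetric JSD case). Each chunk launched by `fused_linear_jsd_forward` must be masked by its own rows of `shift_labels`; before the fix, every chunk reused the labels of rows `0:chunk_size`.

```python
import torch
import torch.nn.functional as F

def chunked_jsd_loss(student_logits, teacher_logits, shift_labels,
                     ignore_index=-100, chunk_size=1024):
    """Illustrative chunked JSD with ignore_index; beta fixed at 0.5."""
    n_rows = student_logits.shape[0]
    n_non_ignore = (shift_labels != ignore_index).sum()
    loss = torch.zeros((), dtype=torch.float32, device=student_logits.device)
    for start_idx in range(0, n_rows, chunk_size):
        end_idx = min(start_idx + chunk_size, n_rows)
        log_q = student_logits[start_idx:end_idx].log_softmax(dim=-1)  # student
        log_p = teacher_logits[start_idx:end_idx].log_softmax(dim=-1)  # teacher
        # The fix: mask each chunk with its own slice of the labels,
        # not with shift_labels[0:chunk_size] for every chunk.
        labels_chunk = shift_labels[start_idx:end_idx]
        log_m = (0.5 * (log_q.exp() + log_p.exp())).log()
        # per-row JSD(P || Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M)
        per_row = 0.5 * (
            F.kl_div(log_m, log_p, reduction="none", log_target=True).sum(-1)
            + F.kl_div(log_m, log_q, reduction="none", log_target=True).sum(-1)
        )
        per_row = torch.where(labels_chunk != ignore_index, per_row, 0.0)
        if n_non_ignore > 0:  # all-ignored batches contribute zero, never NaN
            loss = loss + per_row.sum() / n_non_ignore
    return loss
```

When every position carries `ignore_index`, `n_non_ignore` is 0, so the loss and hence the input gradients must be exactly zero; that is the invariant the new `test_correctness_all_ignored` and `test_correctness_with_all_indices_ignored` tests pin down.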