Skip to content

Commit

Permalink
fix fused JSD with ignore index (#330)
Browse files Browse the repository at this point in the history
## Summary
<!--- This is a required section; please describe the main purpose of
this proposed code change. --->
1. Fix a bug in fused linear JSD: the kernel was given the full label
tensor instead of the slice of labels corresponding to the chunk
currently being processed.
2. Add tests to make sure the results are correct when all tokens are
ignored.
<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->

## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->
<!-- 
Replace BLANK with your device type. For example, A100-80G-PCIe

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them. 
-->

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
  • Loading branch information
yundai424 authored Oct 30, 2024
1 parent 6cdc93d commit 1c0c75c
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 3 deletions.
4 changes: 3 additions & 1 deletion src/liger_kernel/ops/fused_linear_jsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def fused_linear_jsd_forward(
dX_ptr=student_prob_chunk,
dX_stride=student_prob_chunk.stride(-2),
label_ptr=(
shift_labels if has_label else torch.empty(1, device=device)
shift_labels[start_idx:end_idx]
if has_label
else torch.empty(1, device=device)
), # dummy ptr if no label
beta=jsd_beta,
n_non_ignore=n_non_ignore,
Expand Down
2 changes: 1 addition & 1 deletion src/liger_kernel/ops/jsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def _jsd_kernel(
dX_stride,
label_ptr,
beta,
n_non_ignore,
n_non_ignore: int,
ignore_index: tl.constexpr,
n_cols,
BLOCK_SIZE: tl.constexpr,
Expand Down
74 changes: 74 additions & 0 deletions test/transformers/test_fused_linear_jsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def test_correctness_with_ignore_index(
dtype=dtype,
device=device,
temperature=temperature,
ignore_index=ignore_index,
beta=beta,
).to(device)
liger_lm_head_jsd = LigerLMHeadJSD(
Expand All @@ -206,6 +207,7 @@ def test_correctness_with_ignore_index(
dtype=dtype,
device=device,
temperature=temperature,
ignore_index=ignore_index,
beta=beta,
).to(device)

Expand Down Expand Up @@ -329,3 +331,75 @@ def test_correctness_functional(
assert torch.allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol)

assert torch.allclose(_weight1.grad, _weight2.grad, atol=atol, rtol=rtol)


@pytest.mark.parametrize("B, T, H, V", [(2, 4, 2048, 3200)])
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        (1.0, torch.bfloat16, 5e-3, 5e-2),
        (1.0, torch.float32, 1e-5, 5e-4),
    ],
)
@pytest.mark.parametrize(
    "temperature, beta, ignore_index",
    [
        (1.0, 0.5, 2),
        (2.0, 0.1, 42),
    ],
)
def test_correctness_all_ignored(
    B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol
):
    """Check the fused linear JSD when every label equals ``ignore_index``.

    Both heads are given identical weights and inputs. The test asserts that
    the two losses agree, that the Liger loss is (close to) zero, and that
    the gradient reaching the Liger student input is (close to) zero.
    """
    device = "cuda"
    n_rows = B * T
    student_dim = H // 2

    # Both modules take exactly the same constructor arguments.
    head_kwargs = dict(
        H=H,
        V=V,
        dtype=dtype,
        device=device,
        temperature=temperature,
        ignore_index=ignore_index,
        beta=beta,
    )
    torch_lm_head_jsd = TorchLMHeadJSD(**head_kwargs).to(device)
    liger_lm_head_jsd = LigerLMHeadJSD(**head_kwargs).to(device)

    # Give the student/teacher projections of both heads the very same weight
    # tensors so that their outputs are directly comparable.
    student_weight = torch.rand(V, student_dim, device=device, dtype=dtype)
    torch_lm_head_jsd.student_lin.weight.data = student_weight
    liger_lm_head_jsd.student_lin.weight.data = student_weight

    teacher_weight = torch.rand(V, H, device=device, dtype=dtype)
    torch_lm_head_jsd.teacher_lin.weight.data = teacher_weight
    liger_lm_head_jsd.teacher_lin.weight.data = teacher_weight

    # Independent leaf copies of one student input, one per implementation.
    student_base = torch.rand(n_rows, student_dim, device=device, dtype=dtype) * scalar
    _input1 = student_base.detach().clone().requires_grad_(True)
    _input2 = student_base.detach().clone().requires_grad_(True)

    teacher_input = torch.rand(n_rows, H, device=device, dtype=dtype) * scalar

    # Every position carries ignore_index, so no token is left unmasked.
    label = torch.full((n_rows,), ignore_index, device=device, dtype=torch.long)

    output1 = torch_lm_head_jsd(_input1, teacher_input, label)
    output2 = liger_lm_head_jsd(_input2, teacher_input, label)

    assert torch.allclose(output1, output2, atol=atol, rtol=rtol)
    expected_loss = torch.zeros_like(output2)
    assert torch.allclose(output2, expected_loss, atol=atol, rtol=rtol)

    output2.backward()

    expected_grad = torch.zeros_like(_input2.grad)
    assert torch.allclose(expected_grad, _input2.grad, atol=atol, rtol=rtol)
47 changes: 46 additions & 1 deletion test/transformers/test_jsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def forward(
loss = torch.where(label != self.ignore_index, loss, 0.0)
n_non_ignore = (label != self.ignore_index).sum().item()
if n_non_ignore == 0:
loss = 0.0
loss = torch.tensor(0.0).to(loss.device)
else:
loss = (loss / n_non_ignore).sum()
else:
Expand Down Expand Up @@ -294,3 +294,48 @@ def test_correctness_functional(
_test_correctness_functional(
B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol
)


def test_correctness_with_all_indices_ignored(
    B=2,
    T=10,
    V=32,
    dtype=torch.bfloat16,
    atol=1e-3,
    rtol=1e-3,
    device="cuda",
):
    """Check ``JSD`` against ``LigerJSD`` when every label is ``ignore_index``.

    The test asserts that both modules return the same loss, that the Liger
    loss is (close to) zero, and that the gradient w.r.t. the Liger input is
    (close to) zero.
    """
    ignore_index = -100
    torch_jsd = JSD(ignore_index=ignore_index, dtype=dtype)
    liger_jsd = LigerJSD(ignore_index=ignore_index)

    # `inp` is only a value source for the two leaf copies below, so it does
    # not need to track gradients itself.
    inp = torch.randn(B * T, V, device=device, dtype=dtype).log_softmax(dim=-1)

    x1 = inp.detach().clone().requires_grad_(True)
    x2 = inp.detach().clone().requires_grad_(True)

    with torch.no_grad():
        target = torch.randn(B * T, V, dtype=dtype, device=device).log_softmax(dim=-1)

    # Every position is ignored. The previous version additionally re-assigned
    # ignore_index to a random subset of this tensor, which was a no-op.
    label = torch.full((B * T,), ignore_index, device=device, dtype=torch.long)

    output = torch_jsd(x1, target, label)
    output2 = liger_jsd(x2, target, label)
    assert_verbose_allclose(output, output2, atol=atol, rtol=rtol)
    assert_verbose_allclose(torch.zeros_like(output2), output2, atol=atol, rtol=rtol)

    output2.backward()
    assert_verbose_allclose(torch.zeros_like(x2.grad), x2.grad, atol=atol, rtol=rtol)

0 comments on commit 1c0c75c

Please sign in to comment.