fix dpo tests: reduce tolerance and change default compute_nll_loss false (#490)

## Summary

Reduces the float32 test tolerances for the chunked DPO loss (`atol` 2e-2 → 1e-5, `rtol` 5e-1 → 5e-4) and changes the default of `compute_nll_loss` from `True` to `False` in `LigerFusedLinearDPOFunction.forward` and `LigerFusedLinearDPOLoss`. The correctness tests and the HF reference implementation in `test/utils.py` now take `compute_nll_loss` explicitly, so both code paths are exercised.

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
shivam15s authored Dec 20, 2024
1 parent 7a781b7 commit 3205342
Showing 3 changed files with 69 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/liger_kernel/chunked_loss/dpo_loss.py

```diff
@@ -64,7 +64,7 @@ def forward(
         ref_bias=None,
         ignore_index=-100,
         beta=0.1,
-        compute_nll_loss=True,
+        compute_nll_loss=False,
         compiled=True,
         use_ref_model=True,
     ):
@@ -100,7 +100,7 @@ def __init__(
         self,
         ignore_index: int = -100,
         beta: float = 0.1,
-        compute_nll_loss: bool = True,
+        compute_nll_loss: bool = False,
         compiled: bool = True,
         use_ref_model: bool = False,
     ):
```
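For context, a minimal sketch of what the new default means at the call site. The import path follows the file shown in this diff, and the constructor arguments are the ones visible in the hunks above; the surrounding usage is illustrative only:

```python
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss

# After this change, compute_nll_loss defaults to False: the chosen-response
# NLL term is skipped unless a caller opts back in explicitly.
dpo_loss = LigerFusedLinearDPOLoss(ignore_index=-100, beta=0.1)

# Restoring the old behavior now requires passing the flag:
dpo_loss_with_nll = LigerFusedLinearDPOLoss(
    ignore_index=-100, beta=0.1, compute_nll_loss=True
)
```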
69 changes: 60 additions & 9 deletions test/chunked_loss/test_dpo_loss.py

```diff
@@ -23,10 +23,17 @@ class HFDPOLoss(HFAlignmentLoss):
     """

     def __init__(
-        self, ignore_index: int = -100, beta: float = 0.1, use_ref_model: bool = True
+        self,
+        ignore_index: int = -100,
+        beta: float = 0.1,
+        use_ref_model: bool = True,
+        compute_nll_loss: bool = False,
     ):
         super().__init__(
-            beta=beta, ignore_index=ignore_index, use_ref_model=use_ref_model
+            beta=beta,
+            ignore_index=ignore_index,
+            use_ref_model=use_ref_model,
+            compute_nll_loss=compute_nll_loss,
         )

     def alignment_loss(
@@ -61,6 +68,7 @@ def __init__(
         dtype: torch.dtype,
         bias: bool = False,
         ref_bias: bool = False,
+        compute_nll_loss: bool = False,
         ignore_index: int = -100,
         beta: float = 0.1,
     ):
@@ -72,7 +80,10 @@ def __init__(
             in_features=H, out_features=V, bias=ref_bias, dtype=dtype
         )
         self.dpo_loss = HFDPOLoss(
-            ignore_index=ignore_index, beta=beta, use_ref_model=True
+            ignore_index=ignore_index,
+            beta=beta,
+            use_ref_model=True,
+            compute_nll_loss=compute_nll_loss,
         ).get_batch_loss_metrics

     def forward(self, x, ref_x, y):
@@ -95,6 +106,7 @@ def __init__(
         dtype: torch.dtype,
         bias: bool = False,
         ref_bias: bool = False,
+        compute_nll_loss: bool = False,
         ignore_index: int = -100,
         beta: float = 0.1,
     ):
@@ -106,7 +118,10 @@ def __init__(
             in_features=H, out_features=V, bias=ref_bias, dtype=dtype
         )
         self.dpo_loss = LigerFusedLinearDPOLoss(
-            ignore_index=ignore_index, beta=beta, use_ref_model=True
+            ignore_index=ignore_index,
+            beta=beta,
+            use_ref_model=True,
+            compute_nll_loss=compute_nll_loss,
         )

     def forward(self, x, ref_x, y):
@@ -132,14 +147,27 @@ def forward(self, x, ref_x, y):
     "scalar, dtype, atol, rtol",
     [
         (1.0, torch.bfloat16, 5e-2, 5e-1),
-        (1.0, torch.float32, 2e-2, 5e-1),
+        (1.0, torch.float32, 1e-5, 5e-4),
     ],
 )
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("ref_bias", [True, False])
+@pytest.mark.parametrize("compute_nll_loss", [True, False])
 @pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)])
 def test_correctness(
-    B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, ignore_index, beta
+    B,
+    T,
+    H,
+    V,
+    scalar,
+    dtype,
+    atol,
+    rtol,
+    bias,
+    ref_bias,
+    compute_nll_loss,
+    ignore_index,
+    beta,
 ):
     B = 2 * B  # dpo loss requires B to be even

@@ -149,6 +177,7 @@ def test_correctness(
         dtype=dtype,
         bias=bias,
         ref_bias=ref_bias,
+        compute_nll_loss=compute_nll_loss,
         ignore_index=ignore_index,
         beta=beta,
     )
@@ -158,6 +187,7 @@ def test_correctness(
         dtype=dtype,
         bias=bias,
         ref_bias=ref_bias,
+        compute_nll_loss=compute_nll_loss,
         ignore_index=ignore_index,
         beta=beta,
     )
@@ -251,7 +281,10 @@ def test_correctness(
 )
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("ref_bias", [True, False])
-def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias):
+@pytest.mark.parametrize("compute_nll_loss", [True, False])
+def test_correctness_functional(
+    B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, compute_nll_loss
+):
     B = 2 * B

     _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
@@ -290,10 +323,28 @@ def test_correctness_functional(
     ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None

     loss1, aggregated_aux_outputs1 = LigerFusedLinearDPOFunction.apply(
-        input1, weight1, target, bias1, ref_input, ref_weight1, ref_bias1
+        input1,
+        weight1,
+        target,
+        bias1,
+        ref_input,
+        ref_weight1,
+        ref_bias1,
+        -100,
+        0.1,
+        compute_nll_loss,
     )
     loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo(
-        input2, weight2, target, bias2, ref_input, ref_weight2, ref_bias2
+        input2,
+        weight2,
+        target,
+        bias2,
+        ref_input,
+        ref_weight2,
+        ref_bias2,
+        -100,
+        0.1,
+        compute_nll_loss,
     )

     assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
```
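The functional test now pins down the positional argument order of the fused op: `ignore_index`, `beta`, and `compute_nll_loss` follow the reference weights. A minimal standalone sketch of a call in that order; the shapes, sizes, and CUDA availability here are illustrative assumptions, not values from the repo:

```python
import torch

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction

B, T, H, V = 4, 8, 32, 64  # illustrative sizes; B must be even (chosen/rejected halves)
device = "cuda"  # assumes a supported GPU, as in the test environment

_input = torch.randn(B, T, H, device=device, requires_grad=True)
weight = torch.randn(V, H, device=device, requires_grad=True)
target = torch.randint(0, V, (B, T), device=device)
ref_input = torch.randn(B, T, H, device=device)
ref_weight = torch.randn(V, H, device=device)

# Positional order exercised by test_correctness_functional:
# (input, weight, target, bias, ref_input, ref_weight, ref_bias,
#  ignore_index, beta, compute_nll_loss)
loss, aux_outputs = LigerFusedLinearDPOFunction.apply(
    _input, weight, target, None, ref_input, ref_weight, None, -100, 0.1, False
)
loss.backward()
```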
10 changes: 7 additions & 3 deletions test/utils.py

```diff
@@ -350,11 +350,13 @@ def __init__(
         beta: float = 0.1,
         ignore_index: int = -100,
         use_ref_model: bool = False,
+        compute_nll_loss: bool = True,
     ):
         self.alpha = alpha
         self.beta = beta
         self.ignore_index = ignore_index
         self.use_ref_model = use_ref_model
+        self.compute_nll_loss = compute_nll_loss

     @abstractmethod
     def alignment_loss(self):
@@ -448,9 +450,11 @@ def cross_entropy_loss(logits, labels):
             return loss

         labels = target
-        chosen_nll_loss = cross_entropy_loss(
-            all_logits[:len_chosen], labels[:len_chosen]
-        )
+        chosen_nll_loss = torch.tensor(0.0, device=all_logits.device)
+        if self.compute_nll_loss:
+            chosen_nll_loss = cross_entropy_loss(
+                all_logits[:len_chosen], labels[:len_chosen]
+            )

         all_logps = self.get_batch_logps(
             all_logits,
```
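For reference, the gating pattern the HF baseline adopts above can be written as a self-contained helper. This sketch uses plain `torch.nn.functional` rather than the repo's nested `cross_entropy_loss`, and the flattening of the `(batch, seq)` dimensions is an assumption about the label layout:

```python
import torch
import torch.nn.functional as F

def chosen_nll_loss(
    all_logits: torch.Tensor,   # (batch, seq, vocab) logits, chosen rows first
    labels: torch.Tensor,       # (batch, seq) target token ids
    len_chosen: int,
    compute_nll_loss: bool,
    ignore_index: int = -100,
) -> torch.Tensor:
    # Default to a zero scalar on the logits' device so downstream sums
    # and device handling stay uniform when the term is disabled.
    nll = torch.tensor(0.0, device=all_logits.device)
    if compute_nll_loss:
        chosen_logits = all_logits[:len_chosen].flatten(end_dim=1)
        chosen_labels = labels[:len_chosen].flatten()
        nll = F.cross_entropy(chosen_logits, chosen_labels, ignore_index=ignore_index)
    return nll
```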
