Adding reverse and symmetric KLD losses #2094

Status: Open. Wants to merge 1 commit into base branch main.
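The diff below covers only the new tests and the package exports; the loss implementations themselves (in torchtune/modules/loss/kd_losses.py) are not shown on this page. As a rough orientation: the existing ForwardKLLoss minimizes KL(teacher || student), the new ReverseKLLoss targets KL(student || teacher), and SymmetricKLLoss combines the two directions. The sketch below only illustrates those definitions; the function names, the masked-mean reduction, and the equal weighting in the symmetric case are assumptions for illustration, not the actual torchtune code.

```python
# Illustrative sketch only: not the kd_losses.py implementation from this PR.
# The masked-mean reduction and the 0.5/0.5 weighting below are assumptions.
import torch
import torch.nn.functional as F


def reverse_kl(student_logits, teacher_logits, labels, ignore_index=-100):
    """Token-level KL(student || teacher), averaged over non-ignored tokens."""
    q = F.log_softmax(student_logits.float(), dim=-1)  # student log-probs
    p = F.log_softmax(teacher_logits.float(), dim=-1)  # teacher log-probs
    per_token = (q.exp() * (q - p)).sum(-1)  # sum_v q(v) * (log q(v) - log p(v))
    mask = labels != ignore_index
    return per_token[mask].mean()


def symmetric_kl(student_logits, teacher_logits, labels, ignore_index=-100):
    """One plausible symmetric variant: an equal mix of forward and reverse KL."""
    q = F.log_softmax(student_logits.float(), dim=-1)
    p = F.log_softmax(teacher_logits.float(), dim=-1)
    fwd = (p.exp() * (p - q)).sum(-1)  # KL(teacher || student) per token
    rev = (q.exp() * (q - p)).sum(-1)  # KL(student || teacher) per token
    mask = labels != ignore_index
    return (0.5 * fwd + 0.5 * rev)[mask].mean()
```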
tests/torchtune/modules/loss/test_kd_losses.py (199 additions, 1 deletion)
@@ -7,7 +7,7 @@
import pytest
import torch
from tests.test_utils import assert_expected
from torchtune.modules.loss import (
    ForwardKLLoss,
    ForwardKLWithChunkedOutputLoss,
    ReverseKLLoss,
    ReverseKLWithChunkedOutputLoss,
    SymmetricKLLoss,
    SymmetricKLWithChunkedOutputLoss,
)
from torchtune.training.seed import set_seed


@@ -114,3 +114,201 @@ def test_forward_kl_loss_expected(self):
        # assert
        assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2)
        assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)

class TestReverseKLWithChunkedOutputLoss:
Contributor: Thanks for adding these unit tests!

    def test_reverse_kl_loss(self):
        # Create a sample input and label
        ignore_index = -100
        batch_size = 3
        num_tokens = 50
        vocab_size = 50
        logits = torch.randn(batch_size, num_tokens, vocab_size, dtype=torch.bfloat16)
        teacher_logits = torch.randn(
            batch_size, num_tokens, vocab_size, dtype=torch.bfloat16
        )
        labels = torch.randint(
            0, vocab_size, (batch_size, num_tokens), dtype=torch.long
        )

        # add random ignore index to random tokens in the label
        random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens))
        labels[random_indices < num_tokens // 5] = ignore_index

        # chunked RKL
        chunked_rkl_loss = ReverseKLWithChunkedOutputLoss(
            num_output_chunks=8, ignore_index=ignore_index
        )
        logits_chunks = logits.chunk(chunked_rkl_loss.num_output_chunks, dim=1)
        teacher_logits_chunks = teacher_logits.chunk(
            chunked_rkl_loss.num_output_chunks, dim=1
        )
        chunked_loss = chunked_rkl_loss(logits_chunks, teacher_logits_chunks, labels)

        # vanilla RKL
        rkl_loss = ReverseKLLoss(ignore_index=ignore_index)
        logits = logits.reshape(-1, logits.size(-1))
        teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1))
        labels = labels.reshape(-1)
        standard_loss = rkl_loss(logits, teacher_logits, labels)

        # Assert
        assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2)

    def test_reverse_kl_loss_expected(self):
        student_logits = torch.tensor(
            [
                [
                    [1.1250, -0.4102, -0.0879, -2.5000],
                    [0.2676, 0.3535, 0.8711, -1.4688],
                    [-0.1084, 1.6641, 0.0084, 0.1196],
                    [0.5000, -0.6406, -0.2236, -1.5938],
                ],
                [
                    [-1.5312, -1.9219, 0.0000, -0.5039],
                    [-1.5391, 1.5312, 0.5820, 0.2695],
                    [-0.3887, 1.2188, 0.0000, 0.6055],
                    [0.5000, 1.3828, 0.1309, -1.0312],
                ],
            ],
            dtype=torch.bfloat16,
        )
        teacher_logits = torch.tensor(
            [
                [
                    [-0.0381, -1.2578, -1.2031, 0.0947],
                    [-0.7852, 0.4492, 1.5547, 0.0972],
                    [0.8203, 0.0012, 0.7656, 0.3477],
                    [-1.5781, 0.4297, 0.5977, 0.3926],
                ],
                [
                    [1.5156, 0.1641, 2.0781, -0.7734],
                    [-0.5898, 0.4453, -0.7969, 0.6328],
                    [0.6289, -0.8359, 0.9258, 0.2109],
                    [0.0006, 0.5195, 3.2344, -1.5781],
                ],
            ],
            dtype=torch.bfloat16,
        )
        labels = torch.tensor([[0, 3, 3, 1], [1, 1, 1, 1]])
        expected_loss = torch.tensor(0.6775, dtype=torch.float32)
Contributor: Just to verify: did you set this value based on the reference implementation in distillm?


        # chunked RKL loss
        chunked_rkl_loss = ReverseKLWithChunkedOutputLoss(
            num_output_chunks=2, ignore_index=-100
        )
        student_logits_chunks = student_logits.chunk(
            chunked_rkl_loss.num_output_chunks, dim=1
        )
        teacher_logits_chunks = teacher_logits.chunk(
            chunked_rkl_loss.num_output_chunks, dim=1
        )
        chunked_loss = chunked_rkl_loss(
            student_logits_chunks, teacher_logits_chunks, labels
        )

        # vanilla RKL loss
        rkl_loss = ReverseKLLoss(ignore_index=-100)
        standard_loss = rkl_loss(student_logits, teacher_logits, labels)

        # assert
        assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2)
        assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)
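In case the hard-coded 0.6775 ever needs an independent sanity check (separate from the distillm reference mentioned above), the masked reverse KL can be recomputed directly from the student_logits, teacher_logits, and labels tensors defined in this test, for example via torch.nn.functional.kl_div. Whether the result matches the stored value exactly depends on the dtype upcasting and reduction that ReverseKLLoss actually applies, which are not visible in this diff, so the snippet is a sketch rather than a specification:

```python
# Hedged cross-check: reverse KL(student || teacher) per token via F.kl_div,
# averaged over positions whose label is not the ignore_index. Exact agreement
# with 0.6775 depends on ReverseKLLoss's real reduction/dtype handling.
# Uses the student_logits, teacher_logits, and labels tensors from the test above.
import torch.nn.functional as F

student_probs = F.softmax(student_logits.float(), dim=-1)
teacher_logprobs = F.log_softmax(teacher_logits.float(), dim=-1)
# kl_div(input, target) computes target * (log target - input) pointwise,
# i.e. sum_v q(v) * (log q(v) - log p_teacher(v)) for each token after .sum(-1).
per_token = F.kl_div(teacher_logprobs, student_probs, reduction="none").sum(-1)
mask = labels != -100
print(per_token[mask].mean())
```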

class TestSymmetricKLWithChunkedOutputLoss:
    def test_symmetric_kl_loss(self):
        # Create a sample input and label
        ignore_index = -100
        batch_size = 3
        num_tokens = 50
        vocab_size = 50
        logits = torch.randn(batch_size, num_tokens, vocab_size, dtype=torch.bfloat16)
        teacher_logits = torch.randn(
            batch_size, num_tokens, vocab_size, dtype=torch.bfloat16
        )
        labels = torch.randint(
            0, vocab_size, (batch_size, num_tokens), dtype=torch.long
        )

        # add random ignore index to random tokens in the label
        random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens))
        labels[random_indices < num_tokens // 5] = ignore_index

        # chunked Symmetric KL
        chunked_sym_kl_loss = SymmetricKLWithChunkedOutputLoss(
            num_output_chunks=8, ignore_index=ignore_index
        )
        logits_chunks = logits.chunk(chunked_sym_kl_loss.num_output_chunks, dim=1)
        teacher_logits_chunks = teacher_logits.chunk(
            chunked_sym_kl_loss.num_output_chunks, dim=1
        )
        chunked_loss = chunked_sym_kl_loss(logits_chunks, teacher_logits_chunks, labels)

        # vanilla Symmetric KL
        sym_kl_loss = SymmetricKLLoss(ignore_index=ignore_index)
        logits = logits.reshape(-1, logits.size(-1))
        teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1))
        labels = labels.reshape(-1)
        standard_loss = sym_kl_loss(logits, teacher_logits, labels)

        # Assert
        assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2)

    def test_symmetric_kl_loss_expected(self):
        student_logits = torch.tensor(
            [
                [
                    [1.1250, -0.4102, -0.0879, -2.5000],
                    [0.2676, 0.3535, 0.8711, -1.4688],
                    [-0.1084, 1.6641, 0.0084, 0.1196],
                    [0.5000, -0.6406, -0.2236, -1.5938],
                ],
                [
                    [-1.5312, -1.9219, 0.0000, -0.5039],
                    [-1.5391, 1.5312, 0.5820, 0.2695],
                    [-0.3887, 1.2188, 0.0000, 0.6055],
                    [0.5000, 1.3828, 0.1309, -1.0312],
                ],
            ],
            dtype=torch.bfloat16,
        )
        teacher_logits = torch.tensor(
            [
                [
                    [-0.0381, -1.2578, -1.2031, 0.0947],
                    [-0.7852, 0.4492, 1.5547, 0.0972],
                    [0.8203, 0.0012, 0.7656, 0.3477],
                    [-1.5781, 0.4297, 0.5977, 0.3926],
                ],
                [
                    [1.5156, 0.1641, 2.0781, -0.7734],
                    [-0.5898, 0.4453, -0.7969, 0.6328],
                    [0.6289, -0.8359, 0.9258, 0.2109],
                    [0.0006, 0.5195, 3.2344, -1.5781],
                ],
            ],
            dtype=torch.bfloat16,
        )
        labels = torch.tensor([[0, 3, 3, 1], [1, 1, 1, 1]])
        expected_loss = torch.tensor(1.1992, dtype=torch.float32)

        # chunked Symmetric KL loss
        chunked_sym_kl_loss = SymmetricKLWithChunkedOutputLoss(
            num_output_chunks=2, ignore_index=-100
        )
        student_logits_chunks = student_logits.chunk(
            chunked_sym_kl_loss.num_output_chunks, dim=1
        )
        teacher_logits_chunks = teacher_logits.chunk(
            chunked_sym_kl_loss.num_output_chunks, dim=1
        )
        chunked_loss = chunked_sym_kl_loss(
            student_logits_chunks, teacher_logits_chunks, labels
        )

        # vanilla Symmetric KL loss
        sym_kl_loss = SymmetricKLLoss(ignore_index=-100)
        standard_loss = sym_kl_loss(student_logits, teacher_logits, labels)

        # assert
        assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2)
        assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)
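For readers trying to relate the two hard-coded values (0.6775 for reverse KL and 1.1992 for symmetric KL on the same inputs), one plausible way a symmetric KL loss can be built is by delegating to the forward and reverse losses over the same ignore_index mask. This composition is only an illustration; how kd_losses.py actually weights the two directions is not shown in this diff:

```python
# Illustration only: a symmetric KL composed from the exported forward and
# reverse losses. The equal weighting is an assumption, not necessarily the
# PR's actual choice in kd_losses.py.
import torch
from torchtune.modules.loss import ForwardKLLoss, ReverseKLLoss


class NaiveSymmetricKLLoss(torch.nn.Module):
    def __init__(self, ignore_index: int = -100):
        super().__init__()
        self.fkl = ForwardKLLoss(ignore_index=ignore_index)
        self.rkl = ReverseKLLoss(ignore_index=ignore_index)

    def forward(self, student_logits, teacher_logits, labels):
        # Both terms share the same ignore_index mask and reduction.
        fwd = self.fkl(student_logits, teacher_logits, labels)
        rev = self.rkl(student_logits, teacher_logits, labels)
        return 0.5 * fwd + 0.5 * rev
```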
torchtune/modules/loss/__init__.py (5 additions, 1 deletion)
@@ -5,10 +5,14 @@
# LICENSE file in the root directory of this source tree.

from .ce_chunked_output_loss import CEWithChunkedOutputLoss
from .kd_losses import (
    ForwardKLLoss,
    ForwardKLWithChunkedOutputLoss,
    ReverseKLLoss,
    ReverseKLWithChunkedOutputLoss,
    SymmetricKLLoss,
    SymmetricKLWithChunkedOutputLoss,
)

__all__ = [
    "CEWithChunkedOutputLoss",
    "ForwardKLLoss",
    "ForwardKLWithChunkedOutputLoss",
    "ReverseKLLoss",
    "ReverseKLWithChunkedOutputLoss",
    "SymmetricKLLoss",
    "SymmetricKLWithChunkedOutputLoss",
]
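With the exports above in place, the new losses can be constructed and called the same way the existing ForwardKLWithChunkedOutputLoss is used in the knowledge-distillation tests. A minimal, hypothetical usage sketch follows; the shapes and chunk count are arbitrary and not taken from this PR:

```python
# Hypothetical usage sketch; shapes and chunk count are arbitrary.
import torch
from torchtune.modules.loss import ReverseKLWithChunkedOutputLoss

loss_fn = ReverseKLWithChunkedOutputLoss(num_output_chunks=8, ignore_index=-100)

batch_size, num_tokens, vocab_size = 2, 16, 128
student_logits = torch.randn(batch_size, num_tokens, vocab_size)
teacher_logits = torch.randn(batch_size, num_tokens, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, num_tokens))

# The chunked variants take logits pre-split along the sequence dimension,
# mirroring how the tests above call them.
loss = loss_fn(
    student_logits.chunk(loss_fn.num_output_chunks, dim=1),
    teacher_logits.chunk(loss_fn.num_output_chunks, dim=1),
    labels,
)
```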