-
Notifications
You must be signed in to change notification settings - Fork 468
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding reverse and symmetric KLD losses #2094
Open
insop
wants to merge
1
commit into
pytorch:main
Choose a base branch
from
insop:insop/kld
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
import pytest | ||
import torch | ||
from tests.test_utils import assert_expected | ||
from torchtune.modules.loss import ForwardKLLoss, ForwardKLWithChunkedOutputLoss | ||
from torchtune.modules.loss import ForwardKLLoss, ForwardKLWithChunkedOutputLoss, ReverseKLLoss, ReverseKLWithChunkedOutputLoss, SymmetricKLLoss, SymmetricKLWithChunkedOutputLoss | ||
from torchtune.training.seed import set_seed | ||
|
||
|
||
|
@@ -114,3 +114,201 @@ def test_forward_kl_loss_expected(self): | |
# assert | ||
assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2) | ||
assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2) | ||
|
||
class TestReverseKLWithChunkedOutputLoss: | ||
def test_reverse_kl_loss(self): | ||
# Create a sample input and label | ||
ignore_index = -100 | ||
batch_size = 3 | ||
num_tokens = 50 | ||
vocab_size = 50 | ||
logits = torch.randn(batch_size, num_tokens, vocab_size, dtype=torch.bfloat16) | ||
teacher_logits = torch.randn( | ||
batch_size, num_tokens, vocab_size, dtype=torch.bfloat16 | ||
) | ||
labels = torch.randint( | ||
0, vocab_size, (batch_size, num_tokens), dtype=torch.long | ||
) | ||
|
||
# add random ignore index to random tokens in the label | ||
random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens)) | ||
labels[random_indices < num_tokens // 5] = ignore_index | ||
|
||
# chunked RKL | ||
chunked_rkl_loss = ReverseKLWithChunkedOutputLoss( | ||
num_output_chunks=8, ignore_index=ignore_index | ||
) | ||
logits_chunks = logits.chunk(chunked_rkl_loss.num_output_chunks, dim=1) | ||
teacher_logits_chunks = teacher_logits.chunk( | ||
chunked_rkl_loss.num_output_chunks, dim=1 | ||
) | ||
chunked_loss = chunked_rkl_loss(logits_chunks, teacher_logits_chunks, labels) | ||
|
||
# vanilla RKL | ||
rkl_loss = ReverseKLLoss(ignore_index=ignore_index) | ||
logits = logits.reshape(-1, logits.size(-1)) | ||
teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1)) | ||
labels = labels.reshape(-1) | ||
standard_loss = rkl_loss(logits, teacher_logits, labels) | ||
|
||
# Assert | ||
assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2) | ||
|
||
def test_reverse_kl_loss_expected(self): | ||
student_logits = torch.tensor( | ||
[ | ||
[ | ||
[1.1250, -0.4102, -0.0879, -2.5000], | ||
[0.2676, 0.3535, 0.8711, -1.4688], | ||
[-0.1084, 1.6641, 0.0084, 0.1196], | ||
[0.5000, -0.6406, -0.2236, -1.5938], | ||
], | ||
[ | ||
[-1.5312, -1.9219, 0.0000, -0.5039], | ||
[-1.5391, 1.5312, 0.5820, 0.2695], | ||
[-0.3887, 1.2188, 0.0000, 0.6055], | ||
[0.5000, 1.3828, 0.1309, -1.0312], | ||
], | ||
], | ||
dtype=torch.bfloat16, | ||
) | ||
teacher_logits = torch.tensor( | ||
[ | ||
[ | ||
[-0.0381, -1.2578, -1.2031, 0.0947], | ||
[-0.7852, 0.4492, 1.5547, 0.0972], | ||
[0.8203, 0.0012, 0.7656, 0.3477], | ||
[-1.5781, 0.4297, 0.5977, 0.3926], | ||
], | ||
[ | ||
[1.5156, 0.1641, 2.0781, -0.7734], | ||
[-0.5898, 0.4453, -0.7969, 0.6328], | ||
[0.6289, -0.8359, 0.9258, 0.2109], | ||
[0.0006, 0.5195, 3.2344, -1.5781], | ||
], | ||
], | ||
dtype=torch.bfloat16, | ||
) | ||
labels = torch.tensor([[0, 3, 3, 1], [1, 1, 1, 1]]) | ||
expected_loss = torch.tensor(0.6775, dtype=torch.float32) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to verify: did you set this value based on the reference implementation in distillm? |
||
|
||
# chunked RKL loss | ||
chunked_rkl_loss = ReverseKLWithChunkedOutputLoss( | ||
num_output_chunks=2, ignore_index=-100 | ||
) | ||
student_logits_chunks = student_logits.chunk( | ||
chunked_rkl_loss.num_output_chunks, dim=1 | ||
) | ||
teacher_logits_chunks = teacher_logits.chunk( | ||
chunked_rkl_loss.num_output_chunks, dim=1 | ||
) | ||
chunked_loss = chunked_rkl_loss( | ||
student_logits_chunks, teacher_logits_chunks, labels | ||
) | ||
|
||
# vanilla RKL loss | ||
rkl_loss = ReverseKLLoss(ignore_index=-100) | ||
standard_loss = rkl_loss(student_logits, teacher_logits, labels) | ||
|
||
# assert | ||
assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2) | ||
assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2) | ||
|
||
class TestSymmetricKLWithChunkedOutputLoss: | ||
def test_symmetric_kl_loss(self): | ||
# Create a sample input and label | ||
ignore_index = -100 | ||
batch_size = 3 | ||
num_tokens = 50 | ||
vocab_size = 50 | ||
logits = torch.randn(batch_size, num_tokens, vocab_size, dtype=torch.bfloat16) | ||
teacher_logits = torch.randn( | ||
batch_size, num_tokens, vocab_size, dtype=torch.bfloat16 | ||
) | ||
labels = torch.randint( | ||
0, vocab_size, (batch_size, num_tokens), dtype=torch.long | ||
) | ||
|
||
# add random ignore index to random tokens in the label | ||
random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens)) | ||
labels[random_indices < num_tokens // 5] = ignore_index | ||
|
||
# chunked Symmetric KL | ||
chunked_sym_kl_loss = SymmetricKLWithChunkedOutputLoss( | ||
num_output_chunks=8, ignore_index=ignore_index | ||
) | ||
logits_chunks = logits.chunk(chunked_sym_kl_loss.num_output_chunks, dim=1) | ||
teacher_logits_chunks = teacher_logits.chunk( | ||
chunked_sym_kl_loss.num_output_chunks, dim=1 | ||
) | ||
chunked_loss = chunked_sym_kl_loss(logits_chunks, teacher_logits_chunks, labels) | ||
|
||
# vanilla Symmetric KL | ||
sym_kl_loss = SymmetricKLLoss(ignore_index=ignore_index) | ||
logits = logits.reshape(-1, logits.size(-1)) | ||
teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1)) | ||
labels = labels.reshape(-1) | ||
standard_loss = sym_kl_loss(logits, teacher_logits, labels) | ||
|
||
# Assert | ||
assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2) | ||
|
||
def test_symmetric_kl_loss_expected(self): | ||
student_logits = torch.tensor( | ||
[ | ||
[ | ||
[1.1250, -0.4102, -0.0879, -2.5000], | ||
[0.2676, 0.3535, 0.8711, -1.4688], | ||
[-0.1084, 1.6641, 0.0084, 0.1196], | ||
[0.5000, -0.6406, -0.2236, -1.5938], | ||
], | ||
[ | ||
[-1.5312, -1.9219, 0.0000, -0.5039], | ||
[-1.5391, 1.5312, 0.5820, 0.2695], | ||
[-0.3887, 1.2188, 0.0000, 0.6055], | ||
[0.5000, 1.3828, 0.1309, -1.0312], | ||
], | ||
], | ||
dtype=torch.bfloat16, | ||
) | ||
teacher_logits = torch.tensor( | ||
[ | ||
[ | ||
[-0.0381, -1.2578, -1.2031, 0.0947], | ||
[-0.7852, 0.4492, 1.5547, 0.0972], | ||
[0.8203, 0.0012, 0.7656, 0.3477], | ||
[-1.5781, 0.4297, 0.5977, 0.3926], | ||
], | ||
[ | ||
[1.5156, 0.1641, 2.0781, -0.7734], | ||
[-0.5898, 0.4453, -0.7969, 0.6328], | ||
[0.6289, -0.8359, 0.9258, 0.2109], | ||
[0.0006, 0.5195, 3.2344, -1.5781], | ||
], | ||
], | ||
dtype=torch.bfloat16, | ||
) | ||
labels = torch.tensor([[0, 3, 3, 1], [1, 1, 1, 1]]) | ||
expected_loss = torch.tensor(1.1992, dtype=torch.float32) | ||
|
||
# chunked Symmetric KL loss | ||
chunked_sym_kl_loss = SymmetricKLWithChunkedOutputLoss( | ||
num_output_chunks=2, ignore_index=-100 | ||
) | ||
student_logits_chunks = student_logits.chunk( | ||
chunked_sym_kl_loss.num_output_chunks, dim=1 | ||
) | ||
teacher_logits_chunks = teacher_logits.chunk( | ||
chunked_sym_kl_loss.num_output_chunks, dim=1 | ||
) | ||
chunked_loss = chunked_sym_kl_loss( | ||
student_logits_chunks, teacher_logits_chunks, labels | ||
) | ||
|
||
# vanilla Symmetric KL loss | ||
sym_kl_loss = SymmetricKLLoss(ignore_index=-100) | ||
standard_loss = sym_kl_loss(student_logits, teacher_logits, labels) | ||
|
||
# assert | ||
assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2) | ||
assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding these unit tests!