fix: TextEnvironment cache combination and batching issue
Konrad Gerlach committed Jan 10, 2025
1 parent af06d63 commit f6f12b5
Showing 2 changed files with 51 additions and 24 deletions.
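
For context: the caches exercised below use the transformers legacy layout, a tuple with one (key, value) tensor pair per attention layer (real caches hold tensors of shape (batch, num_heads, seq_len, head_dim); the tests use small 2-D stand-ins). A minimal sketch of the round trip between DynamicCache and that layout, assuming a recent transformers release:

import torch
from transformers import DynamicCache

# Toy legacy cache: one (key, value) pair per layer, batch dimension first.
legacy = (
    (torch.tensor([[1.0], [2.0]]), torch.tensor([[3.0], [4.0]])),   # layer 0
    (torch.tensor([[7.0], [8.0]]), torch.tensor([[9.0], [10.0]])),  # layer 1
)

cache = DynamicCache.from_legacy_cache(legacy)  # wrap the tuples in a Cache object
roundtrip = cache.to_legacy_cache()             # and convert back

assert len(roundtrip) == len(legacy)               # one entry per layer
assert torch.equal(roundtrip[0][0], legacy[0][0])  # keys preserved
assert torch.equal(roundtrip[1][1], legacy[1][1])  # values preserved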
44 changes: 32 additions & 12 deletions tests/test_environments.py
@@ -291,15 +291,24 @@ def test_combine_cache(self):
)

caches = [
((torch.tensor([[1], [2]]), torch.tensor([[3], [4]])),),
((torch.tensor([[5]]), torch.tensor([[6]])),),
(
(torch.tensor([[1], [2]]), torch.tensor([[3], [4]])),
(torch.tensor([[7], [8]]), torch.tensor([[9], [10]])),
),
(
(torch.tensor([[5]]), torch.tensor([[6]])),
(torch.tensor([[11]]), torch.tensor([[12]])),
),
]
caches = [DynamicCache().from_legacy_cache(cache) for cache in caches]
attention_masks = [torch.tensor([[0], [1]]), torch.tensor([[2]])]
input_ids = [torch.tensor([[1], [2]]), torch.tensor([[3]])]
example_mask = [True, False, True]

expected_cache = ((torch.tensor([[1], [5]]), torch.tensor([[3], [6]])),)
expected_cache = (
(torch.tensor([[1], [5]]), torch.tensor([[3], [6]])),
(torch.tensor([[7], [11]]), torch.tensor([[9], [12]])),
)
expected_attention_mask = torch.tensor([[0], [2]])
expected_input_ids = torch.tensor([[1], [3]])

@@ -311,6 +320,9 @@ def test_combine_cache(self):
self.assertEqual(len(combined_cache[0]), len(expected_cache[0]))
self.assertTrue(torch.all(combined_cache[0][0] == expected_cache[0][0]))
self.assertTrue(torch.all(combined_cache[0][1] == expected_cache[0][1]))
self.assertEqual(len(combined_cache[1]), len(expected_cache[1]))
self.assertTrue(torch.all(combined_cache[1][0] == expected_cache[1][0]))
self.assertTrue(torch.all(combined_cache[1][1] == expected_cache[1][1]))
self.assertTrue(torch.all(combined_attention_masks == expected_attention_mask))
self.assertTrue(torch.all(combined_input_ids == expected_input_ids))
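
The combination this test checks can be written as a small standalone helper (an illustrative sketch, not the TRL implementation; combine_legacy_caches is a hypothetical name): for every layer, walk over the per-batch caches, keep the rows whose example_mask entry is True, and concatenate the survivors along the batch dimension.

import torch

def combine_legacy_caches(example_mask, caches):
    """Concatenate per-batch legacy caches layer by layer, keeping only the
    examples whose entry in example_mask is True."""
    mask = torch.tensor(example_mask)
    combined = []
    for layer_id in range(len(caches[0])):
        keys, values = [], []
        offset = 0  # the walk over example_mask restarts for every layer
        for cache in caches:
            key, value = cache[layer_id]
            keep = mask[offset : offset + key.shape[0]]
            keys.append(key[keep])
            values.append(value[keep])
            offset += key.shape[0]
        combined.append((torch.cat(keys, dim=0), torch.cat(values, dim=0)))
    return tuple(combined)

# Mirrors the test above: batches of size 2 and 1, two layers, drop example 1.
caches = [
    (
        (torch.tensor([[1], [2]]), torch.tensor([[3], [4]])),
        (torch.tensor([[7], [8]]), torch.tensor([[9], [10]])),
    ),
    (
        (torch.tensor([[5]]), torch.tensor([[6]])),
        (torch.tensor([[11]]), torch.tensor([[12]])),
    ),
]
combined = combine_legacy_caches([True, False, True], caches)
assert torch.equal(combined[0][0], torch.tensor([[1], [5]]))
assert torch.equal(combined[1][1], torch.tensor([[9], [12]]))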

@@ -324,19 +336,28 @@ def test_get_batched_cache(self):
max_turns=2,
)

cache = ((torch.tensor([[1], [2], [3]]), torch.tensor([[4], [5], [6]])),)
cache = (
(torch.tensor([[1], [2], [3]]), torch.tensor([[4], [5], [6]])),
(torch.tensor([[7], [8], [9]]), torch.tensor([[10], [11], [12]])),
)
attention_masks = torch.tensor([[1], [2], [3]])
input_ids = torch.tensor([[4], [5], [6]])
batched_cache, batched_attention_masks, batched_input_ids = env._get_batched_cache(
1, 3, cache, attention_masks, input_ids
)
batched_cache = batched_cache.to_legacy_cache()
expected_cache = ((torch.tensor([[2], [3]]), torch.tensor([[5], [6]])),)
expected_cache = (
(torch.tensor([[2], [3]]), torch.tensor([[5], [6]])),
(torch.tensor([[8], [9]]), torch.tensor([[11], [12]])),
)

self.assertEqual(len(batched_cache), len(expected_cache))
self.assertEqual(len(batched_cache[0]), len(expected_cache[0]))
self.assertTrue(torch.all(batched_cache[0][0] == expected_cache[0][0]))
self.assertTrue(torch.all(batched_cache[0][1] == expected_cache[0][1]))
self.assertEqual(len(batched_cache[1]), len(expected_cache[1]))
self.assertTrue(torch.all(batched_cache[1][0] == expected_cache[1][0]))
self.assertTrue(torch.all(batched_cache[1][1] == expected_cache[1][1]))

expected_attention_mask = torch.tensor([[2], [3]])
self.assertTrue(torch.all(batched_attention_masks == expected_attention_mask))
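
The inverse operation, pulling one mini-batch back out of a combined legacy cache, is plain slicing on the batch dimension of every layer. A hypothetical standalone sketch with the same toy shapes:

import torch

def slice_legacy_cache(start, end, cache):
    """Take examples [start, end) out of every (key, value) pair of a legacy cache."""
    return tuple((key[start:end], value[start:end]) for key, value in cache)

cache = (
    (torch.tensor([[1], [2], [3]]), torch.tensor([[4], [5], [6]])),
    (torch.tensor([[7], [8], [9]]), torch.tensor([[10], [11], [12]])),
)
batch = slice_legacy_cache(1, 3, cache)
assert torch.equal(batch[0][0], torch.tensor([[2], [3]]))
assert torch.equal(batch[1][1], torch.tensor([[11], [12]]))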
@@ -355,16 +376,17 @@ def test_cached_generate_batched(self):
generation_kwargs=generation_kwargs,
)

input_texts = ["this is a test", "this is another, longer test"]
input_texts = ["this is a test", "this is another, longer test", "some other batch"]
model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts]
outputs, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched(
model_inputs, batch_size=2
)
past_key_values = past_key_values[0].to_legacy_cache()
past_attention_masks = past_attention_masks[0]
past_input_ids = past_input_ids[0]

input_texts2 = [" short interim", " a slightly longer interim"]
past_key_values, past_attention_masks, past_input_ids = env._combine_cache(
[True, True, True], past_key_values, past_attention_masks, past_input_ids
)

input_texts2 = [" short interim", " a slightly longer interim", "another interim"]
model_inputs2 = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts2]

outputs_cached, _, _, _, _ = env._generate_batched(
@@ -378,8 +400,6 @@ def test_cached_generate_batched(self):
model_inputs2_full = [
torch.concat([in1, out1, in2], dim=0) for in1, out1, in2 in zip(model_inputs, outputs, model_inputs2)
]

outputs_uncached, _, _, _, _ = env._generate_batched(model_inputs2_full, batch_size=2)

for cached, uncached in zip(outputs_cached, outputs_uncached):
self.assertTrue(torch.all(cached == uncached))
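
The property this test asserts is that generation resumed from a reused key/value cache matches generation over the fully concatenated prompt. The same idea can be checked outside the environment with two GPT-2 forward passes (a minimal sketch, not the TextEnvironment code path; the model choice and tolerance are arbitrary assumptions):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

prefix = tokenizer("this is a test", return_tensors="pt")
with torch.no_grad():
    cache = model(**prefix, use_cache=True).past_key_values  # cache for the prefix

continuation = tokenizer(" short interim", return_tensors="pt")
with torch.no_grad():
    # Cached path: feed only the new tokens, plus the prefix cache and a full-length mask.
    cached_logits = model(
        input_ids=continuation.input_ids,
        attention_mask=torch.cat([prefix.attention_mask, continuation.attention_mask], dim=1),
        past_key_values=cache,
        use_cache=True,
    ).logits[:, -1]

    # Uncached path: feed the whole concatenated sequence from scratch.
    full_ids = torch.cat([prefix.input_ids, continuation.input_ids], dim=1)
    uncached_logits = model(input_ids=full_ids).logits[:, -1]

# Next-token predictions agree whether or not the prefix came from the cache.
assert torch.allclose(cached_logits, uncached_logits, atol=1e-4)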
31 changes: 19 additions & 12 deletions trl/environment/base_environment.py
@@ -410,10 +410,10 @@ def _combine_cache(self, example_mask, past_key_values, past_attention_masks, past_input_ids):
past_input_ids (list[torch.Tensor]): Batched list of input ids from the last generation
"""
legacy_format = [cache.to_legacy_cache() for cache in past_key_values]
example_mask_offset = 0
combined_cache = []
for layer_id in range(len(legacy_format[0])):
combined_layer = None
example_mask_offset = 0
for cache_idx, cache in enumerate(legacy_format):
layer = cache[layer_id]
num_examples = len(layer[0])
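
Read against the hunk above, the reset of example_mask_offset appears to move from before the layer loop to inside it, so the walk over example_mask restarts for every layer. A small numeric illustration of why that matters (hypothetical standalone code, using the shapes from the updated tests: two batches of sizes 2 and 1, two layers):

batch_sizes = [2, 1]  # examples per batched cache, as in the updated test
num_layers = 2

offset = 0  # buggy placement: reset once, before the layer loop
for layer_id in range(num_layers):
    # the fixed placement resets the offset here, at the top of every layer
    for size in batch_sizes:
        print(f"layer {layer_id}: example_mask[{offset}:{offset + size}]")
        offset += size
# Layer 0 reads example_mask[0:2] and example_mask[2:3], as intended; layer 1
# then reads example_mask[3:5] and example_mask[5:6], past the end of the
# 3-element mask, so every layer after the first is combined from the wrong
# (empty) slices.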
@@ -580,27 +580,34 @@ def _generate_batched(
new_past_key_values = []
new_past_attention_masks = []
new_past_input_ids = []

# pad all batches to same length for cache compatibility
mask = [torch.ones_like(element) for element in query_tensors]
inputs = {"input_ids": query_tensors, "attention_mask": mask}
all_padded_inputs = self.tokenizer.pad(
inputs,
padding=True,
max_length=None,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors="pt",
).to(self.current_device)

# in case we have fewer examples than bs
batch_size = min(len(query_tensors), batch_size)
for batch_index, i in enumerate(range(0, len(query_tensors), batch_size)):
# prevent overflow if query tensors are not even multiple of bs
end_index = min(len(query_tensors), i + batch_size)
batch = query_tensors[i:end_index]
batch_mask = [torch.ones_like(element) for element in batch]
past_key_values, past_attention_masks, past_input_ids = (None, None, None)
if combined_past_key_values is not None:
past_key_values, past_attention_masks, past_input_ids = self._get_batched_cache(
i, end_index, combined_past_key_values, combined_past_attention_masks, combined_past_input_ids
)
inputs = {"input_ids": batch, "attention_mask": batch_mask}

padded_inputs = self.tokenizer.pad(
inputs,
padding=True,
max_length=None,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors="pt",
).to(self.current_device)

padded_inputs = {
"input_ids": all_padded_inputs["input_ids"][i:end_index],
"attention_mask": all_padded_inputs["attention_mask"][i:end_index],
}

input_attention_mask = padded_inputs["attention_mask"].clone()
stopping_criteria = StringStoppingCriteria([self.call_token, self.submit_token], self.tokenizer)

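The other half of the change, visible in this last hunk, pads all queries once, up front, instead of per mini-batch, so every batch ends up with the same padded width and the per-batch caches stay compatible when they are recombined later. A rough sketch of that pattern (illustrative only; the pad-token setup is an assumption, since GPT-2 ships without one):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # assumption: GPT-2 defines no pad token by default

queries = [
    tokenizer(text, return_tensors="pt").input_ids.squeeze(0)
    for text in ["this is a test", "this is another, longer test", "some other batch"]
]

# Pad the whole set once so every mini-batch shares the same sequence length.
masks = [torch.ones_like(q) for q in queries]
padded = tokenizer.pad(
    {"input_ids": queries, "attention_mask": masks},
    padding=True,
    return_tensors="pt",
)

batch_size = 2
for start in range(0, len(queries), batch_size):
    end = min(len(queries), start + batch_size)
    batch_ids = padded["input_ids"][start:end]
    batch_mask = padded["attention_mask"][start:end]
    # Each batch now has an identical width, so the key/value caches produced
    # per batch can later be concatenated along the batch dimension.
    print(batch_ids.shape, batch_mask.shape)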
