IDEFICS: support inputs embeds (#34043)
* support embeds

* use cache from config

* style...

* fix tests after rebase
zucchini-nlp authored Oct 16, 2024
1 parent 9d6998c commit d087165
Showing 7 changed files with 100 additions and 18 deletions.
8 changes: 4 additions & 4 deletions src/transformers/generation/utils.py
@@ -857,7 +857,7 @@ def _get_logits_processor(
                     self,
                     unconditional_ids=negative_prompt_ids,
                     unconditional_attention_mask=negative_prompt_attention_mask,
-                    use_cache=model_kwargs["use_cache"],
+                    use_cache=generation_config.use_cache,
                 )
             )
         if generation_config.sequence_bias is not None:
@@ -2004,10 +2004,7 @@ def generate(
         # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
         # generating the first new token or not, and we only want to use the embeddings for the first new token)
         if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
-            model_kwargs["use_cache"] = True
             generation_config.use_cache = True
-        else:
-            model_kwargs["use_cache"] = generation_config.use_cache
 
         if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
@@ -2116,6 +2113,9 @@ def generate(
             generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
         )
 
+        # Set model_kwargs `use_cache` so we can use it later in forward runs
+        model_kwargs["use_cache"] = generation_config.use_cache
+
         # 10. go into different generation modes
         if generation_mode == GenerationMode.ASSISTED_GENERATION:
             if generation_config.num_return_sequences > 1:
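As the comment in the hunk above notes, a decoder-only model fed `inputs_embeds` has to run with the cache on: the cache position is the only signal for whether we are still on the prompt step (where the embeddings are consumed) or already decoding new tokens (where only ids are forwarded). A minimal sketch of that split, with the helper name and tensor shapes chosen purely for illustration:

import torch

def prepare_step_inputs(input_ids, inputs_embeds, cache_position):
    # Hypothetical helper mirroring the pattern `generate()` relies on: embeddings are
    # consumed only on the very first step (cache_position starts at 0); afterwards only
    # the not-yet-processed token ids are forwarded, which is why use_cache must be True.
    if inputs_embeds is not None and cache_position[0] == 0:
        return {"inputs_embeds": inputs_embeds, "input_ids": None}
    return {"input_ids": input_ids[:, -cache_position.shape[0] :], "inputs_embeds": None}

# prompt step: three prompt positions, the embeddings are used and the ids are dropped
first_step = prepare_step_inputs(torch.tensor([[5, 6, 7]]), torch.randn(1, 3, 16), torch.arange(3))
# decoding step: only the single newly sampled token id is processed
next_step = prepare_step_inputs(torch.tensor([[5, 6, 7, 8]]), torch.randn(1, 3, 16), torch.tensor([3]))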
34 changes: 22 additions & 12 deletions src/transformers/models/idefics/modeling_idefics.py
@@ -1663,19 +1663,31 @@ def forward(
     def prepare_inputs_for_generation(
         self,
         input_ids,
-        past_key_values=None,
         attention_mask=None,
         position_ids=None,
+        inputs_embeds=None,
+        past_key_values=None,
+        cache_position=None,
         pixel_values=None,
         image_hidden_states=None,
         image_attention_mask=None,
         use_cache=None,
-        cache_position=None,
         **kwargs,
     ):
+        model_inputs = {}
+        if image_hidden_states is not None:
+            if self.config.use_resampler:
+                model_inputs["perceiver_embeddings"] = image_hidden_states
+            else:
+                model_inputs["image_encoder_embeddings"] = image_hidden_states
+        else:
+            model_inputs["pixel_values"] = pixel_values
+
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         if past_key_values is not None:
-            if input_ids.shape[1] != cache_position.shape[0]:
+            if inputs_embeds is not None:
+                input_ids = input_ids[:, -cache_position.shape[0] :]
+            elif input_ids.shape[1] != cache_position.shape[0]:
                 input_ids = input_ids[:, cache_position]
             if image_attention_mask is not None:
                 image_attention_mask = image_attention_mask[:, -input_ids.shape[1] :]
@@ -1690,19 +1702,17 @@ def prepare_inputs_for_generation(
                 # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
                 position_ids = position_ids.clone(memory_format=torch.contiguous_format)
 
-        model_inputs = {}
-        image_hidden_states = kwargs.pop("image_hidden_states", None)
-        if image_hidden_states is not None:
-            if self.config.use_resampler:
-                model_inputs["perceiver_embeddings"] = image_hidden_states
-            else:
-                model_inputs["image_encoder_embeddings"] = image_hidden_states
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs.update({"inputs_embeds": inputs_embeds, "input_ids": None})
         else:
-            model_inputs["pixel_values"] = pixel_values
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs.update(
+                {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
+            )
 
         model_inputs.update(
             {
-                "input_ids": input_ids,
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "cache_position": cache_position,
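The rewritten `prepare_inputs_for_generation` above does two things: it routes the visual inputs first, and it forwards `inputs_embeds` only on the first generation step. The visual routing, pulled out as a standalone sketch (the helper name is invented; the branching mirrors the added lines): precomputed `image_hidden_states` are reused either as perceiver embeddings or as plain image-encoder embeddings depending on `config.use_resampler`, and raw `pixel_values` are sent only when no cached visual features exist yet.

def route_visual_inputs(config, pixel_values, image_hidden_states):
    # Sketch of the branching introduced in this commit, not a drop-in copy of the model code.
    model_inputs = {}
    if image_hidden_states is not None:
        if config.use_resampler:
            # features already went through the perceiver resampler
            model_inputs["perceiver_embeddings"] = image_hidden_states
        else:
            # plain vision-encoder features
            model_inputs["image_encoder_embeddings"] = image_hidden_states
    else:
        # no cached visual features yet: pass raw pixels so the model encodes them itself
        model_inputs["pixel_values"] = pixel_values
    return model_inputs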
6 changes: 6 additions & 0 deletions tests/models/idefics/test_modeling_idefics.py
@@ -772,6 +772,12 @@ def test_contrastive_generate_low_memory(self):
def test_custom_4d_attention_mask(self):
pass

@unittest.skip(
reason="IDEFICS has specific requirements for working with inputs embeds like passing also the ids and pixels"
)
def test_generate_from_inputs_embeds_decoder_only(self):
pass

@unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
def test_generate_compile_fullgraph(self):
pass
25 changes: 25 additions & 0 deletions tests/models/idefics2/test_modeling_idefics2.py
@@ -539,6 +539,31 @@ def test_resize_embeddings_untied(self):
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))

def test_inputs_embeds_matches_input_ids_with_generate(self):
# overridden because IDEFICS needs both the input ids and the input embeds to be non-None
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()

inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1

wte = model.get_input_embeddings()

input_ids = inputs["input_ids"]
# some models infer the position ids / attention mask differently depending on whether
# the input ids contain the pad_token, so make sure no padding is left in the input ids
not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
input_ids[input_ids == pad_token_id] = not_pad_token_id
del inputs["input_ids"]
inputs_embeds = wte(input_ids)
out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)
out_embeds = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)

self.assertTrue(torch.allclose(out_embeds, out_ids))


@require_torch
class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):
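The pad-token substitution in the overridden test above (and its Idefics3 twin below) is easy to misread, so here is the same expression in isolation with a few concrete values; the helper name is chosen for illustration. Ids equal to the pad token are rewritten to a neighbouring id so that no padding remains in `input_ids`, keeping attention-mask and position-id inference identical between the ids path and the embeddings path.

def not_pad(pad_token_id: int) -> int:
    # Same rule as in the tests: for pad ids 0 or 1 the next id up is used,
    # otherwise the id just below the pad id.
    return pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1

assert not_pad(0) == 1
assert not_pad(1) == 2
assert not_pad(5) == 4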
25 changes: 25 additions & 0 deletions tests/models/idefics3/test_modeling_idefics3.py
@@ -526,6 +526,31 @@ def test_resize_embeddings_untied(self):
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))

def test_inputs_embeds_matches_input_ids_with_generate(self):
# overridden because IDEFICS needs both the input ids and the input embeds to be non-None
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()

inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1

wte = model.get_input_embeddings()

input_ids = inputs["input_ids"]
# some models infer the position ids / attention mask differently depending on whether
# the input ids contain the pad_token, so make sure no padding is left in the input ids
not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
input_ids[input_ids == pad_token_id] = not_pad_token_id
del inputs["input_ids"]
inputs_embeds = wte(input_ids)
out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)
out_embeds = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)

self.assertTrue(torch.allclose(out_embeds, out_ids))


@require_torch
class Idefics3ForConditionalGenerationIntegrationTest(unittest.TestCase):
6 changes: 6 additions & 0 deletions tests/models/kosmos2/test_modeling_kosmos2.py
@@ -428,6 +428,12 @@ def check_same_values(layer_1, layer_2):
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))

@unittest.skip(
"KOSMOS-2 doesn't support inputs embeds. The test isn't skipped by checking ipnut args because KOSMOS-2 has `generate()` overwritten"
)
def test_inputs_embeds_matches_input_ids_with_generate(self):
pass

@slow
def test_model_from_pretrained(self):
model_name = "microsoft/kosmos-2-patch14-224"
14 changes: 12 additions & 2 deletions tests/test_modeling_common.py
@@ -3000,8 +3000,11 @@ def test_inputs_embeds_matches_input_ids(self):
 
     def test_inputs_embeds_matches_input_ids_with_generate(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        for model_class in self.all_generative_model_classes:
-            if model_class.__name__ not in get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES):
+        for model_class in self.all_model_classes:
+            if model_class.__name__ not in [
+                *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
+                *get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),
+            ]:
                 continue
             model = model_class(config)
             model.to(torch_device)
@@ -3018,6 +3021,13 @@ def test_inputs_embeds_matches_input_ids_with_generate(self):
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1

# VLMs can't generate with embeds and pixels at the same time. We expect the user to pass merged
# embeds already
if model_class.__name__ in get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES):
inputs.pop("pixel_values", None)
inputs.pop("pixel_values_videos", None)
inputs.pop("pixel_values_images", None)

wte = model.get_input_embeddings()
if not self.is_encoder_decoder:
input_ids = inputs["input_ids"]
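The new guard exists because `generate()` cannot consume `inputs_embeds` and raw pixel inputs at the same time; as the comment says, callers are expected to pass embeddings that already carry the visual information. A rough sketch of the call pattern this leaves for such classes, written to mirror the overridden IDEFICS tests above (the function name is invented, and `model`/`inputs` are assumed to come from the usual test-preparation helpers):

import torch

def check_ids_vs_embeds_generation(model, inputs, max_new_tokens=2):
    # Drop the pixel inputs, embed the token ids, and expect generation from ids and
    # from ids+embeddings to produce the same tokens under greedy decoding.
    inputs = dict(inputs)
    for key in ("pixel_values", "pixel_values_videos", "pixel_values_images"):
        inputs.pop(key, None)  # embeds and raw pixels cannot be mixed in generate()
    input_ids = inputs.pop("input_ids")
    inputs_embeds = model.get_input_embeddings()(input_ids)
    out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=max_new_tokens)
    out_embeds = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, **inputs, max_new_tokens=max_new_tokens)
    return torch.allclose(out_ids, out_embeds)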
