Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)

IDEFICS: support inputs embeds (#34043)

* support embeds
* use cache from config
* style...
* fix tests after rebase

This commit is contained in:
parent 9d6998c759
commit d087165db0
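
The upshot of the patch: `IdeficsForVisionText2Text.generate()` can now take `inputs_embeds`, with the caveat (spelled out in the new test skips below) that IDEFICS still expects the matching `input_ids` and any pixel inputs to be passed alongside the embeddings. A minimal usage sketch of that call pattern; the checkpoint name and the hard-coded token ids are illustrative assumptions, not part of this commit:

import torch
from transformers import IdeficsForVisionText2Text

# illustrative checkpoint; any IDEFICS checkpoint should behave the same way
model = IdeficsForVisionText2Text.from_pretrained("HuggingFaceM4/idefics-9b")

input_ids = torch.tensor([[1, 887, 526, 263]])            # an already-tokenized prompt (made-up ids)
inputs_embeds = model.get_input_embeddings()(input_ids)   # embed the same tokens ourselves

# IDEFICS wants the ids next to the embeds; the embeds are only consumed on the first decoding step
output_ids = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, max_new_tokens=10)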

@@ -857,7 +857,7 @@ class GenerationMixin:
                     self,
                     unconditional_ids=negative_prompt_ids,
                     unconditional_attention_mask=negative_prompt_attention_mask,
-                    use_cache=model_kwargs["use_cache"],
+                    use_cache=generation_config.use_cache,
                 )
             )
         if generation_config.sequence_bias is not None:

@@ -2004,10 +2004,7 @@ class GenerationMixin:
         # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
         # generating the first new token or not, and we only want to use the embeddings for the first new token)
         if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
-            model_kwargs["use_cache"] = True
-        else:
-            model_kwargs["use_cache"] = generation_config.use_cache
+            generation_config.use_cache = True

         if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(

@@ -2116,6 +2113,9 @@ class GenerationMixin:
             generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
         )

+        # Set model_kwargs `use_cache` so we can use it later in forward runs
+        model_kwargs["use_cache"] = generation_config.use_cache
+
         # 10. go into different generation modes
         if generation_mode == GenerationMode.ASSISTED_GENERATION:
             if generation_config.num_return_sequences > 1:
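
Taken together, the three GenerationMixin hunks make `generation_config.use_cache` the single source of truth during `generate()`: decoder-only `inputs_embeds` forwarding forces it to True, the classifier-free-guidance processor reads it before `model_kwargs` is finalized, and only then is the resolved value copied into `model_kwargs` for the forward passes. A rough sketch of that ordering with a hypothetical helper name, not the actual `generate()` body:

def resolve_use_cache(generation_config, model_kwargs, is_encoder_decoder, model_input_name):
    # decoder-only models fed inputs_embeds must cache, otherwise the first step can't be told apart
    if not is_encoder_decoder and model_input_name == "inputs_embeds":
        generation_config.use_cache = True
    # anything built before model_kwargs is final (e.g. the CFG processor) reads
    # generation_config.use_cache; the forward passes later read model_kwargs["use_cache"]
    model_kwargs["use_cache"] = generation_config.use_cache
    return model_kwargs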

@@ -1663,19 +1663,31 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
     def prepare_inputs_for_generation(
         self,
         input_ids,
-        past_key_values=None,
         attention_mask=None,
         position_ids=None,
+        inputs_embeds=None,
+        past_key_values=None,
+        cache_position=None,
         pixel_values=None,
         image_hidden_states=None,
         image_attention_mask=None,
         use_cache=None,
-        cache_position=None,
         **kwargs,
     ):
+        model_inputs = {}
+        if image_hidden_states is not None:
+            if self.config.use_resampler:
+                model_inputs["perceiver_embeddings"] = image_hidden_states
+            else:
+                model_inputs["image_encoder_embeddings"] = image_hidden_states
+        else:
+            model_inputs["pixel_values"] = pixel_values
+
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         if past_key_values is not None:
-            if input_ids.shape[1] != cache_position.shape[0]:
+            if inputs_embeds is not None:
+                input_ids = input_ids[:, -cache_position.shape[0] :]
+            elif input_ids.shape[1] != cache_position.shape[0]:
                 input_ids = input_ids[:, cache_position]
             if image_attention_mask is not None:
                 image_attention_mask = image_attention_mask[:, -input_ids.shape[1] :]

@@ -1690,19 +1702,17 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
             # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
             position_ids = position_ids.clone(memory_format=torch.contiguous_format)

-        model_inputs = {}
-        image_hidden_states = kwargs.pop("image_hidden_states", None)
-        if image_hidden_states is not None:
-            if self.config.use_resampler:
-                model_inputs["perceiver_embeddings"] = image_hidden_states
-            else:
-                model_inputs["image_encoder_embeddings"] = image_hidden_states
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and cache_position[0] == 0:
+            model_inputs.update({"inputs_embeds": inputs_embeds, "input_ids": None})
         else:
-            model_inputs["pixel_values"] = pixel_values
+            # The clone here is for the same reason as for `position_ids`.
+            model_inputs.update(
+                {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
+            )

         model_inputs.update(
             {
-                "input_ids": input_ids,
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "cache_position": cache_position,
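
With the reordered `prepare_inputs_for_generation`, IDEFICS forwards `inputs_embeds` only on the very first decoding step (`cache_position[0] == 0`) and switches to the freshly generated `input_ids` on every later step. A self-contained sketch of just that branch, mirroring the hunk above (the helper name and tensors are made up for illustration):

import torch

def pick_step_inputs(input_ids, inputs_embeds, cache_position):
    # mirrors the added branch: embeds on the first step, ids on every later step
    model_inputs = {}
    if inputs_embeds is not None and cache_position[0] == 0:
        model_inputs.update({"inputs_embeds": inputs_embeds, "input_ids": None})
    else:
        model_inputs.update(
            {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
        )
    return model_inputs

ids = torch.tensor([[5, 6, 7]])
embeds = torch.randn(1, 3, 8)
first_step = pick_step_inputs(ids, embeds, cache_position=torch.arange(3))             # embeds win
later_step = pick_step_inputs(ids[:, -1:], embeds, cache_position=torch.tensor([3]))   # ids win
assert first_step["input_ids"] is None and later_step["inputs_embeds"] is None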

@@ -772,6 +772,12 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
     def test_custom_4d_attention_mask(self):
         pass

+    @unittest.skip(
+        reason="IDEFICS has specific requirements for working with inputs embeds like passing also the ids and pixels"
+    )
+    def test_generate_from_inputs_embeds_decoder_only(self):
+        pass
+
     @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
     def test_generate_compile_fullgraph(self):
         pass

@@ -539,6 +539,31 @@ class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
             # Check that the model can still do a forward pass successfully (every parameter should be resized)
             model(**self._prepare_for_class(inputs_dict, model_class))

+    def test_inputs_embeds_matches_input_ids_with_generate(self):
+        # overwrite because IDEFICS needs ids and embeds at the input to be not None
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+            pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
+
+            wte = model.get_input_embeddings()
+
+            input_ids = inputs["input_ids"]
+            # some models infer position ids/attn mask differently when input ids
+            # by check if pad_token let's make sure no padding is in input ids
+            not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
+            input_ids[input_ids == pad_token_id] = not_pad_token_id
+            del inputs["input_ids"]
+            inputs_embeds = wte(input_ids)
+            out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)
+            out_embeds = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)
+
+            self.assertTrue(torch.allclose(out_embeds, out_ids))
+

 @require_torch
 class Idefics2ForConditionalGenerationIntegrationTest(unittest.TestCase):

@@ -526,6 +526,31 @@ class Idefics3ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTest
             # Check that the model can still do a forward pass successfully (every parameter should be resized)
             model(**self._prepare_for_class(inputs_dict, model_class))

+    def test_inputs_embeds_matches_input_ids_with_generate(self):
+        # overwrite because IDEFICS needs ids and embeds at the input to be not None
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+            pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
+
+            wte = model.get_input_embeddings()
+
+            input_ids = inputs["input_ids"]
+            # some models infer position ids/attn mask differently when input ids
+            # by check if pad_token let's make sure no padding is in input ids
+            not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
+            input_ids[input_ids == pad_token_id] = not_pad_token_id
+            del inputs["input_ids"]
+            inputs_embeds = wte(input_ids)
+            out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)
+            out_embeds = model.generate(input_ids=input_ids, inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)
+
+            self.assertTrue(torch.allclose(out_embeds, out_ids))
+

 @require_torch
 class Idefics3ForConditionalGenerationIntegrationTest(unittest.TestCase):

@@ -428,6 +428,12 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
         # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
         # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))

+    @unittest.skip(
+        "KOSMOS-2 doesn't support inputs embeds. The test isn't skipped by checking ipnut args because KOSMOS-2 has `generate()` overwritten"
+    )
+    def test_inputs_embeds_matches_input_ids_with_generate(self):
+        pass
+
     @slow
     def test_model_from_pretrained(self):
         model_name = "microsoft/kosmos-2-patch14-224"

@@ -3000,8 +3000,11 @@ class ModelTesterMixin:

     def test_inputs_embeds_matches_input_ids_with_generate(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        for model_class in self.all_generative_model_classes:
-            if model_class.__name__ not in get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES):
+        for model_class in self.all_model_classes:
+            if model_class.__name__ not in [
+                *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
+                *get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),
+            ]:
                 continue
             model = model_class(config)
             model.to(torch_device)

@@ -3018,6 +3021,13 @@ class ModelTesterMixin:
             inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
             pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1

+            # VLMs can't generate with embeds and pixels at the same time. We expect the user to pass merged
+            # embeds already
+            if model_class.__name__ in get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES):
+                inputs.pop("pixel_values", None)
+                inputs.pop("pixel_values_videos", None)
+                inputs.pop("pixel_values_images", None)
+
             wte = model.get_input_embeddings()
             if not self.is_encoder_decoder:
                 input_ids = inputs["input_ids"]
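
The last common-test hunk encodes a calling convention rather than a bug fix: when a vision-to-text model is generated from `inputs_embeds`, the test drops all pixel inputs because the visual features are expected to be merged into the embeddings already. A minimal sketch of that convention; the dictionary contents are placeholders, not real processor output:

# placeholder inputs standing in for what a processor would return
inputs = {
    "attention_mask": [[1, 1, 1, 1]],
    "pixel_values": "image tensor (placeholder)",
    "pixel_values_videos": "video tensor (placeholder)",
}

# when generating from inputs_embeds, pixel inputs are popped: the embeds already carry them
for key in ("pixel_values", "pixel_values_videos", "pixel_values_images"):
    inputs.pop(key, None)

assert "pixel_values" not in inputs  # only text-side inputs remain alongside inputs_embeds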