Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
fix llava (#27909)
* fix llava
* nits
* attention_mask was forgotten
* nice
* :)
* fixup
This commit is contained in:
parent e0b617d192
commit aa7ab98e72
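For context, a minimal usage sketch of the code path this commit touches: `generate()` calls `prepare_inputs_for_generation()` on every decoding step, and the fix makes LLaVA handle `attention_mask` and the KV cache in its own hook instead of delegating everything to the language model. The checkpoint name and image path below are assumptions for illustration only, not part of the commit.

from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"  # assumption: any LLaVA checkpoint on the Hub
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id)

prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
image = Image.open("example.jpg")  # hypothetical local image
inputs = processor(text=prompt, images=image, return_tensors="pt")

# generate() calls prepare_inputs_for_generation at each step; with this fix the
# attention_mask produced by the processor is handled directly in LLaVA's own
# prepare_inputs_for_generation and trimmed to match the KV cache.
output_ids = model.generate(**inputs, max_new_tokens=30)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])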
@@ -22,6 +22,7 @@ from torch import nn
 
 from ... import PreTrainedModel
 from ...activations import ACT2FN
+from ...cache_utils import Cache
 from ...modeling_outputs import ModelOutput
 from ...utils import (
     add_start_docstrings,
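The new `Cache` import is used below to branch on the cache format. A small standalone sketch of that branch (the helper name is hypothetical, not part of the commit):

from transformers.cache_utils import Cache

def get_cache_lengths(past_key_values):
    # Cache objects expose the number of cached tokens directly.
    if isinstance(past_key_values, Cache):
        return past_key_values.get_seq_length(), past_key_values.seen_tokens
    # Legacy cache: tuple of per-layer (key, value) tensors shaped
    # (batch, num_heads, seq_len, head_dim); the sequence length is dim 2.
    cache_length = past_length = past_key_values[0][0].shape[2]
    return cache_length, past_length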
@@ -472,14 +473,57 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, **kwargs
+        self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
     ):
-        # Call `prepare_inputs_for_generation` from the LM
-        model_input = self.language_model.prepare_inputs_for_generation(
-            input_ids, past_key_values, inputs_embeds=inputs_embeds, **kwargs
-        )
-        model_input.update({"pixel_values": pixel_values})
-        return model_input
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+            # input)
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+            elif self.config.image_token_index in input_ids:
+                input_ids = input_ids[:, input_ids.shape[1] - 1 :]
+            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
+            # older attention values, as their corresponding values are not part of the input.
+            if cache_length < past_length and attention_mask is not None:
+                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+                "pixel_values": pixel_values,
+            }
+        )
+        return model_inputs
 
     def _reorder_cache(self, *args, **kwargs):
         return self.language_model._reorder_cache(*args, **kwargs)
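To make trimming rule 2 above concrete, a toy example with made-up tensor values: when past_length is smaller than input_ids.shape[1], input_ids still contains prompt tokens that are already in the KV cache, so only the unprocessed suffix is kept.

import torch

input_ids = torch.tensor([[101, 5, 6, 7, 8]])  # cached prompt + 1 new token
past_length = 4                                # tokens already in the KV cache
input_ids = input_ids[:, past_length:]
print(input_ids)  # tensor([[8]]) -> only the new token is fed to the model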