Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 10:41:07 +06:00)
Fix vipllava for generation (#29874)

* fix vipllava generation
* consistent llava code
* revert llava tests changes

Parent: 240e10626b
Commit: cc75f1ac73
@@ -569,10 +569,11 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
                 batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
 
                 # Get the target length
-                target_seqlen = first_layer_past_key_value.shape[-1] + 1
+                target_length = input_ids.shape[1]
+                past_length = first_layer_past_key_value.shape[-1]
 
                 extended_attention_mask = torch.ones(
-                    (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
+                    (attention_mask.shape[0], past_length),
                     dtype=attention_mask.dtype,
                     device=attention_mask.device,
                 )
@@ -587,7 +588,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
                 # Zero-out the places where we don't need to attend
                 extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
 
-                attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
+                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                 position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
 
         outputs = self.language_model(
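To illustrate the two hunks above, here is a minimal, self-contained sketch of how the attention mask is rebuilt during cached generation. The tensor sizes and the new_batch_index / new_non_attended_tokens values are hypothetical, not taken from this diff.

import torch

# Hypothetical sizes: 2 sequences, 10 cached positions, 1 new token being decoded.
batch_size, past_length, target_length = 2, 10, 1

# Mask passed into the forward call, covering cached positions plus the new token(s).
attention_mask = torch.ones(batch_size, past_length + target_length, dtype=torch.long)

# A fresh mask over the cached positions only, now sized with `past_length`
# (the old code sized it as `target_seqlen - attention_mask.shape[1]`).
extended_attention_mask = torch.ones(
    (attention_mask.shape[0], past_length), dtype=attention_mask.dtype
)

# Zero out cache slots that were never attended to, as located from the first
# KV-cache layer; these index values are made up for the example.
new_batch_index = torch.tensor([0])
new_non_attended_tokens = torch.tensor([3])
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

# Cached positions first, then the last `target_length` columns of the incoming mask.
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
print(attention_mask.shape, position_ids.squeeze(-1))  # torch.Size([2, 11]) tensor([ 9, 10])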
@@ -441,10 +441,10 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
             if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
                 # Retrieve the first layer to inspect the logits and mask out the hidden states
                 # that are set to 0
-                first_layer_past_key_value = past_key_values[0][0][:, 0, :, :]
+                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
 
                 # Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-1) == 0)
+                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
 
                 target_length = input_ids.shape[1]
                 past_length = first_layer_past_key_value.shape[-1]
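The VipLlava change mirrors the indexing already used in the Llava code paths. Below is a minimal sketch of why the new indexing works, assuming the conventional KV-cache layout (batch, num_heads, seq_len, head_dim); the cache values are synthetic.

import torch

# Synthetic first-layer key cache in the usual (batch, num_heads, seq_len, head_dim) layout.
batch, num_heads, seq_len, head_dim = 1, 2, 5, 4
key_cache = torch.randn(batch, num_heads, seq_len, head_dim)
key_cache[:, :, 3, :] = 0  # pretend position 3 was never attended to

# New indexing: drop head_dim ([:, :, :, 0]) and sum over num_heads (-2),
# leaving a (batch, seq_len) tensor whose zeros mark non-attended positions.
first_layer_past_key_value = key_cache[:, :, :, 0]
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
print(non_attended_tokens)  # tensor([3])

# The old code sliced head 0 ([:, 0, :, :]) and summed over head_dim (-1) instead;
# the fix aligns VipLlava with the Llava/LlavaNext implementation.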
@@ -423,7 +423,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
         output = model(**inputs)
 
         expected_slice = torch.tensor(
-            [[-4.7695, -4.5664, -0.2786], [-10.6172, -10.8906, -2.5234], [-6.7344, -7.2422, -0.6758]],
+            [[-4.7695, -4.5664, -0.2786], [-10.6250, -10.8906, -2.5254], [-6.7383, -7.2461, -0.6787]],
             dtype=torch.float32,
             device=torch_device,
         )
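The updated expected values are compared against a slice of the model output; a hedged sketch of such a check is below (the exact slicing and tolerance used by the real test are assumptions).

import torch

expected_slice = torch.tensor(
    [[-4.7695, -4.5664, -0.2786], [-10.6250, -10.8906, -2.5254], [-6.7383, -7.2461, -0.6787]],
    dtype=torch.float32,
)

# Stand-in for a 3x3 slice of `output.logits` from the forward pass above.
logits_slice = expected_slice.clone() + 1e-4
print(torch.allclose(logits_slice, expected_slice, atol=1e-3))  # True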