Paligemma: fix static cache test (#33941)
* fix

* not flaky anymore + style
parent 38f9f10dd9
commit 612065efeb
@@ -881,9 +881,7 @@ class DummyModel(DummyPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You must specify exactly one of input_ids or inputs_embeds"
-            )
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
         if self.gradient_checkpointing and self.training and use_cache:
             logger.warning_once(
@@ -758,9 +758,7 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You must specify exactly one of input_ids or inputs_embeds"
-            )
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
         if self.gradient_checkpointing and self.training and use_cache:
             logger.warning_once(
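Both hunks above collapse the multi-line raise into a single line without touching the surrounding check: the XOR expression is True exactly when input_ids and inputs_embeds are both given or both missing. A minimal, self-contained sketch of the same validation idiom (the helper name and tensors here are illustrative, not from the diff):

    import torch

    def check_exactly_one(input_ids=None, inputs_embeds=None):
        # The XOR is True when both or neither argument is supplied,
        # and False only when exactly one of them is provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    check_exactly_one(input_ids=torch.tensor([[1, 2, 3]]))  # ok
    check_exactly_one(inputs_embeds=torch.zeros(1, 3, 8))   # ok
    # check_exactly_one() with no arguments, or with both, raises ValueError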
@@ -57,8 +57,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
     min_dtype: float,
     cache_position: torch.Tensor,
     batch_size: int,
-    is_training: bool,
-    token_type_ids: torch.Tensor,
+    is_training: bool = False,
+    token_type_ids: torch.Tensor = None,
 ):
     """
     Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
@@ -94,7 +94,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
             if is_training:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
             else:
-                causal_mask = torch.zeros_like(causal_mask)
+                causal_mask[:, :sequence_length] = 0.0
 
         causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
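The change in the non-training branch is easiest to see on a toy mask. With a static cache, target_length (the pre-allocated cache size) can exceed sequence_length, and the two lines differ in whether the unused tail of the cache stays masked. A small sketch under made-up shapes (values are illustrative):

    import torch

    sequence_length, target_length = 3, 6  # prompt length vs. pre-allocated static cache size
    min_dtype = torch.finfo(torch.float32).min
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype)

    # Removed line: the whole mask becomes attendable, including cache slots past the prompt.
    old_branch = torch.zeros_like(causal_mask)

    # Added line: only the first `sequence_length` key positions are unmasked; the padded
    # tail of the static cache keeps the large negative fill value.
    new_branch = causal_mask.clone()
    new_branch[:, :sequence_length] = 0.0

    print(old_branch[0])  # all zeros: [0, 0, 0, 0, 0, 0]
    print(new_branch[0])  # [0, 0, 0, min, min, min]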
@@ -378,7 +378,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
             if is_training:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
             else:
-                causal_mask = torch.zeros_like(causal_mask)
+                causal_mask[:, :sequence_length] = 0.0
 
         causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
         causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
@@ -593,7 +593,6 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
 
         dtype = self.get_output_embeddings().weight.dtype
         min_dtype = torch.finfo(dtype).min
-        is_training = token_type_ids is not None and kwargs.get("labels", None) is not None
 
         model_inputs["attention_mask"] = _prepare_4d_causal_attention_mask_with_cache_position(
             attention_mask,
@@ -604,8 +603,6 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
             min_dtype=min_dtype,
             cache_position=cache_position,
             batch_size=batch_size,
-            is_training=is_training,
-            token_type_ids=token_type_ids,
         )
 
         model_inputs["token_type_ids"] = token_type_ids
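Taken together with the signature change above, these two hunks mean prepare_inputs_for_generation no longer computes is_training or forwards token_type_ids; the helper's new defaults (False and None) cover the generation path. A hedged, self-contained sketch of that pattern, using a simplified stand-in for the helper rather than the real PaliGemma function:

    import torch

    # Simplified stand-in mirroring the updated signature: the two new keyword
    # arguments default to False / None. attention_mask handling is omitted here,
    # and the real helper lives in the PaliGemma modeling code.
    def build_mask(attention_mask, sequence_length, target_length, dtype, device,
                   min_dtype, cache_position, batch_size,
                   is_training: bool = False, token_type_ids: torch.Tensor = None):
        mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if is_training:
            mask = torch.triu(mask, diagonal=1)
        else:
            mask[:, :sequence_length] = 0.0
        mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        return mask[None, None, :, :].expand(batch_size, 1, -1, -1)

    # Generation-time call site: is_training and token_type_ids can simply be omitted.
    mask = build_mask(
        attention_mask=None,
        sequence_length=3,
        target_length=6,
        dtype=torch.float32,
        device="cpu",
        min_dtype=torch.finfo(torch.float32).min,
        cache_position=torch.arange(3),
        batch_size=2,
    )
    print(mask.shape)  # torch.Size([2, 1, 3, 6])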
@@ -159,7 +159,8 @@ class PaliGemmaVisionText2TextModelTester:
         config_and_inputs = self.prepare_config_and_inputs()
         config, pixel_values = config_and_inputs
         input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
-        attention_mask = input_ids.ne(1).to(torch_device)
+        attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
+
         # set the 16 first tokens to be image, and ensure that no other tokens are image tokens
         # do not change this unless you modified image size or patch size
         input_ids[input_ids == config.image_token_index] = self.pad_token_id
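This tester change is what removes the flakiness at the source: ids_tensor draws token ids at random, so an unlucky draw could coincide with config.image_token_index and corrupt the expected image-token layout. Rewriting any such collision to the pad token before the image prefix is placed makes the inputs well-formed on every run. A small self-contained sketch of the same idea (vocab size, token ids and shapes are made up):

    import torch

    vocab_size, image_token_index, pad_token_id = 99, 98, 0
    batch_size, seq_length, num_image_tokens = 2, 32, 16

    # Random ids in [1, vocab_size - 1], mimicking ids_tensor(...) + 1 in the tester.
    input_ids = torch.randint(1, vocab_size, (batch_size, seq_length))
    attention_mask = input_ids.ne(pad_token_id)

    # Make sure no randomly drawn token is accidentally the image token ...
    input_ids[input_ids == image_token_index] = pad_token_id
    # ... then mark the first `num_image_tokens` positions as the image prefix.
    input_ids[:, :num_image_tokens] = image_token_index

    assert (input_ids == image_token_index).sum() == batch_size * num_image_tokens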
@@ -4868,7 +4868,6 @@ class ModelTesterMixin:
         normalized_1 = F.softmax(out_shared_prefix_last_tokens)
         torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
 
-    @is_flaky(max_attempts=10)  # TODO @raushan: this test is VERY flaky on some VLMs, like paligemma
     def test_static_cache_matches_dynamic(self):
         """
         Tests that generating with static cache give almost same results as with dynamic cache.
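With the tester inputs fixed, the @is_flaky(max_attempts=10) retry decorator is dropped from test_static_cache_matches_dynamic. For context, is_flaky in transformers' testing utilities reruns a failing test up to max_attempts times before reporting the failure; a rough sketch of that retry pattern (an illustration, not the library's actual implementation):

    import functools

    def is_flaky(max_attempts: int = 5):
        """Rerun a decorated test up to `max_attempts` times before letting the failure surface (sketch)."""
        def decorator(test_func):
            @functools.wraps(test_func)
            def wrapper(*args, **kwargs):
                for attempt in range(max_attempts):
                    try:
                        return test_func(*args, **kwargs)
                    except AssertionError:
                        if attempt == max_attempts - 1:
                            raise  # out of retries: report the real failure
            return wrapper
        return decorator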