From 612065efeba61b9ee9f45e80aa4c1368d6d43934 Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Sat, 5 Oct 2024 09:47:37 +0200
Subject: [PATCH] Paligemma: fix static cache test (#33941)

* fix

* not flaky anymore + style
---
 examples/modular-transformers/modeling_dummy.py    |  4 +---
 .../modular-transformers/modeling_my_new_model2.py |  4 +---
 .../models/paligemma/modeling_paligemma.py         | 11 ++++-------
 tests/models/paligemma/test_modeling_paligemma.py  |  3 ++-
 tests/test_modeling_common.py                      |  1 -
 5 files changed, 8 insertions(+), 15 deletions(-)

diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py
index c57d41785f6..51349ecf4ec 100644
--- a/examples/modular-transformers/modeling_dummy.py
+++ b/examples/modular-transformers/modeling_dummy.py
@@ -881,9 +881,7 @@ class DummyModel(DummyPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You must specify exactly one of input_ids or inputs_embeds"
-            )
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
         if self.gradient_checkpointing and self.training and use_cache:
             logger.warning_once(
diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py
index a1fabf5d8c5..49cdd274162 100644
--- a/examples/modular-transformers/modeling_my_new_model2.py
+++ b/examples/modular-transformers/modeling_my_new_model2.py
@@ -758,9 +758,7 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You must specify exactly one of input_ids or inputs_embeds"
-            )
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
         if self.gradient_checkpointing and self.training and use_cache:
             logger.warning_once(
diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py
index cba66aa3a81..5e695f3387d 100644
--- a/src/transformers/models/paligemma/modeling_paligemma.py
+++ b/src/transformers/models/paligemma/modeling_paligemma.py
@@ -57,8 +57,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
     min_dtype: float,
     cache_position: torch.Tensor,
     batch_size: int,
-    is_training: bool,
-    token_type_ids: torch.Tensor,
+    is_training: bool = False,
+    token_type_ids: torch.Tensor = None,
 ):
     """
     Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
@@ -94,7 +94,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
         if is_training:
             causal_mask = torch.triu(causal_mask, diagonal=1)
         else:
-            causal_mask = torch.zeros_like(causal_mask)
+            causal_mask[:, :sequence_length] = 0.0
 
         causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
         causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
@@ -378,7 +378,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
             if is_training:
                 causal_mask = torch.triu(causal_mask, diagonal=1)
             else:
-                causal_mask = torch.zeros_like(causal_mask)
+                causal_mask[:, :sequence_length] = 0.0
 
             causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
             causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
@@ -593,7 +593,6 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
 
             dtype = self.get_output_embeddings().weight.dtype
             min_dtype = torch.finfo(dtype).min
-            is_training = token_type_ids is not None and kwargs.get("labels", None) is not None
 
             model_inputs["attention_mask"] = _prepare_4d_causal_attention_mask_with_cache_position(
                 attention_mask,
@@ -604,8 +603,6 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
                 min_dtype=min_dtype,
                 cache_position=cache_position,
                 batch_size=batch_size,
-                is_training=is_training,
-                token_type_ids=token_type_ids,
             )
 
         model_inputs["token_type_ids"] = token_type_ids
diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py
index d954fa2a0f5..7d72226e41b 100644
--- a/tests/models/paligemma/test_modeling_paligemma.py
+++ b/tests/models/paligemma/test_modeling_paligemma.py
@@ -159,7 +159,8 @@ class PaliGemmaVisionText2TextModelTester:
         config_and_inputs = self.prepare_config_and_inputs()
         config, pixel_values = config_and_inputs
         input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
-        attention_mask = input_ids.ne(1).to(torch_device)
+        attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
+
         # set the 16 first tokens to be image, and ensure that no other tokens are image tokens
         # do not change this unless you modified image size or patch size
         input_ids[input_ids == config.image_token_index] = self.pad_token_id
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 1f5b232f0db..6d40359f917 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4868,7 +4868,6 @@ class ModelTesterMixin:
         normalized_1 = F.softmax(out_shared_prefix_last_tokens)
         torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
 
-    @is_flaky(max_attempts=10)  # TODO @raushan: this test is VERY flaky on some VLMs, like paligemma
     def test_static_cache_matches_dynamic(self):
         """
         Tests that generating with static cache give almost same results as with dynamic cache.
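
Why the `causal_mask` change matters: with a static cache, `target_length` (the
cache's maximum size) can exceed the number of positions actually filled. The old
inference branch replaced the whole mask with `torch.zeros_like(causal_mask)`, but
the line that follows re-masks slots past `cache_position` by keeping the existing
`min_dtype` entries, and `0 * True` is still `0`. An all-zeros mask therefore left
every slot attendable, including the unfilled tail of the static cache, which is
what made `test_static_cache_matches_dynamic` flaky. Assigning
`causal_mask[:, :sequence_length] = 0.0` unmasks only the current input's columns
and preserves `min_dtype` everywhere else, so the multiply can still mask the
unfilled slots. A minimal, standalone sketch of the two branches (not the library
code; the names mirror the diff):

    import torch

    # Shapes mirror the diff: `sequence_length` is the current query length,
    # `target_length` is the static cache's maximum size.
    sequence_length, target_length = 4, 8
    min_dtype = torch.finfo(torch.float32).min
    cache_position = torch.arange(sequence_length)  # prefill: slots 0..3 are filled

    def build_inference_mask(use_old_branch: bool) -> torch.Tensor:
        mask = torch.full((sequence_length, target_length), min_dtype)
        if use_old_branch:
            mask = torch.zeros_like(mask)  # old branch: every slot becomes attendable
        else:
            mask[:, :sequence_length] = 0.0  # fixed branch: unmask only the real prefix
        # The model's next line: keep min_dtype where the key slot lies past
        # cache_position. With the old all-zeros mask there is nothing left to keep.
        mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
        return mask

    old, new = build_inference_mask(True), build_inference_mask(False)
    print((old == min_dtype).any())                       # tensor(False): bug, tail attendable
    print((new[:, sequence_length:] == min_dtype).all())  # tensor(True): tail stays masked

With the tail masked, static-cache generation matches the dynamic-cache path, which
is why the `@is_flaky(max_attempts=10)` decorator on `test_static_cache_matches_dynamic`
can be dropped. Defaulting `is_training=False` and `token_type_ids=None` in the helper
also lets `prepare_inputs_for_generation` stop computing and forwarding them; generation
does not pass `labels`, so `is_training` was always `False` on that path anyway.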