Paligemma: fix static cache test (#33941)

* fix

* not flaky anymore + style
This commit is contained in:
Raushan Turganbay 2024-10-05 09:47:37 +02:00 committed by GitHub
parent 38f9f10dd9
commit 612065efeb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 8 additions and 15 deletions

View File

@ -881,9 +881,7 @@ class DummyModel(DummyPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(

View File

@ -758,9 +758,7 @@ class MyNewModel2Model(MyNewModel2PreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(

View File

@ -57,8 +57,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
is_training: bool,
token_type_ids: torch.Tensor,
is_training: bool = False,
token_type_ids: torch.Tensor = None,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
@ -94,7 +94,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
if is_training:
causal_mask = torch.triu(causal_mask, diagonal=1)
else:
causal_mask = torch.zeros_like(causal_mask)
causal_mask[:, :sequence_length] = 0.0
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
@ -378,7 +378,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
if is_training:
causal_mask = torch.triu(causal_mask, diagonal=1)
else:
causal_mask = torch.zeros_like(causal_mask)
causal_mask[:, :sequence_length] = 0.0
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
@ -593,7 +593,6 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
dtype = self.get_output_embeddings().weight.dtype
min_dtype = torch.finfo(dtype).min
is_training = token_type_ids is not None and kwargs.get("labels", None) is not None
model_inputs["attention_mask"] = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
@ -604,8 +603,6 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
is_training=is_training,
token_type_ids=token_type_ids,
)
model_inputs["token_type_ids"] = token_type_ids

View File

@ -159,7 +159,8 @@ class PaliGemmaVisionText2TextModelTester:
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
attention_mask = input_ids.ne(1).to(torch_device)
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
# set the 16 first tokens to be image, and ensure that no other tokens are image tokens
# do not change this unless you modified image size or patch size
input_ids[input_ids == config.image_token_index] = self.pad_token_id

View File

@ -4868,7 +4868,6 @@ class ModelTesterMixin:
normalized_1 = F.softmax(out_shared_prefix_last_tokens)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
@is_flaky(max_attempts=10) # TODO @raushan: this test is VERY flaky on some VLMs, like paligemma
def test_static_cache_matches_dynamic(self):
"""
Tests that generating with static cache give almost same results as with dynamic cache.