From 3cb9309024aeeeba5cf820539119f2f12aa4eac7 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Fri, 19 May 2023 19:14:16 +0200
Subject: [PATCH] [`Blip`] Remove redundant shift right (#23153)

* remove redundant shift right

* fix failing tests

* this time fix tests
---
 src/transformers/models/blip/modeling_blip.py | 17 ----
 .../models/blip/modeling_tf_blip.py           | 28 ------
 tests/models/blip/test_modeling_blip.py       | 83 +++++++++++++++---
 tests/models/blip/test_modeling_tf_blip.py    | 86 ++++++++++++++++++-
 4 files changed, 154 insertions(+), 60 deletions(-)

diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py
index e2459239ac5..9e0fc7419d4 100644
--- a/src/transformers/models/blip/modeling_blip.py
+++ b/src/transformers/models/blip/modeling_blip.py
@@ -1121,19 +1121,6 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
     def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.patch_embedding
 
-    # Adapted from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right
-    def _shift_right(self, input_ids):
-        pad_token_id = self.decoder_pad_token_id
-
-        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
-        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
-        shifted_input_ids[..., 0] = self.decoder_start_token_id
-
-        # replace possible -100 values in labels by `pad_token_id`
-        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
-
-        return shifted_input_ids
-
     @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
     def forward(
@@ -1215,10 +1202,6 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
 
         question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
 
-        if labels is not None and decoder_input_ids is None:
-            # get decoder inputs from shifting lm labels to the right - this is used in training mode
-            decoder_input_ids = self._shift_right(labels)
-
         answer_output = self.text_decoder(
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
diff --git a/src/transformers/models/blip/modeling_tf_blip.py b/src/transformers/models/blip/modeling_tf_blip.py
index 6ae7a2503cc..4ea9c7e7e5f 100644
--- a/src/transformers/models/blip/modeling_tf_blip.py
+++ b/src/transformers/models/blip/modeling_tf_blip.py
@@ -1335,30 +1335,6 @@ class TFBlipForQuestionAnswering(TFBlipPreTrainedModel):
             attentions=attns,
         )
 
-    # Adapted from transformers.models.t5.modeling_tf_t5.TFT5PreTrainedModel._shift_right
-    def _shift_right(self, input_ids):
-        decoder_start_token_id = self.decoder_start_token_id
-        pad_token_id = self.decoder_pad_token_id
-
-        if decoder_start_token_id is None or pad_token_id is None:
-            raise ValueError("decoder_start_token_id and pad_token_id must be defined!")
-
-        start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
-        start_tokens = tf.cast(start_tokens, input_ids.dtype)  # Ensure compatible dtypes for concatenation
-        shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
-
-        # replace possible -100 values in labels by `pad_token_id`
-        shifted_input_ids = tf.where(
-            shifted_input_ids == -100,
-            tf.cast(tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids.dtype),
-            shifted_input_ids,
-        )
-
-        # "Verify that `labels` has only positive values and -100"
-        tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype))
-
-        return shifted_input_ids
-
     @unpack_inputs
     @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFBlipTextVisionModelOutput, config_class=BlipVisionConfig)
@@ -1440,10 +1416,6 @@ class TFBlipForQuestionAnswering(TFBlipPreTrainedModel):
 
         question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
 
-        if labels is not None and decoder_input_ids is None:
-            # get decoder inputs from shifting lm labels to the right - this is used in training mode
-            decoder_input_ids = self._shift_right(labels)
-
         answer_output = self.text_decoder(
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
diff --git a/tests/models/blip/test_modeling_blip.py b/tests/models/blip/test_modeling_blip.py
index 8829903c16c..7d9c6b5ba58 100644
--- a/tests/models/blip/test_modeling_blip.py
+++ b/tests/models/blip/test_modeling_blip.py
@@ -626,17 +626,73 @@ class BlipTextImageModelsModelTester:
         return config, inputs_dict
 
 
+class BlipVQAModelTester:
+    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
+        if text_kwargs is None:
+            text_kwargs = {}
+        if vision_kwargs is None:
+            vision_kwargs = {}
+
+        self.parent = parent
+        self.text_model_tester = BlipTextModelTester(parent, **text_kwargs)
+        self.vision_model_tester = BlipVisionModelTester(parent, **vision_kwargs)
+        self.is_training = is_training
+
+    def prepare_config_and_inputs(self):
+        text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
+        vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
+
+        config = self.get_config()
+
+        return config, input_ids, attention_mask, pixel_values
+
+    def get_config(self):
+        return BlipConfig.from_text_vision_configs(
+            self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
+        )
+
+    def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
+        model = BlipModel(config).to(torch_device).eval()
+        with torch.no_grad():
+            result = model(input_ids, pixel_values, attention_mask)
+        self.parent.assertEqual(
+            result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
+        )
+        self.parent.assertEqual(
+            result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
+        )
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, input_ids, attention_mask, pixel_values = config_and_inputs
+        inputs_dict = {
+            "input_ids": input_ids,
+            "labels": input_ids,
+            "decoder_input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "pixel_values": pixel_values,
+        }
+        return config, inputs_dict
+
+
 @require_torch
 @require_vision
-class BlipVQAModelTest(unittest.TestCase):
+class BlipVQAModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (BlipForQuestionAnswering,) if is_torch_available() else ()
+    fx_compatible = False
+    test_head_masking = False
+    test_pruning = False
+    test_resize_embeddings = False
+    test_attention_outputs = False
+    test_torchscript = False
 
     def setUp(self):
-        self.model_tester = BlipModelTester(self)
+        self.model_tester = BlipVQAModelTester(self)
 
     def _prepare_inputs_for_vqa(self):
         _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         inputs_dict["labels"] = inputs_dict["input_ids"]
+        inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
         inputs_dict.pop("return_loss")
         return inputs_dict
 
@@ -658,7 +714,7 @@ class BlipVQAModelTest(unittest.TestCase):
         for model_class in self.all_model_classes:
             model = model_class(self.model_tester.get_config()).to(torch_device)
             model.train()
-            loss = model(**self._prepare_inputs_for_vqa()).loss
+            loss = model(**self.model_tester.prepare_config_and_inputs_for_common()[1]).loss
             loss.backward()
 
             # verify the gradients are not None
@@ -687,6 +743,18 @@ class BlipVQAModelTest(unittest.TestCase):
                 f"Argument {arg} of forward function signature should include {arg}. Found {args}.",
             )
 
+    @unittest.skip(reason="Hidden_states is tested in individual model tests")
+    def test_hidden_states_output(self):
+        pass
+
+    @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
+    def test_inputs_embeds(self):
+        pass
+
+    @unittest.skip(reason="BlipModel does not have input/output embeddings")
+    def test_model_common_attributes(self):
+        pass
+
 
 @require_torch
 class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
@@ -886,14 +954,7 @@ class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
 
 @require_torch
 class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
-    all_model_classes = (
-        (
-            BlipForConditionalGeneration,
-            BlipForQuestionAnswering,
-        )
-        if is_torch_available()
-        else ()
-    )
+    all_model_classes = (BlipForConditionalGeneration,) if is_torch_available() else ()
     fx_compatible = False
     test_head_masking = False
     test_pruning = False
diff --git a/tests/models/blip/test_modeling_tf_blip.py b/tests/models/blip/test_modeling_tf_blip.py
index b8fd916ec13..3bb7b87edbb 100644
--- a/tests/models/blip/test_modeling_tf_blip.py
+++ b/tests/models/blip/test_modeling_tf_blip.py
@@ -526,17 +526,71 @@ class BlipTextImageModelsModelTester:
         return config, inputs_dict
 
 
+class BlipVQAModelsModelTester:
+    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
+        if text_kwargs is None:
+            text_kwargs = {}
+        if vision_kwargs is None:
+            vision_kwargs = {}
+
+        self.parent = parent
+        self.text_model_tester = TFBlipTextModelTester(parent, **text_kwargs)
+        self.vision_model_tester = TFBlipVisionModelTester(parent, **vision_kwargs)
+        self.is_training = is_training
+
+    def prepare_config_and_inputs(self):
+        text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
+        vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
+
+        config = self.get_config()
+
+        return config, input_ids, attention_mask, pixel_values
+
+    def get_config(self):
+        return BlipConfig.from_text_vision_configs(
+            self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
+        )
+
+    def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
+        model = TFBlipModel(config)
+        result = model(input_ids, pixel_values, attention_mask, training=False)
+        self.parent.assertEqual(
+            result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
+        )
+        self.parent.assertEqual(
+            result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
+        )
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, input_ids, attention_mask, pixel_values = config_and_inputs
+        inputs_dict = {
+            "input_ids": input_ids,
+            "decoder_input_ids": input_ids,
+            "labels": input_ids,
+            "attention_mask": attention_mask,
+            "pixel_values": pixel_values,
+        }
+        return config, inputs_dict
+
+
 @require_tf
 @require_vision
-class BlipVQAModelTest(unittest.TestCase):
+class TFBlipVQAModelTest(TFModelTesterMixin, unittest.TestCase):
     all_model_classes = (TFBlipForQuestionAnswering,) if is_tf_available() else ()
+    test_head_masking = False
+    test_pruning = False
+    test_resize_embeddings = False
+    test_attention_outputs = False
+    test_onnx = False
 
     def setUp(self):
-        self.model_tester = TFBlipModelTester(self)
+        self.model_tester = BlipVQAModelsModelTester(self)
 
     def _prepare_inputs_for_vqa(self):
         _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         inputs_dict["labels"] = inputs_dict["input_ids"]
+        inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
         inputs_dict.pop("return_loss")
         return inputs_dict
 
@@ -557,10 +611,34 @@ class BlipVQAModelTest(unittest.TestCase):
         """
         for model_class in self.all_model_classes:
             model = model_class(self.model_tester.get_config())
-            loss = model(**self._prepare_inputs_for_vqa(), training=True).loss
+            loss = model(**self.model_tester.prepare_config_and_inputs_for_common()[1], training=True).loss
 
             self.assertIsNotNone(loss, "Loss should not be None")
 
+    @unittest.skip(reason="Hidden_states is tested in individual model tests")
+    def test_hidden_states_output(self):
+        pass
+
+    @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
+    def test_inputs_embeds(self):
+        pass
+
+    @unittest.skip(reason="Retain_grad is tested in individual model tests")
+    def test_retain_grad_hidden_states_attentions(self):
+        pass
+
+    @unittest.skip(reason="BlipModel does not have input/output embeddings")
+    def test_model_common_attributes(self):
+        pass
+
+    @unittest.skip(reason="Tested in individual model tests")
+    def test_compile_tf_model(self):
+        pass
+
+    @unittest.skip("Model doesn't have a clean loss output.")
+    def test_keras_fit(self):
+        pass
+
 
 @require_tf
 class TFBlipTextRetrievalModelTest(TFModelTesterMixin, unittest.TestCase):
@@ -643,7 +721,7 @@ class TFBlipTextRetrievalModelTest(TFModelTesterMixin, unittest.TestCase):
 
 @require_tf
 class TFBlipTextImageModelTest(TFModelTesterMixin, unittest.TestCase):
-    all_model_classes = (TFBlipForConditionalGeneration, TFBlipForQuestionAnswering) if is_tf_available() else ()
+    all_model_classes = (TFBlipForConditionalGeneration,) if is_tf_available() else ()
     test_head_masking = False
     test_pruning = False
     test_resize_embeddings = False
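
After this patch, BlipForQuestionAnswering and TFBlipForQuestionAnswering no longer build decoder_input_ids by shifting labels themselves; the updated tests instead pass the unshifted answer ids as both labels and decoder_input_ids. The following is a minimal training-style sketch, not part of the patch, written under the assumption that the BLIP text decoder performs any required shift internally when computing the loss; the checkpoint name, question, and answer text are placeholders for illustration only.

import torch
from PIL import Image
from transformers import BlipForQuestionAnswering, BlipProcessor

# Placeholder checkpoint; any BLIP VQA checkpoint with this architecture should behave the same.
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
model.train()

image = Image.new("RGB", (384, 384))  # dummy image standing in for a real photo
inputs = processor(images=image, text="how many cats are on the bed?", return_tensors="pt")
answer_ids = processor(text="two", return_tensors="pt").input_ids

outputs = model(
    input_ids=inputs.input_ids,
    pixel_values=inputs.pixel_values,
    # decoder inputs are passed explicitly; they are no longer derived via _shift_right(labels)
    decoder_input_ids=answer_ids,
    labels=answer_ids,
)
outputs.loss.backward()  # a loss is still returned when labels are provided, as the updated tests check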