mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
[Blip
] Remove redundant shift right (#23153)
* remove redundant shit right * fix failing tests * this time fix tests
This commit is contained in:
parent
847e5691a6
commit
3cb9309024
@ -1121,19 +1121,6 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.vision_model.embeddings.patch_embedding
|
||||
|
||||
# Adapted from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right
|
||||
def _shift_right(self, input_ids):
|
||||
pad_token_id = self.decoder_pad_token_id
|
||||
|
||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||
shifted_input_ids[..., 0] = self.decoder_start_token_id
|
||||
|
||||
# replace possible -100 values in labels by `pad_token_id`
|
||||
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig)
|
||||
def forward(
|
||||
@ -1215,10 +1202,6 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
|
||||
|
||||
question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
|
||||
|
||||
if labels is not None and decoder_input_ids is None:
|
||||
# get decoder inputs from shifting lm labels to the right - this is used in training mode
|
||||
decoder_input_ids = self._shift_right(labels)
|
||||
|
||||
answer_output = self.text_decoder(
|
||||
input_ids=decoder_input_ids,
|
||||
attention_mask=decoder_attention_mask,
|
||||
|
@ -1335,30 +1335,6 @@ class TFBlipForQuestionAnswering(TFBlipPreTrainedModel):
|
||||
attentions=attns,
|
||||
)
|
||||
|
||||
# Adapted from transformers.models.t5.modeling_tf_t5.TFT5PreTrainedModel._shift_right
|
||||
def _shift_right(self, input_ids):
|
||||
decoder_start_token_id = self.decoder_start_token_id
|
||||
pad_token_id = self.decoder_pad_token_id
|
||||
|
||||
if decoder_start_token_id is None or pad_token_id is None:
|
||||
raise ValueError("decoder_start_token_id and pad_token_id must be defined!")
|
||||
|
||||
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
|
||||
start_tokens = tf.cast(start_tokens, input_ids.dtype) # Ensure compatible dtypes for concatenation
|
||||
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
|
||||
|
||||
# replace possible -100 values in labels by `pad_token_id`
|
||||
shifted_input_ids = tf.where(
|
||||
shifted_input_ids == -100,
|
||||
tf.cast(tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids.dtype),
|
||||
shifted_input_ids,
|
||||
)
|
||||
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype))
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
@unpack_inputs
|
||||
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=TFBlipTextVisionModelOutput, config_class=BlipVisionConfig)
|
||||
@ -1440,10 +1416,6 @@ class TFBlipForQuestionAnswering(TFBlipPreTrainedModel):
|
||||
|
||||
question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state
|
||||
|
||||
if labels is not None and decoder_input_ids is None:
|
||||
# get decoder inputs from shifting lm labels to the right - this is used in training mode
|
||||
decoder_input_ids = self._shift_right(labels)
|
||||
|
||||
answer_output = self.text_decoder(
|
||||
input_ids=decoder_input_ids,
|
||||
attention_mask=decoder_attention_mask,
|
||||
|
@ -626,17 +626,73 @@ class BlipTextImageModelsModelTester:
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
class BlipVQAModelTester:
|
||||
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
|
||||
if text_kwargs is None:
|
||||
text_kwargs = {}
|
||||
if vision_kwargs is None:
|
||||
vision_kwargs = {}
|
||||
|
||||
self.parent = parent
|
||||
self.text_model_tester = BlipTextModelTester(parent, **text_kwargs)
|
||||
self.vision_model_tester = BlipVisionModelTester(parent, **vision_kwargs)
|
||||
self.is_training = is_training
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
|
||||
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, attention_mask, pixel_values
|
||||
|
||||
def get_config(self):
|
||||
return BlipConfig.from_text_vision_configs(
|
||||
self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
|
||||
model = BlipModel(config).to(torch_device).eval()
|
||||
with torch.no_grad():
|
||||
result = model(input_ids, pixel_values, attention_mask)
|
||||
self.parent.assertEqual(
|
||||
result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
|
||||
)
|
||||
self.parent.assertEqual(
|
||||
result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, attention_mask, pixel_values = config_and_inputs
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"labels": input_ids,
|
||||
"decoder_input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class BlipVQAModelTest(unittest.TestCase):
|
||||
class BlipVQAModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (BlipForQuestionAnswering,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_attention_outputs = False
|
||||
test_torchscript = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BlipModelTester(self)
|
||||
self.model_tester = BlipVQAModelTester(self)
|
||||
|
||||
def _prepare_inputs_for_vqa(self):
|
||||
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||
inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
|
||||
inputs_dict.pop("return_loss")
|
||||
return inputs_dict
|
||||
|
||||
@ -658,7 +714,7 @@ class BlipVQAModelTest(unittest.TestCase):
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(self.model_tester.get_config()).to(torch_device)
|
||||
model.train()
|
||||
loss = model(**self._prepare_inputs_for_vqa()).loss
|
||||
loss = model(**self.model_tester.prepare_config_and_inputs_for_common()[1]).loss
|
||||
loss.backward()
|
||||
|
||||
# verify the gradients are not None
|
||||
@ -687,6 +743,18 @@ class BlipVQAModelTest(unittest.TestCase):
|
||||
f"Argument {arg} of forward function signature should include {arg}. Found {args}.",
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="BlipModel does not have input/output embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
@ -886,14 +954,7 @@ class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
BlipForConditionalGeneration,
|
||||
BlipForQuestionAnswering,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_model_classes = (BlipForConditionalGeneration,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
|
@ -526,17 +526,71 @@ class BlipTextImageModelsModelTester:
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
class BlipVQAModelsModelTester:
|
||||
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
|
||||
if text_kwargs is None:
|
||||
text_kwargs = {}
|
||||
if vision_kwargs is None:
|
||||
vision_kwargs = {}
|
||||
|
||||
self.parent = parent
|
||||
self.text_model_tester = TFBlipTextModelTester(parent, **text_kwargs)
|
||||
self.vision_model_tester = TFBlipVisionModelTester(parent, **vision_kwargs)
|
||||
self.is_training = is_training
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
|
||||
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, attention_mask, pixel_values
|
||||
|
||||
def get_config(self):
|
||||
return BlipConfig.from_text_vision_configs(
|
||||
self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
|
||||
model = TFBlipModel(config)
|
||||
result = model(input_ids, pixel_values, attention_mask, training=False)
|
||||
self.parent.assertEqual(
|
||||
result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
|
||||
)
|
||||
self.parent.assertEqual(
|
||||
result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, attention_mask, pixel_values = config_and_inputs
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"decoder_input_ids": input_ids,
|
||||
"labels": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_tf
|
||||
@require_vision
|
||||
class BlipVQAModelTest(unittest.TestCase):
|
||||
class TFBlipVQAModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFBlipForQuestionAnswering,) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_attention_outputs = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBlipModelTester(self)
|
||||
self.model_tester = BlipVQAModelsModelTester(self)
|
||||
|
||||
def _prepare_inputs_for_vqa(self):
|
||||
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||
inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
|
||||
inputs_dict.pop("return_loss")
|
||||
return inputs_dict
|
||||
|
||||
@ -557,10 +611,34 @@ class BlipVQAModelTest(unittest.TestCase):
|
||||
"""
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(self.model_tester.get_config())
|
||||
loss = model(**self._prepare_inputs_for_vqa(), training=True).loss
|
||||
loss = model(**self.model_tester.prepare_config_and_inputs_for_common()[1], training=True).loss
|
||||
|
||||
self.assertIsNotNone(loss, "Loss should not be None")
|
||||
|
||||
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Retain_grad is tested in individual model tests")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="BlipModel does not have input/output embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Tested in individual model tests")
|
||||
def test_compile_tf_model(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Model doesn't have a clean loss output.")
|
||||
def test_keras_fit(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFBlipTextRetrievalModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
@ -643,7 +721,7 @@ class TFBlipTextRetrievalModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@require_tf
|
||||
class TFBlipTextImageModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFBlipForConditionalGeneration, TFBlipForQuestionAnswering) if is_tf_available() else ()
|
||||
all_model_classes = (TFBlipForConditionalGeneration,) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
|
Loading…
Reference in New Issue
Block a user