Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 10:12:23 +06:00)
Fix pos_mask application and update tests accordingly (#27892)
* Fix pos_mask application and update tests accordingly
* Fix style
* Adding comments

Co-authored-by: Fernando Rodriguez <fernando.rodriguez@nielseniq.com>
This commit is contained in:
parent 03b980990a
commit 57e9c83213
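In the modeling change below, the positive-pair mask derived from the ITM labels is applied to the masked-image labels and to the boolean mask of masked patches together, so the two tensors stay shape-aligned once negative pairs are dropped. A minimal, self-contained sketch of that masking pattern (the tensor shapes and the `pos_mask` construction here are illustrative assumptions, not the exact FLAVA internals):

```python
import torch

# Illustrative shapes only; the real FLAVA tensors come from the model/processor.
batch_size, num_patches = 4, 16
itm_labels = torch.tensor([1, 0, 1, 1])  # 1 = the image/text pair matches
mim_labels = torch.randint(0, 8192, (batch_size, num_patches))
bool_masked_pos = torch.rand(batch_size, num_patches) > 0.85

# Keep only the positive (matching) pairs for the masked-image objective.
pos_mask = itm_labels.ne(0)

# Apply the same mask to both tensors so downstream indexing stays consistent.
mim_labels = mim_labels[pos_mask]
bool_masked_pos = bool_masked_pos[pos_mask]

assert mim_labels.shape[0] == bool_masked_pos.shape[0] == int(pos_mask.sum())
```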
```diff
@@ -1949,6 +1949,7 @@ class FlavaForPreTraining(FlavaPreTrainedModel):
 
                 if mim_labels is not None:
                     mim_labels = mim_labels[pos_mask]
+                    bool_masked_pos = bool_masked_pos[pos_mask]
 
         # MMM Image Loss
         if multimodal_masked_embeddings is not None and self.mmm_image_weight > 0:
```
```diff
@@ -1956,8 +1957,6 @@ class FlavaForPreTraining(FlavaPreTrainedModel):
             end_index = image_masked_embeddings.size(1) - 1
             sequence_for_image = sequence_for_image[:, 2 : 2 + end_index, :]
 
-            if pos_mask is not None:
-                sequence_for_image = sequence_for_image[pos_mask]
             if mim_labels is not None:
                 mim_labels = self._resize_to_2d(mim_labels)
                 bool_masked_pos = self._resize_to_2d(bool_masked_pos)
```
```diff
@@ -1979,8 +1978,6 @@ class FlavaForPreTraining(FlavaPreTrainedModel):
         if multimodal_masked_embeddings is not None and self.mmm_text_weight > 0:
             sequence_for_text = multimodal_masked_embeddings
             sequence_for_text = sequence_for_text[:, -text_masked_embeddings.size(1) :, :]
-            if pos_mask is not None:
-                sequence_for_text = sequence_for_text[pos_mask]
 
             if mlm_labels is not None:
                 mlm_labels = self._resize_to_2d(mlm_labels)
```
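For orientation, the two slices above pull the image part (`[:, 2 : 2 + end_index, :]`) and the text part (`[:, -text_masked_embeddings.size(1):, :]`) back out of the fused multimodal sequence. The sketch below mimics that slicing on a dummy tensor; the assumed layout (two leading positions, then image patch embeddings, then text token embeddings) is an inference from the offsets, not something stated in the diff:

```python
import torch

# Hypothetical lengths; in the real model these come from the unimodal encoders.
batch, image_len, text_len, hidden = 2, 5, 7, 8
multimodal_masked_embeddings = torch.randn(batch, 2 + (image_len - 1) + text_len, hidden)

# Image part: starts after two leading positions, spans image_len - 1 embeddings.
end_index = image_len - 1
sequence_for_image = multimodal_masked_embeddings[:, 2 : 2 + end_index, :]

# Text part: the trailing text_len positions of the fused sequence.
sequence_for_text = multimodal_masked_embeddings[:, -text_len:, :]

print(sequence_for_image.shape)  # torch.Size([2, 4, 8])
print(sequence_for_text.shape)   # torch.Size([2, 7, 8])
```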
```diff
@@ -1313,8 +1313,12 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
             return_codebook_pixels=True,
             return_image_mask=True,
         )
+        # Create a clone of the input_ids tensor that will be its masked version
         inputs["input_ids_masked"] = inputs["input_ids"].clone()
+        # Mask the tokens "a" & "cat" from the "a photo of a cat" text using the special 103 value
         inputs["input_ids_masked"][0, 4:6] = 103
+        # MLM labels. It is a cloned version of input_ids where all values are -100 (i.e., ignored)
+        # except those that are masked, whose original values are stored
         inputs["mlm_labels"] = inputs["input_ids"].clone()
         inputs["mlm_labels"][:, :] = -100
         inputs["mlm_labels"][0, 4:6] = inputs["input_ids"][0, 4:6]
```
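The added comments document the usual masked-language-modeling label recipe: clone the token ids, set everything to -100 (the default `ignore_index` of `torch.nn.CrossEntropyLoss`), then copy the original ids back only at the masked positions; the "special 103 value" presumably refers to the `[MASK]` id of the BERT-style vocabulary FLAVA's text tokenizer uses. A small standalone sketch of the same recipe (the token ids here are made up, not real tokenizer output):

```python
import torch

# Hypothetical token ids standing in for a tokenized "a photo of a cat".
input_ids = torch.tensor([[101, 1037, 6302, 1997, 1037, 4937, 102]])

# Masked copy of the inputs: positions 4:6 replaced by the assumed mask id 103.
input_ids_masked = input_ids.clone()
input_ids_masked[0, 4:6] = 103

# MLM labels: -100 everywhere (ignored by CrossEntropyLoss) except the masked
# positions, which keep their original ids so the loss is computed only there.
mlm_labels = input_ids.clone()
mlm_labels[:, :] = -100
mlm_labels[0, 4:6] = input_ids[0, 4:6]

print(mlm_labels)  # tensor([[-100, -100, -100, -100, 1037, 4937, -100]])
```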
```diff
@@ -1338,3 +1342,54 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
         self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
         self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.0290069, places=4)
         self.assertAlmostEqual(outputs.loss.item(), 11.0626, places=4)
+
+    @slow
+    def test_inference_with_itm_labels(self):
+        model_name = "facebook/flava-full"
+        model = FlavaForPreTraining.from_pretrained(model_name).to(torch_device)
+        processor = FlavaProcessor.from_pretrained(model_name)
+        torch.manual_seed(1)
+        random.seed(1)
+
+        image = prepare_img()
+        inputs = processor(
+            text=["a photo of a cat", "a photo of a dog"],
+            images=[image, image],
+            padding="max_length",
+            max_length=77,
+            return_tensors="pt",
+            return_codebook_pixels=True,
+            return_image_mask=True,
+        )
+        # Create a clone of the input_ids tensor that will be its masked version
+        inputs["input_ids_masked"] = inputs["input_ids"].clone()
+        # Mask the tokens "a" & "cat" from the "a photo of a cat" text using the special 103 value
+        inputs["input_ids_masked"][0, 4:6] = 103
+        # MLM labels. It is a cloned version of input_ids where all values are -100 (i.e., ignored)
+        # except those that are masked, whose original values are stored
+        inputs["mlm_labels"] = inputs["input_ids"].clone()
+        inputs["mlm_labels"][:, :] = -100
+        inputs["mlm_labels"][0, 4:6] = inputs["input_ids"][0, 4:6]
+        # Manually create the itm_labels tensor that indicates whether each image-text pair matches.
+        # In this case, the first pair matches and the second does not
+        inputs["itm_labels"] = torch.tensor([1, 0])
+        inputs = inputs.to(torch_device)
+        # forward pass
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        # verify the logits
+        self.assertEqual(
+            outputs.contrastive_logits_per_image.shape,
+            torch.Size((torch.count_nonzero(inputs["itm_labels"]).item(), inputs.input_ids.shape[0])),
+        )
+        self.assertEqual(
+            outputs.contrastive_logits_per_text.shape,
+            torch.Size((torch.count_nonzero(inputs["itm_labels"]).item(), inputs.pixel_values.shape[0])),
+        )
+
+        expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
+        self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
+        self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
+        self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 6.89590501, places=4)
+        self.assertAlmostEqual(outputs.loss.item(), 9.1995, places=4)
```
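As a sanity check on the new assertions: with `itm_labels = [1, 0]`, only one of the two pairs is a true match, so the contrastive logits compared against the full batch are expected to have one row and `batch_size` columns. A trivial sketch that recomputes the expected shape outside the model, mirroring the test's own expression:

```python
import torch

itm_labels = torch.tensor([1, 0])  # first pair matches, second does not
batch_size = itm_labels.shape[0]

# Same expression the test uses for the expected contrastive-logits shape.
num_positive = torch.count_nonzero(itm_labels).item()
expected_shape = torch.Size((num_positive, batch_size))
print(expected_shape)  # torch.Size([1, 2])
```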