Fix pos_mask application and update tests accordingly (#27892)

* Fix pos_mask application and update tests accordingly

* Fix style

* Adding comments

---------

Co-authored-by: Fernando Rodriguez <fernando.rodriguez@nielseniq.com>
This commit is contained in:
Fernando Rodriguez Sanchez 2024-01-05 12:36:10 +01:00 committed by GitHub
parent 03b980990a
commit 57e9c83213
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 56 additions and 4 deletions

View File

@ -1949,6 +1949,7 @@ class FlavaForPreTraining(FlavaPreTrainedModel):
if mim_labels is not None:
mim_labels = mim_labels[pos_mask]
bool_masked_pos = bool_masked_pos[pos_mask]
# MMM Image Loss
if multimodal_masked_embeddings is not None and self.mmm_image_weight > 0:
@ -1956,8 +1957,6 @@ class FlavaForPreTraining(FlavaPreTrainedModel):
end_index = image_masked_embeddings.size(1) - 1
sequence_for_image = sequence_for_image[:, 2 : 2 + end_index, :]
if pos_mask is not None:
sequence_for_image = sequence_for_image[pos_mask]
if mim_labels is not None:
mim_labels = self._resize_to_2d(mim_labels)
bool_masked_pos = self._resize_to_2d(bool_masked_pos)
@ -1979,8 +1978,6 @@ class FlavaForPreTraining(FlavaPreTrainedModel):
if multimodal_masked_embeddings is not None and self.mmm_text_weight > 0:
sequence_for_text = multimodal_masked_embeddings
sequence_for_text = sequence_for_text[:, -text_masked_embeddings.size(1) :, :]
if pos_mask is not None:
sequence_for_text = sequence_for_text[pos_mask]
if mlm_labels is not None:
mlm_labels = self._resize_to_2d(mlm_labels)

View File

@ -1313,8 +1313,12 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
return_codebook_pixels=True,
return_image_mask=True,
)
# Create a clone of the input_ids tensor that will be its masked version
inputs["input_ids_masked"] = inputs["input_ids"].clone()
# Mask the tokens "a" & "cat" from the "a photo of a cat" text using the special 103 value
inputs["input_ids_masked"][0, 4:6] = 103
# MLM labels. It is a cloned version of input_ids where all values are -100 (i.e., ignored)
# except those that are masked, whose original values are stored
inputs["mlm_labels"] = inputs["input_ids"].clone()
inputs["mlm_labels"][:, :] = -100
inputs["mlm_labels"][0, 4:6] = inputs["input_ids"][0, 4:6]
@ -1338,3 +1342,54 @@ class FlavaForPreTrainingIntegrationTest(unittest.TestCase):
self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 7.0290069, places=4)
self.assertAlmostEqual(outputs.loss.item(), 11.0626, places=4)
@slow
def test_inference_with_itm_labels(self):
model_name = "facebook/flava-full"
model = FlavaForPreTraining.from_pretrained(model_name).to(torch_device)
processor = FlavaProcessor.from_pretrained(model_name)
torch.manual_seed(1)
random.seed(1)
image = prepare_img()
inputs = processor(
text=["a photo of a cat", "a photo of a dog"],
images=[image, image],
padding="max_length",
max_length=77,
return_tensors="pt",
return_codebook_pixels=True,
return_image_mask=True,
)
# Create a clone of the input_ids tensor that will be its masked version
inputs["input_ids_masked"] = inputs["input_ids"].clone()
# Mask the tokens "a" & "cat" from the "a photo of a cat" text using the special 103 value
inputs["input_ids_masked"][0, 4:6] = 103
# MLM labels. It is a cloned version of input_ids where all values are -100 (i.e., ignored)
# except those that are masked, whose original values are stored
inputs["mlm_labels"] = inputs["input_ids"].clone()
inputs["mlm_labels"][:, :] = -100
inputs["mlm_labels"][0, 4:6] = inputs["input_ids"][0, 4:6]
# Manually create the itm_labels tensor that indicates if the image-text match.
# In this case, the firs pair matches and the second does not
inputs["itm_labels"] = torch.tensor([1, 0])
inputs = inputs.to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs)
# verify the logits
self.assertEqual(
outputs.contrastive_logits_per_image.shape,
torch.Size((torch.count_nonzero(inputs["itm_labels"]).item(), inputs.input_ids.shape[0])),
)
self.assertEqual(
outputs.contrastive_logits_per_text.shape,
torch.Size((torch.count_nonzero(inputs["itm_labels"]).item(), inputs.pixel_values.shape[0])),
)
expected_logits = torch.tensor([[16.1291, 8.4033], [16.1291, 8.4033]], device=torch_device)
self.assertTrue(torch.allclose(outputs.contrastive_logits_per_image, expected_logits, atol=1e-3))
self.assertAlmostEqual(outputs.loss_info.mmm_text.item(), 1.75533199, places=4)
self.assertAlmostEqual(outputs.loss_info.mmm_image.item(), 6.89590501, places=4)
self.assertAlmostEqual(outputs.loss.item(), 9.1995, places=4)