Update BridgeTowerForContrastiveLearning (#22145)

* Use return_loss for BridgeTowerForContrastiveLearning, add example

* fix tests

* Update example in BridgeTowerForContrastiveLearning

* Update test_modeling_bridgetower.py

* update model output format

* minor update

* Update src/transformers/models/bridgetower/modeling_bridgetower.py

* make style

---------

Co-authored-by: Tiep Le <97980157+tileintel@users.noreply.github.com>
Co-authored-by: Tiep Le <tiep.le@intel.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Anahita Bhiwandiwalla 2023-03-15 12:54:38 -07:00 committed by GitHub
parent 42ad693b7b
commit 16121bae5c
2 changed files with 45 additions and 29 deletions

src/transformers/models/bridgetower/modeling_bridgetower.py

@@ -166,6 +166,8 @@ class BridgeTowerContrastiveOutput(ModelOutput):
     Output type of [`BridgeTowerForContrastiveLearning`]

     Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+            Image-text contrastive loss.
         logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
         text_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`):
@@ -174,24 +176,22 @@ class BridgeTowerContrastiveOutput(ModelOutput):
             The image embeddings obtained by applying the projection layer to the pooler_output.
         cross_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`):
             The text-image cross-modal embeddings obtained by applying the projection layer to the pooler_output.
-        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
-            Image-text contrastive loss.
-        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
-            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
-            sequence_length)`.
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
             the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
     """

+    loss: Optional[torch.FloatTensor] = None
     logits: torch.FloatTensor = None
     text_embeds: Optional[Tuple[torch.FloatTensor]] = None
     image_embeds: Optional[Tuple[torch.FloatTensor]] = None
     cross_embeds: Optional[Tuple[torch.FloatTensor]] = None
-    loss: Optional[torch.FloatTensor] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None


 class BridgeTowerResidualAttention(nn.Module):
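A note on the reordering above: `ModelOutput` subclasses behave like tuples over their non-`None` fields in declaration order, so declaring `loss` first aligns `BridgeTowerContrastiveOutput` with the `(loss, logits, ...)` convention used across transformers. A minimal sketch of how this surfaces to callers, reusing the checkpoint and an image from the docstring example further down:

```python
import requests
from PIL import Image
from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning

processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(image, "two cats sleeping on a couch", return_tensors="pt")

outputs = model(**inputs, return_loss=True)
# ModelOutput indexes over its non-None fields in declaration order, so with
# `loss` declared first it comes back at position 0, followed by `logits`:
assert outputs[0] is outputs.loss
assert outputs[1] is outputs.logits
```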
@@ -1789,12 +1789,11 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = True,
         return_dict: Optional[bool] = None,
-        labels: Optional[torch.LongTensor] = None,
+        return_loss: Optional[bool] = None,
     ) -> Union[BridgeTowerContrastiveOutput, Tuple[torch.FloatTensor]]:
         r"""
-        labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
-            Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
-            The pairs with 0 will be skipped for calculation.
+        return_loss (`bool`, *optional*):
+            Whether or not to return the contrastive loss.
         Returns:

         Examples:
@@ -1803,14 +1802,29 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
         >>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
         >>> import requests
         >>> from PIL import Image
+        >>> import torch

-        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
-        >>> texts = "An image of two cats chilling on a couch"
+        >>> image_urls = [
+        ...     "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
+        ...     "http://images.cocodataset.org/val2017/000000039769.jpg",
+        ... ]
+        >>> texts = ["two dogs in a car", "two cats sleeping on a couch"]
+        >>> images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]

         >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
         >>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")

-        >>> outputs = model(**inputs, output_hidden_states=True)
+        >>> inputs = processor(images, texts, padding=True, return_tensors="pt")
+        >>> loss = model(**inputs, return_loss=True).loss

+        >>> inputs = processor(images, texts[::-1], padding=True, return_tensors="pt")
+        >>> loss_swapped = model(**inputs, return_loss=True).loss

+        >>> print("Loss", round(loss.item(), 4))
+        Loss 0.0019
+        >>> print("Loss with swapped images", round(loss_swapped.item(), 4))
+        Loss with swapped images 2.126
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1857,23 +1871,23 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
         itc_loss = None

-        if labels is not None:
-            labels = torch.arange(len(labels), device=logits.device)
+        if return_loss:
+            labels = torch.arange(len(logits), device=logits.device)
             text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, labels)
             text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, labels)
             image_to_cross_loss = nn.functional.cross_entropy(logits_image_to_cross, labels)
             itc_loss = (text_to_image_loss + text_to_cross_loss + image_to_cross_loss) / 3.0

         if not return_dict:
-            output = tuple(logits)
+            output = (logits, text_embeds, image_embeds, cross_embeds) + outputs[3:]
             return ((itc_loss,) + output) if itc_loss is not None else output

         return BridgeTowerContrastiveOutput(
-            attentions=outputs.attentions,
-            hidden_states=outputs.hidden_states,
+            loss=itc_loss,
+            logits=logits,
             text_embeds=text_embeds,
             image_embeds=image_embeds,
             cross_embeds=cross_embeds,
-            logits=logits,
-            loss=itc_loss,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
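For readers of the hunk above: `return_loss` simply gates an in-batch image-text contrastive (ITC) loss in which the i-th text is the positive match for the i-th image. Below is a standalone sketch of that computation; `text_embeds`, `image_embeds`, and `cross_embeds` are assumed to be the model's normalized projection outputs of shape `(batch_size, hidden_size)`, and `temperature` stands in for the model's learned logit scale:

```python
import torch
import torch.nn as nn

def itc_loss(text_embeds, image_embeds, cross_embeds, temperature=1.0):
    # Pairwise similarity logits between the three embedding spaces;
    # each matrix is (batch_size, batch_size).
    logits_text_to_image = text_embeds @ image_embeds.t() / temperature
    logits_text_to_cross = text_embeds @ cross_embeds.t() / temperature
    logits_image_to_cross = image_embeds @ cross_embeds.t() / temperature

    # In-batch targets: example i matches example i, hence torch.arange,
    # mirroring the `labels = torch.arange(len(logits), ...)` line above.
    labels = torch.arange(len(text_embeds), device=text_embeds.device)

    # Symmetric cross-entropy over the three pairings, averaged.
    text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, labels)
    text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, labels)
    image_to_cross_loss = nn.functional.cross_entropy(logits_image_to_cross, labels)
    return (text_to_image_loss + text_to_cross_loss + image_to_cross_loss) / 3.0
```

This is why the docstring example reports a much lower loss for correctly paired inputs than for swapped ones: matched batches concentrate probability mass on the diagonal of each similarity matrix.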

tests/models/bridgetower/test_modeling_bridgetower.py

@@ -94,7 +94,7 @@ class BridgeTowerModelTester:
         self.num_hidden_layers = num_hidden_layers
         self.tie_word_embeddings = tie_word_embeddings
         self.init_layernorm_from_vision_encoder = init_layernorm_from_vision_encoder
-        self.vocab_size = 50265
+        self.vocab_size = 99
         self.num_channels = 3
         self.seq_length = 4
         self.num_image_features = 325
@@ -188,7 +188,7 @@ class BridgeTowerModelTester:
         result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, pixel_mask=pixel_mask)
         result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
-        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, 50265))

     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
@@ -231,7 +231,7 @@ class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def setUp(self):
         self.model_tester = BridgeTowerModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
+        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=99)

     def test_config(self):
         self.config_tester.run_common_tests()
@@ -486,9 +486,9 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
         processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
         image = prepare_img()
         text = "a bunch of cats laying on a tower."
-        inputs = processor(image, text, return_tensors="pt").to(torch_device)
+        inputs = processor(image, text, padding=True, return_tensors="pt").to(torch_device)
         with torch.no_grad():
-            outputs = model(**inputs, output_hidden_states=True)
+            outputs = model(**inputs, output_hidden_states=True, return_loss=True)

         # verify the logits
         expected_shape = torch.Size([1, 3, 512])
@@ -507,14 +507,16 @@ class BridgeTowerModelTrainingTest(unittest.TestCase):
     def setUp(self):
         self.model_tester = BridgeTowerModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
+        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=99)

     def _prepare_inputs_for_training(self, model_class):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if model_class == BridgeTowerForMaskedLM:
             inputs_dict["labels"] = inputs_dict["input_ids"]
-        elif model_class == BridgeTowerForImageAndTextRetrieval or model_class == BridgeTowerForContrastiveLearning:
+        elif model_class == BridgeTowerForImageAndTextRetrieval:
             inputs_dict["labels"] = ids_tensor([1], 2)
+        elif model_class == BridgeTowerForContrastiveLearning:
+            inputs_dict["return_loss"] = True
         return config, inputs_dict

     def _get_non_used_layer_names(self, model_class):
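The training test now toggles the contrastive loss via the `return_loss` forward kwarg rather than a `labels` tensor. A minimal training-step sketch of that path, reusing the checkpoint, URLs, and captions from the docstring example (the optimizer choice and learning rate are hypothetical):

```python
import requests
import torch
from PIL import Image
from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning

processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

urls = [
    "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
    "http://images.cocodataset.org/val2017/000000039769.jpg",
]
texts = ["two dogs in a car", "two cats sleeping on a couch"]
images = [Image.open(requests.get(u, stream=True).raw) for u in urls]
inputs = processor(images, texts, padding=True, return_tensors="pt")

# No labels tensor: the in-batch contrastive loss is requested directly.
loss = model(**inputs, return_loss=True).loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
```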