Update BridgeTowerForContrastiveLearning (#22145)

* Use return_loss for BridgeTowerForContrastiveLearning, add example

* fix tests

* Update example in BridgeTowerForContrastiveLearning

* Update test_modeling_bridgetower.py

* update model output format

* minor update

* Update src/transformers/models/bridgetower/modeling_bridgetower.py

* make style

---------

Co-authored-by: Tiep Le <97980157+tileintel@users.noreply.github.com>
Co-authored-by: Tiep Le <tiep.le@intel.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Author: Anahita Bhiwandiwalla (committed by GitHub)
Date: 2023-03-15 12:54:38 -07:00
Commit: 16121bae5c (parent: 42ad693b7b)
2 changed files with 45 additions and 29 deletions

File: src/transformers/models/bridgetower/modeling_bridgetower.py

@@ -166,6 +166,8 @@ class BridgeTowerContrastiveOutput(ModelOutput):
     Output type of [`BridgeTowerForContrastiveLearning`]

     Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+            Image-text contrastive loss.
         logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
         text_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`):
@@ -174,24 +176,22 @@ class BridgeTowerContrastiveOutput(ModelOutput):
             The image embeddings obtained by applying the projection layer to the pooler_output.
         cross_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`):
             The text-image cross-modal embeddings obtained by applying the projection layer to the pooler_output.
-        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
-            Image-text contrastive loss.
-        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
-            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
-            sequence_length)`.
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
             the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
     """

+    loss: Optional[torch.FloatTensor] = None
     logits: torch.FloatTensor = None
     text_embeds: Optional[Tuple[torch.FloatTensor]] = None
     image_embeds: Optional[Tuple[torch.FloatTensor]] = None
     cross_embeds: Optional[Tuple[torch.FloatTensor]] = None
-    loss: Optional[torch.FloatTensor] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None


 class BridgeTowerResidualAttention(nn.Module):
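
A side effect of the reordering above: `ModelOutput` subclasses index and `to_tuple()` in field-declaration order and skip `None` entries, so `loss` now comes first whenever it is present. A minimal sketch of the resulting ordering (dummy tensors; the shapes are placeholders, not the model's real output shapes):

```python
import torch

from transformers.models.bridgetower.modeling_bridgetower import BridgeTowerContrastiveOutput

# Dummy tensors standing in for real model outputs.
out = BridgeTowerContrastiveOutput(
    loss=torch.tensor(0.5),
    logits=torch.randn(2, 3, 512),
    text_embeds=torch.randn(2, 512),
    image_embeds=torch.randn(2, 512),
    cross_embeds=torch.randn(2, 512),
)

# `to_tuple()` follows the declaration order above and drops `None` fields,
# so `loss` sits at index 0, mirroring the `(itc_loss,) + output` tuple path.
assert out.to_tuple()[0] is out.loss
assert out[1] is out.logits
```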
@@ -1789,12 +1789,11 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = True,
         return_dict: Optional[bool] = None,
-        labels: Optional[torch.LongTensor] = None,
+        return_loss: Optional[bool] = None,
     ) -> Union[BridgeTowerContrastiveOutput, Tuple[torch.FloatTensor]]:
         r"""
-        labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
-            Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
-            The pairs with 0 will be skipped for calculation.
+        return_loss (`bool`, *optional*):
+            Whether or not to return the contrastive loss.

         Returns:

         Examples:
@@ -1803,14 +1802,29 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
         >>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
         >>> import requests
         >>> from PIL import Image
+        >>> import torch

-        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
-        >>> texts = "An image of two cats chilling on a couch"
+        >>> image_urls = [
+        ...     "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
+        ...     "http://images.cocodataset.org/val2017/000000039769.jpg",
+        ... ]
+        >>> texts = ["two dogs in a car", "two cats sleeping on a couch"]
+        >>> images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]

         >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
         >>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
-        >>> outputs = model(**inputs, output_hidden_states=True)
+
+        >>> inputs = processor(images, texts, padding=True, return_tensors="pt")
+        >>> loss = model(**inputs, return_loss=True).loss
+
+        >>> inputs = processor(images, texts[::-1], padding=True, return_tensors="pt")
+        >>> loss_swapped = model(**inputs, return_loss=True).loss
+
+        >>> print("Loss", round(loss.item(), 4))
+        Loss 0.0019
+
+        >>> print("Loss with swapped images", round(loss_swapped.item(), 4))
+        Loss with swapped images 2.126
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1857,23 +1871,23 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
         itc_loss = None

-        if labels is not None:
-            labels = torch.arange(len(labels), device=logits.device)
+        if return_loss:
+            labels = torch.arange(len(logits), device=logits.device)
             text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, labels)
             text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, labels)
             image_to_cross_loss = nn.functional.cross_entropy(logits_image_to_cross, labels)
             itc_loss = (text_to_image_loss + text_to_cross_loss + image_to_cross_loss) / 3.0

         if not return_dict:
-            output = tuple(logits)
+            output = (logits, text_embeds, image_embeds, cross_embeds) + outputs[3:]
             return ((itc_loss,) + output) if itc_loss is not None else output

         return BridgeTowerContrastiveOutput(
-            attentions=outputs.attentions,
-            hidden_states=outputs.hidden_states,
+            loss=itc_loss,
+            logits=logits,
             text_embeds=text_embeds,
             image_embeds=image_embeds,
             cross_embeds=cross_embeds,
-            logits=logits,
-            loss=itc_loss,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
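
In isolation, the loss added above is the average of three cross-entropy terms, one per similarity matrix, with `torch.arange` labels because the matched text/image pair for row `i` sits at column `i` (on the diagonal). A self-contained sketch with random stand-in logits — the `(batch, batch)` shapes are an assumption about the projected-similarity matrices, not something shown in the diff:

```python
import torch
import torch.nn as nn

batch = 4

# Hypothetical similarity matrices; in the model these come from the text,
# image, and cross-modal projection heads.
logits_text_to_image = torch.randn(batch, batch)
logits_text_to_cross = torch.randn(batch, batch)
logits_image_to_cross = torch.randn(batch, batch)

# Each text matches the image at the same batch index, hence arange labels.
labels = torch.arange(batch)

# Average of the three directional contrastive losses, as in the diff.
itc_loss = (
    nn.functional.cross_entropy(logits_text_to_image, labels)
    + nn.functional.cross_entropy(logits_text_to_cross, labels)
    + nn.functional.cross_entropy(logits_image_to_cross, labels)
) / 3.0
```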

File: tests/models/bridgetower/test_modeling_bridgetower.py

@@ -94,7 +94,7 @@ class BridgeTowerModelTester:
         self.num_hidden_layers = num_hidden_layers
         self.tie_word_embeddings = tie_word_embeddings
         self.init_layernorm_from_vision_encoder = init_layernorm_from_vision_encoder
-        self.vocab_size = 50265
+        self.vocab_size = 99
         self.num_channels = 3
         self.seq_length = 4
         self.num_image_features = 325
@@ -188,7 +188,7 @@ class BridgeTowerModelTester:
         result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, pixel_mask=pixel_mask)
         result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values)

-        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, 50265))

     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
@@ -231,7 +231,7 @@ class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def setUp(self):
         self.model_tester = BridgeTowerModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
+        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=99)

     def test_config(self):
         self.config_tester.run_common_tests()
@@ -486,9 +486,9 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
         processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")

         image = prepare_img()
         text = "a bunch of cats laying on a tower."
-        inputs = processor(image, text, return_tensors="pt").to(torch_device)
+        inputs = processor(image, text, padding=True, return_tensors="pt").to(torch_device)
         with torch.no_grad():
-            outputs = model(**inputs, output_hidden_states=True)
+            outputs = model(**inputs, output_hidden_states=True, return_loss=True)

         # verify the logits
         expected_shape = torch.Size([1, 3, 512])
@@ -507,14 +507,16 @@ class BridgeTowerModelTrainingTest(unittest.TestCase):
     def setUp(self):
         self.model_tester = BridgeTowerModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
+        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=99)

     def _prepare_inputs_for_training(self, model_class):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if model_class == BridgeTowerForMaskedLM:
             inputs_dict["labels"] = inputs_dict["input_ids"]
-        elif model_class == BridgeTowerForImageAndTextRetrieval or model_class == BridgeTowerForContrastiveLearning:
+        elif model_class == BridgeTowerForImageAndTextRetrieval:
             inputs_dict["labels"] = ids_tensor([1], 2)
+        elif model_class == BridgeTowerForContrastiveLearning:
+            inputs_dict["return_loss"] = True

         return config, inputs_dict

     def _get_non_used_layer_names(self, model_class):
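
End to end, the training path the test now exercises is: build a batch with the processor, call the model with `return_loss=True`, and backpropagate `outputs.loss`. A hedged sketch reusing the checkpoint and inputs from the docstring example above (the optimizer choice and learning rate are placeholders, not part of the commit):

```python
import requests
import torch
from PIL import Image

from transformers import BridgeTowerForContrastiveLearning, BridgeTowerProcessor

processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # placeholder hyperparameters

# Batch of two matched image-text pairs so the contrastive loss is non-trivial.
image_urls = [
    "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
    "http://images.cocodataset.org/val2017/000000039769.jpg",
]
texts = ["two dogs in a car", "two cats sleeping on a couch"]
images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]
inputs = processor(images, texts, padding=True, return_tensors="pt")

model.train()
outputs = model(**inputs, return_loss=True)
outputs.loss.backward()  # averaged image-text contrastive loss from the diff above
optimizer.step()
```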