diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py
index f405407d7d9..209ff9703ff 100644
--- a/src/transformers/models/bridgetower/modeling_bridgetower.py
+++ b/src/transformers/models/bridgetower/modeling_bridgetower.py
@@ -166,6 +166,8 @@ class BridgeTowerContrastiveOutput(ModelOutput):
     Output type of ['BridgeTowerForContrastiveLearning']
 
     Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+            Image-text contrastive loss.
         logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
         text_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`):
@@ -174,24 +176,22 @@ class BridgeTowerContrastiveOutput(ModelOutput):
             The image embeddings obtained by applying the projection layer to the pooler_output.
         cross_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`):
             The text-image cross-modal embeddings obtained by applying the projection layer to the pooler_output.
-        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
-            Image-text contrastive loss.
-        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
-            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
-            sequence_length)`.
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
             the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
     """
 
+    loss: Optional[torch.FloatTensor] = None
     logits: torch.FloatTensor = None
     text_embeds: Optional[Tuple[torch.FloatTensor]] = None
     image_embeds: Optional[Tuple[torch.FloatTensor]] = None
     cross_embeds: Optional[Tuple[torch.FloatTensor]] = None
-    loss: Optional[torch.FloatTensor] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
 
 
 class BridgeTowerResidualAttention(nn.Module):
@@ -1789,12 +1789,11 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = True,
         return_dict: Optional[bool] = None,
-        labels: Optional[torch.LongTensor] = None,
+        return_loss: Optional[bool] = None,
     ) -> Union[BridgeTowerContrastiveOutput, Tuple[torch.FloatTensor]]:
         r"""
-        labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
-            Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
-            The pairs with 0 will be skipped for calculation.
+        return_loss (`bool`, *optional*):
+            Whether or not to return the contrastive loss.
         Returns:
 
         Examples:
@@ -1803,14 +1802,29 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
         >>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
         >>> import requests
         >>> from PIL import Image
+        >>> import torch
 
-        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
-        >>> texts = "An image of two cats chilling on a couch"
+        >>> image_urls = [
+        ...     "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
+        ...     "http://images.cocodataset.org/val2017/000000039769.jpg",
+        ... ]
+        >>> texts = ["two dogs in a car", "two cats sleeping on a couch"]
+        >>> images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]
 
         >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
         >>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
-        >>> outputs = model(**inputs, output_hidden_states=True)
+
+        >>> inputs = processor(images, texts, padding=True, return_tensors="pt")
+        >>> loss = model(**inputs, return_loss=True).loss
+
+        >>> inputs = processor(images, texts[::-1], padding=True, return_tensors="pt")
+        >>> loss_swapped = model(**inputs, return_loss=True).loss
+
+        >>> print("Loss", round(loss.item(), 4))
+        Loss 0.0019
+
+        >>> print("Loss with swapped images", round(loss_swapped.item(), 4))
+        Loss with swapped images 2.126
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
@@ -1857,23 +1871,23 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
 
         itc_loss = None
 
-        if labels is not None:
-            labels = torch.arange(len(labels), device=logits.device)
+        if return_loss:
+            labels = torch.arange(len(logits), device=logits.device)
             text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, labels)
             text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, labels)
             image_to_cross_loss = nn.functional.cross_entropy(logits_image_to_cross, labels)
             itc_loss = (text_to_image_loss + text_to_cross_loss + image_to_cross_loss) / 3.0
 
         if not return_dict:
-            output = tuple(logits)
+            output = (logits, text_embeds, image_embeds, cross_embeds) + outputs[3:]
             return ((itc_loss,) + output) if itc_loss is not None else output
 
         return BridgeTowerContrastiveOutput(
-            attentions=outputs.attentions,
-            hidden_states=outputs.hidden_states,
+            loss=itc_loss,
+            logits=logits,
             text_embeds=text_embeds,
             image_embeds=image_embeds,
             cross_embeds=cross_embeds,
-            logits=logits,
-            loss=itc_loss,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
diff --git a/tests/models/bridgetower/test_modeling_bridgetower.py b/tests/models/bridgetower/test_modeling_bridgetower.py
index 20396c8bf7b..078d66b5164 100644
--- a/tests/models/bridgetower/test_modeling_bridgetower.py
+++ b/tests/models/bridgetower/test_modeling_bridgetower.py
@@ -94,7 +94,7 @@ class BridgeTowerModelTester:
         self.num_hidden_layers = num_hidden_layers
         self.tie_word_embeddings = tie_word_embeddings
         self.init_layernorm_from_vision_encoder = init_layernorm_from_vision_encoder
-        self.vocab_size = 50265
+        self.vocab_size = 99
         self.num_channels = 3
         self.seq_length = 4
         self.num_image_features = 325
@@ -188,7 +188,7 @@ class BridgeTowerModelTester:
         result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, pixel_mask=pixel_mask)
         result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
 
-        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, 50265))
 
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
@@ -231,7 +231,7 @@ class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
 
     def setUp(self):
         self.model_tester = BridgeTowerModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
+        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=99)
 
     def test_config(self):
         self.config_tester.run_common_tests()
@@ -486,9 +486,9 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
         processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
         image = prepare_img()
         text = "a bunch of cats laying on a tower."
-        inputs = processor(image, text, return_tensors="pt").to(torch_device)
+        inputs = processor(image, text, padding=True, return_tensors="pt").to(torch_device)
         with torch.no_grad():
-            outputs = model(**inputs, output_hidden_states=True)
+            outputs = model(**inputs, output_hidden_states=True, return_loss=True)
 
         # verify the logits
         expected_shape = torch.Size([1, 3, 512])
@@ -507,14 +507,16 @@ class BridgeTowerModelTrainingTest(unittest.TestCase):
 
     def setUp(self):
         self.model_tester = BridgeTowerModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
+        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=99)
 
     def _prepare_inputs_for_training(self, model_class):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if model_class == BridgeTowerForMaskedLM:
             inputs_dict["labels"] = inputs_dict["input_ids"]
-        elif model_class == BridgeTowerForImageAndTextRetrieval or model_class == BridgeTowerForContrastiveLearning:
+        elif model_class == BridgeTowerForImageAndTextRetrieval:
             inputs_dict["labels"] = ids_tensor([1], 2)
+        elif model_class == BridgeTowerForContrastiveLearning:
+            inputs_dict["return_loss"] = True
         return config, inputs_dict
 
     def _get_non_used_layer_names(self, model_class):
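For reference, a minimal standalone sketch of the three-way image-text contrastive objective this patch computes when `return_loss=True`: matching pairs lie on the diagonal of each similarity matrix, so the cross-entropy targets are simply `arange(batch_size)` and the three losses are averaged. The random tensors below are illustrative stand-ins for the model's projected text, image, and cross-modal embeddings, and any learned temperature scaling is omitted; this is not the model's actual forward pass.

```python
import torch
import torch.nn as nn

batch_size, hidden_size = 4, 512

# Stand-ins for the L2-normalized text, image, and cross-modal embeddings.
text_embeds = nn.functional.normalize(torch.randn(batch_size, hidden_size), dim=-1)
image_embeds = nn.functional.normalize(torch.randn(batch_size, hidden_size), dim=-1)
cross_embeds = nn.functional.normalize(torch.randn(batch_size, hidden_size), dim=-1)

# Pairwise similarity logits between the three views, each of shape (batch_size, batch_size).
logits_text_to_image = text_embeds @ image_embeds.t()
logits_text_to_cross = text_embeds @ cross_embeds.t()
logits_image_to_cross = image_embeds @ cross_embeds.t()

# The i-th text matches the i-th image, so the targets are the diagonal indices.
labels = torch.arange(batch_size)
itc_loss = (
    nn.functional.cross_entropy(logits_text_to_image, labels)
    + nn.functional.cross_entropy(logits_text_to_cross, labels)
    + nn.functional.cross_entropy(logits_image_to_cross, labels)
) / 3.0
print(itc_loss)
```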