Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-24 23:08:57 +06:00)
Update BridgeTowerForContrastiveLearning (#22145)
* Use return_loss for BridgeTowerForContrastiveLearning, add example
* fix tests
* Update example in BridgeTowerForContrastiveLearning
* Update test_modeling_bridgetower.py
* update model output format
* minor update
* Update src/transformers/models/bridgetower/modeling_bridgetower.py
* make style

---------

Co-authored-by: Tiep Le <97980157+tileintel@users.noreply.github.com>
Co-authored-by: Tiep Le <tiep.le@intel.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent 42ad693b7b
commit 16121bae5c
src/transformers/models/bridgetower/modeling_bridgetower.py
@@ -166,6 +166,8 @@ class BridgeTowerContrastiveOutput(ModelOutput):
     Output type of [`BridgeTowerForContrastiveLearning`]
 
     Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
+            Image-text contrastive loss.
         logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
         text_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`):
@@ -174,24 +176,22 @@ class BridgeTowerContrastiveOutput(ModelOutput):
             The image embeddings obtained by applying the projection layer to the pooler_output.
         cross_embeds (`torch.FloatTensor`, *optional*, returned when model is initialized with `with_projection=True`):
             The text-image cross-modal embeddings obtained by applying the projection layer to the pooler_output.
-        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
-            Image-text contrastive loss.
-        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
-            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
-            sequence_length)`.
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
             the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
     """
 
+    loss: Optional[torch.FloatTensor] = None
     logits: torch.FloatTensor = None
     text_embeds: Optional[Tuple[torch.FloatTensor]] = None
     image_embeds: Optional[Tuple[torch.FloatTensor]] = None
     cross_embeds: Optional[Tuple[torch.FloatTensor]] = None
-    loss: Optional[torch.FloatTensor] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
 
 
 class BridgeTowerResidualAttention(nn.Module):
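For context on the reordering above: `ModelOutput` subclasses index and convert to tuples in field-declaration order, so putting `loss` first keeps `loss = outputs[0]` consistent with other Transformers models when the loss is returned. A minimal sketch (the `ToyContrastiveOutput` class is hypothetical, for illustration only):

```python
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.utils import ModelOutput


@dataclass
class ToyContrastiveOutput(ModelOutput):
    # Declaration order drives tuple/index order, which is why the diff
    # moves `loss` to the front of BridgeTowerContrastiveOutput.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None


out = ToyContrastiveOutput(loss=torch.tensor(0.5), logits=torch.ones(2, 2))
print(out.to_tuple()[0])   # tensor(0.5000) -- loss comes first
print(out[0] is out.loss)  # True
```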
@@ -1789,12 +1789,11 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = True,
         return_dict: Optional[bool] = None,
-        labels: Optional[torch.LongTensor] = None,
+        return_loss: Optional[bool] = None,
     ) -> Union[BridgeTowerContrastiveOutput, Tuple[torch.FloatTensor]]:
         r"""
-        labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
-            Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
-            The pairs with 0 will be skipped for calculation.
+        return_loss (`bool`, *optional*):
+            Whether or not to return the contrastive loss.
         Returns:
 
         Examples:
@@ -1803,14 +1802,29 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
         >>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
         >>> import requests
         >>> from PIL import Image
+        >>> import torch
 
-        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
-        >>> texts = "An image of two cats chilling on a couch"
+        >>> image_urls = [
+        ...     "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
+        ...     "http://images.cocodataset.org/val2017/000000039769.jpg",
+        ... ]
+        >>> texts = ["two dogs in a car", "two cats sleeping on a couch"]
+        >>> images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]
 
         >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
         >>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
-        >>> outputs = model(**inputs, output_hidden_states=True)
+
+        >>> inputs = processor(images, texts, padding=True, return_tensors="pt")
+        >>> loss = model(**inputs, return_loss=True).loss
+
+        >>> inputs = processor(images, texts[::-1], padding=True, return_tensors="pt")
+        >>> loss_swapped = model(**inputs, return_loss=True).loss
+
+        >>> print("Loss", round(loss.item(), 4))
+        Loss 0.0019
+
+        >>> print("Loss with swapped images", round(loss_swapped.item(), 4))
+        Loss with swapped images 2.126
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1857,23 +1871,23 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
 
         itc_loss = None
 
-        if labels is not None:
-            labels = torch.arange(len(labels), device=logits.device)
+        if return_loss:
+            labels = torch.arange(len(logits), device=logits.device)
             text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, labels)
             text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, labels)
             image_to_cross_loss = nn.functional.cross_entropy(logits_image_to_cross, labels)
             itc_loss = (text_to_image_loss + text_to_cross_loss + image_to_cross_loss) / 3.0
 
         if not return_dict:
-            output = tuple(logits)
+            output = (logits, text_embeds, image_embeds, cross_embeds) + outputs[3:]
             return ((itc_loss,) + output) if itc_loss is not None else output
 
         return BridgeTowerContrastiveOutput(
-            attentions=outputs.attentions,
-            hidden_states=outputs.hidden_states,
+            loss=itc_loss,
+            logits=logits,
             text_embeds=text_embeds,
             image_embeds=image_embeds,
             cross_embeds=cross_embeds,
-            logits=logits,
-            loss=itc_loss,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
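An aside on the new loss path: `return_loss` no longer needs user-supplied labels because, for a batch of N aligned image-text pairs, the N x N similarity matrix has its positives on the diagonal, so `torch.arange(N)` is the correct target for every row. A minimal self-contained sketch of that idea (variable names mirror the diff; the embeddings here are random stand-ins):

```python
import torch
from torch import nn

batch_size, dim = 4, 8

# Stand-ins for the projected text/image embeddings (assumed L2-normalized).
text_embeds = nn.functional.normalize(torch.randn(batch_size, dim), dim=-1)
image_embeds = nn.functional.normalize(torch.randn(batch_size, dim), dim=-1)

# Entry (i, j) scores text i against image j; matched pairs sit on the diagonal.
logits_text_to_image = text_embeds @ image_embeds.t()

# Row i's correct "class" is column i, hence the arange labels.
labels = torch.arange(batch_size)
itc_loss = nn.functional.cross_entropy(logits_text_to_image, labels)
print(itc_loss.item())
```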
tests/models/bridgetower/test_modeling_bridgetower.py
@@ -94,7 +94,7 @@ class BridgeTowerModelTester:
         self.num_hidden_layers = num_hidden_layers
         self.tie_word_embeddings = tie_word_embeddings
         self.init_layernorm_from_vision_encoder = init_layernorm_from_vision_encoder
-        self.vocab_size = 50265
+        self.vocab_size = 99
         self.num_channels = 3
         self.seq_length = 4
         self.num_image_features = 325
@@ -188,7 +188,7 @@ class BridgeTowerModelTester:
         result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, pixel_mask=pixel_mask)
         result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
 
-        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, 50265))
 
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
@@ -231,7 +231,7 @@ class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
 
     def setUp(self):
         self.model_tester = BridgeTowerModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
+        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=99)
 
     def test_config(self):
         self.config_tester.run_common_tests()
@@ -486,9 +486,9 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
         processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
         image = prepare_img()
         text = "a bunch of cats laying on a tower."
-        inputs = processor(image, text, return_tensors="pt").to(torch_device)
+        inputs = processor(image, text, padding=True, return_tensors="pt").to(torch_device)
         with torch.no_grad():
-            outputs = model(**inputs, output_hidden_states=True)
+            outputs = model(**inputs, output_hidden_states=True, return_loss=True)
 
         # verify the logits
         expected_shape = torch.Size([1, 3, 512])
@@ -507,14 +507,16 @@ class BridgeTowerModelTrainingTest(unittest.TestCase):
 
     def setUp(self):
         self.model_tester = BridgeTowerModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
+        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=99)
 
     def _prepare_inputs_for_training(self, model_class):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if model_class == BridgeTowerForMaskedLM:
             inputs_dict["labels"] = inputs_dict["input_ids"]
-        elif model_class == BridgeTowerForImageAndTextRetrieval or model_class == BridgeTowerForContrastiveLearning:
+        elif model_class == BridgeTowerForImageAndTextRetrieval:
             inputs_dict["labels"] = ids_tensor([1], 2)
+        elif model_class == BridgeTowerForContrastiveLearning:
+            inputs_dict["return_loss"] = True
         return config, inputs_dict
 
     def _get_non_used_layer_names(self, model_class):
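As the updated training test hints, fine-tuning the contrastive head now means passing `return_loss=True` instead of `labels`. A hedged sketch of one training step, reusing the checkpoint and data from the docstring example above (the optimizer choice and learning rate are illustrative, not from the commit):

```python
import requests
import torch
from PIL import Image
from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning

# Checkpoint and image URLs taken from the example in the diff above.
processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")

image_urls = [
    "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
    "http://images.cocodataset.org/val2017/000000039769.jpg",
]
texts = ["two dogs in a car", "two cats sleeping on a couch"]
images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]

inputs = processor(images, texts, padding=True, return_tensors="pt")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

model.train()
loss = model(**inputs, return_loss=True).loss  # no labels argument anymore
loss.backward()
optimizer.step()
```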