Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-25 23:38:59 +06:00
Update BridgeTowerForContrastiveLearning (#22145)
* Use return_loss for BridgeTowerForContrastiveLearning, add example
* fix tests
* Update example in BridgeTowerForContrastiveLearning
* Update test_modeling_bridgetower.py
* update model output format
* minor update
* Update src/transformers/models/bridgetower/modeling_bridgetower.py
* make style

---------

Co-authored-by: Tiep Le <97980157+tileintel@users.noreply.github.com>
Co-authored-by: Tiep Le <tiep.le@intel.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent 42ad693b7b
commit 16121bae5c
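The gist of the change: `BridgeTowerForContrastiveLearning.forward` no longer takes a `labels` argument; the image-text contrastive loss is now requested explicitly with `return_loss=True`. A minimal before/after sketch of the call site, reusing the images and captions from the updated docstring example in the diff below:

import requests
from PIL import Image
from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning

image_urls = [
    "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
    "http://images.cocodataset.org/val2017/000000039769.jpg",
]
texts = ["two dogs in a car", "two cats sleeping on a couch"]
images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]

processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
inputs = processor(images, texts, padding=True, return_tensors="pt")

# Before this commit the loss was gated on a `labels` argument:
#   outputs = model(**inputs, labels=...)
# After this commit the loss is requested explicitly:
outputs = model(**inputs, return_loss=True)
print(outputs.loss, outputs.logits.shape)
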
src/transformers/models/bridgetower/modeling_bridgetower.py

@@ -166,6 +166,8 @@ class BridgeTowerContrastiveOutput(ModelOutput):
     Output type of [`BridgeTowerForContrastiveLearning`]

     Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`:
+            Image-text contrastive loss.
         logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
         text_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`):

@@ -174,24 +176,22 @@ class BridgeTowerContrastiveOutput(ModelOutput):
             The image embeddings obtained by applying the projection layer to the pooler_output.
         cross_embeds (`torch.FloatTensor)`, *optional*, returned when model is initialized with `with_projection=True`):
             The text-image cross-modal embeddings obtained by applying the projection layer to the pooler_output.
-        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
-            Image-text contrastive loss.
-        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
-            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
-            sequence_length)`.
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
             the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
     """

+    loss: Optional[torch.FloatTensor] = None
     logits: torch.FloatTensor = None
     text_embeds: Optional[Tuple[torch.FloatTensor]] = None
     image_embeds: Optional[Tuple[torch.FloatTensor]] = None
     cross_embeds: Optional[Tuple[torch.FloatTensor]] = None
-    loss: Optional[torch.FloatTensor] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None


 class BridgeTowerResidualAttention(nn.Module):

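For callers, the field reordering above mainly matters when `return_dict=False`, where values come back positionally in the dataclass order, with the loss first only when it was computed. A hedged sketch of consuming the new output, continuing from the `model` and `inputs` set up in the sketch after the commit header; the tuple layout follows the `if not return_dict` branch later in this diff:

# Keyword access on the returned BridgeTowerContrastiveOutput.
outputs = model(**inputs, return_loss=True)
loss, logits = outputs.loss, outputs.logits  # outputs.loss is None when return_loss is not set
text_embeds, image_embeds, cross_embeds = (
    outputs.text_embeds,
    outputs.image_embeds,
    outputs.cross_embeds,
)

# Positional access with return_dict=False: the loss is prepended only when computed.
loss, logits, text_embeds, image_embeds, cross_embeds, *rest = model(
    **inputs, return_loss=True, return_dict=False
)
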
@@ -1789,12 +1789,11 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = True,
         return_dict: Optional[bool] = None,
-        labels: Optional[torch.LongTensor] = None,
+        return_loss: Optional[bool] = None,
     ) -> Union[BridgeTowerContrastiveOutput, Tuple[torch.FloatTensor]]:
         r"""
-        labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
-            Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
-            The pairs with 0 will be skipped for calculation.
+        return_loss (`bool`, *optional*):
+            Whether or not to return the contrastive loss.
         Returns:

         Examples:

@@ -1803,14 +1802,29 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
         >>> from transformers import BridgeTowerProcessor, BridgeTowerForContrastiveLearning
         >>> import requests
         >>> from PIL import Image
+        >>> import torch

-        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
-        >>> texts = "An image of two cats chilling on a couch"
+        >>> image_urls = [
+        ...     "https://farm4.staticflickr.com/3395/3428278415_81c3e27f15_z.jpg",
+        ...     "http://images.cocodataset.org/val2017/000000039769.jpg",
+        ... ]
+        >>> texts = ["two dogs in a car", "two cats sleeping on a couch"]
+        >>> images = [Image.open(requests.get(url, stream=True).raw) for url in image_urls]

         >>> processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
         >>> model = BridgeTowerForContrastiveLearning.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
-        >>> outputs = model(**inputs, output_hidden_states=True)
+
+        >>> inputs = processor(images, texts, padding=True, return_tensors="pt")
+        >>> loss = model(**inputs, return_loss=True).loss
+
+        >>> inputs = processor(images, texts[::-1], padding=True, return_tensors="pt")
+        >>> loss_swapped = model(**inputs, return_loss=True).loss
+
+        >>> print("Loss", round(loss.item(), 4))
+        Loss 0.0019
+
+        >>> print("Loss with swapped images", round(loss_swapped.item(), 4))
+        Loss with swapped images 2.126
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

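The matched-versus-swapped check in the example works because, for a matched batch, each caption should be most similar to its own image. A hedged sketch of making that explicit from the returned embeddings, continuing the same `model` and `inputs`; it assumes `text_embeds` and `image_embeds` hold one projected vector per example, which is consistent with the `(1, 3, 512)` logits shape asserted in the integration test further down:

import torch

out = model(**inputs, return_loss=True)
text = torch.nn.functional.normalize(out.text_embeds, dim=-1)
image = torch.nn.functional.normalize(out.image_embeds, dim=-1)

# Caption-to-image similarity matrix; matched pairs should dominate their row
# if the losses printed in the docstring example (0.0019 vs. 2.126) are representative.
sim = text @ image.T
print(sim.argmax(dim=-1))  # expected tensor([0, 1]) for the two matched pairs
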
@@ -1857,23 +1871,23 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):

         itc_loss = None

-        if labels is not None:
-            labels = torch.arange(len(labels), device=logits.device)
+        if return_loss:
+            labels = torch.arange(len(logits), device=logits.device)
             text_to_image_loss = nn.functional.cross_entropy(logits_text_to_image, labels)
             text_to_cross_loss = nn.functional.cross_entropy(logits_text_to_cross, labels)
             image_to_cross_loss = nn.functional.cross_entropy(logits_image_to_cross, labels)
             itc_loss = (text_to_image_loss + text_to_cross_loss + image_to_cross_loss) / 3.0

         if not return_dict:
-            output = tuple(logits)
+            output = (logits, text_embeds, image_embeds, cross_embeds) + outputs[3:]
             return ((itc_loss,) + output) if itc_loss is not None else output

         return BridgeTowerContrastiveOutput(
-            attentions=outputs.attentions,
-            hidden_states=outputs.hidden_states,
+            loss=itc_loss,
+            logits=logits,
             text_embeds=text_embeds,
             image_embeds=image_embeds,
             cross_embeds=cross_embeds,
-            logits=logits,
-            loss=itc_loss,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )

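The loss itself is a standard in-batch contrastive objective: with `N` matched image-text pairs, the target for example `i` is index `i`, and cross-entropy is averaged over the three pairings in the hunk above. A self-contained sketch with stand-in similarity matrices (the real `logits_text_to_image`, `logits_text_to_cross`, and `logits_image_to_cross` are computed earlier in `forward`):

import torch
import torch.nn as nn

batch_size = 4
# Stand-ins for the (batch_size, batch_size) similarity matrices built in forward().
logits_text_to_image = torch.randn(batch_size, batch_size)
logits_text_to_cross = torch.randn(batch_size, batch_size)
logits_image_to_cross = torch.randn(batch_size, batch_size)

# Matched pairs sit on the diagonal, so the target class for row i is simply i.
labels = torch.arange(batch_size)
itc_loss = (
    nn.functional.cross_entropy(logits_text_to_image, labels)
    + nn.functional.cross_entropy(logits_text_to_cross, labels)
    + nn.functional.cross_entropy(logits_image_to_cross, labels)
) / 3.0
print(itc_loss)
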
tests/models/bridgetower/test_modeling_bridgetower.py

@@ -94,7 +94,7 @@ class BridgeTowerModelTester:
         self.num_hidden_layers = num_hidden_layers
         self.tie_word_embeddings = tie_word_embeddings
         self.init_layernorm_from_vision_encoder = init_layernorm_from_vision_encoder
-        self.vocab_size = 50265
+        self.vocab_size = 99
         self.num_channels = 3
         self.seq_length = 4
         self.num_image_features = 325

@@ -188,7 +188,7 @@ class BridgeTowerModelTester:
         result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, pixel_mask=pixel_mask)
         result = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values)

-        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, 50265))

     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()

@@ -231,7 +231,7 @@ class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):

     def setUp(self):
         self.model_tester = BridgeTowerModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
+        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=99)

     def test_config(self):
         self.config_tester.run_common_tests()

@@ -486,9 +486,9 @@ class BridgeTowerModelIntegrationTest(unittest.TestCase):
         processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-large-itm-mlm-itc")
         image = prepare_img()
         text = "a bunch of cats laying on a tower."
-        inputs = processor(image, text, return_tensors="pt").to(torch_device)
+        inputs = processor(image, text, padding=True, return_tensors="pt").to(torch_device)
         with torch.no_grad():
-            outputs = model(**inputs, output_hidden_states=True)
+            outputs = model(**inputs, output_hidden_states=True, return_loss=True)

         # verify the logits
         expected_shape = torch.Size([1, 3, 512])

@@ -507,14 +507,16 @@ class BridgeTowerModelTrainingTest(unittest.TestCase):

     def setUp(self):
         self.model_tester = BridgeTowerModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=50265)
+        self.config_tester = ConfigTester(self, config_class=BridgeTowerConfig, hidden_size=37, vocab_size=99)

     def _prepare_inputs_for_training(self, model_class):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if model_class == BridgeTowerForMaskedLM:
             inputs_dict["labels"] = inputs_dict["input_ids"]
-        elif model_class == BridgeTowerForImageAndTextRetrieval or model_class == BridgeTowerForContrastiveLearning:
+        elif model_class == BridgeTowerForImageAndTextRetrieval:
             inputs_dict["labels"] = ids_tensor([1], 2)
+        elif model_class == BridgeTowerForContrastiveLearning:
+            inputs_dict["return_loss"] = True
         return config, inputs_dict

     def _get_non_used_layer_names(self, model_class):

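Training-style tests therefore stop feeding the contrastive model dummy `labels` and simply set `return_loss=True`, then check that the returned loss backpropagates. A hedged sketch of that flow, mirroring `_prepare_inputs_for_training` above (`config` and `inputs_dict` are produced by `BridgeTowerModelTester.prepare_config_and_inputs_for_common`, as in the hunk):

from transformers import BridgeTowerForContrastiveLearning

model = BridgeTowerForContrastiveLearning(config)  # config from the model tester above
model.train()

inputs_dict["return_loss"] = True  # replaces the old inputs_dict["labels"]
loss = model(**inputs_dict).loss
loss.backward()  # training tests typically just check this completes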