Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Add TF port of BLIP (#22090)
* Initial commit
* more stash commit
* Yet another stash commit
* yet more stash commit
* Mostly working except for docs / repo consistency
* Stop importing model list from torch file
* Add TF BLIP models to docs
* Add auto classes
* Move get_text_features and get_image_features
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip_text.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update tests/models/blip/test_modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update tests/models/blip/test_modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Update tests/models/blip/test_modeling_tf_blip_text.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip_text.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Use channels_last convolutions in TF (better performance + compatibility)
* Remove _shape function
* Move multi-line statement to one line in PT + TF
* Specify tf.keras.layers instead of importing from it
* Remove test_gradient_checkpointing and empty test_training methods
* move some multi-line statements to one line
* Update docstring for generate
* Remove pruned heads set
* Remove self.seq_len_dim
* Fixed issues with loss computation, should resolve some tests. Also ensured that the PT version follows the config for output_attentions and output_hidden_states
* ensure original model follows config in more cases
* Skip the same cross-attention tests in the PT tests - didn't realize we did it twice!
* Add training args throughout the models and layers
* make fixup
* Fix docstring for inputs_embeds
* Add docstring for is_decoder
* Add docstrings to text models
* Remove redundant computation
* Add unpack_inputs / keras_serializable
* Add modeling_tf_blip to doctests
* Add config classes for keras serialization
* Changes to allow model porting with pt-to-tf
* Quick fix to decoder head and test tweaks
* Revert an issue with masking the embeddings outputs
* Allow missing keys in some equivalence tests (for unused layers)
* Add tf-pt equivalence tests back in
* Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip_text.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/models/blip/modeling_tf_blip_text.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* make fixup
* Refactor invert_attention_mask out into tf_utils
* Re-enable cross-tests on the PT side too
---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent a515d0a77c
commit 5f3ea66bc0
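As a quick orientation to what this commit enables, here is a minimal captioning sketch with the new TF classes. It is a hedged example rather than part of the diff: the checkpoint name is the one used in the integration tests below, and `from_pt=True` is an assumption in case the Hub repository does not yet ship TF weights.

```python
import requests
from PIL import Image

from transformers import BlipProcessor, TFBlipForConditionalGeneration

# Checkpoint and from_pt are assumptions; any BLIP captioning checkpoint should behave the same.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", from_pt=True)

url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Conditional captioning: the text acts as a prompt that generate() continues.
inputs = processor(images=image, text="a picture of", return_tensors="tf")
output_ids = model.generate(**inputs)
print(processor.decode(output_ids[0], skip_special_tokens=True))
```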
@@ -269,7 +269,7 @@ Flax), PyTorch, and/or TensorFlow.
| BiT | ❌ | ❌ | ✅ | ❌ | ❌ |
| Blenderbot | ✅ | ✅ | ✅ | ✅ | ✅ |
| BlenderbotSmall | ✅ | ✅ | ✅ | ✅ | ✅ |
| BLIP | ❌ | ❌ | ✅ | ❌ | ❌ |
| BLIP | ❌ | ❌ | ✅ | ✅ | ❌ |
| BLIP-2 | ❌ | ❌ | ✅ | ❌ | ❌ |
| BLOOM | ❌ | ✅ | ✅ | ❌ | ❌ |
| BridgeTower | ❌ | ❌ | ✅ | ❌ | ❌ |
@@ -1,4 +1,4 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
@@ -93,4 +93,40 @@ The original code can be found [here](https://github.com/salesforce/BLIP).
## BlipForQuestionAnswering

[[autodoc]] BlipForQuestionAnswering
    - forward

## TFBlipModel

[[autodoc]] TFBlipModel
    - call
    - get_text_features
    - get_image_features

## TFBlipTextModel

[[autodoc]] TFBlipTextModel
    - call


## TFBlipVisionModel

[[autodoc]] TFBlipVisionModel
    - call


## TFBlipForConditionalGeneration

[[autodoc]] TFBlipForConditionalGeneration
    - call


## TFBlipForImageTextRetrieval

[[autodoc]] TFBlipForImageTextRetrieval
    - call


## TFBlipForQuestionAnswering

[[autodoc]] TFBlipForQuestionAnswering
    - call
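The `get_text_features` / `get_image_features` entries documented above mirror the PyTorch `BlipModel` API. A small hedged sketch of how they are typically called; the checkpoint choice and `from_pt=True` are assumptions, and because captioning checkpoints store a text decoder rather than the dual-encoder weights, partial-loading warnings are expected.

```python
import requests
from PIL import Image

from transformers import BlipProcessor, TFBlipModel

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
# Some TFBlipModel weights may be newly initialized when loading a captioning checkpoint.
model = TFBlipModel.from_pretrained("Salesforce/blip-image-captioning-base", from_pt=True)

url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)

text_inputs = processor(text=["a woman and her dog"], padding=True, return_tensors="tf")
image_inputs = processor(images=image, return_tensors="tf")

text_features = model.get_text_features(**text_inputs)     # shape: (batch_size, projection_dim)
image_features = model.get_image_features(**image_inputs)  # shape: (batch_size, projection_dim)
```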
@@ -2903,6 +2903,18 @@ else:
    _import_structure["models.blenderbot_small"].extend(
        ["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel", "TFBlenderbotSmallPreTrainedModel"]
    )
    _import_structure["models.blip"].extend(
        [
            "TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
            "TFBlipForConditionalGeneration",
            "TFBlipForImageTextRetrieval",
            "TFBlipForQuestionAnswering",
            "TFBlipModel",
            "TFBlipPreTrainedModel",
            "TFBlipTextModel",
            "TFBlipVisionModel",
        ]
    )
    _import_structure["models.camembert"].extend(
        [
            "TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -6145,6 +6157,16 @@ if TYPE_CHECKING:
        TFBlenderbotSmallModel,
        TFBlenderbotSmallPreTrainedModel,
    )
    from .models.blip import (
        TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFBlipForConditionalGeneration,
        TFBlipForImageTextRetrieval,
        TFBlipForQuestionAnswering,
        TFBlipModel,
        TFBlipPreTrainedModel,
        TFBlipTextModel,
        TFBlipVisionModel,
    )
    from .models.camembert import (
        TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFCamembertForCausalLM,
@@ -196,7 +196,7 @@ class PTtoTFCommand(BaseTransformersCLICommand):
        self._extra_commit_description = extra_commit_description
        self._override_model_class = override_model_class

    def get_inputs(self, pt_model, config):
    def get_inputs(self, pt_model, tf_dummy_inputs, config):
        """
        Returns the right inputs for the model, based on its signature.
        """
@@ -255,7 +255,11 @@ class PTtoTFCommand(BaseTransformersCLICommand):
        tf_input = processor(**processor_inputs, return_tensors="tf")

        # Extra input requirements, in addition to the input modality
        if config.is_encoder_decoder or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder")):
        if (
            config.is_encoder_decoder
            or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder"))
            or "decoder_input_ids" in tf_dummy_inputs
        ):
            decoder_input_ids = np.asarray([[1], [1]], dtype=int) * (pt_model.config.decoder_start_token_id or 0)
            pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)})
            tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})
@@ -306,18 +310,24 @@ class PTtoTFCommand(BaseTransformersCLICommand):
        except AttributeError:
            raise AttributeError(f"The TensorFlow equivalent of {architectures[0]} doesn't exist in transformers.")

        # Load models and acquire a basic input compatible with the model.
        # Check the TF dummy inputs to see what keys we need in the forward pass
        tf_from_pt_model = tf_class.from_config(config)
        tf_dummy_inputs = tf_from_pt_model.dummy_inputs

        del tf_from_pt_model  # Try to keep only one model in memory at a time

        # Load the model and get some basic inputs
        pt_model = pt_class.from_pretrained(self._local_dir)
        pt_model.eval()

        pt_input, tf_input = self.get_inputs(pt_model, config)
        pt_input, tf_input = self.get_inputs(pt_model, tf_dummy_inputs, config)

        with torch.no_grad():
            pt_outputs = pt_model(**pt_input, output_hidden_states=True)
        del pt_model  # will no longer be used, and may have a large memory footprint

        tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
        tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True)
        tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True, training=False)

        # Confirms that cross loading PT weights into TF worked.
        crossload_differences = self.find_pt_tf_differences(pt_outputs, tf_from_pt_outputs)
@@ -406,6 +406,7 @@ def unpack_inputs(func):
        func (`callable`):
            The callable function of the TensorFlow model.


    Returns:
        A callable that wraps the original `func` with the behavior described above.
    """
@@ -1157,6 +1158,38 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
        """
        return cls(config, **kwargs)

    def get_head_mask(self, head_mask: Optional[tf.Tensor], num_hidden_layers: int) -> tf.Tensor:
        """
        Prepare the head mask if needed.

        Args:
            head_mask (`tf.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
                The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
            num_hidden_layers (`int`):
                The number of hidden layers in the model.

        Returns:
            `tf.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
            `[None]` for each layer.
        """
        if head_mask is not None:
            head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
        else:
            head_mask = [None] * num_hidden_layers

        return head_mask

    def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
        """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
        if head_mask.shape.rank == 1:
            head_mask = head_mask[None, None, :, None, None]
            head_mask = tf.repeat(head_mask, repeats=num_hidden_layers, axis=0)
        elif head_mask.shape.rank == 2:
            head_mask = head_mask[:, None, :, None, None]
        assert head_mask.shape.rank == 5, f"head_mask.dim != 5, instead {head_mask.shape.rank}"
        head_mask = tf.cast(head_mask, tf.float32)  # switch to float if needed + fp16 compatibility
        return head_mask

    def eager_serving(self, inputs):
        """
        Method used for serving the model. Intended not to be compiled with a tf.function decorator so that we can use
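To make the shape contract of the new `get_head_mask` / `_convert_head_mask_to_5d` helpers concrete, here is a standalone sketch with toy sizes; it is not part of the diff, it just reproduces the same broadcasting.

```python
import tensorflow as tf

num_hidden_layers, num_heads = 4, 2

# A per-head mask of shape [num_heads] is broadcast to every layer...
head_mask_1d = tf.constant([1.0, 0.0])
mask_5d = tf.repeat(head_mask_1d[None, None, :, None, None], repeats=num_hidden_layers, axis=0)
print(mask_5d.shape)  # (4, 1, 2, 1, 1) -> [num_hidden_layers, batch, num_heads, seq_length, seq_length]

# ...while a [num_hidden_layers, num_heads] mask already specifies one row per layer.
head_mask_2d = tf.ones((num_hidden_layers, num_heads))
print(head_mask_2d[:, None, :, None, None].shape)  # (4, 1, 2, 1, 1)
```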
@@ -34,6 +34,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
        ("bert", "TFBertModel"),
        ("blenderbot", "TFBlenderbotModel"),
        ("blenderbot-small", "TFBlenderbotSmallModel"),
        ("blip", "TFBlipModel"),
        ("camembert", "TFCamembertModel"),
        ("clip", "TFCLIPModel"),
        ("convbert", "TFConvBertModel"),
@@ -213,6 +214,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Image Classification mapping
        ("blip", "TFBlipModel"),
        ("clip", "TFCLIPModel"),
    ]
)
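With the `("blip", "TFBlipModel")` entries above, the TF auto classes can resolve BLIP checkpoints. A hedged sketch follows; `from_pt=True` is an assumption, and since the captioning checkpoint stores a decoder rather than the dual-encoder weights of `TFBlipModel`, partial-loading warnings are expected.

```python
from transformers import TFAutoModel

# from_pt=True is an assumption in case no TF weights are published for this repo yet.
model = TFAutoModel.from_pretrained("Salesforce/blip-image-captioning-base", from_pt=True)
print(type(model).__name__)  # TFBlipModel, resolved via TF_MODEL_MAPPING_NAMES
```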
@@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_tf_available,
    is_torch_available,
    is_vision_available,
)


_import_structure = {
@@ -52,6 +58,23 @@ else:
        "BlipForImageTextRetrieval",
    ]

try:
    if not is_tf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_tf_blip"] = [
        "TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFBlipModel",
        "TFBlipPreTrainedModel",
        "TFBlipForConditionalGeneration",
        "TFBlipForQuestionAnswering",
        "TFBlipVisionModel",
        "TFBlipTextModel",
        "TFBlipForImageTextRetrieval",
    ]

if TYPE_CHECKING:
    from .configuration_blip import BLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, BlipConfig, BlipTextConfig, BlipVisionConfig
    from .processing_blip import BlipProcessor
@@ -81,6 +104,23 @@ if TYPE_CHECKING:
            BlipVisionModel,
        )

    try:
        if not is_tf_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_tf_blip import (
            TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFBlipForConditionalGeneration,
            TFBlipForImageTextRetrieval,
            TFBlipForQuestionAnswering,
            TFBlipModel,
            TFBlipPreTrainedModel,
            TFBlipTextModel,
            TFBlipVisionModel,
        )

else:
    import sys
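The try/except blocks above follow the usual optional-dependency pattern: the TF symbols are only registered when TensorFlow is importable, and otherwise the dummy objects added later in this diff take their place. A small sketch of how user code can check this:

```python
from transformers.utils import is_tf_available

if is_tf_available():
    # Resolved lazily through the _import_structure registered above.
    from transformers import TFBlipModel
    print(TFBlipModel.__module__)
else:
    print("TensorFlow is not installed; the TFBlip* names point to dummy objects that raise on use.")
```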
@@ -313,17 +313,12 @@ class BlipAttention(nn.Module):

        bsz, tgt_len, embed_dim = hidden_states.size()

        mixed_qkv = self.qkv(hidden_states)
        mixed_qkv = (
            self.qkv(hidden_states)
            .reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        query_states, key_states, value_states = (
            mixed_qkv[0],
            mixed_qkv[1],
            mixed_qkv[2],
        )
        query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
@@ -587,9 +582,7 @@ class BlipEncoder(nn.Module):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
                Embedded representation of the inputs. Should be float, not int tokens.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

@@ -824,10 +817,7 @@ class BlipModel(BlipPreTrainedModel):
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            return_dict=return_dict,
        )
        vision_outputs = self.vision_model(pixel_values=pixel_values, return_dict=return_dict)

        pooled_output = vision_outputs[1]  # pooled_output
        image_features = self.visual_projection(pooled_output)
@@ -993,6 +983,10 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
        ```"""

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
@@ -1037,7 +1031,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
        Overrides *generate* function to be able to use the model as a conditional generator

        Parameters:
            pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*:
            pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*:
                Input image to be processed
            input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
                The sequence used as a prompt for the generation.
@@ -1066,9 +1060,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):
        """

        batch_size = pixel_values.shape[0]
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
        )
        vision_outputs = self.vision_model(pixel_values=pixel_values)

        image_embeds = vision_outputs[0]

@@ -1198,6 +1190,10 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
        )

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
@@ -1266,7 +1262,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
        Parameters:
            input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*):
                The sequence used as a prompt for the generation.
            pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*:
            pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*:
                Input image to be processed
            attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for
@@ -1295,9 +1291,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel):
        2
        ```
        """
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
        )
        vision_outputs = self.vision_model(pixel_values=pixel_values)

        image_embeds = vision_outputs[0]

@@ -1412,6 +1406,10 @@ class BlipForImageTextRetrieval(BlipPreTrainedModel):
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
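The docstring fix above changes the documented `pixel_values` shape to `(batch_size, num_channels, image_height, image_width)`. That can be sanity-checked straight from the processor output; a sketch, with the 384x384 resolution being the usual BLIP base setting rather than something stated in this diff.

```python
import requests
from PIL import Image

from transformers import BlipProcessor

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)

pixel_values = processor(images=image, return_tensors="np")["pixel_values"]
print(pixel_values.shape)  # (batch_size, num_channels, image_height, image_width), e.g. (1, 3, 384, 384)
```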
src/transformers/models/blip/modeling_tf_blip.py (new file, 1753 lines): file diff suppressed because it is too large
src/transformers/models/blip/modeling_tf_blip_text.py (new file, 1013 lines): file diff suppressed because it is too large
@@ -453,9 +453,7 @@ class Blip2Encoder(nn.Module):
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
                Embedded representation of the inputs. Should be float, not int tokens.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

@@ -68,3 +68,31 @@ def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional
    # TODO: When the issue linked above gets sorted, add a check on TF version here and use the original function if
    # it has the fix. After we drop the support for unfixed versions, remove this function.
    return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)


def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor:
    """
    Invert an attention mask (e.g., switches 0. and 1.).

    Args:
        encoder_attention_mask (`tf.Tensor`): An attention mask.

    Returns:
        `tf.Tensor`: The inverted attention mask.
    """
    if not isinstance(encoder_attention_mask, tf.Tensor):
        encoder_attention_mask = tf.convert_to_tensor(encoder_attention_mask)  # Catches stray NumPy inputs
    if encoder_attention_mask.shape.rank == 3:
        encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
    if encoder_attention_mask.shape.rank == 2:
        encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
    # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
    # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
    # /transformer/transformer_layers.py#L270
    # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
    # encoder_extended_attention_mask.transpose(-1, -2))
    encoder_extended_attention_mask = (
        tf.cast(1, encoder_attention_mask.dtype) - encoder_extended_attention_mask
    ) * encoder_extended_attention_mask.dtype.min

    return encoder_extended_attention_mask
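A quick illustration of what the `invert_attention_mask` helper added above produces for a toy 2D mask: 0 for positions to attend to and a very large negative number for masked positions, i.e. the additive-mask convention. This sketch is not part of the diff.

```python
import tensorflow as tf

from transformers.tf_utils import invert_attention_mask

mask = tf.constant([[1.0, 1.0, 0.0]])  # 1 = attend, 0 = masked
extended = invert_attention_mask(mask)

print(extended.shape)    # (1, 1, 1, 3), broadcastable over heads and query length
print(extended.numpy())  # approx. [[[[0., 0., -3.4e38]]]] for float32 inputs
```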
@@ -556,6 +556,58 @@ class TFBlenderbotSmallPreTrainedModel(metaclass=DummyObject):
        requires_backends(self, ["tf"])


TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST = None


class TFBlipForConditionalGeneration(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFBlipForImageTextRetrieval(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFBlipForQuestionAnswering(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFBlipModel(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFBlipPreTrainedModel(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFBlipTextModel(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFBlipVisionModel(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -342,6 +342,9 @@ class BlipTextModelTest(ModelTesterMixin, unittest.TestCase):
            model = BlipTextModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    def test_pt_tf_model_equivalence(self):
        super().test_pt_tf_model_equivalence(allow_missing_keys=True)


class BlipModelTester:
    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
@@ -524,6 +527,9 @@ class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
            model = BlipModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    def test_pt_tf_model_equivalence(self):
        super().test_pt_tf_model_equivalence(allow_missing_keys=True)


class BlipTextRetrievalModelTester:
    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
@@ -164,3 +164,6 @@ class BlipTextModelTest(ModelTesterMixin, unittest.TestCase):
        for model_name in BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = BlipTextModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    def test_pt_tf_model_equivalence(self):
        super().test_pt_tf_model_equivalence(allow_missing_keys=True)
tests/models/blip/test_modeling_tf_blip.py (new file, 824 lines)
@@ -0,0 +1,824 @@
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the TensorFlow Blip model. """
|
||||
|
||||
|
||||
import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from transformers import BlipConfig, BlipTextConfig, BlipVisionConfig
|
||||
from transformers.testing_utils import require_tf, require_vision, slow
|
||||
from transformers.utils import is_tf_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import (
|
||||
TFBlipForConditionalGeneration,
|
||||
TFBlipForImageTextRetrieval,
|
||||
TFBlipForQuestionAnswering,
|
||||
TFBlipModel,
|
||||
TFBlipTextModel,
|
||||
TFBlipVisionModel,
|
||||
)
|
||||
from transformers.models.blip.modeling_tf_blip import TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import BlipProcessor
|
||||
|
||||
|
||||
class TFBlipVisionModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=12,
|
||||
image_size=30,
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
hidden_size=32,
|
||||
projection_dim=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
initializer_range=1e-10,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.hidden_size = hidden_size
|
||||
self.projection_dim = projection_dim
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
|
||||
# in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
self.seq_length = num_patches + 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def get_config(self):
|
||||
return BlipVisionConfig(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
projection_dim=self.projection_dim,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
dropout=self.dropout,
|
||||
attention_dropout=self.attention_dropout,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values):
|
||||
model = TFBlipVisionModel(config=config)
|
||||
result = model(pixel_values)
|
||||
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
|
||||
image_size = (self.image_size, self.image_size)
|
||||
patch_size = (self.patch_size, self.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFBlipVisionModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Here we also overwrite some of the tests of test_modeling_common.py, as Blip does not use input_ids, inputs_embeds,
|
||||
attention_mask and seq_length.
|
||||
"""
|
||||
|
||||
all_model_classes = (TFBlipVisionModel,) if is_tf_available() else ()
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBlipVisionModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BlipVisionConfig, has_text_modality=False, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@unittest.skip(reason="Blip does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.call)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer))
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, tf.keras.layers.Layer))
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="BlipVisionModel has no base class and is not available in MODEL_MAPPING")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="BlipVisionModel has no base class and is not available in MODEL_MAPPING")
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
try:
|
||||
model = TFBlipVisionModel.from_pretrained(model_name)
|
||||
except OSError:
|
||||
model = TFBlipVisionModel.from_pretrained(model_name, from_pt=True)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
class TFBlipTextModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=12,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
projection_dim=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
max_position_embeddings=512,
|
||||
initializer_range=0.02,
|
||||
bos_token_id=0,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.projection_dim = projection_dim
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
self.bos_token_id = bos_token_id
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
if input_mask is not None:
|
||||
input_mask = input_mask.numpy()
|
||||
batch_size, seq_length = input_mask.shape
|
||||
rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
|
||||
for batch_idx, start_index in enumerate(rnd_start_indices):
|
||||
input_mask[batch_idx, :start_index] = 1
|
||||
input_mask[batch_idx, start_index:] = 0
|
||||
input_mask = tf.convert_to_tensor(input_mask)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, input_mask
|
||||
|
||||
def get_config(self):
|
||||
return BlipTextConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
projection_dim=self.projection_dim,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
dropout=self.dropout,
|
||||
attention_dropout=self.attention_dropout,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
initializer_range=self.initializer_range,
|
||||
bos_token_id=self.bos_token_id,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_ids, input_mask):
|
||||
model = TFBlipTextModel(config=config)
|
||||
result = model(input_ids, attention_mask=input_mask, training=False)
|
||||
result = model(input_ids, training=False)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, input_mask = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFBlipTextModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFBlipTextModel,) if is_tf_available() else ()
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBlipTextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=BlipTextConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Blip does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING")
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
try:
|
||||
model = TFBlipTextModel.from_pretrained(model_name)
|
||||
except OSError:
|
||||
model = TFBlipTextModel.from_pretrained(model_name, from_pt=True)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
|
||||
super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)
|
||||
|
||||
|
||||
class TFBlipModelTester:
|
||||
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
|
||||
if text_kwargs is None:
|
||||
text_kwargs = {}
|
||||
if vision_kwargs is None:
|
||||
vision_kwargs = {}
|
||||
|
||||
self.parent = parent
|
||||
self.text_model_tester = TFBlipTextModelTester(parent, **text_kwargs)
|
||||
self.vision_model_tester = TFBlipVisionModelTester(parent, **vision_kwargs)
|
||||
self.is_training = is_training
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
|
||||
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, attention_mask, pixel_values
|
||||
|
||||
def get_config(self):
|
||||
return BlipConfig.from_text_vision_configs(
|
||||
self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
|
||||
model = TFBlipModel(config)
|
||||
result = model(input_ids, pixel_values, attention_mask, training=False)
|
||||
self.parent.assertEqual(
|
||||
result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
|
||||
)
|
||||
self.parent.assertEqual(
|
||||
result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, attention_mask, pixel_values = config_and_inputs
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
"return_loss": True,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFBlipModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFBlipModel,) if is_tf_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{"feature-extraction": TFBlipModel, "image-to-text": TFBlipForConditionalGeneration}
|
||||
if is_tf_available()
|
||||
else {}
|
||||
)
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_attention_outputs = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBlipModelTester(self)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Retain_grad is tested in individual model tests")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="BlipModel does not have input/output embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
def test_load_vision_text_config(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Save BlipConfig and check if we can load BlipVisionConfig from it
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
config.save_pretrained(tmp_dir_name)
|
||||
vision_config = BlipVisionConfig.from_pretrained(tmp_dir_name)
|
||||
self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
|
||||
|
||||
# Save BlipConfig and check if we can load BlipTextConfig from it
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
config.save_pretrained(tmp_dir_name)
|
||||
text_config = BlipTextConfig.from_pretrained(tmp_dir_name)
|
||||
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = TFBlipModel.from_pretrained(model_name, from_pt=True)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
|
||||
super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)
|
||||
|
||||
|
||||
class BlipTextRetrievalModelTester:
|
||||
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
|
||||
if text_kwargs is None:
|
||||
text_kwargs = {}
|
||||
if vision_kwargs is None:
|
||||
vision_kwargs = {}
|
||||
|
||||
self.parent = parent
|
||||
self.text_model_tester = TFBlipTextModelTester(parent, **text_kwargs)
|
||||
self.vision_model_tester = TFBlipVisionModelTester(parent, **vision_kwargs)
|
||||
self.is_training = is_training
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
|
||||
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, attention_mask, pixel_values
|
||||
|
||||
def get_config(self):
|
||||
return BlipConfig.from_text_vision_configs(
|
||||
self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
|
||||
model = TFBlipModel(config)
|
||||
result = model(input_ids, pixel_values, attention_mask, training=False)
|
||||
self.parent.assertEqual(
|
||||
result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
|
||||
)
|
||||
self.parent.assertEqual(
|
||||
result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, attention_mask, pixel_values = config_and_inputs
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
class BlipTextImageModelsModelTester:
|
||||
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
|
||||
if text_kwargs is None:
|
||||
text_kwargs = {}
|
||||
if vision_kwargs is None:
|
||||
vision_kwargs = {}
|
||||
|
||||
self.parent = parent
|
||||
self.text_model_tester = TFBlipTextModelTester(parent, **text_kwargs)
|
||||
self.vision_model_tester = TFBlipVisionModelTester(parent, **vision_kwargs)
|
||||
self.is_training = is_training
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
|
||||
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, attention_mask, pixel_values
|
||||
|
||||
def get_config(self):
|
||||
return BlipConfig.from_text_vision_configs(
|
||||
self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
|
||||
model = TFBlipModel(config)
|
||||
result = model(input_ids, pixel_values, attention_mask, training=False)
|
||||
self.parent.assertEqual(
|
||||
result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
|
||||
)
|
||||
self.parent.assertEqual(
|
||||
result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, attention_mask, pixel_values = config_and_inputs
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"labels": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_tf
|
||||
@require_vision
|
||||
class BlipVQAModelTest(unittest.TestCase):
|
||||
all_model_classes = (TFBlipForQuestionAnswering,) if is_tf_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBlipModelTester(self)
|
||||
|
||||
def _prepare_inputs_for_vqa(self):
|
||||
_, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||
inputs_dict.pop("return_loss")
|
||||
return inputs_dict
|
||||
|
||||
def test_class_name_consistency(self):
|
||||
"""
|
||||
Tests that all VQA models have a class name that ends with "ForQuestionAnswering"
|
||||
"""
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(self.model_tester.get_config())
|
||||
self.assertTrue(
|
||||
model.__class__.__name__.endswith("ForQuestionAnswering"),
|
||||
f"Class name should end with 'ForQuestionAnswering', got {model.__class__.__name__}",
|
||||
)
|
||||
|
||||
def test_training(self):
|
||||
"""
|
||||
Tests that all VQA models can be trained on a single batch
|
||||
"""
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(self.model_tester.get_config())
|
||||
loss = model(**self._prepare_inputs_for_vqa(), training=True).loss
|
||||
|
||||
self.assertIsNotNone(loss, "Loss should not be None")
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFBlipTextRetrievalModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFBlipForImageTextRetrieval,) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_attention_outputs = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BlipTextRetrievalModelTester(self)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Retain_grad is tested in individual model tests")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="BlipModel does not have input/output embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
def test_training(self):
|
||||
if not self.model_tester.is_training:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes[:-1]:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
model = model_class(config)
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
|
||||
# hardcode labels to be the same as input_ids
|
||||
inputs["labels"] = inputs["input_ids"]
|
||||
|
||||
loss = model(**inputs, training=True).loss
|
||||
self.assertTrue(loss is not None)
|
||||
|
||||
def test_load_vision_text_config(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Save BlipConfig and check if we can load BlipVisionConfig from it
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
config.save_pretrained(tmp_dir_name)
|
||||
vision_config = BlipVisionConfig.from_pretrained(tmp_dir_name)
|
||||
self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
|
||||
|
||||
# Save BlipConfig and check if we can load BlipTextConfig from it
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
config.save_pretrained(tmp_dir_name)
|
||||
text_config = BlipTextConfig.from_pretrained(tmp_dir_name)
|
||||
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = TFBlipModel.from_pretrained(model_name, from_pt=True)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@unittest.skip(reason="Tested in individual model tests")
|
||||
def test_compile_tf_model(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Model doesn't have a clean loss output.")
|
||||
def test_keras_fit(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFBlipTextImageModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFBlipForConditionalGeneration, TFBlipForQuestionAnswering) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_attention_outputs = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BlipTextImageModelsModelTester(self)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.call)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
expected_arg_names = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
]
|
||||
expected_arg_names.extend(
|
||||
["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
|
||||
if all(name in arg_names for name in ("head_mask", "decoder_head_mask", "cross_attn_head_mask"))
|
||||
else ["encoder_outputs"]
|
||||
)
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
else:
|
||||
expected_arg_names = (
|
||||
["input_ids"] if model_class != TFBlipForConditionalGeneration else ["pixel_values"]
|
||||
)
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
@unittest.skip(reason="Tested in individual model tests")
|
||||
def test_compile_tf_model(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Has some odd input names!")
|
||||
def test_keras_fit(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Retain_grad is tested in individual model tests")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="BlipModel does not have input/output embeddings")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
def test_training(self):
|
||||
if not self.model_tester.is_training:
|
||||
return
|
||||
|
||||
for model_class in self.all_model_classes[:-1]:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
model = model_class(config)
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
|
||||
# hardcode labels to be the same as input_ids
|
||||
inputs["labels"] = inputs["input_ids"]
|
||||
|
||||
loss = model(**inputs, training=True).loss
|
||||
self.assertIsNotNone(loss)
|
||||
|
||||
def test_load_vision_text_config(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Save BlipConfig and check if we can load BlipVisionConfig from it
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
config.save_pretrained(tmp_dir_name)
|
||||
vision_config = BlipVisionConfig.from_pretrained(tmp_dir_name)
|
||||
self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
|
||||
|
||||
# Save BlipConfig and check if we can load BlipTextConfig from it
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
config.save_pretrained(tmp_dir_name)
|
||||
text_config = BlipTextConfig.from_pretrained(tmp_dir_name)
|
||||
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
try:
|
||||
model = TFBlipModel.from_pretrained(model_name)
|
||||
except OSError:
|
||||
model = TFBlipModel.from_pretrained(model_name, from_pt=True)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
# We will verify our results on an image of a woman and her dog on a beach
|
||||
def prepare_img():
|
||||
url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im


@require_vision
@require_tf
@slow
class TFBlipModelIntegrationTest(unittest.TestCase):
    def test_inference_image_captioning(self):
        model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", from_pt=True)
        processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        image = prepare_img()

        # image only
        inputs = processor(images=image, return_tensors="tf")

        predictions = model.generate(**inputs)

        # Test output
        self.assertEqual(
            predictions[0].numpy().tolist(), [30522, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102]
        )

        # image and context
        context = ["a picture of"]
        inputs = processor(images=image, text=context, return_tensors="tf")

        predictions = model.generate(**inputs)

        # Test output
        self.assertEqual(
            predictions[0].numpy().tolist(),
            [30522, 1037, 3861, 1997, 1037, 2450, 3564, 2006, 1996, 3509, 2007, 2014, 3899, 102],
        )

    def test_inference_vqa(self):
        model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base", from_pt=True)
        processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")

        image = prepare_img()
        text = "how many dogs are in the picture?"
        inputs = processor(image, text=text, return_tensors="tf")
        out = model.generate(**inputs)

        # Test output
        self.assertEqual(out[0].numpy().tolist(), [30522, 1015, 102])

    def test_inference_itm(self):
        model = TFBlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco", from_pt=True)
        processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base-coco")

        image = prepare_img()
        text = "A woman and her dog sitting in a beach"

        inputs = processor(image, text, return_tensors="tf")

        out_itm = model(**inputs)
        out = model(**inputs, use_itm_head=False, training=False)

        expected_scores = tf.convert_to_tensor([[0.9798, 0.0202]])
        self.assertTrue(np.allclose(tf.nn.softmax(out_itm[0]).numpy(), expected_scores, rtol=1e-3, atol=1e-3))
        self.assertTrue(np.allclose(out[0], tf.convert_to_tensor([[0.5053]]), rtol=1e-3, atol=1e-3))
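
For readers who want to exercise the new TF port outside the test harness, here is a minimal captioning sketch based on the integration test above. It assumes network access for the checkpoint and demo image, and mirrors the test's use of from_pt=True in case TF weights are not yet published on the Hub; decoding the token ids asserted above should yield roughly "a woman sitting on the beach with her dog".

import requests
from PIL import Image

from transformers import BlipProcessor, TFBlipForConditionalGeneration

# Same checkpoint and demo image as the integration test above
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = TFBlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base", from_pt=True
)

url = "https://huggingface.co/hf-internal-testing/blip-test-image/resolve/main/demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Unconditional captioning (image only)
inputs = processor(images=image, return_tensors="tf")
out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))

# Conditional captioning (image plus a text prompt)
inputs = processor(images=image, text="a picture of", return_tensors="tf")
out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))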

tests/models/blip/test_modeling_tf_blip_text.py (new file, 170 lines)
@@ -0,0 +1,170 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow Blip model. """

import unittest

import numpy as np

from transformers import BlipTextConfig
from transformers.testing_utils import require_tf, slow
from transformers.utils import is_tf_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_attention_mask


if is_tf_available():
    import tensorflow as tf

    from transformers import TFBlipTextModel
    from transformers.models.blip.modeling_tf_blip import TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST


class BlipTextModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        projection_dim=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        bos_token_id=0,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.scope = scope
        self.bos_token_id = bos_token_id

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        if input_mask is not None:
            input_mask = input_mask.numpy()
            batch_size, seq_length = input_mask.shape
            rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
            for batch_idx, start_index in enumerate(rnd_start_indices):
                input_mask[batch_idx, :start_index] = 1
                input_mask[batch_idx, start_index:] = 0

        config = self.get_config()

        return config, input_ids, tf.convert_to_tensor(input_mask)

    def get_config(self):
        return BlipTextConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            bos_token_id=self.bos_token_id,
        )

    def create_and_check_model(self, config, input_ids, input_mask):
        model = TFBlipTextModel(config=config)
        result = model(input_ids, attention_mask=input_mask, training=False)
        result = model(input_ids, training=False)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, input_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_tf
class BlipTextModelTest(TFModelTesterMixin, unittest.TestCase):
    all_model_classes = (TFBlipTextModel,) if is_tf_available() else ()
    test_onnx = False
    test_pruning = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = BlipTextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=BlipTextConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_training(self):
        pass

    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="Blip does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="BlipTextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            try:
                model = TFBlipTextModel.from_pretrained(model_name)
            except OSError:
                model = TFBlipTextModel.from_pretrained(model_name, from_pt=True)
            self.assertIsNotNone(model)

    def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
        super().test_pt_tf_model_equivalence(allow_missing_keys=allow_missing_keys)
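
As a quick illustration of what BlipTextModelTester exercises, the following standalone sketch (not part of the PR) builds a randomly initialized TFBlipTextModel from a tiny BlipTextConfig mirroring the tester defaults above and checks the output shapes.

import tensorflow as tf

from transformers import BlipTextConfig, TFBlipTextModel

# Tiny configuration mirroring BlipTextModelTester's defaults
config = BlipTextConfig(
    vocab_size=99,
    hidden_size=32,
    projection_dim=32,
    num_hidden_layers=5,
    num_attention_heads=4,
    intermediate_size=37,
    max_position_embeddings=512,
    bos_token_id=0,
)
model = TFBlipTextModel(config)

input_ids = tf.constant([[0, 5, 7, 11, 2, 9, 4]])  # (batch_size=1, seq_length=7)
attention_mask = tf.ones_like(input_ids)

outputs = model(input_ids, attention_mask=attention_mask, training=False)
print(outputs.last_hidden_state.shape)  # (1, 7, 32)
print(outputs.pooler_output.shape)      # (1, 32)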


tests/test_modeling_common.py
@@ -1984,7 +1984,7 @@ class ModelTesterMixin:
         self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(pt_model))

     @is_pt_tf_cross_test
-    def test_pt_tf_model_equivalence(self):
+    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
         import transformers

         for model_class in self.all_model_classes:
@@ -2036,8 +2036,12 @@ class ModelTesterMixin:
             # Check we can load pt model in tf and vice-versa with model => model functions
             # Here requires `tf_inputs_dict` to build `tf_model`
             tf_inputs_dict = self.prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
-            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
-            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
+            tf_model = transformers.load_pytorch_model_in_tf2_model(
+                tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
+            )
+            pt_model = transformers.load_tf2_model_in_pytorch_model(
+                pt_model, tf_model, allow_missing_keys=allow_missing_keys
+            )

             # Original test: check without `labels`
             self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
@@ -2049,11 +2053,15 @@ class ModelTesterMixin:
             with tempfile.TemporaryDirectory() as tmpdirname:
                 pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                 torch.save(pt_model.state_dict(), pt_checkpoint_path)
-                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
+                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
+                    tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )

                 tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
                 tf_model.save_weights(tf_checkpoint_path)
-                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
+                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
+                    pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )

                 # Original test: check without `labels`
                 self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
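
The hunks above thread the new allow_missing_keys argument through the PT-to-TF conversion helpers so that weights present in only one framework (for example, layers unused by the TF BLIP port) can be skipped during equivalence testing. A rough sketch of that path outside the test harness, assuming randomly initialized default-size BLIP models (no Hub download needed, but untested here):

import tensorflow as tf

import transformers
from transformers import BlipConfig, BlipModel, TFBlipModel

config = BlipConfig()
pt_model = BlipModel(config).eval()
tf_model = TFBlipModel(config)

# Dummy inputs just to build the TF variables before weights are copied over;
# the channels-first pixel shape follows the usual `pixel_values` convention.
image_size = config.vision_config.image_size
tf_inputs = {
    "input_ids": tf.constant([[0, 5, 7, 2]]),
    "pixel_values": tf.random.uniform((1, 3, image_size, image_size)),
}

# PT -> TF, tolerating keys that exist in only one framework
tf_model = transformers.load_pytorch_model_in_tf2_model(
    tf_model, pt_model, tf_inputs=tf_inputs, allow_missing_keys=True
)
# ...and back, TF -> PT
pt_model = transformers.load_tf2_model_in_pytorch_model(
    pt_model, tf_model, allow_missing_keys=True
)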


tests/test_modeling_tf_common.py
@@ -668,7 +668,7 @@ class TFModelTesterMixin:
         self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(tf_model))

     @is_pt_tf_cross_test
-    def test_pt_tf_model_equivalence(self):
+    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
         import transformers

         for model_class in self.all_model_classes:
@@ -703,8 +703,12 @@ class TFModelTesterMixin:
             tf_inputs_dict_with_labels = None

             # Check we can load pt model in tf and vice-versa with model => model functions
-            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
-            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
+            tf_model = transformers.load_pytorch_model_in_tf2_model(
+                tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
+            )
+            pt_model = transformers.load_tf2_model_in_pytorch_model(
+                pt_model, tf_model, allow_missing_keys=allow_missing_keys
+            )

             # Original test: check without `labels`
             self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
@@ -716,11 +720,15 @@ class TFModelTesterMixin:
             with tempfile.TemporaryDirectory() as tmpdirname:
                 pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                 torch.save(pt_model.state_dict(), pt_checkpoint_path)
-                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
+                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
+                    tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )

                 tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
                 tf_model.save_weights(tf_checkpoint_path)
-                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
+                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
+                    pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )

                 # Original test: check without `labels`
                 self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
@@ -791,7 +799,7 @@ class TFModelTesterMixin:
                     name="pixel_values",
                     dtype="float32",
                 )
-            elif model_class.__name__ in ["TFCLIPModel", "TFGroupViTModel"]:
+            elif model_class.__name__ in ["TFCLIPModel", "TFGroupViTModel", "TFBlipModel"]:
                 inputs = {
                     "input_ids": tf.keras.Input(batch_shape=(3, max_input), name="input_ids", dtype="int32"),
                     "pixel_values": tf.keras.Input(
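
The hunk above adds TFBlipModel to the multimodal branch of the compile test, which builds symbolic Keras inputs for both modalities before running the model and checking that the graph compiles. A small illustration of that input pattern; the batch size of 3 follows the test, while the sequence length and the 384 image size are assumptions (384 being the BlipVisionConfig default):

import tensorflow as tf

max_input = 32    # assumed sequence length for the symbolic text input
image_size = 384  # assumed; matches the default BlipVisionConfig image size

inputs = {
    "input_ids": tf.keras.Input(batch_shape=(3, max_input), name="input_ids", dtype="int32"),
    "pixel_values": tf.keras.Input(
        batch_shape=(3, 3, image_size, image_size), name="pixel_values", dtype="float32"
    ),
}
# The common test roughly feeds such symbolic tensors through the model and builds a
# Keras model around the outputs to verify that the graph compiles end to end.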
@@ -1792,6 +1800,8 @@ class TFModelTesterMixin:
         for model_class in self.all_model_classes:
             model = model_class(config)
             tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=False)
+            if "labels" in tf_inputs_dict:
+                return  # This is some kinda funky decoder model that needs labels in its forward pass
             tf_inputs_dict = {
                 key: val
                 for key, val in tf_inputs_dict.items()
@@ -1805,7 +1815,7 @@ class TFModelTesterMixin:
             test_batch = next(iter(tf_dataset))
             if isinstance(test_batch, tf.Tensor):
                 self.assertEqual(len(test_batch), len(input_dataset))  # Assert we didn't lose any data
-            else:
+            elif isinstance(test_batch, dict):
                 # Assert we discarded the unwanted extra column but kept everything else
                 self.assertEqual(len(test_batch), len(input_dataset.features) - 1)
                 self.assertNotIn("extra_unwanted_column", test_batch)


utils/check_repo.py
@@ -145,6 +145,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
     "TFSegformerDecodeHead",  # Not a regular model.
     "AltRobertaModel",  # Building part of bigger (tested) model.
     "BlipTextLMHeadModel",  # No need to test it as it is tested by BlipTextVision models
+    "TFBlipTextLMHeadModel",  # No need to test it as it is tested by BlipTextVision models
     "BridgeTowerTextModel",  # No need to test it as it is tested by BridgeTowerModel model.
     "BridgeTowerVisionModel",  # No need to test it as it is tested by BridgeTowerModel model.
     "SpeechT5Decoder",  # Building part of bigger (tested) model.
@@ -205,6 +206,12 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
     "BlipVisionModel",
     "BlipTextLMHeadModel",
     "BlipTextModel",
+    "TFBlipForConditionalGeneration",
+    "TFBlipForImageTextRetrieval",
+    "TFBlipForQuestionAnswering",
+    "TFBlipVisionModel",
+    "TFBlipTextLMHeadModel",
+    "TFBlipTextModel",
     "Swin2SRForImageSuperResolution",
     "BridgeTowerForImageAndTextRetrieval",
     "BridgeTowerForMaskedLM",


utils/documentation_tests.txt
@@ -36,6 +36,7 @@ src/transformers/models/blenderbot/modeling_blenderbot.py
 src/transformers/models/blenderbot_small/configuration_blenderbot_small.py
 src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
 src/transformers/models/blip/modeling_blip.py
+src/transformers/models/blip/modeling_tf_blip.py
 src/transformers/models/bloom/configuration_bloom.py
 src/transformers/models/camembert/configuration_camembert.py
 src/transformers/models/canine/configuration_canine.py