mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Expose blip2qformer (#37254)
* Expose blip2qformer * Add missing args to blip2 config
This commit is contained in:
parent
2da82e432d
commit
2515a5a290
@ -54,6 +54,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("blenderbot-small", "BlenderbotSmallConfig"),
|
||||
("blip", "BlipConfig"),
|
||||
("blip-2", "Blip2Config"),
|
||||
("blip_2_qformer", "Blip2QFormerConfig"),
|
||||
("bloom", "BloomConfig"),
|
||||
("bridgetower", "BridgeTowerConfig"),
|
||||
("bros", "BrosConfig"),
|
||||
@ -391,6 +392,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("blenderbot-small", "BlenderbotSmall"),
|
||||
("blip", "BLIP"),
|
||||
("blip-2", "BLIP-2"),
|
||||
("blip_2_qformer", "BLIP-2 QFormer"),
|
||||
("bloom", "BLOOM"),
|
||||
("bort", "BORT"),
|
||||
("bridgetower", "BridgeTower"),
|
||||
@ -781,6 +783,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
|
||||
("granitevision", "llava_next"),
|
||||
("sam_vision_model", "sam"),
|
||||
("llama4_text", "llama4"),
|
||||
("blip_2_qformer", "blip_2"),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -53,6 +53,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("blenderbot-small", "BlenderbotSmallModel"),
|
||||
("blip", "BlipModel"),
|
||||
("blip-2", "Blip2Model"),
|
||||
("blip_2_qformer", "Blip2QFormerModel"),
|
||||
("bloom", "BloomModel"),
|
||||
("bridgetower", "BridgeTowerModel"),
|
||||
("bros", "BrosModel"),
|
||||
|
@ -144,6 +144,8 @@ class Blip2QFormerConfig(PretrainedConfig):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
Index to be used for padding token.
|
||||
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
|
||||
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
|
||||
positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
|
||||
|
@ -456,6 +456,21 @@ BLIP_2_START_DOCSTRING = r"""
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
BLIP_2_QFORMER_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`Blip2QFormerConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
BLIP_2_VISION_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
@ -621,6 +636,60 @@ BLIP2_IMAGE_TEXT_RETRIEVAL_INPUTS_DOCSTRING = r"""
|
||||
"""
|
||||
|
||||
|
||||
BLIP2_QFORMER_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Hidden states to be used in the attention computation. If cross-attention,
|
||||
will be used for the query (i.e., key and value will use the encoder_hidden_states).
|
||||
|
||||
query_length (`int`, *optional*):
|
||||
Length of the query, usually based on the number of query tokens.
|
||||
If no value is provided, query_length will be inferred by the query_embeds.
|
||||
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Attention mask of size `(batch, sequence_length)` where padding elements
|
||||
are indicated by 0.
|
||||
|
||||
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||
the model is configured as a decoder.
|
||||
|
||||
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
|
||||
shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
|
||||
value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
|
||||
used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
|
||||
value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
|
||||
`(batch_size, sequence_length)`.
|
||||
|
||||
use_cache (`bool`, `optional`):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2
|
||||
class Blip2Encoder(nn.Module):
|
||||
"""
|
||||
@ -1248,11 +1317,13 @@ class Blip2TextEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
BLIP-2 Querying Transformer (Q-Former).
|
||||
""",
|
||||
BLIP_2_QFORMER_START_DOCSTRING,
|
||||
)
|
||||
class Blip2QFormerModel(Blip2PreTrainedModel):
|
||||
"""
|
||||
Querying Transformer (Q-Former), used in BLIP-2.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Blip2QFormerConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
@ -1323,6 +1394,10 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
return extended_attention_mask
|
||||
|
||||
@add_start_docstrings_to_model_forward(BLIP2_QFORMER_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(
|
||||
output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=Blip2QFormerConfig
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
query_embeds: torch.FloatTensor,
|
||||
@ -1338,23 +1413,7 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
|
||||
r"""
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||
the model is configured as a decoder.
|
||||
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
|
||||
shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
|
||||
value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
|
||||
used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
|
||||
value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
|
||||
`(batch_size, sequence_length)`.
|
||||
use_cache (`bool`, `optional`):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
Returns:
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
|
@ -106,7 +106,6 @@ OBJECTS_TO_IGNORE = [
|
||||
"BlenderbotSmallConfig",
|
||||
"BlenderbotSmallTokenizerFast",
|
||||
"BlenderbotTokenizerFast",
|
||||
"Blip2QFormerConfig",
|
||||
"Blip2VisionConfig",
|
||||
"BlipTextConfig",
|
||||
"BlipVisionConfig",
|
||||
|
@ -187,7 +187,6 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||
"ClapAudioModelWithProjection",
|
||||
"Blip2TextModelWithProjection",
|
||||
"Blip2VisionModelWithProjection",
|
||||
"Blip2QFormerModel",
|
||||
"Blip2VisionModel",
|
||||
"ErnieMForInformationExtraction",
|
||||
"FastSpeech2ConformerHifiGan",
|
||||
|
Loading…
Reference in New Issue
Block a user