Expose blip2qformer (#37254)

* Expose blip2qformer

* Add missing args to blip2 config
Alex Brooks 2025-04-08 04:04:33 -06:00 committed by GitHub
parent 2da82e432d
commit 2515a5a290
6 changed files with 86 additions and 23 deletions

View File

@@ -54,6 +54,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("blenderbot-small", "BlenderbotSmallConfig"),
("blip", "BlipConfig"),
("blip-2", "Blip2Config"),
("blip_2_qformer", "Blip2QFormerConfig"),
("bloom", "BloomConfig"),
("bridgetower", "BridgeTowerConfig"),
("bros", "BrosConfig"),
@@ -391,6 +392,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("blenderbot-small", "BlenderbotSmall"),
("blip", "BLIP"),
("blip-2", "BLIP-2"),
("blip_2_qformer", "BLIP-2 QFormer"),
("bloom", "BLOOM"),
("bort", "BORT"),
("bridgetower", "BridgeTower"),
@@ -781,6 +783,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
("granitevision", "llava_next"),
("sam_vision_model", "sam"),
("llama4_text", "llama4"),
("blip_2_qformer", "blip_2"),
]
)
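For context on what these mapping entries enable: once `blip_2_qformer` is registered in `CONFIG_MAPPING_NAMES`, the auto API can resolve that model type string directly to the Q-Former config class. A minimal sketch (illustrative, not part of the diff):

```python
from transformers import AutoConfig, Blip2QFormerConfig

# With "blip_2_qformer" registered in CONFIG_MAPPING_NAMES, AutoConfig can
# build the Q-Former config from the model type string alone.
config = AutoConfig.for_model("blip_2_qformer")
assert isinstance(config, Blip2QFormerConfig)
print(config.model_type)  # matches the new mapping key
```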

View File

@@ -53,6 +53,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("blenderbot-small", "BlenderbotSmallModel"),
("blip", "BlipModel"),
("blip-2", "Blip2Model"),
("blip_2_qformer", "Blip2QFormerModel"),
("bloom", "BloomModel"),
("bridgetower", "BridgeTowerModel"),
("bros", "BrosModel"),

View File

@@ -144,6 +144,8 @@ class Blip2QFormerConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
pad_token_id (`int`, *optional*, defaults to 0):
Index to be used for the padding token.
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
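The arguments documented in this excerpt can be passed straight to the config constructor; a short illustration (the values shown are simply the documented defaults, not something this diff changes):

```python
from transformers import Blip2QFormerConfig

# Explicitly setting the arguments covered by the docstring addition.
config = Blip2QFormerConfig(
    pad_token_id=0,
    layer_norm_eps=1e-12,
    position_embedding_type="absolute",
)
print(config.pad_token_id, config.position_embedding_type)
```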

View File

@@ -456,6 +456,21 @@ BLIP_2_START_DOCSTRING = r"""
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
BLIP_2_QFORMER_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
etc.).
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.
Parameters:
config ([`Blip2QFormerConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
BLIP_2_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
@@ -621,6 +636,60 @@ BLIP2_IMAGE_TEXT_RETRIEVAL_INPUTS_DOCSTRING = r"""
"""
BLIP2_QFORMER_INPUTS_DOCSTRING = r"""
Args:
query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Hidden states to be used in the attention computation. When cross-attention is applied, these serve as the
queries (i.e., the keys and values come from `encoder_hidden_states`).
query_length (`int`, *optional*):
Length of the query, usually based on the number of query tokens.
If no value is provided, `query_length` is inferred from `query_embeds`.
attention_mask (`torch.FloatTensor`, *optional*):
Attention mask of size `(batch, sequence_length)` where padding elements
are indicated by 0.
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers`, with each tuple having 4 tensors
of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
`(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2
class Blip2Encoder(nn.Module):
"""
@@ -1248,11 +1317,13 @@ class Blip2TextEmbeddings(nn.Module):
return embeddings
@add_start_docstrings(
"""
BLIP-2 Querying Transformer (Q-Former).
""",
BLIP_2_QFORMER_START_DOCSTRING,
)
class Blip2QFormerModel(Blip2PreTrainedModel):
"""
Querying Transformer (Q-Former), used in BLIP-2.
"""
def __init__(self, config: Blip2QFormerConfig):
super().__init__(config)
self.config = config
@@ -1323,6 +1394,10 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
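A quick aside on the two lines above (commentary, not part of the diff): the model converts a 0/1 padding mask into an additive bias, so masked positions receive a large negative value before the softmax. A standalone sketch of that trick, not the library code itself:

```python
import torch

# 1 = attend, 0 = padding; broadcastable to (batch, num_heads, query_len, key_len)
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.float32)
extended_attention_mask = attention_mask[:, None, None, :]
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# Attended positions get a bias of 0, the padded position gets -10000.0,
# which drives its softmax weight to (effectively) zero.
print(extended_attention_mask)
```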
@add_start_docstrings_to_model_forward(BLIP2_QFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=Blip2QFormerConfig
)
def forward(
self,
query_embeds: torch.FloatTensor,
@@ -1338,23 +1413,7 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
`(batch_size, sequence_length)`.
use_cache (`bool`, `optional`):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (

View File

@@ -106,7 +106,6 @@ OBJECTS_TO_IGNORE = [
"BlenderbotSmallConfig",
"BlenderbotSmallTokenizerFast",
"BlenderbotTokenizerFast",
"Blip2QFormerConfig",
"Blip2VisionConfig",
"BlipTextConfig",
"BlipVisionConfig",

View File

@@ -187,7 +187,6 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"ClapAudioModelWithProjection",
"Blip2TextModelWithProjection",
"Blip2VisionModelWithProjection",
"Blip2QFormerModel",
"Blip2VisionModel",
"ErnieMForInformationExtraction",
"FastSpeech2ConformerHifiGan",