[Blip2] Add Blip2Model (#21817)

* add v1
* add `Blip2Model`
  - add relevant functions
  - add tests
  - add on automapping
* fix docs
* fix doctest
parent ae9230af40
commit b8de7e448e
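For orientation before the diff: the commit exposes BLIP-2 as a feature extractor via the new `Blip2Model`. A minimal usage sketch, assembled from the doctest examples added below (checkpoint name, fp16 dtype, and device handling are taken from those doctests; the sketch itself is not part of the committed code):

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer, Blip2Model

device = "cuda" if torch.cuda.is_available() else "cpu"

# Blip2Model wraps the vision encoder, the Q-Former and the language model without a generation head.
model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16).to(device)
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
tokenizer = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b")

# Text-only path: runs just the language model on the tokenized prompt.
text_inputs = tokenizer(["a photo of a cat"], padding=True, return_tensors="pt").to(device)
text_features = model.get_text_features(**text_inputs)

# Image path: runs just the vision encoder.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
image_inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
image_features = model.get_image_features(**image_inputs)

# Q-Former path: vision encoder plus learned query tokens cross-attending to the image embeddings.
qformer_features = model.get_qformer_features(**image_inputs)
```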
@@ -71,6 +71,14 @@ If you're interested in submitting a resource to be included here, please feel f
[[autodoc]] Blip2QFormerModel
    - forward

## Blip2Model

[[autodoc]] Blip2Model
    - forward
    - get_text_features
    - get_image_features
    - get_qformer_features

## Blip2ForConditionalGeneration

[[autodoc]] Blip2ForConditionalGeneration
@@ -1191,6 +1191,7 @@ else:
        [
            "BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST",
            "Blip2ForConditionalGeneration",
            "Blip2Model",
            "Blip2PreTrainedModel",
            "Blip2QFormerModel",
            "Blip2VisionModel",
@@ -4651,6 +4652,7 @@ if TYPE_CHECKING:
        from .models.blip_2 import (
            BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST,
            Blip2ForConditionalGeneration,
            Blip2Model,
            Blip2PreTrainedModel,
            Blip2QFormerModel,
            Blip2VisionModel,
@@ -42,6 +42,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("blenderbot", "BlenderbotModel"),
        ("blenderbot-small", "BlenderbotSmallModel"),
        ("blip", "BlipModel"),
        ("blip_2", "Blip2Model"),
        ("bloom", "BloomModel"),
        ("bridgetower", "BridgeTowerModel"),
        ("camembert", "CamembertModel"),
@@ -34,6 +34,7 @@ except OptionalDependencyNotAvailable:
else:
    _import_structure["modeling_blip_2"] = [
        "BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST",
        "Blip2Model",
        "Blip2QFormerModel",
        "Blip2PreTrainedModel",
        "Blip2ForConditionalGeneration",

@@ -58,6 +59,7 @@ if TYPE_CHECKING:
        from .modeling_blip_2 import (
            BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST,
            Blip2ForConditionalGeneration,
            Blip2Model,
            Blip2PreTrainedModel,
            Blip2QFormerModel,
            Blip2VisionModel,
@@ -342,6 +342,43 @@ BLIP_2_VISION_INPUTS_DOCSTRING = r"""
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

BLIP_2_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
            Training](./t5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

BLIP_2_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
@@ -1171,6 +1208,337 @@ class Blip2QFormerModel(Blip2PreTrainedModel):
        )


@add_start_docstrings(
    """
    BLIP-2 Model for generating text and image features. The model consists of a vision encoder, Querying Transformer
    (Q-Former) and a language model.
    """,
    BLIP_2_START_DOCSTRING,
)
class Blip2Model(Blip2PreTrainedModel):
    config_class = Blip2Config
    main_input_name = "pixel_values"

    def __init__(self, config: Blip2Config):
        super().__init__(config)

        self.vision_model = Blip2VisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = Blip2QFormerModel(config.qformer_config)

        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
        if config.use_decoder_only_language_model:
            language_model = AutoModelForCausalLM.from_config(config.text_config)
        else:
            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
        self.language_model = language_model

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(BLIP_2_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Returns:
            text_outputs (`CausalLMOutputWithPast`, or `tuple(torch.FloatTensor)` if `return_dict=False`):
                The language model outputs. If `return_dict=True`, the output is a [`CausalLMOutputWithPast`] that
                contains the language model logits, the past key values and the hidden states if
                `output_hidden_states=True`.
        Examples:
        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, Blip2Model

        >>> device = "cuda" if torch.cuda.is_available() else "cpu"

        >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)

        >>> model.to(device)  # doctest: +IGNORE_RESULT

        >>> tokenizer = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b")
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt").to(device)
        >>> text_features = model.get_text_features(**inputs)
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.use_decoder_only_language_model:
            text_outputs = self.language_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        else:
            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

            text_outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                labels=labels,
            )

        return text_outputs

    @add_start_docstrings_to_model_forward(BLIP_2_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Returns:
            vision_outputs (`BaseModelOutputWithPooling` or tuple of `torch.FloatTensor`):
                The vision model outputs. If `return_dict=True`, the output is a [`BaseModelOutputWithPooling`] that
                contains the image features, the pooled image features and the hidden states if
                `output_hidden_states=True`.
        Examples:
        ```python
        >>> import torch
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Blip2Model

        >>> device = "cuda" if torch.cuda.is_available() else "cpu"

        >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)

        >>> model.to(device)  # doctest: +IGNORE_RESULT

        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
        >>> image_outputs = model.get_image_features(**inputs)
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return vision_outputs

    @add_start_docstrings_to_model_forward(BLIP_2_INPUTS_DOCSTRING)
    def get_qformer_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Returns:
            vision_outputs (`BaseModelOutputWithPooling` or tuple of `torch.FloatTensor`):
                The vision model outputs. If `return_dict=True`, the output is a [`BaseModelOutputWithPooling`] that
                contains the image features, the pooled image features and the hidden states if
                `output_hidden_states=True`.
        Examples:
        ```python
        >>> import torch
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import Blip2Processor, Blip2Model

        >>> device = "cuda" if torch.cuda.is_available() else "cpu"

        >>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
        >>> model.to(device)  # doctest: +IGNORE_RESULT

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
        >>> qformer_outputs = model.get_qformer_features(**inputs)
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return query_outputs

    @add_start_docstrings_to_model_forward(BLIP_2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Blip2ForConditionalGenerationModelOutput, config_class=Blip2VisionConfig)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.FloatTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import Blip2Processor, Blip2Model
        >>> import torch

        >>> device = "cuda" if torch.cuda.is_available() else "cpu"

        >>> processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        >>> model = Blip2Model.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
        >>> model.to(device)  # doctest: +IGNORE_RESULT

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> prompt = "Question: how many cats are there? Answer:"
        >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)

        >>> outputs = model(**inputs)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)
        language_model_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat([language_model_inputs, inputs_embeds], dim=1)

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        expected_device = language_model_attention_mask.device
        attention_mask = torch.cat([language_model_attention_mask, attention_mask.to(expected_device)], dim=1)

        if self.config.use_decoder_only_language_model:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            logits = outputs.logits if return_dict else outputs[0]
            loss = None
            # we compute the loss here since we need to take into account the sequence length of the query embeds
            if labels is not None:
                logits = logits[:, -labels.size(1) :, :]
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous().to(logits.device)

                # Flatten the tokens
                loss_fct = CrossEntropyLoss(reduction="mean")

                loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
        else:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                labels=labels,
            )
            loss = outputs.loss if return_dict else outputs[0]
            logits = outputs.logits if return_dict else outputs[1]

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return Blip2ForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )


@add_start_docstrings(
    """
    BLIP-2 Model for generating text given an image and an optional text prompt. The model consists of a vision
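One detail in the decoder-only loss branch above is easy to miss: the projected query tokens are prepended to the text embeddings, so the language model's logits are longer than `labels`, and the code first slices the logits back to the label length before the usual causal shift. A toy shape walk-through with hypothetical sizes (not part of the diff):

```python
import torch
from torch.nn import CrossEntropyLoss

# Hypothetical sizes: 32 query tokens prepended to a 10-token prompt, small fake vocab.
batch_size, num_query_tokens, text_len, vocab_size = 2, 32, 10, 100
logits = torch.randn(batch_size, num_query_tokens + text_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, text_len))

# Keep only the positions that correspond to the text tokens.
logits = logits[:, -labels.size(1) :, :]
# Shift so that position i predicts token i + 1, then flatten, as in the diff.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = CrossEntropyLoss(reduction="mean")(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss.shape)  # torch.Size([]), a scalar loss over the text positions only
```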
@@ -1221,6 +1221,13 @@ class Blip2ForConditionalGeneration(metaclass=DummyObject):
        requires_backends(self, ["torch"])


class Blip2Model(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class Blip2PreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]
@@ -40,7 +40,7 @@ if is_torch_available():
    import torch
    from torch import nn

-    from transformers import Blip2ForConditionalGeneration, Blip2VisionModel
+    from transformers import Blip2ForConditionalGeneration, Blip2Model, Blip2VisionModel
    from transformers.models.blip_2.modeling_blip_2 import BLIP_2_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -664,8 +664,8 @@ class Blip2ForConditionalGenerationModelTester:


@require_torch
-class Blip2ForConditionalGenerationTest(ModelTesterMixin, unittest.TestCase):
-    all_model_classes = (Blip2ForConditionalGeneration,) if is_torch_available() else ()
+class Blip2ModelTest(ModelTesterMixin, unittest.TestCase):
+    all_model_classes = (Blip2ForConditionalGeneration, Blip2Model) if is_torch_available() else ()
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
@@ -737,6 +737,56 @@ class Blip2ForConditionalGenerationTest(ModelTesterMixin, unittest.TestCase):
            model = Blip2ForConditionalGeneration.from_pretrained(model_name)
            self.assertIsNotNone(model)

    def test_get_text_features(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        inputs_dict = {
            "input_ids": torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]).to(torch_device),
            "attention_mask": torch.LongTensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]).to(torch_device),
            "decoder_input_ids": torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]).to(torch_device),
        }

        model = Blip2Model(config).to(torch_device)
        model.eval()
        text_features = model.get_text_features(**inputs_dict)
        self.assertEqual(text_features[0].shape, (1, 10, config.text_config.vocab_size))

    def test_get_image_features(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"]

        for key in keys_to_pop:
            inputs_dict.pop(key)

        model = Blip2Model(config).to(torch_device)
        model.eval()
        image_features = model.get_image_features(**inputs_dict)
        self.assertEqual(
            image_features[0].shape,
            (
                self.model_tester.vision_model_tester.batch_size,
                self.model_tester.vision_model_tester.seq_length,
                config.vision_config.hidden_size,
            ),
        )

    def test_get_qformer_features(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        keys_to_pop = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"]

        for key in keys_to_pop:
            inputs_dict.pop(key)

        model = Blip2Model(config).to(torch_device)
        model.eval()
        qformer_features = model.get_qformer_features(**inputs_dict)
        self.assertEqual(
            qformer_features[0].shape,
            (self.model_tester.vision_model_tester.batch_size, 10, config.vision_config.hidden_size),
        )

    # override from common to deal with nested configurations (`vision_config`, `text_config` and `qformer_config`)
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()