mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add video modality for InstructBLIP (#30182)
* squash in single commit
* add docs
* dummy obj
* more changes in diff converter
* tiny fix
* make docs happy
* skip test
* repo consistency tests
* update docstring
* style
* fix tests
* change diff imports
* [run-slow] instructblipvideo
* [run-slow] instructblipvideo
* fix tests and remove logit check
* [run-slow] instructblipvideo
This commit is contained in:
parent
a958c4a801
commit
fc689d75a0
@ -776,6 +776,8 @@
title: Idefics2
- local: model_doc/instructblip
title: InstructBLIP
- local: model_doc/instructblipvideo
title: InstructBlipVideo
- local: model_doc/kosmos-2
title: KOSMOS-2
- local: model_doc/layoutlm
@ -165,6 +165,7 @@ Flax), PyTorch, and/or TensorFlow.
| [ImageGPT](model_doc/imagegpt) | ✅ | ❌ | ❌ |
| [Informer](model_doc/informer) | ✅ | ❌ | ❌ |
| [InstructBLIP](model_doc/instructblip) | ✅ | ❌ | ❌ |
| [InstructBlipVideo](model_doc/instructblipvideo) | ✅ | ❌ | ❌ |
| [Jamba](model_doc/jamba) | ✅ | ❌ | ❌ |
| [JetMoe](model_doc/jetmoe) | ✅ | ❌ | ❌ |
| [Jukebox](model_doc/jukebox) | ✅ | ❌ | ❌ |
@ -50,6 +50,7 @@ InstructBLIP uses the same architecture as [BLIP-2](blip2) with a tiny but impor
[[autodoc]] InstructBlipProcessor

## InstructBlipVisionModel

[[autodoc]] InstructBlipVisionModel
74
docs/source/en/model_doc/instructblipvideo.md
Normal file
@ -0,0 +1,74 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# InstructBlipVideo

## Overview

InstructBlipVideo is an extension of the models proposed in [InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning](https://arxiv.org/abs/2305.06500) by Wenliang Dai, Junnan Li, Dongxu Li, Anthony Meng Huat Tiong, Junqi Zhao, Weisheng Wang, Boyang Li, Pascale Fung, Steven Hoi.
InstructBlipVideo uses the same architecture and the same checkpoints as [InstructBLIP](instructblip); the only difference is the ability to process videos.

The abstract from the paper is the following:

*General-purpose language models that can solve various language-domain tasks have emerged driven by the pre-training and instruction-tuning pipeline. However, building general-purpose vision-language models is challenging due to the increased task discrepancy introduced by the additional visual input. Although vision-language pre-training has been widely studied, vision-language instruction tuning remains relatively less explored. In this paper, we conduct a systematic and comprehensive study on vision-language instruction tuning based on the pre-trained BLIP-2 models. We gather a wide variety of 26 publicly available datasets, transform them into instruction tuning format and categorize them into two clusters for held-in instruction tuning and held-out zero-shot evaluation. Additionally, we introduce instruction-aware visual feature extraction, a crucial method that enables the model to extract informative features tailored to the given instruction. The resulting InstructBLIP models achieve state-of-the-art zero-shot performance across all 13 held-out datasets, substantially outperforming BLIP-2 and the larger Flamingo. Our models also lead to state-of-the-art performance when finetuned on individual downstream tasks (e.g., 90.7% accuracy on ScienceQA IMG). Furthermore, we qualitatively demonstrate the advantages of InstructBLIP over concurrent multimodal models.*

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/instructblip_architecture.jpg"
alt="drawing" width="600"/>

<small> InstructBlipVideo architecture. Taken from the <a href="https://arxiv.org/abs/2305.06500">original paper.</a> </small>

This model was contributed by [RaushanTurganbay](https://huggingface.co/RaushanTurganbay).
The original code can be found [here](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip).

## Usage tips

- The model was trained by sampling 4 frames per video, so it's recommended to sample 4 frames per video at inference time as well (see the sketch below).
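A minimal sketch of that workflow, using PyAV for decoding; the demo video and the `Salesforce/instructblip-vicuna-7b` checkpoint are borrowed from the doctest example in the modeling code, so substitute the video checkpoint you actually intend to use:

```python
import av
import numpy as np
from huggingface_hub import hf_hub_download
from transformers import InstructBlipVideoForConditionalGeneration, InstructBlipVideoProcessor

# load model and processor (checkpoint name follows the doctest example; adapt as needed)
model = InstructBlipVideoForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto")
processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")

# decode 4 uniformly spaced frames with PyAV
file_path = hf_hub_download(repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset")
container = av.open(file_path)
total_frames = container.streams.video[0].frames
indices = set(np.linspace(0, total_frames - 1, num=4).astype(int))
frames = [
    frame.to_ndarray(format="rgb24")
    for i, frame in enumerate(container.decode(video=0))
    if i in indices
]
clip = np.stack(frames)

# run generation on the 4-frame clip
inputs = processor(videos=clip, text="What is happening in the video?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=50)
print(processor.batch_decode(outputs, skip_special_tokens=True)[0].strip())
```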

## InstructBlipVideoConfig

[[autodoc]] InstructBlipVideoConfig
    - from_vision_qformer_text_configs

## InstructBlipVideoVisionConfig

[[autodoc]] InstructBlipVideoVisionConfig

## InstructBlipVideoQFormerConfig

[[autodoc]] InstructBlipVideoQFormerConfig

## InstructBlipVideoProcessor

[[autodoc]] InstructBlipVideoProcessor

## InstructBlipVideoImageProcessor

[[autodoc]] InstructBlipVideoImageProcessor
    - preprocess

## InstructBlipVideoVisionModel

[[autodoc]] InstructBlipVideoVisionModel
    - forward

## InstructBlipVideoQFormerModel

[[autodoc]] InstructBlipVideoQFormerModel
    - forward

## InstructBlipVideoForConditionalGeneration

[[autodoc]] InstructBlipVideoForConditionalGeneration
    - forward
    - generate
@ -473,6 +473,12 @@ _import_structure = {
|
||||
"InstructBlipQFormerConfig",
|
||||
"InstructBlipVisionConfig",
|
||||
],
|
||||
"models.instructblipvideo": [
|
||||
"InstructBlipVideoConfig",
|
||||
"InstructBlipVideoProcessor",
|
||||
"InstructBlipVideoQFormerConfig",
|
||||
"InstructBlipVideoVisionConfig",
|
||||
],
|
||||
"models.jamba": ["JambaConfig"],
|
||||
"models.jetmoe": ["JetMoeConfig"],
|
||||
"models.kosmos2": [
|
||||
@ -1137,6 +1143,7 @@ else:
|
||||
_import_structure["models.idefics"].extend(["IdeficsImageProcessor"])
|
||||
_import_structure["models.idefics2"].extend(["Idefics2ImageProcessor"])
|
||||
_import_structure["models.imagegpt"].extend(["ImageGPTFeatureExtractor", "ImageGPTImageProcessor"])
|
||||
_import_structure["models.instructblipvideo"].extend(["InstructBlipVideoImageProcessor"])
|
||||
_import_structure["models.layoutlmv2"].extend(["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"])
|
||||
_import_structure["models.layoutlmv3"].extend(["LayoutLMv3FeatureExtractor", "LayoutLMv3ImageProcessor"])
|
||||
_import_structure["models.levit"].extend(["LevitFeatureExtractor", "LevitImageProcessor"])
|
||||
@ -2318,6 +2325,14 @@ else:
|
||||
"InstructBlipVisionModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.instructblipvideo"].extend(
|
||||
[
|
||||
"InstructBlipVideoForConditionalGeneration",
|
||||
"InstructBlipVideoPreTrainedModel",
|
||||
"InstructBlipVideoQFormerModel",
|
||||
"InstructBlipVideoVisionModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.jamba"].extend(
|
||||
[
|
||||
"JambaForCausalLM",
|
||||
@ -5079,6 +5094,12 @@ if TYPE_CHECKING:
|
||||
InstructBlipQFormerConfig,
|
||||
InstructBlipVisionConfig,
|
||||
)
|
||||
from .models.instructblipvideo import (
|
||||
InstructBlipVideoConfig,
|
||||
InstructBlipVideoProcessor,
|
||||
InstructBlipVideoQFormerConfig,
|
||||
InstructBlipVideoVisionConfig,
|
||||
)
|
||||
from .models.jamba import JambaConfig
|
||||
from .models.jetmoe import JetMoeConfig
|
||||
from .models.kosmos2 import (
|
||||
@ -5772,6 +5793,7 @@ if TYPE_CHECKING:
|
||||
from .models.idefics import IdeficsImageProcessor
|
||||
from .models.idefics2 import Idefics2ImageProcessor
|
||||
from .models.imagegpt import ImageGPTFeatureExtractor, ImageGPTImageProcessor
|
||||
from .models.instructblipvideo import InstructBlipVideoImageProcessor
|
||||
from .models.layoutlmv2 import (
|
||||
LayoutLMv2FeatureExtractor,
|
||||
LayoutLMv2ImageProcessor,
|
||||
@ -6771,6 +6793,12 @@ if TYPE_CHECKING:
|
||||
InstructBlipQFormerModel,
|
||||
InstructBlipVisionModel,
|
||||
)
|
||||
from .models.instructblipvideo import (
|
||||
InstructBlipVideoForConditionalGeneration,
|
||||
InstructBlipVideoPreTrainedModel,
|
||||
InstructBlipVideoQFormerModel,
|
||||
InstructBlipVideoVisionModel,
|
||||
)
|
||||
from .models.jamba import (
|
||||
JambaForCausalLM,
|
||||
JambaForSequenceClassification,
|
||||
|
@ -111,6 +111,7 @@ from . import (
|
||||
imagegpt,
|
||||
informer,
|
||||
instructblip,
|
||||
instructblipvideo,
|
||||
jamba,
|
||||
jetmoe,
|
||||
kosmos2,
|
||||
|
@ -128,6 +128,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("imagegpt", "ImageGPTConfig"),
|
||||
("informer", "InformerConfig"),
|
||||
("instructblip", "InstructBlipConfig"),
|
||||
("instructblipvideo", "InstructBlipVideoConfig"),
|
||||
("jamba", "JambaConfig"),
|
||||
("jetmoe", "JetMoeConfig"),
|
||||
("jukebox", "JukeboxConfig"),
|
||||
@ -404,6 +405,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("imagegpt", "ImageGPT"),
|
||||
("informer", "Informer"),
|
||||
("instructblip", "InstructBLIP"),
|
||||
("instructblipvideo", "InstructBlipVideo"),
|
||||
("jamba", "Jamba"),
|
||||
("jetmoe", "JetMoe"),
|
||||
("jukebox", "Jukebox"),
|
||||
|
@ -89,6 +89,7 @@ else:
|
||||
("idefics2", ("Idefics2ImageProcessor",)),
|
||||
("imagegpt", ("ImageGPTImageProcessor",)),
|
||||
("instructblip", ("BlipImageProcessor",)),
|
||||
("instructblipvideo", ("InstructBlipVideoImageProcessor",)),
|
||||
("kosmos-2", ("CLIPImageProcessor",)),
|
||||
("layoutlmv2", ("LayoutLMv2ImageProcessor",)),
|
||||
("layoutlmv3", ("LayoutLMv3ImageProcessor",)),
|
||||
@ -156,7 +157,6 @@ for model_type, image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.items():
|
||||
|
||||
IMAGE_PROCESSOR_MAPPING_NAMES[model_type] = (slow_image_processor_class, fast_image_processor_class)
|
||||
|
||||
|
||||
IMAGE_PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, IMAGE_PROCESSOR_MAPPING_NAMES)
|
||||
|
||||
|
||||
|
@ -697,6 +697,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
("git", "GitForCausalLM"),
|
||||
("idefics2", "Idefics2ForConditionalGeneration"),
|
||||
("instructblip", "InstructBlipForConditionalGeneration"),
|
||||
("instructblipvideo", "InstructBlipVideoForConditionalGeneration"),
|
||||
("kosmos-2", "Kosmos2ForConditionalGeneration"),
|
||||
("llava", "LlavaForConditionalGeneration"),
|
||||
("llava_next", "LlavaNextForConditionalGeneration"),
|
||||
|
@ -64,6 +64,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("idefics", "IdeficsProcessor"),
|
||||
("idefics2", "Idefics2Processor"),
|
||||
("instructblip", "InstructBlipProcessor"),
|
||||
("instructblipvideo", "InstructBlipVideoProcessor"),
|
||||
("kosmos-2", "Kosmos2Processor"),
|
||||
("layoutlmv2", "LayoutLMv2Processor"),
|
||||
("layoutlmv3", "LayoutLMv3Processor"),
|
||||
|
@ -205,6 +205,7 @@ else:
|
||||
("idefics", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("idefics2", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("instructblipvideo", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"jamba",
|
||||
(
|
||||
|
@ -317,7 +317,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
|
||||
if isinstance(module, Blip2VisionEmbeddings):
|
||||
if hasattr(self.config, "vision_config"):
|
||||
if hasattr(self.config, "vision_config") and not isinstance(self.config, Blip2VisionConfig):
|
||||
factor = self.config.vision_config.initializer_range
|
||||
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
|
||||
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
|
||||
|
@ -164,6 +164,8 @@ class InstructBlipQFormerConfig(PretrainedConfig):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
Token id used for padding sequences.
|
||||
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
|
||||
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
|
||||
positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
|
||||
|
@ -324,7 +324,7 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
|
||||
if isinstance(module, InstructBlipVisionEmbeddings):
|
||||
if hasattr(self.config, "vision_config"):
|
||||
if hasattr(self.config, "vision_config") and not isinstance(self.config, InstructBlipVisionConfig):
|
||||
factor = self.config.vision_config.initializer_range
|
||||
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
|
||||
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
|
||||
|
83
src/transformers/models/instructblipvideo/__init__.py
Normal file
@ -0,0 +1,83 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_instructblipvideo": [
|
||||
"InstructBlipVideoConfig",
|
||||
"InstructBlipVideoQFormerConfig",
|
||||
"InstructBlipVideoVisionConfig",
|
||||
],
|
||||
"processing_instructblipvideo": ["InstructBlipVideoProcessor"],
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["image_processing_instructblipvideo"] = ["InstructBlipVideoImageProcessor"]
|
||||
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_instructblipvideo"] = [
|
||||
"InstructBlipVideoQFormerModel",
|
||||
"InstructBlipVideoPreTrainedModel",
|
||||
"InstructBlipVideoForConditionalGeneration",
|
||||
"InstructBlipVideoVisionModel",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_instructblipvideo import (
|
||||
InstructBlipVideoConfig,
|
||||
InstructBlipVideoQFormerConfig,
|
||||
InstructBlipVideoVisionConfig,
|
||||
)
|
||||
from .processing_instructblipvideo import InstructBlipVideoProcessor
|
||||
|
||||
try:
|
||||
if not is_vision_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .image_processing_instructblipvideo import InstructBlipVideoImageProcessor
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_instructblipvideo import (
|
||||
InstructBlipVideoForConditionalGeneration,
|
||||
InstructBlipVideoPreTrainedModel,
|
||||
InstructBlipVideoQFormerModel,
|
||||
InstructBlipVideoVisionModel,
|
||||
)
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
@ -0,0 +1,364 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from <path_to_diff_file.py>.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the diff. If any change should be done, please apply the change to the
|
||||
# diff.py file directly.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
from ...utils import (
|
||||
logging,
|
||||
)
|
||||
from ..auto import CONFIG_MAPPING
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class InstructBlipVideoVisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`InstructBlipVideoVisionModel`]. It is used to
|
||||
instantiate a Instructblipvideo vision encoder according to the specified arguments, defining the model architecture.
|
||||
Instantiating a configuration with the defaults will yield a similar configuration to that of the Instructblipvideo
|
||||
[Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 1408):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
intermediate_size (`int`, *optional*, defaults to 6144):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 39):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
image_size (`int`, *optional*, defaults to 224):
|
||||
The size (resolution) of each image.
|
||||
patch_size (`int`, *optional*, defaults to 14):
|
||||
The size (resolution) of each patch.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. to 1e-5): The epsilon used by the layer
|
||||
normalization layers.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
initializer_range (`float`, *optional*, defaults to 1e-10):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
qkv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to add a bias to the queries and values in the self-attention layers.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import InstructBlipVideoVisionConfig, InstructBlipVideoVisionModel
|
||||
|
||||
>>> # Initializing a InstructBlipVideoVisionConfig with Salesforce/instruct-blip-flan-t5 style configuration
|
||||
>>> configuration = InstructBlipVideoVisionConfig()
|
||||
|
||||
>>> # Initializing a InstructBlipVideoVisionModel (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration
|
||||
>>> model = InstructBlipVideoVisionModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "instructblipvideo_vision_model"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=1408,
|
||||
intermediate_size=6144,
|
||||
num_hidden_layers=39,
|
||||
num_attention_heads=16,
|
||||
image_size=224,
|
||||
patch_size=14,
|
||||
hidden_act="gelu",
|
||||
layer_norm_eps=1e-6,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=1e-10,
|
||||
qkv_bias=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.patch_size = patch_size
|
||||
self.image_size = image_size
|
||||
self.initializer_range = initializer_range
|
||||
self.attention_dropout = attention_dropout
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.hidden_act = hidden_act
|
||||
self.qkv_bias = qkv_bias
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
||||
cls._set_token_in_kwargs(kwargs)
|
||||
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# get the vision config dict if we are loading from InstructBlipVideoConfig
|
||||
if config_dict.get("model_type") == "instructblipvideo":
|
||||
config_dict = config_dict["vision_config"]
|
||||
|
||||
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
||||
logger.warning(
|
||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||
)
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`InstructBlipVideoQFormerModel`]. It is used to
|
||||
instantiate a Instructblipvideo Querying Transformer (Q-Former) model according to the specified arguments, defining the
|
||||
model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
|
||||
the Instructblipvideo [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5)
|
||||
architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
|
||||
Read the documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Note that [`InstructBlipVideoQFormerModel`] is very similar to [`BertLMHeadModel`] with interleaved cross-attention.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 30522):
|
||||
Vocabulary size of the Q-Former model. Defines the number of different tokens that can be represented by
|
||||
the `inputs_ids` passed when calling the model.
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (`int`, *optional*, defaults to 3072):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
||||
The dropout ratio for the attention probabilities.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 512):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
Token id used for padding sequences.
|
||||
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
|
||||
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
|
||||
positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
|
||||
[Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
|
||||
For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
|
||||
with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
|
||||
cross_attention_frequency (`int`, *optional*, defaults to 2):
|
||||
The frequency of adding cross-attention to the Transformer layers.
|
||||
encoder_hidden_size (`int`, *optional*, defaults to 1408):
|
||||
The hidden size of the hidden states for cross-attention.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import InstructBlipVideoQFormerConfig, InstructBlipVideoQFormerModel
|
||||
|
||||
>>> # Initializing a Instructblipvideo Salesforce/instruct-blip-flan-t5 style configuration
|
||||
>>> configuration = InstructBlipVideoQFormerConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration
|
||||
>>> model = InstructBlipVideoQFormerModel(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "instructblipvideo_qformer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
pad_token_id=0,
|
||||
position_embedding_type="absolute",
|
||||
cross_attention_frequency=2,
|
||||
encoder_hidden_size=1408,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.position_embedding_type = position_embedding_type
|
||||
self.cross_attention_frequency = cross_attention_frequency
|
||||
self.encoder_hidden_size = encoder_hidden_size
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
||||
cls._set_token_in_kwargs(kwargs)
|
||||
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
# get the qformer config dict if we are loading from InstructBlipVideoConfig
|
||||
if config_dict.get("model_type") == "instructblipvideo":
|
||||
config_dict = config_dict["qformer_config"]
|
||||
|
||||
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
||||
logger.warning(
|
||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||
)
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
|
||||
class InstructBlipVideoConfig(PretrainedConfig):
|
||||
r"""
|
||||
[`InstructBlipVideoConfig`] is the configuration class to store the configuration of a
|
||||
[`InstructBlipVideoForConditionalGeneration`]. It is used to instantiate a Instructblipvideo model according to the specified
|
||||
arguments, defining the vision model, Q-Former model and language model configs. Instantiating a configuration with
|
||||
the defaults will yield a similar configuration to that of the Instructblipvideo
|
||||
[Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vision_config (`dict`, *optional*):
|
||||
Dictionary of configuration options used to initialize [`InstructBlipVideoVisionConfig`].
|
||||
qformer_config (`dict`, *optional*):
|
||||
Dictionary of configuration options used to initialize [`InstructBlipVideoQFormerConfig`].
|
||||
text_config (`dict`, *optional*):
|
||||
Dictionary of configuration options used to initialize any [`PretrainedConfig`].
|
||||
num_query_tokens (`int`, *optional*, defaults to 32):
|
||||
The number of query tokens passed through the Transformer.
|
||||
|
||||
kwargs (*optional*):
|
||||
Dictionary of keyword arguments.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import (
|
||||
... InstructBlipVideoVisionConfig,
|
||||
... InstructBlipVideoQFormerConfig,
|
||||
... OPTConfig,
|
||||
... InstructBlipVideoConfig,
|
||||
... InstructBlipVideoForConditionalGeneration,
|
||||
... )
|
||||
|
||||
>>> # Initializing a InstructBlipVideoConfig with Salesforce/instruct-blip-flan-t5 style configuration
|
||||
>>> configuration = InstructBlipVideoConfig()
|
||||
|
||||
>>> # Initializing a InstructBlipVideoForConditionalGeneration (with random weights) from the Salesforce/instruct-blip-flan-t5 style configuration
|
||||
>>> model = InstructBlipVideoForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
|
||||
>>> # We can also initialize a InstructBlipVideoConfig from a InstructBlipVideoVisionConfig, InstructBlipVideoQFormerConfig and any PretrainedConfig
|
||||
|
||||
>>> # Initializing Instructblipvideo vision, Instructblipvideo Q-Former and language model configurations
|
||||
>>> vision_config = InstructBlipVideoVisionConfig()
|
||||
>>> qformer_config = InstructBlipVideoQFormerConfig()
|
||||
>>> text_config = OPTConfig()
|
||||
|
||||
>>> config = InstructBlipVideoConfig.from_vision_qformer_text_configs(vision_config, qformer_config, text_config)
|
||||
```"""
|
||||
|
||||
model_type = "instructblipvideo"
|
||||
|
||||
def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if vision_config is None:
|
||||
vision_config = {}
|
||||
logger.info("vision_config is None. initializing the InstructBlipVideoVisionConfig with default values.")
|
||||
|
||||
if qformer_config is None:
|
||||
qformer_config = {}
|
||||
logger.info("qformer_config is None. Initializing the InstructBlipVideoQFormerConfig with default values.")
|
||||
|
||||
if text_config is None:
|
||||
text_config = {}
|
||||
logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")
|
||||
|
||||
self.vision_config = InstructBlipVideoVisionConfig(**vision_config)
|
||||
self.qformer_config = InstructBlipVideoQFormerConfig(**qformer_config)
|
||||
text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
|
||||
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
|
||||
|
||||
self.tie_word_embeddings = self.text_config.tie_word_embeddings
|
||||
self.is_encoder_decoder = self.text_config.is_encoder_decoder
|
||||
|
||||
self.num_query_tokens = num_query_tokens
|
||||
self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
|
||||
self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
self.initializer_factor = 1.0
|
||||
self.initializer_range = 0.02
|
||||
|
||||
@classmethod
|
||||
def from_vision_qformer_text_configs(
|
||||
cls,
|
||||
vision_config: InstructBlipVideoVisionConfig,
|
||||
qformer_config: InstructBlipVideoQFormerConfig,
|
||||
text_config: PretrainedConfig,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`InstructBlipVideoConfig`] (or a derived class) from a Instructblipvideo vision model, Q-Former and
|
||||
language model configurations.
|
||||
|
||||
Returns:
|
||||
[`InstructBlipVideoConfig`]: An instance of a configuration object
|
||||
"""
|
||||
|
||||
return cls(
|
||||
vision_config=vision_config.to_dict(),
|
||||
qformer_config=qformer_config.to_dict(),
|
||||
text_config=text_config.to_dict(),
|
||||
**kwargs,
|
||||
)
|
@ -0,0 +1,305 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Convert InstructBlipVideo checkpoints from the original repository.
|
||||
|
||||
URL: https://github.com/salesforce/LAVIS/tree/main/projects/instructblipvideo
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
# pip3 install salesforce-lavis
|
||||
# I'm actually installing a slightly modified version: pip3 install git+https://github.com/nielsrogge/LAVIS.git@fix_lavis_float32 (there's also the fix_lavis branch)
|
||||
# also note: to convert Vicuna checkpoints, we had to include /home/niels/python_projects/checkpoints/FastChat/vicuna-7b in lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml
|
||||
# same for Vicuna-13b
|
||||
from lavis.models import load_model_and_preprocess
|
||||
from PIL import Image
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
BlipImageProcessor,
|
||||
InstructBlipProcessor,
|
||||
InstructBlipVideoConfig,
|
||||
InstructBlipVideoForConditionalGeneration,
|
||||
InstructBlipVideoQFormerConfig,
|
||||
InstructBlipVideoVisionConfig,
|
||||
LlamaConfig,
|
||||
LlamaTokenizerFast,
|
||||
T5Config,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
|
||||
|
||||
def load_demo_image():
|
||||
url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
||||
|
||||
return image
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def create_rename_keys(config):
|
||||
rename_keys = []
|
||||
# fmt: off
|
||||
|
||||
# vision encoder
|
||||
rename_keys.append(("visual_encoder.cls_token", "vision_model.embeddings.class_embedding"))
|
||||
rename_keys.append(("visual_encoder.pos_embed", "vision_model.embeddings.position_embedding"))
|
||||
rename_keys.append(("visual_encoder.patch_embed.proj.weight", "vision_model.embeddings.patch_embedding.weight"))
|
||||
rename_keys.append(("visual_encoder.patch_embed.proj.bias", "vision_model.embeddings.patch_embedding.bias"))
|
||||
rename_keys.append(("ln_vision.weight", "vision_model.post_layernorm.weight"))
|
||||
rename_keys.append(("ln_vision.bias", "vision_model.post_layernorm.bias"))
|
||||
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.weight", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm1.bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.weight", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.norm2.bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.qkv.weight", f"vision_model.encoder.layers.{i}.self_attn.qkv.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.weight", f"vision_model.encoder.layers.{i}.self_attn.projection.weight",))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.attn.proj.bias", f"vision_model.encoder.layers.{i}.self_attn.projection.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.weight", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc1.bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.weight", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
|
||||
rename_keys.append((f"visual_encoder.blocks.{i}.mlp.fc2.bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
|
||||
|
||||
# QFormer
|
||||
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.weight", "qformer.embeddings.layernorm.weight"))
|
||||
rename_keys.append(("Qformer.bert.embeddings.LayerNorm.bias", "qformer.embeddings.layernorm.bias"))
|
||||
|
||||
# fmt: on
|
||||
return rename_keys
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
def read_in_q_v_bias(state_dict, config):
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
# read in original q and v biases
|
||||
q_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.q_bias")
|
||||
v_bias = state_dict.pop(f"visual_encoder.blocks.{i}.attn.v_bias")
|
||||
|
||||
# next, set bias in the state dict
|
||||
qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
|
||||
state_dict[f"vision_model.encoder.layers.{i}.self_attn.qkv.bias"] = qkv_bias
|
||||
|
||||
|
||||
def get_blip2_config(model_name):
|
||||
image_size = 364 if "coco" in model_name else 224
|
||||
vision_config = InstructBlipVideoVisionConfig(image_size=image_size).to_dict()
|
||||
|
||||
# make sure the models have proper bos_token_id and eos_token_id set (important for generation)
|
||||
# seems like flan-T5 models don't have bos_token_id properly set?
|
||||
if "t5-xl" in model_name:
|
||||
text_config = T5Config.from_pretrained("google/flan-t5-xl", dense_act_fn="gelu", bos_token_id=1).to_dict()
|
||||
elif "t5-xxl" in model_name:
|
||||
text_config = T5Config.from_pretrained("google/flan-t5-xxl", dense_act_fn="gelu", bos_token_id=1).to_dict()
|
||||
elif "vicuna-7b" in model_name:
|
||||
text_config = LlamaConfig.from_pretrained("decapoda-research/llama-7b-hf", vocab_size=32001).to_dict()
|
||||
elif "vicuna-13b" in model_name:
|
||||
text_config = LlamaConfig.from_pretrained("decapoda-research/llama-13b-hf", vocab_size=32001).to_dict()
|
||||
else:
|
||||
raise ValueError("Model name not supported")
|
||||
|
||||
# the authors add one special "[DEC]" token to the vocab of Q-Former, hence vocab size = 30522 + 1
|
||||
qformer_config = InstructBlipVideoQFormerConfig(vocab_size=30523).to_dict()
|
||||
config = InstructBlipVideoConfig(
|
||||
vision_config=vision_config, text_config=text_config, qformer_config=qformer_config
|
||||
)
|
||||
|
||||
return config, image_size
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_blip2_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to Transformers design.
|
||||
"""
|
||||
qformer_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased", truncation_side="left")
|
||||
qformer_tokenizer.add_special_tokens({"bos_token": "[DEC]"})
|
||||
|
||||
if "t5" in model_name:
|
||||
tokenizer = T5TokenizerFast.from_pretrained("google/flan-t5-xl", truncation_side="left")
|
||||
elif "vicuna" in model_name:
|
||||
# the following was used in the original implementation:
|
||||
# tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", use_fast=False, truncation_side="left")
|
||||
# tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
# tokenizer.add_special_tokens({"bos_token": "</s>"})
|
||||
# tokenizer.add_special_tokens({"eos_token": "</s>"})
|
||||
# tokenizer.add_special_tokens({"unk_token": "</s>"})
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(
|
||||
"huggyllama/llama-7b", truncation_side="left", bos_token="</s>", unk_token="</s>"
|
||||
)
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
|
||||
config, image_size = get_blip2_config(model_name)
|
||||
hf_model = InstructBlipVideoForConditionalGeneration(config).eval()
|
||||
|
||||
model_name_to_original = {
|
||||
"instructblipvideo-vicuna-7b": ("blip2_vicuna_instruct", "vicuna7b"),
|
||||
"instructblipvideo-vicuna-13b": ("blip2_vicuna_instruct", "vicuna13b"),
|
||||
"instructblipvideo-flan-t5-xl": ("blip2_t5_instruct", "flant5xl"),
|
||||
"instructblipvideo-flan-t5-xxl": ("blip2_t5_instruct", "flant5xxl"),
|
||||
}
|
||||
|
||||
name, type = model_name_to_original[model_name]
|
||||
|
||||
# load original model
|
||||
print("Loading original model...")
|
||||
hf_model_device = "cuda:1" if torch.cuda.is_available() else "cpu"
|
||||
lavis_device = "cuda:2" if torch.cuda.is_available() else "cpu"
|
||||
original_model, vis_processors, _ = load_model_and_preprocess(
|
||||
name=name, model_type=type, is_eval=True, device=lavis_device
|
||||
)
|
||||
original_model.eval()
|
||||
print("Done!")
|
||||
|
||||
# update state dict keys
|
||||
state_dict = original_model.state_dict()
|
||||
rename_keys = create_rename_keys(config)
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
|
||||
# some keys can be renamed efficiently
|
||||
for key, val in state_dict.copy().items():
|
||||
val = state_dict.pop(key)
|
||||
if key.startswith("Qformer.bert"):
|
||||
key = key.replace("Qformer.bert", "qformer")
|
||||
if "attention.self" in key:
|
||||
key = key.replace("self", "attention")
|
||||
if "llm_proj" in key:
|
||||
key = key.replace("llm_proj", "language_projection")
|
||||
if "t5_proj" in key:
|
||||
key = key.replace("t5_proj", "language_projection")
|
||||
if key.startswith("llm_model"):
|
||||
key = key.replace("llm_model", "language_model")
|
||||
if key.startswith("t5"):
|
||||
key = key.replace("t5", "language")
|
||||
state_dict[key] = val
|
||||
|
||||
# read in qv biases
|
||||
read_in_q_v_bias(state_dict, config)
|
||||
|
||||
# note: weights get loaded in torch.float32 by default
|
||||
hf_model.load_state_dict(state_dict, strict=True)
|
||||
|
||||
image = load_demo_image()
|
||||
prompt = "What is unusual about this image?"
|
||||
|
||||
# create processor
|
||||
image_processor = BlipImageProcessor(
|
||||
size={"height": image_size, "width": image_size}, image_mean=OPENAI_CLIP_MEAN, image_std=OPENAI_CLIP_STD
|
||||
)
|
||||
processor = InstructBlipProcessor(
|
||||
image_processor=image_processor,
|
||||
tokenizer=tokenizer,
|
||||
qformer_tokenizer=qformer_tokenizer,
|
||||
)
|
||||
inputs = processor(images=image, text=prompt, return_tensors="pt").to(hf_model_device)
|
||||
|
||||
# make sure processor creates exact same pixel values
|
||||
original_pixel_values = vis_processors["eval"](image).unsqueeze(0).to(lavis_device)
|
||||
pixel_values = inputs.pixel_values
|
||||
assert torch.allclose(original_pixel_values.to(pixel_values.device), pixel_values)
|
||||
|
||||
original_model.to(lavis_device)
|
||||
hf_model.to(hf_model_device)
|
||||
with torch.no_grad():
|
||||
if "vicuna" in model_name:
|
||||
original_logits = original_model({"image": original_pixel_values, "text_input": [prompt]}).logits
|
||||
logits = hf_model(**inputs).logits
|
||||
else:
|
||||
original_logits = original_model(
|
||||
{"image": original_pixel_values, "text_input": [prompt], "text_output": ["\n"]}
|
||||
).logits
|
||||
label_input_ids = tokenizer("\n", return_tensors="pt").input_ids.to(hf_model_device)
|
||||
labels = label_input_ids.masked_fill(label_input_ids == tokenizer.pad_token_id, -100)
|
||||
logits = hf_model(**inputs, labels=labels).logits
|
||||
|
||||
print("First values of original logits:", original_logits[0, :3, :3])
|
||||
print("First values of HF logits:", logits[0, :3, :3])
|
||||
|
||||
# assert values
|
||||
assert original_logits.shape == logits.shape
|
||||
atol = 1e-4 if "vicuna" in model_name else 1e-5
|
||||
assert torch.allclose(original_logits.to(logits.device), logits, atol=atol)
|
||||
print("Looks ok!")
|
||||
|
||||
print("Generating with original model...")
|
||||
original_outputs = original_model.generate({"image": original_pixel_values, "prompt": prompt}, num_beams=5)
|
||||
|
||||
# important: we need to cast the weights of the HF model to the appropriate type
|
||||
print("Generating with HF model...")
|
||||
outputs = hf_model.generate(
|
||||
**inputs,
|
||||
do_sample=False,
|
||||
num_beams=5,
|
||||
max_length=256,
|
||||
min_length=1,
|
||||
top_p=0.9,
|
||||
repetition_penalty=1.5,
|
||||
length_penalty=1.0,
|
||||
temperature=1,
|
||||
)
|
||||
if "vicuna" in model_name:
|
||||
# convert output id 0 to 2 (eos_token_id)
|
||||
# TODO add this in the generate method?
|
||||
outputs[outputs == 0] = 2
|
||||
print("Original generation:", original_outputs)
|
||||
output_text = processor.batch_decode(outputs, skip_special_tokens=True)
|
||||
output_text = [text.strip() for text in output_text]
|
||||
print("HF generation:", output_text)
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
hf_model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
processor.push_to_hub(f"Salesforce/{model_name}")
|
||||
hf_model.push_to_hub(f"Salesforce/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
choices = [
|
||||
"instructblipvideo-vicuna-7b",
|
||||
"instructblipvideo-vicuna-13b",
|
||||
"instructblipvideo-flan-t5-xl",
|
||||
"instructblipvideo-flan-t5-xxl",
|
||||
]
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="instructblipvideo-flan-t5-xl",
|
||||
choices=choices,
|
||||
type=str,
|
||||
help="Path to hf config.json of model to convert",
|
||||
)
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether to push the model and processor to the hub after converting",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_blip2_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -0,0 +1,430 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from transformers.models.instructblip.configuration_instructblip import (
|
||||
InstructBlipConfig,
|
||||
InstructBlipQFormerConfig,
|
||||
InstructBlipVisionConfig,
|
||||
)
|
||||
from transformers.models.instructblip.modeling_instructblip import (
|
||||
InstructBlipAttention,
|
||||
InstructBlipEncoder,
|
||||
InstructBlipEncoderLayer,
|
||||
InstructBlipForConditionalGeneration,
|
||||
InstructBlipForConditionalGenerationModelOutput,
|
||||
InstructBlipMLP,
|
||||
InstructBlipPreTrainedModel,
|
||||
InstructBlipQFormerAttention,
|
||||
InstructBlipQFormerEmbeddings,
|
||||
InstructBlipQFormerEncoder,
|
||||
InstructBlipQFormerIntermediate,
|
||||
InstructBlipQFormerLayer,
|
||||
InstructBlipQFormerModel,
|
||||
InstructBlipQFormerOutput,
|
||||
InstructBlipQFormerSelfOutput,
|
||||
InstructBlipVisionEmbeddings,
|
||||
InstructBlipVisionModel,
|
||||
)
|
||||
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class InstructBlipVideoVisionConfig(InstructBlipVisionConfig):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerConfig(InstructBlipQFormerConfig):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoConfig(InstructBlipConfig):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class InstructBlipVideoForConditionalGenerationModelOutput(InstructBlipForConditionalGenerationModelOutput):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoVisionEmbeddings(InstructBlipVisionEmbeddings):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoAttention(InstructBlipAttention):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoMLP(InstructBlipMLP):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoEncoderLayer(InstructBlipEncoderLayer):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoPreTrainedModel(InstructBlipPreTrainedModel):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoEncoder(InstructBlipEncoder):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoVisionModel(InstructBlipVisionModel):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerSelfOutput(InstructBlipQFormerSelfOutput):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerAttention(InstructBlipQFormerAttention):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerIntermediate(InstructBlipQFormerIntermediate):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerOutput(InstructBlipQFormerOutput):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerLayer(InstructBlipQFormerLayer):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerEncoder(InstructBlipQFormerEncoder):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerEmbeddings(InstructBlipQFormerEmbeddings):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerModel(InstructBlipQFormerModel):
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
qformer_input_ids: torch.FloatTensor,
|
||||
qformer_attention_mask: Optional[torch.LongTensor] = None,
|
||||
input_ids: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size -
|
||||
1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
|
||||
config.vocab_size]`
|
||||
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
|
||||
>>> import torch
|
||||
>>> from huggingface_hub import hf_hub_download
|
||||
>>> import av
>>> import numpy as np
|
||||
|
||||
>>> def read_video_pyav(container, indices):
|
||||
... '''
|
||||
... Decode the video with PyAV decoder.
|
||||
... Args:
|
||||
... container (`av.container.input.InputContainer`): PyAV container.
|
||||
... indices (`List[int]`): List of frame indices to decode.
|
||||
... Returns:
|
||||
... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
|
||||
... '''
|
||||
... frames = []
|
||||
... container.seek(0)
|
||||
... start_index = indices[0]
|
||||
... end_index = indices[-1]
|
||||
... for i, frame in enumerate(container.decode(video=0)):
|
||||
... if i > end_index:
|
||||
... break
|
||||
... if i >= start_index and i in indices:
|
||||
... frames.append(frame)
|
||||
... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
||||
|
||||
>>> model = InstructBlipVideoForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto")
>>> processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
|
||||
|
||||
>>> file_path = hf_hub_download(
|
||||
repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
|
||||
)
|
||||
>>> container = av.open(file_path)
|
||||
>>> # sample uniformly 4 frames from the video
|
||||
>>> total_frames = container.streams.video[0].frames
|
||||
>>> indices = np.arange(0, total_frames, total_frames / 4).astype(int)
|
||||
>>> clip = read_video_pyav(container, indices)
|
||||
|
||||
>>> prompt = "What is happening in the video?"
|
||||
>>> inputs = processor(videos=clip, text=prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
>>> outputs = model.generate(
|
||||
... **inputs,
|
||||
... do_sample=False,
|
||||
... num_beams=5,
|
||||
... max_length=256,
|
||||
... repetition_penalty=1.5,
|
||||
... length_penalty=1.0,
|
||||
... )
|
||||
>>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
|
||||
>>> print(generated_text)
|
||||
"A person is eating a bowl of pasta, and they are using a fork to eat it. The person is sitting at a table, and the plate of pasta is on the table in front"
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# step 1: forward the images through the vision encoder,
|
||||
# we process in a batched way, later unbatch it back (video has frames=4 always)
|
||||
batch_size, frames, channel, height, width = pixel_values.shape
|
||||
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
|
||||
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
image_embeds = vision_outputs[0]
|
||||
|
||||
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
|
||||
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
||||
|
||||
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
|
||||
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
||||
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
||||
|
||||
if qformer_attention_mask is None:
|
||||
qformer_attention_mask = torch.ones_like(qformer_input_ids)
|
||||
|
||||
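# descriptive comment (not part of the original diff): each video frame is treated as a separate
# image by the Q-Former, so the instruction tokens and their mask are replicated once per frame
# before the learned query-token mask is prepended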
qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
|
||||
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
|
||||
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
|
||||
query_outputs = self.qformer(
|
||||
input_ids=qformer_input_ids,
|
||||
attention_mask=qformer_attention_mask,
|
||||
query_embeds=query_tokens,
|
||||
encoder_hidden_states=image_embeds,
|
||||
encoder_attention_mask=image_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
query_output = query_outputs[0][:, : query_tokens.size(1), :]
|
||||
|
||||
# step 3: use the language model, conditioned on the query outputs and the prompt
|
||||
language_model_inputs = self.language_projection(query_output)
|
||||
|
||||
# unbatch inputs back, each video-frame gets `num_query_tokens` seq length
|
||||
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
|
||||
language_model_attention_mask = torch.ones(
|
||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||
)
|
||||
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
attention_mask = torch.cat([language_model_attention_mask.to(attention_mask.device), attention_mask], dim=1)
|
||||
|
||||
if self.config.use_decoder_only_language_model:
|
||||
outputs = self.language_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
logits = outputs.logits if return_dict else outputs[0]
|
||||
loss = None
|
||||
# we compute the loss here since we need to take into account the sequence length of the query embeds
|
||||
if labels is not None:
|
||||
labels = labels.to(logits.device)
|
||||
logits = logits[:, -labels.size(1) :, :]
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous().to(logits.device)
|
||||
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss(reduction="mean")
|
||||
|
||||
loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
|
||||
else:
|
||||
outputs = self.language_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
labels=labels,
|
||||
)
|
||||
loss = outputs.loss if return_dict else outputs[0]
|
||||
logits = outputs.logits if return_dict else outputs[1]
|
||||
|
||||
if not return_dict:
|
||||
output = (logits, vision_outputs, query_outputs, outputs)
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return InstructBlipVideoForConditionalGenerationModelOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
vision_outputs=vision_outputs,
|
||||
qformer_outputs=query_outputs,
|
||||
language_model_outputs=outputs,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
qformer_input_ids: Optional[torch.LongTensor] = None,
|
||||
qformer_attention_mask: Optional[torch.LongTensor] = None,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
**generate_kwargs,
|
||||
) -> torch.LongTensor:
|
||||
"""
|
||||
Overrides `generate` function to be able to use the model as a conditional generator.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width) or
|
||||
(batch_size, num_frames, num_channels, height, width)): Input images or videos to be processed.
|
||||
qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
|
||||
The sequence used as a prompt to be fed to the Q-Former module.
|
||||
qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
|
||||
The sequence used as a prompt for the generation.
|
||||
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the positional encoding of the image embeddings.
|
||||
|
||||
Returns:
|
||||
The generated token ids (a `torch.LongTensor`), or a generation output object whose `sequences` attribute holds them.
|
||||
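
Example (a minimal sketch, not part of the original diff; it assumes `clip` is a numpy array of
decoded video frames such as the one built by `read_video_pyav` in the `forward` docstring, and it
reuses the same `Salesforce/instructblip-vicuna-7b` checkpoint):

```python
>>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration

>>> processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
>>> model = InstructBlipVideoForConditionalGeneration.from_pretrained(
...     "Salesforce/instructblip-vicuna-7b", device_map="auto"
... )

>>> inputs = processor(images=clip, text="What is happening in the video?", return_tensors="pt").to(model.device)
>>> output_ids = model.generate(**inputs, max_new_tokens=50)
>>> print(processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip())
```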
"""
|
||||
if hasattr(self, "hf_device_map"):
|
||||
# preprocess for `accelerate`
|
||||
self._preprocess_accelerate()
|
||||
|
||||
# we process in a batched way, later unbatch it back (video has frames=4)
|
||||
batch_size, frames, channel, height, width = pixel_values.shape
|
||||
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
|
||||
|
||||
image_embeds = self.vision_model(
|
||||
pixel_values,
|
||||
return_dict=True,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
).last_hidden_state
|
||||
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
||||
|
||||
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
||||
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
||||
if qformer_attention_mask is None:
|
||||
qformer_attention_mask = torch.ones_like(qformer_input_ids)
|
||||
|
||||
qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
|
||||
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
|
||||
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
|
||||
query_outputs = self.qformer(
|
||||
input_ids=qformer_input_ids,
|
||||
attention_mask=qformer_attention_mask,
|
||||
query_embeds=query_tokens,
|
||||
encoder_hidden_states=image_embeds,
|
||||
encoder_attention_mask=image_attention_mask,
|
||||
return_dict=True,
|
||||
)
|
||||
query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :]
|
||||
|
||||
language_model_inputs = self.language_projection(query_output)
|
||||
|
||||
# unbatch the embeddings back by moving frames to seq-len
|
||||
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
|
||||
language_attention_mask = torch.ones(
|
||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
input_ids = (
|
||||
torch.LongTensor([[self.config.text_config.bos_token_id]])
|
||||
.repeat(batch_size, 1)
|
||||
.to(image_embeds.device)
|
||||
)
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1)
|
||||
|
||||
# concatenate query embeddings with prompt embeddings
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
|
||||
|
||||
# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
|
||||
# -1 is to account for the prepended BOS after `generate`.
|
||||
if not self.language_model.config.is_encoder_decoder:
|
||||
generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
|
||||
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
|
||||
|
||||
outputs = self.language_model.generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
# this is a temporary workaround to be consistent with other generation models and
|
||||
# have BOS as the first token, even though under the hood we are calling LM with embeds
|
||||
if not self.language_model.config.is_encoder_decoder:
|
||||
# the InstructBLIP authors used inconsistent tokenizer/model files during training,
|
||||
# with the tokenizer's bos token being set to </s> which has ID=2,
|
||||
# whereas the model's text config has bos token id = 0
|
||||
bos_token_id = (
|
||||
2
|
||||
if self.config.text_config.architectures[0] == "LLaMAForCausalLM"
|
||||
else self.config.text_config.bos_token_id
|
||||
)
|
||||
bos_tokens = torch.LongTensor([[bos_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
|
||||
if not isinstance(outputs, torch.Tensor):
|
||||
outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
|
||||
else:
|
||||
outputs = torch.cat([bos_tokens, outputs], dim=-1)
|
||||
|
||||
return outputs
|
@ -0,0 +1,362 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Image processor class for InstructBLIPVideo. Largely a copy of Blip2Processor with the addition of video processing abilities.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format
|
||||
from ...image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
VideoInput,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
is_valid_image,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_kwargs,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, is_vision_available, logging
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def make_batched_videos(videos) -> List[VideoInput]:
|
||||
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
|
||||
return videos
|
||||
|
||||
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
|
||||
if isinstance(videos[0], PIL.Image.Image):
|
||||
return [videos]
|
||||
elif len(videos[0].shape) == 4:
|
||||
return [list(video) for video in videos]
|
||||
|
||||
elif is_valid_image(videos) and len(videos.shape) == 4:
|
||||
return [list(videos)]
|
||||
|
||||
raise ValueError(f"Could not make batched video from {videos}")
|
||||
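# Example (hedged sketch, not part of the original diff; `frame0`..`frame3` are placeholder PIL images):
#
#   make_batched_videos([frame0, frame1, frame2, frame3])   # -> [[frame0, frame1, frame2, frame3]]
#   make_batched_videos(np.zeros((4, 3, 224, 224)))         # -> one video of 4 frames, each of shape (3, 224, 224)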
|
||||
|
||||
# Copied from transformers.models.blip.image_processing_blip.BlipImageProcessor with Blip->InstructBlipVideo, BLIP->InstructBLIPVideo
|
||||
class InstructBlipVideoImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs an InstructBLIPVideo image processor.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
|
||||
`do_resize` parameter in the `preprocess` method.
|
||||
size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
|
||||
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
|
||||
method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
|
||||
overridden by the `resample` parameter in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
|
||||
`do_rescale` parameter in the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
|
||||
overridden by the `rescale_factor` parameter in the `preprocess` method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
||||
method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
||||
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
size = size if size is not None else {"height": 384, "width": 384}
|
||||
size = get_size_dict(size, default_to_square=True)
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
||||
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self._valid_processor_keys = [
|
||||
"images",
|
||||
"do_resize",
|
||||
"size",
|
||||
"resample",
|
||||
"do_rescale",
|
||||
"rescale_factor",
|
||||
"do_normalize",
|
||||
"image_mean",
|
||||
"image_std",
|
||||
"do_convert_rgb",
|
||||
"return_tensors",
|
||||
"data_format",
|
||||
"input_data_format",
|
||||
]
|
||||
|
||||
# Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
|
||||
def resize(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resize an image to `(size["height"], size["width"])`.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The resized image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
||||
|
||||
output_size = (size["height"], size["width"])
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Ignore copy
|
||||
def preprocess(
|
||||
self,
|
||||
images: VideoInput = None,
|
||||
do_resize: Optional[bool] = None,
|
||||
size: Optional[Dict[str, int]] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
do_convert_rgb: bool = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Preprocess a video or batch of images/videos.
|
||||
|
||||
Args:
|
||||
images (`VideoInput`):
|
||||
Video frames to preprocess. Expects a single or batch of videos as a list of frames with pixel values
|
||||
ranging from 0 to 255. If passing in video with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the video.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Controls the size of each frame after `resize`. Every frame is resized to `(size["height"], size["width"])`.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the video. Only has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the video values between [0 - 1].
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the video by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the video.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to normalize the video by if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to normalize the video by if `do_normalize` is set to `True`.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
|
||||
videos = make_batched_videos(images)
|
||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
|
||||
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
)
|
||||
|
||||
if not valid_images(videos):
|
||||
raise ValueError(
|
||||
"Invalid input type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
pixel_values = [
|
||||
[
|
||||
self._preprocess_image(
|
||||
image=frame,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for frame in video
|
||||
]
|
||||
for video in videos
|
||||
]
|
||||
|
||||
encoded_outputs = BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)
|
||||
return encoded_outputs
|
||||
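# Example (hedged sketch, not part of the original diff; `frame` is a placeholder PIL image):
# with the default 384x384 size, a single 4-frame clip comes back as a 5D `pixel_values` tensor:
#
#   image_processor = InstructBlipVideoImageProcessor()
#   out = image_processor(images=[frame] * 4, return_tensors="pt")
#   out.pixel_values.shape   # torch.Size([1, 4, 3, 384, 384])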
|
||||
# Ignore copy
|
||||
def _preprocess_image(
|
||||
self,
|
||||
image: ImageInput = None,
|
||||
do_resize: Optional[bool] = None,
|
||||
size: Optional[Dict[str, int]] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.ndarray:
|
||||
# PIL RGBA images are converted to RGB
|
||||
if do_convert_rgb:
|
||||
image = convert_to_rgb(image)
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
image = to_numpy_array(image)
|
||||
|
||||
if is_scaled_image(image) and do_rescale:
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled video frames. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(image)
|
||||
|
||||
if do_resize:
|
||||
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
|
||||
if do_rescale:
|
||||
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
|
||||
if do_normalize:
|
||||
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
|
||||
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
|
||||
return image
|
File diff suppressed because it is too large
@ -0,0 +1,170 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for InstructBlipVideo. Largely a copy of Blip2Processor with the addition of a tokenizer for the Q-Former.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_utils import VideoInput
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
||||
from ...utils import TensorType
|
||||
from ..auto import AutoTokenizer
|
||||
|
||||
|
||||
class InstructBlipVideoProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs an InstructBlipVideo processor which wraps an InstructBlipVideo image processor and a LLaMa/T5 tokenizer into a single
|
||||
processor.
|
||||
|
||||
[`InstructBlipVideoProcessor`] offers all the functionalities of [`InstructBlipVideoImageProcessor`] and [`AutoTokenizer`]. See the
|
||||
docstring of [`~InstructBlipVideoProcessor.__call__`] and [`~InstructBlipVideoProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor (`InstructBlipVideoImageProcessor`):
|
||||
An instance of [`InstructBlipVideoImageProcessor`]. The image processor is a required input.
|
||||
tokenizer (`AutoTokenizer`):
|
||||
An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
|
||||
qformer_tokenizer (`AutoTokenizer`):
|
||||
An instance of [`PreTrainedTokenizer`]. The Q-Former tokenizer is a required input.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
image_processor_class = "InstructBlipVideoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(self, image_processor, tokenizer, qformer_tokenizer):
|
||||
super().__init__(image_processor, tokenizer)
|
||||
|
||||
# add QFormer tokenizer
|
||||
self.qformer_tokenizer = qformer_tokenizer
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: VideoInput = None,
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding: Union[bool, str, PaddingStrategy] = False,
|
||||
truncation: Union[bool, str, TruncationStrategy] = None,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_token_type_ids: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
This method uses [`InstructBlipVideoImageProcessor.__call__`] to prepare image(s) or video(s) for the model, and
|
||||
[`BertTokenizerFast.__call__`] to prepare text for the model.
|
||||
|
||||
Please refer to the docstring of the above two methods for more information.
|
||||
"""
|
||||
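# descriptive note (not part of the original diff): for a video clip plus a text prompt, the
# returned BatchFeature bundles the language-model text fields (`input_ids`, `attention_mask`),
# the Q-Former text fields (`qformer_input_ids`, `qformer_attention_mask`) and the image processor
# output (`pixel_values`), as assembled below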
encoding = BatchFeature()
|
||||
|
||||
if text is not None:
|
||||
text_encoding = self.tokenizer(
|
||||
text=text,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
return_tensors=return_tensors,
|
||||
**kwargs,
|
||||
)
|
||||
encoding.update(text_encoding)
|
||||
qformer_text_encoding = self.qformer_tokenizer(
|
||||
text=text,
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
return_tensors=return_tensors,
|
||||
**kwargs,
|
||||
)
|
||||
encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids")
|
||||
encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")
|
||||
|
||||
if images is not None:
|
||||
image_encoding = self.image_processor(images, return_tensors=return_tensors)
|
||||
encoding.update(image_encoding)
|
||||
|
||||
return encoding
|
||||
|
||||
# Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
# Copied from transformers.models.blip.processing_blip.BlipProcessor.decode with BertTokenizerFast->PreTrainedTokenizer
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
# Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
|
||||
# overwrite to save the Q-Former tokenizer in a separate folder
|
||||
def save_pretrained(self, save_directory, **kwargs):
|
||||
if os.path.isfile(save_directory):
|
||||
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
qformer_tokenizer_path = os.path.join(save_directory, "qformer_tokenizer")
|
||||
self.qformer_tokenizer.save_pretrained(qformer_tokenizer_path)
|
||||
return super().save_pretrained(save_directory, **kwargs)
|
||||
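# Resulting layout (hedged sketch, not part of the original diff; exact file names depend on the
# tokenizer and image processor classes in use):
#
#   save_directory/
#   ├── ...                  # image processor / main tokenizer files written by `super().save_pretrained`
#   └── qformer_tokenizer/   # Q-Former tokenizer, reloaded from this subfolder by `from_pretrained` below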
|
||||
# overwrite to load the Q-Former tokenizer from a separate folder
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
qformer_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="qformer_tokenizer")
|
||||
args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
args.append(qformer_tokenizer)
|
||||
return cls(*args)
|
@ -4755,6 +4755,34 @@ class InstructBlipVisionModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class InstructBlipVideoForConditionalGeneration(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class InstructBlipVideoPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class InstructBlipVideoVisionModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class JambaForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -303,6 +303,13 @@ class ImageGPTImageProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class InstructBlipVideoImageProcessor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["vision"])
|
||||
|
||||
|
||||
class LayoutLMv2FeatureExtractor(metaclass=DummyObject):
|
||||
_backends = ["vision"]
|
||||
|
||||
|
0
tests/models/instructblipvideo/__init__.py
Normal file
@ -0,0 +1,191 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
from transformers import InstructBlipVideoImageProcessor
|
||||
|
||||
|
||||
class InstructBlipVideoProcessingTester(unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=5,
|
||||
num_channels=3,
|
||||
image_size=24,
|
||||
min_resolution=30,
|
||||
max_resolution=80,
|
||||
do_resize=True,
|
||||
size=None,
|
||||
do_normalize=True,
|
||||
image_mean=OPENAI_CLIP_MEAN,
|
||||
image_std=OPENAI_CLIP_STD,
|
||||
do_convert_rgb=True,
|
||||
frames=4,
|
||||
):
|
||||
size = size if size is not None else {"height": 18, "width": 18}
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self.frames = frames
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"size": self.size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
}
|
||||
|
||||
def expected_output_image_shape(self, images):
|
||||
return self.frames, self.num_channels, self.size["height"], self.size["width"]
|
||||
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
images = prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
# let's simply copy the frames to fake a long video-clip
|
||||
if numpify or torchify:
|
||||
videos = []
|
||||
for image in images:
|
||||
if numpify:
|
||||
video = image[None, ...].repeat(self.frames, 0)
|
||||
else:
|
||||
video = image[None, ...].repeat(self.frames, 1, 1, 1)
|
||||
videos.append(video)
|
||||
else:
|
||||
videos = []
|
||||
for pil_image in images:
|
||||
videos.append([pil_image] * self.frames)
|
||||
|
||||
return videos
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class InstructBlipVideoProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = InstructBlipVideoImageProcessor if is_vision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.image_processor_tester = InstructBlipVideoProcessingTester(self)
|
||||
|
||||
@property
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict
|
||||
def image_processor_dict(self):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
|
||||
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
|
||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||
|
||||
def test_call_pil(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
video_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video[0], Image.Image)
|
||||
|
||||
# Test not batched input (pass as `videos` arg to test that ImageProcessor can handle videos in absence of images!)
|
||||
encoded_videos = image_processing(images=video_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_video_shape = (1, 4, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = image_processing(images=video_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_video_shape = (5, 4, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
def test_call_numpy(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
video_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, np.ndarray)
|
||||
|
||||
# Test not batched input (pass as `videos` arg to test that ImageProcessor can handle videos in absence of images!)
|
||||
encoded_videos = image_processing(images=video_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_video_shape = (1, 4, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = image_processing(images=video_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_video_shape = (5, 4, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
video_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = image_processing(images=video_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_video_shape = (1, 4, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_videos = image_processing(images=video_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_video_shape = (5, 4, 3, 18, 18)
|
||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
@ -0,0 +1,585 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch InstructBlipVideo model."""
|
||||
|
||||
import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
InstructBlipVideoConfig,
|
||||
InstructBlipVideoProcessor,
|
||||
InstructBlipVideoQFormerConfig,
|
||||
InstructBlipVideoVisionConfig,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
require_accelerate,
|
||||
require_bitsandbytes,
|
||||
require_torch,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
ModelTesterMixin,
|
||||
floats_tensor,
|
||||
ids_tensor,
|
||||
random_attention_mask,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import InstructBlipVideoForConditionalGeneration, InstructBlipVideoVisionModel
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
pass
|
||||
|
||||
|
||||
class InstructBlipVideoVisionModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=12,
|
||||
image_size=30,
|
||||
frames=4,
|
||||
patch_size=2,
|
||||
num_channels=3,
|
||||
is_training=True,
|
||||
hidden_size=32,
|
||||
projection_dim=32,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
initializer_range=1e-10,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.frames = frames
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.hidden_size = hidden_size
|
||||
self.projection_dim = projection_dim
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
|
||||
# in case of a vision transformer, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
self.seq_length = num_patches + 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
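# descriptive comment (not part of the original diff): frames are folded into the batch dimension
# here, mirroring how the model feeds a video to the vision tower as `batch_size * frames`
# independent images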
pixel_values = floats_tensor(
|
||||
[self.batch_size * self.frames, self.num_channels, self.image_size, self.image_size]
|
||||
)
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def get_config(self):
|
||||
return InstructBlipVideoVisionConfig(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
hidden_size=self.hidden_size,
|
||||
projection_dim=self.projection_dim,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
dropout=self.dropout,
|
||||
attention_dropout=self.attention_dropout,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values):
|
||||
model = InstructBlipVideoVisionModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
result = model(pixel_values)
|
||||
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
|
||||
image_size = (self.image_size, self.image_size)
|
||||
patch_size = (self.patch_size, self.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape, (self.batch_size * self.frames, num_patches + 1, self.hidden_size)
|
||||
)
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size * self.frames, self.hidden_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class InstructBlipVideoVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Here we also overwrite some of the tests of test_modeling_common.py, as InstructBlipVideo's vision encoder does not use input_ids, inputs_embeds,
|
||||
attention_mask and seq_length.
|
||||
"""
|
||||
|
||||
all_model_classes = (InstructBlipVideoVisionModel,) if is_torch_available() else ()
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = InstructBlipVideoVisionModelTester(self)
|
||||
self.config_tester = ConfigTester(
|
||||
self, config_class=InstructBlipVideoVisionConfig, has_text_modality=False, hidden_size=37
|
||||
)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@unittest.skip(reason="InstructBlipVideo's vision encoder does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="InstructBlipVideo's vision encoder is an nn.Embeddings layer")
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
|
||||
x = model.get_output_embeddings()
|
||||
self.assertTrue(x is None or isinstance(x, nn.Linear))
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["pixel_values"]
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
@unittest.skip(
|
||||
reason="InstructBlipVideoVisionModel is an internal building block, doesn't support standalone training"
|
||||
)
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="InstructBlipVideoVisionModel is an internal building block, doesn't support standalone training"
|
||||
)
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="InstructBlipVideoVisionModel has no base class and is not available in MODEL_MAPPING")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="InstructBlipVideoVisionModel has no base class and is not available in MODEL_MAPPING")
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
model_name = "Salesforce/instructblip-vicuna-7b"
|
||||
model = InstructBlipVideoVisionModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
class InstructBlipVideoQFormerModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        projection_dim=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        bos_token_id=0,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.scope = scope
        self.bos_token_id = bos_token_id

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        qformer_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])
            qformer_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        if input_mask is not None:
            batch_size, seq_length = input_mask.shape
            rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
            for batch_idx, start_index in enumerate(rnd_start_indices):
                input_mask[batch_idx, :start_index] = 1
                input_mask[batch_idx, start_index:] = 0

        config = self.get_config()

        return config, input_ids, input_mask, qformer_input_ids, qformer_attention_mask

    def get_config(self):
        return InstructBlipVideoQFormerConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            bos_token_id=self.bos_token_id,
        )


# this class is based on `OPTModelTester` found in tests/models/opt/test_modeling_opt.py
class InstructBlipVideoTextModelDecoderOnlyTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        seq_length=7,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=4,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=100,
        eos_token_id=2,
        pad_token_id=1,
        bos_token_id=0,
        embed_dim=16,
        num_labels=3,
        word_embed_proj_dim=16,
        type_sequence_label_size=2,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.embed_dim = embed_dim
        self.num_labels = num_labels
        self.type_sequence_label_size = type_sequence_label_size
        self.word_embed_proj_dim = word_embed_proj_dim
        self.is_encoder_decoder = False

    def prepare_config_and_inputs(self):
        config = self.get_config()

        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(3)
        input_ids[:, -1] = self.eos_token_id  # Eos Token

        attention_mask = input_ids.ne(self.pad_token_id)

        return config, input_ids, attention_mask

    def get_config(self):
        return CONFIG_MAPPING["opt"](
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            ffn_dim=self.intermediate_size,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
            embed_dim=self.embed_dim,
            is_encoder_decoder=False,
            word_embed_proj_dim=self.word_embed_proj_dim,
        )


# this model tester uses a decoder-only language model (OPT)
class InstructBlipVideoForConditionalGenerationDecoderOnlyModelTester:
    def __init__(
        self, parent, vision_kwargs=None, qformer_kwargs=None, text_kwargs=None, is_training=True, num_query_tokens=10
    ):
        if vision_kwargs is None:
            vision_kwargs = {}
        if qformer_kwargs is None:
            qformer_kwargs = {}
        if text_kwargs is None:
            text_kwargs = {}

        self.parent = parent
        self.vision_model_tester = InstructBlipVideoVisionModelTester(parent, **vision_kwargs)
        self.qformer_model_tester = InstructBlipVideoQFormerModelTester(parent, **qformer_kwargs)
        self.text_model_tester = InstructBlipVideoTextModelDecoderOnlyTester(parent, **text_kwargs)
        self.batch_size = self.text_model_tester.batch_size  # need bs for batching_equivalence test
        self.seq_length = self.text_model_tester.seq_length  # need seq_length for common tests
        self.is_training = is_training
        self.num_query_tokens = num_query_tokens

    def prepare_config_and_inputs(self):
        _, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
        _, _, _, qformer_input_ids, qformer_attention_mask = self.qformer_model_tester.prepare_config_and_inputs()
        _, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
        frames = self.vision_model_tester.frames
        _, c, h, w = pixel_values.shape
        pixel_values = pixel_values.reshape(-1, frames, c, h, w)

        config = self.get_config()

        return config, input_ids, attention_mask, qformer_input_ids, qformer_attention_mask, pixel_values

    def get_config(self):
        return InstructBlipVideoConfig.from_vision_qformer_text_configs(
            vision_config=self.vision_model_tester.get_config(),
            qformer_config=self.qformer_model_tester.get_config(),
            text_config=self.text_model_tester.get_config(),
            num_query_tokens=self.num_query_tokens,
        )

    def create_and_check_for_conditional_generation(
        self, config, input_ids, attention_mask, qformer_input_ids, qformer_attention_mask, pixel_values
    ):
        model = InstructBlipVideoForConditionalGeneration(config).to(torch_device).eval()
        with torch.no_grad():
            result = model(
                pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                qformer_input_ids=qformer_input_ids,
                qformer_attention_mask=qformer_attention_mask,
            )

        expected_seq_length = (
            self.num_query_tokens * self.vision_model_tester.frames
        ) + self.text_model_tester.seq_length
        self.parent.assertEqual(
            result.logits.shape,
            (self.vision_model_tester.batch_size, expected_seq_length, self.text_model_tester.vocab_size),
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, attention_mask, qformer_input_ids, qformer_attention_mask, pixel_values = config_and_inputs
        inputs_dict = {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "qformer_input_ids": qformer_input_ids,
            "qformer_attention_mask": qformer_attention_mask,
            "labels": input_ids,
        }
        return config, inputs_dict

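As a quick sanity check on the shape assertion above: each video frame is run through the vision encoder and Q-Former, which contributes `num_query_tokens` query embeddings that are concatenated in front of the text tokens, so the logits span `num_query_tokens * frames + seq_length` positions. A minimal sketch with the tester defaults; the frame count of 4 is an assumption here, since `InstructBlipVideoVisionModelTester` is defined earlier in the file:

```python
# Rough check of the expected logits length asserted by the model tester above.
num_query_tokens = 10  # tester default
frames = 4             # assumed default of InstructBlipVideoVisionModelTester
text_seq_length = 7    # text tester default

expected_seq_length = num_query_tokens * frames + text_seq_length
print(expected_seq_length)  # 47
```
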
@require_torch
class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
    ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
):
    all_model_classes = (InstructBlipVideoForConditionalGeneration,) if is_torch_available() else ()
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False
    test_torchscript = False

    def setUp(self):
        self.model_tester = InstructBlipVideoForConditionalGenerationDecoderOnlyModelTester(self)

    def test_for_conditional_generation(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_conditional_generation(*config_and_inputs)

    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="InstructBlipVideoForConditionalGeneration doesn't support inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Tied weights are tested in individual model tests")
    def test_tied_weights_keys(self):
        pass

    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="InstructBlipVideoModel does not have input/output embeddings")
    def test_model_common_attributes(self):
        pass

    @unittest.skip(reason="There's no base InstructBlipVideoModel")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="There's no base InstructBlipVideoModel")
    def test_save_load_fast_init_to_base(self):
        pass

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_load_vision_qformer_text_config(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # Save InstructBlipVideoConfig and check if we can load InstructBlipVideoVisionConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            vision_config = InstructBlipVideoVisionConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())

        # Save InstructBlipVideoConfig and check if we can load InstructBlipVideoQFormerConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            qformer_config = InstructBlipVideoQFormerConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.qformer_config.to_dict(), qformer_config.to_dict())

    @slow
    def test_model_from_pretrained(self):
        model_name = "Salesforce/instructblip-vicuna-7b"
        model = InstructBlipVideoForConditionalGeneration.from_pretrained(model_name)
        self.assertIsNotNone(model)


# We will verify our results on a short demo video
def prepare_video():
    video_file = hf_hub_download(
        repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
    )
    video = np.load(video_file)[::2]  # sample every 2nd frame to get 4 frames total
    return video


@require_vision
@require_torch
@require_bitsandbytes
@require_accelerate
@slow
class InstructBlipVideoModelIntegrationTest(unittest.TestCase):
    def test_inference_vicuna_7b(self):
        processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
        model = InstructBlipVideoForConditionalGeneration.from_pretrained(
            "Salesforce/instructblip-vicuna-7b", load_in_8bit=True, low_cpu_mem_usage=True
        )

        clip = prepare_video()
        prompt = "Explain what is happening in this short video."
        inputs = processor(images=clip, text=prompt, return_tensors="pt").to(torch_device, torch.float16)

        # verify generation
        outputs = model.generate(**inputs, max_new_tokens=30)
        generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
        self.assertEqual(
            generated_text,
            "a baby girl wearing glasses is reading a book on the bed 1080p",
        )

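For quick reference, here is a minimal inference sketch distilled from the integration test above; the checkpoint name, 8-bit loading (which needs bitsandbytes and a GPU), and the frame-sampling trick simply mirror the test and are not the only supported configuration:

```python
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import InstructBlipVideoForConditionalGeneration, InstructBlipVideoProcessor

processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
model = InstructBlipVideoForConditionalGeneration.from_pretrained(
    "Salesforce/instructblip-vicuna-7b", load_in_8bit=True, low_cpu_mem_usage=True
)

# sample a 4-frame clip from the same demo video the test uses
video_file = hf_hub_download(
    repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
)
clip = np.load(video_file)[::2]

prompt = "Explain what is happening in this short video."
inputs = processor(images=clip, text=prompt, return_tensors="pt").to(model.device, torch.float16)

outputs = model.generate(**inputs, max_new_tokens=30)
print(processor.batch_decode(outputs, skip_special_tokens=True)[0].strip())
```
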
@@ -90,6 +90,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
     "RecurrentGemmaModel",  # Building part of bigger (tested) model.
     "FuyuForCausalLM",  # Not tested fort now
     "InstructBlipQFormerModel",  # Building part of bigger (tested) model.
+    "InstructBlipVideoQFormerModel",  # Building part of bigger (tested) model.
     "UMT5EncoderModel",  # Building part of bigger (tested) model.
     "Blip2QFormerModel",  # Building part of bigger (tested) model.
     "ErnieMForInformationExtraction",

@@ -245,6 +246,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
     "GPTSw3DoubleHeadsModel",
     "InstructBlipVisionModel",
     "InstructBlipQFormerModel",
+    "InstructBlipVideoVisionModel",
+    "InstructBlipVideoQFormerModel",
     "LayoutLMForQuestionAnswering",
     "LukeForMaskedLM",
     "LukeForEntityClassification",

@@ -173,7 +173,7 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
     - LLaMa -> MyNewModel abd MyNewModel -> Llama
     """

-    def __init__(self, old_name, new_name):
+    def __init__(self, old_name, new_name, given_old_name=None, given_new_name=None):
         super().__init__()
         self.old_name = old_name
         self.new_name = new_name
@@ -183,6 +183,8 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
             old_name.upper(): new_name.upper(),
             "".join(x.title() for x in old_name.split("_")): self.default_name,
         }
+        if given_old_name is not None and given_new_name is not None and given_old_name not in self.patterns:
+            self.patterns[given_old_name] = given_new_name

     def preserve_case_replace(self, text):
         # Create a regex pattern to match all variations
@@ -201,9 +203,9 @@ class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
         return updated_node.with_changes(value=update)


-def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma"):
+def find_classes_in_file(module: cst.Module, old_id="llama", new_id="gemma", given_old_name=None, given_new_name=None):
     """Helper function to rename and then parse a source file using the ClassFinder"""
-    transformer = ReplaceNameTransformer(old_id, new_id)
+    transformer = ReplaceNameTransformer(old_id, new_id, given_old_name, given_new_name)
     new_module = module.visit(transformer)

     wrapper = MetadataWrapper(new_module)
@@ -356,11 +358,13 @@ def replace_call_to_super(class_finder: ClassFinder, updated_node: cst.ClassDef,
 class DiffConverterTransformer(CSTTransformer):
     METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider)

-    def __init__(self, python_module, new_name):
+    def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None):
         super().__init__()
         self.model_name = (
             new_name  # name of the model being defined. Should be in the format of `llama` or `layout_xlm` our `phi3`
         )
+        self.given_old_name = given_old_name
+        self.given_new_name = given_new_name
         # fmt: off
         self.python_module = python_module  # we store the original module to use `code_for_node`
         self.transformers_imports = {}  # maps the imports name like "from transformers.models.xxx" to the parsed AST module
@@ -426,6 +430,7 @@ class DiffConverterTransformer(CSTTransformer):
                 "insert_idx": self.global_scope_index,
                 "node": updated_node,
             }
+            self.config_body = [updated_node]
         return updated_node

     def leave_ClassDef(self, original_node, updated_node):
@@ -457,13 +462,18 @@ class DiffConverterTransformer(CSTTransformer):
                     f"Tried parsing the name of the imported package from {super_file_name}, could not extract the model name"
                 )

-            if super_file_name not in self.visited_module:  # only extract classes once
+            visited_module = self.visited_module
+            if super_file_name not in visited_module:  # only extract classes once
                 class_finder = find_classes_in_file(
-                    self.transformers_imports[super_file_name], model_name, self.model_name
+                    self.transformers_imports[super_file_name],
+                    model_name,
+                    self.model_name,
+                    self.given_old_name,
+                    self.given_new_name,
                 )
-                self.visited_module[super_file_name] = class_finder
+                visited_module[super_file_name] = class_finder
             else:  # we are re-using the previously parsed data
-                class_finder = self.visited_module[super_file_name]
+                class_finder = visited_module[super_file_name]

             list_dependencies = {
                 dep: class_finder.class_start_line.get(dep, 1000)
@@ -474,7 +484,7 @@ class DiffConverterTransformer(CSTTransformer):
             start_insert_idx = self.global_scope_index
             for dependency, _ in list_dependencies:
                 node = class_finder.global_nodes.get(dependency, None)
-                if node is not None:
+                if node is not None and "Config" not in class_name:
                     if dependency not in self.new_body:
                         start_insert_idx -= 1
                         self.new_body[dependency] = {"insert_idx": start_insert_idx, "node": node}
@@ -485,7 +495,7 @@ class DiffConverterTransformer(CSTTransformer):
             if len(list_dependencies) > 0:
                 updated_node = replace_call_to_super(class_finder, updated_node, class_name)
             if "Config" in class_name:
-                self.config_body = [updated_node]
+                self.config_body += [updated_node]
             else:
                 self.new_body[class_name] = {"insert_idx": self.global_scope_index, "node": updated_node}
             return updated_node
@@ -503,10 +513,24 @@ class DiffConverterTransformer(CSTTransformer):
     def leave_Module(self, original_node: cst.Assign, node):
         imports = {self.python_module.code_for_node(k): k for k in self.all_imports}
         dependency_imports = {}
+        config_imports = []
         for visiter in self.visited_module.values():
             dependency_imports.update({self.python_module.code_for_node(k): k for k in visiter.imports.values()})
+
+        # manually clean up if it's importing a config from configuration file (ruff doesn't do that)
+        config_imports = []
+        for i in list(dependency_imports.values()):
+            if (
+                hasattr(i.body[0], "module")
+                and isinstance(i.body[0].module, cst.Name)
+                and f"configuration_{self.model_name}" in i.body[0].module.value
+            ):
+                pass
+            else:
+                config_imports.append(i)
+
         if hasattr(self, "config_body"):
-            self.config_body = list(imports.values()) + self.config_body
+            self.config_body = list(imports.values()) + config_imports + self.config_body
         dependency_imports.update(imports)
         new_body = list(dependency_imports.values())
         if len(self.new_body.keys()) > 0:
@@ -516,7 +540,7 @@ class DiffConverterTransformer(CSTTransformer):
         return node.with_changes(body=[*new_body])


-def convert_file(diff_file, cst_transformers=None):
+def convert_file(diff_file, old_model_name=None, new_model_name=None, cst_transformers=None):
     model_name = re.search(r"diff_(.*)(?=\.py$)", diff_file).groups()[0]
     # Parse the Python file
     with open(diff_file, "r") as file:
@@ -524,7 +548,7 @@ def convert_file(diff_file, cst_transformers=None):
     module = cst.parse_module(code)
     wrapper = MetadataWrapper(module)
     if cst_transformers is None:
-        cst_transformers = DiffConverterTransformer(module, model_name)
+        cst_transformers = DiffConverterTransformer(module, model_name, old_model_name, new_model_name)
     new_mod = wrapper.visit(cst_transformers)
     ruffed_code = run_ruff(new_mod.code, True)
     formatted_code = run_ruff(ruffed_code, False)
@@ -551,10 +575,20 @@ if __name__ == "__main__":
         nargs="+",
         help="A list of `diff_xxxx` files that should be converted to single model file",
     )
+    parser.add_argument(
+        "--old_model_name",
+        required=False,
+        help="The name of the model from which the copying is done in CamelCase. If not provided is inferred from diff-file",
+    )
+    parser.add_argument(
+        "--new_model_name",
+        required=False,
+        help="The name of the new model being added in CamelCase. If not provided is inferred from diff-file",
+    )
     args = parser.parse_args()
     if args.files_to_parse == ["all"]:
         args.files_to_parse = glob.glob("src/transformers/models/**/diff_*.py", recursive=True)
     for file_name in args.files_to_parse:
         print(f"Converting {file_name} to a single model single file format")
         module_path = file_name.replace("/", ".").replace(".py", "").replace("src.", "")
-        converter = convert_file(file_name)
+        converter = convert_file(file_name, args.old_model_name, args.new_model_name)
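Taken together, the converter changes thread an optional explicit rename mapping (`given_old_name` / `given_new_name`, exposed as `--old_model_name` / `--new_model_name`) from the command line down to `ReplaceNameTransformer`, for cases where the CamelCase names cannot be inferred from the `diff_xxxx` filename. A hypothetical programmatic sketch of the new `convert_file` signature; the module import path and the diff file path below are illustrative assumptions, not taken from this commit:

```python
# Illustrative only: the import path and file path are assumed; the signature matches the hunks above.
from diff_model_converter import convert_file

converter = convert_file(
    "src/transformers/models/instructblipvideo/diff_instructblipvideo.py",
    old_model_name="InstructBlip",
    new_model_name="InstructBlipVideo",
)
```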