VLMs: patch_size -> num_image_tokens in processing (#33424)

* use num additional tokens
* fix copies + docs
* another fix copies :)
* add docs
* move order for BC
This commit is contained in: parent 3ee24e2208, commit 1646ffb4d1
@@ -40,6 +40,10 @@ The original code can be found [here](https://github.com/salesforce/LAVIS/tree/5
- BLIP-2 can be used for conditional text generation given an image and an optional text prompt. At inference time, it's recommended to use the [`generate`] method.
- One can use [`Blip2Processor`] to prepare images for the model, and decode the predicted token IDs back to text.

> [!NOTE]
> BLIP models after release v4.46 will raise warnings about adding `processor.num_query_tokens = {{num_query_tokens}}` and expanding the model embeddings layer to add the special `<image>` token. It is strongly recommended to add the attributes to the processor if you own the model checkpoint, or to open a PR if it is not owned by you. Adding these attributes means that BLIP will add the number of query tokens required per image and expand the text with as many `<image>` placeholders as there will be query tokens. Usually it is around 500 tokens per image, so make sure that the text is not truncated, as otherwise there will be a failure when merging the embeddings.
The attributes can be obtained from the model config as `model.config.num_query_tokens`, and the model embeddings can be expanded by following [this link](https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042).
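
For checkpoint owners, a minimal sketch of what this setup could look like is shown below. The checkpoint name is only an example, and the linked gist remains the authoritative expansion recipe; this sketch merely sets the attribute and registers the `<image>` token.

```python
from transformers import AutoProcessor, Blip2ForConditionalGeneration

checkpoint = "Salesforce/blip2-opt-2.7b"  # example checkpoint, adjust to your own
processor = AutoProcessor.from_pretrained(checkpoint)
model = Blip2ForConditionalGeneration.from_pretrained(checkpoint)

# Tell the processor how many query tokens each image expands into
processor.num_query_tokens = model.config.num_query_tokens

# Register the special `<image>` placeholder and grow the embedding matrix accordingly
processor.tokenizer.add_tokens(["<image>"], special_tokens=True)
model.resize_token_embeddings(len(processor.tokenizer))
```
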

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with BLIP-2.
@@ -33,6 +33,10 @@ The original code can be found [here](https://github.com/salesforce/LAVIS/tree/m

InstructBLIP uses the same architecture as [BLIP-2](blip2) with a tiny but important difference: it also feeds the text prompt (instruction) to the Q-Former.

> [!NOTE]
> BLIP models after release v4.46 will raise warnings about adding `processor.num_query_tokens = {{num_query_tokens}}` and expanding the model embeddings layer to add the special `<image>` token. It is strongly recommended to add the attributes to the processor if you own the model checkpoint, or to open a PR if it is not owned by you. Adding these attributes means that BLIP will add the number of query tokens required per image and expand the text with as many `<image>` placeholders as there will be query tokens. Usually it is around 500 tokens per image, so make sure that the text is not truncated, as otherwise there will be a failure when merging the embeddings.
The attributes can be obtained from the model config as `model.config.num_query_tokens`, and the model embeddings can be expanded by following [this link](https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042).

## InstructBlipConfig

[[autodoc]] InstructBlipConfig
@@ -35,6 +35,10 @@ The original code can be found [here](https://github.com/salesforce/LAVIS/tree/m

- The model was trained by sampling 4 frames per video, so it's recommended to sample 4 frames per video at inference time.

> [!NOTE]
> BLIP models after release v4.46 will raise warnings about adding `processor.num_query_tokens = {{num_query_tokens}}` and expanding the model embeddings layer to add the special `<image>` token. It is strongly recommended to add the attributes to the processor if you own the model checkpoint, or to open a PR if it is not owned by you. Adding these attributes means that BLIP will add the number of query tokens required per image and expand the text with as many `<image>` placeholders as there will be query tokens. Usually it is around 500 tokens per image, so make sure that the text is not truncated, as otherwise there will be a failure when merging the embeddings.
The attributes can be obtained from the model config as `model.config.num_query_tokens`, and the model embeddings can be expanded by following [this link](https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042).

## InstructBlipVideoConfig

[[autodoc]] InstructBlipVideoConfig
@@ -40,6 +40,13 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/

- Note that the model has not been explicitly trained to process multiple images in the same prompt; although this is technically possible, you may experience inaccurate results.

> [!NOTE]
> LLaVA models after release v4.46 will raise warnings about adding `processor.patch_size = {{patch_size}}`, `processor.num_additional_image_tokens = {{num_additional_image_tokens}}` and `processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. It is strongly recommended to add the attributes to the processor if you own the model checkpoint, or to open a PR if it is not owned by you.
Adding these attributes means that LLaVA will try to infer the number of image tokens required per image and expand the text with as many `<image>` placeholders as there will be tokens. Usually it is around 500 tokens per image, so make sure that the text is not truncated, as otherwise there will be a failure when merging the embeddings.
The attributes can be obtained from the model config, as `model.config.vision_config.patch_size` or `model.config.vision_feature_select_strategy`. The `num_additional_image_tokens` should be `1` if the vision backbone adds a CLS token, or `0` if nothing extra is added to the vision patches.
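
To make the note concrete, here is a minimal sketch of setting these attributes on an existing processor; the checkpoint name is illustrative, and `num_additional_image_tokens = 1` assumes a vision backbone that prepends a CLS token.

```python
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"  # illustrative checkpoint
model = LlavaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

# Values taken from the model config, as described above
processor.patch_size = model.config.vision_config.patch_size
processor.vision_feature_select_strategy = model.config.vision_feature_select_strategy
processor.num_additional_image_tokens = 1  # assumes the vision tower adds a CLS token
```
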

### Single image inference

For best results, we recommend using the processor's `apply_chat_template()` method to format your prompt correctly. For that you need to construct a conversation history; passing in a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with keys "role" and "content". The "content" should be a list of dictionaries, for "text" and "image" modalities, as follows:
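
As an illustration of that structure (the checkpoint name and question are placeholders), a conversation and the templated prompt could be built like this:

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")  # example checkpoint

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
]

# apply_chat_template turns the structured conversation into a correctly formatted prompt
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
```
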
@@ -53,6 +53,12 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/
</Tip>

> [!NOTE]
> LLaVA models after release v4.46 will raise warnings about adding `processor.patch_size = {{patch_size}}`, `processor.num_additional_image_tokens = {{num_additional_image_tokens}}` and `processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. It is strongly recommended to add the attributes to the processor if you own the model checkpoint, or to open a PR if it is not owned by you.
Adding these attributes means that LLaVA will try to infer the number of image tokens required per image and expand the text with as many `<image>` placeholders as there will be tokens. Usually it is around 500 tokens per image, so make sure that the text is not truncated, as otherwise there will be a failure when merging the embeddings.
The attributes can be obtained from the model config, as `model.config.vision_config.patch_size` or `model.config.vision_feature_select_strategy`. The `num_additional_image_tokens` should be `1` if the vision backbone adds a CLS token, or `0` if nothing extra is added to the vision patches.

- Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use the processor's `apply_chat_template` to format your prompts correctly. For that you have to construct a conversation history; passing a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with keys "role" and "content". The "content" should be a list of dictionaries, for "text" and "image" modalities. Below is an example of how to do that and the list of formats accepted by each checkpoint.

We will use [llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) and a conversation history of text and image. Each content field has to be a list of dicts, as follows:
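
A hedged end-to-end sketch of that flow (the image URL is a placeholder and the generation arguments are illustrative) might look like the following:

```python
import requests
from PIL import Image
from transformers import AutoProcessor, LlavaNextForConditionalGeneration

model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaNextForConditionalGeneration.from_pretrained(model_id)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# Placeholder URL -- substitute the image you actually want to ask about
image = Image.open(requests.get("https://example.com/image.png", stream=True).raw)

inputs = processor(images=image, text=prompt, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=64)
print(processor.decode(output[0], skip_special_tokens=True))
```
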
@@ -50,6 +50,12 @@ The original code can be found [here](https://github.com/LLaVA-VL/LLaVA-NeXT/tre
</Tip>

> [!NOTE]
> LLaVA models after release v4.46 will raise warnings about adding `processor.patch_size = {{patch_size}}`, `processor.num_additional_image_tokens = {{num_additional_image_tokens}}` and `processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. It is strongly recommended to add the attributes to the processor if you own the model checkpoint, or to open a PR if it is not owned by you.
Adding these attributes means that LLaVA will try to infer the number of image tokens required per image and expand the text with as many `<image>` placeholders as there will be tokens. Usually it is around 500 tokens per image, so make sure that the text is not truncated, as otherwise there will be a failure when merging the embeddings.
The attributes can be obtained from the model config, as `model.config.vision_config.patch_size` or `model.config.vision_feature_select_strategy`. The `num_additional_image_tokens` should be `1` if the vision backbone adds a CLS token, or `0` if nothing extra is added to the vision patches.

- Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use the tokenizer's `apply_chat_template` to format your prompts correctly. Below is an example of how to do that.

We will use [LLaVA-NeXT-Video-7B-hf](https://huggingface.co/llava-hf/LLaVA-NeXT-Video-7B-hf) and a conversation history of videos and images. Each content field has to be a list of dicts, as follows:
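
For illustration, a mixed image-and-video conversation of that shape could be constructed as follows (the prompts themselves are placeholders):

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "video"},
            {"type": "text", "text": "What is happening in this video?"},
        ],
    },
    {
        "role": "assistant",
        "content": [{"type": "text", "text": "A short description of the clip."}],
    },
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "How is this image related to the video?"},
        ],
    },
]

prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
```
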
@@ -54,6 +54,12 @@ This model was contributed by [RaushanTurganbay](https://huggingface.co/RaushanT
The original code can be found [here](https://github.com/PKU-YuanGroup/Video-LLaVA).

> [!NOTE]
> LLaVA models after release v4.46 will raise warnings about adding `processor.patch_size = {{patch_size}}`, `processor.num_additional_image_tokens = {{num_additional_image_tokens}}` and `processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. It is strongly recommended to add the attributes to the processor if you own the model checkpoint, or to open a PR if it is not owned by you.
Adding these attributes means that LLaVA will try to infer the number of image tokens required per image and expand the text with as many `<image>` placeholders as there will be tokens. Usually it is around 500 tokens per image, so make sure that the text is not truncated, as otherwise there will be a failure when merging the embeddings.
The attributes can be obtained from the model config, as `model.config.vision_config.patch_size` or `model.config.vision_feature_select_strategy`. The `num_additional_image_tokens` should be `1` if the vision backbone adds a CLS token, or `0` if nothing extra is added to the vision patches.

## Usage example

### Single Media Mode
@@ -39,6 +39,12 @@ This model was contributed by [Younes Belkada](https://huggingface.co/ybelkada)

- Note that the model has not been explicitly trained to process multiple images in the same prompt; although this is technically possible, you may experience inaccurate results.

> [!NOTE]
> LLaVA models after release v4.46 will raise warnings about adding `processor.patch_size = {{patch_size}}`, `processor.num_additional_image_tokens = {{num_additional_image_tokens}}` and `processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. It is strongly recommended to add the attributes to the processor if you own the model checkpoint, or to open a PR if it is not owned by you.
Adding these attributes means that LLaVA will try to infer the number of image tokens required per image and expand the text with as many `<image>` placeholders as there will be tokens. Usually it is around 500 tokens per image, so make sure that the text is not truncated, as otherwise there will be a failure when merging the embeddings.
The attributes can be obtained from the model config, as `model.config.vision_config.patch_size` or `model.config.vision_feature_select_strategy`. The `num_additional_image_tokens` should be `1` if the vision backbone adds a CLS token, or `0` if nothing extra is added to the vision patches.

- For better results, we recommend using the processor's `apply_chat_template()` method to format your prompt correctly. For that you need to construct a conversation history; passing in a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with keys "role" and "content". The "content" should be a list of dictionaries, for "text" and "image" modalities, as follows:

```python
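# The diff cuts the example off at this point; the code below is an illustrative
# sketch (hypothetical question text, example checkpoint) of the conversation
# format described above, not the checkpoint's original example.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/vip-llava-7b-hf")  # example checkpoint

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
]

prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
```
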
@@ -58,10 +58,19 @@ class LlavaProcessor(ProcessorMixin):
in a chat into a tokenizable string.
image_token (`str`, *optional*, defaults to `"<image>"`):
Special token used to denote image location.
num_additional_image_tokens (`int`, *optional*, defaults to 0):
Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other
extra tokens appended, no need to set this arg.
"""

attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template", "patch_size", "vision_feature_select_strategy", "image_token"]
valid_kwargs = [
"chat_template",
"patch_size",
"vision_feature_select_strategy",
"image_token",
"num_additional_image_tokens",
]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
@@ -73,9 +82,11 @@ class LlavaProcessor(ProcessorMixin):
vision_feature_select_strategy=None,
chat_template=None,
image_token="<image>", # set the default and let users change if they have peculiar special tokens in rare cases
num_additional_image_tokens=0,
**kwargs,
):
self.patch_size = patch_size
self.num_additional_image_tokens = num_additional_image_tokens
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
super().__init__(image_processor, tokenizer, chat_template=chat_template)
@@ -147,9 +158,11 @@ class LlavaProcessor(ProcessorMixin):
# Replace the image token with the expanded image token sequence
pixel_values = image_inputs["pixel_values"]
height, width = get_image_size(to_numpy_array(pixel_values[0]))
num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1
num_image_tokens = (height // self.patch_size) * (
width // self.patch_size
) + self.num_additional_image_tokens
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
num_image_tokens -= self.num_additional_image_tokens

prompt_strings = []
for sample in text:
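
As a quick sanity check of this formula (a sketch, assuming a 336x336 input, `patch_size = 14`, one extra CLS token, and the "default" feature-selection strategy), the numbers line up with the 593-token expansion asserted in the integration tests further below:

```python
height = width = 336               # assumed resized input resolution
patch_size = 14
num_additional_image_tokens = 1    # assumed CLS token from the vision backbone

num_image_tokens = (height // patch_size) * (width // patch_size) + num_additional_image_tokens
# 24 * 24 + 1 = 577

# the "default" strategy drops the extra feature again
num_image_tokens -= num_additional_image_tokens
# 577 - 1 = 576

# an 18-token prompt containing a single <image> placeholder therefore expands to
# 18 - 1 + 576 = 593 tokens
```
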
@@ -61,10 +61,19 @@ class LlavaNextProcessor(ProcessorMixin):
in a chat into a tokenizable string.
image_token (`str`, *optional*, defaults to `"<image>"`):
Special token used to denote image location.
num_additional_image_tokens (`int`, *optional*, defaults to 0):
Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other
extra tokens appended, no need to set this arg.
"""

attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template", "patch_size", "vision_feature_select_strategy", "image_token"]
valid_kwargs = [
"chat_template",
"patch_size",
"vision_feature_select_strategy",
"image_token",
"num_additional_image_tokens",
]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
@@ -76,9 +85,11 @@ class LlavaNextProcessor(ProcessorMixin):
vision_feature_select_strategy=None,
chat_template=None,
image_token="<image>", # set the default and let users change if they have peculiar special tokens in rare cases
num_additional_image_tokens=0,
**kwargs,
):
self.patch_size = patch_size
self.num_additional_image_tokens = num_additional_image_tokens
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
super().__init__(image_processor, tokenizer, chat_template=chat_template)
@@ -155,7 +166,7 @@ class LlavaNextProcessor(ProcessorMixin):
orig_height, orig_width = image_size
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
num_image_tokens -= self.num_additional_image_tokens
sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1)
prompt_strings.append(sample)
prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
@@ -178,7 +189,7 @@ class LlavaNextProcessor(ProcessorMixin):
orig_height, orig_width, patches_height, patches_width, scale_height, scale_width
)
# The base patch covers the entire image (+1 for the CLS)
base_features = patches_height * patches_width + 1
base_features = patches_height * patches_width + self.num_additional_image_tokens
num_image_tokens = unpadded_features + newline_features + base_features
return num_image_tokens
@@ -58,12 +58,22 @@ class LlavaNextVideoProcessor(ProcessorMixin):
Special token used to denote video location.
image_token (`str`, *optional*, defaults to `"<image>"`):
Special token used to denote image location.
num_additional_image_tokens (`int`, *optional*, defaults to 0):
Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other
extra tokens appended, no need to set this arg.
"""

# video and image processor share same args, but have different processing logic
# only image processor config is saved in the hub
attributes = ["video_processor", "image_processor", "tokenizer"]
valid_kwargs = ["chat_template", "patch_size", "vision_feature_select_strategy", "image_token", "video_token"]
valid_kwargs = [
"chat_template",
"patch_size",
"vision_feature_select_strategy",
"image_token",
"video_token",
"num_additional_image_tokens",
]
image_processor_class = "LlavaNextImageProcessor"
video_processor_class = "LlavaNextVideoImageProcessor"
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
@@ -78,9 +88,11 @@ class LlavaNextVideoProcessor(ProcessorMixin):
vision_feature_select_strategy=None,
video_token="<video>",
image_token="<image>",
num_additional_image_tokens=0,
**kwargs,
):
self.patch_size = patch_size
self.num_additional_image_tokens = num_additional_image_tokens
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
@@ -164,8 +176,9 @@ class LlavaNextVideoProcessor(ProcessorMixin):
if self.patch_size is None or self.vision_feature_select_strategy is None:
logger.warning_once(
"Expanding inputs for image/video tokens in LLaVa-NeXT-Video should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Please add `patch_size`, `num_additional_image_tokens` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}`, `processor.num_additional_image_tokens = {{num_additional_image_tokens}}` "
"and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
)
else:
@@ -180,7 +193,7 @@ class LlavaNextVideoProcessor(ProcessorMixin):
orig_height, orig_width = image_size
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
num_image_tokens -= self.num_additional_image_tokens
sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1)
prompt_strings.append(sample)
text = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
@@ -190,6 +203,8 @@ class LlavaNextVideoProcessor(ProcessorMixin):
one_video = to_numpy_array(videos_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(one_video[0])
num_frames = one_video.shape[0] # frame dim is always after batch dim

# no `self.num_additional_image_tokens` added because video always has a default feature selection strategy
num_image_tokens = (height // self.patch_size) * (width // self.patch_size)
num_video_tokens = num_image_tokens // 4 * num_frames # divide by 4 needed for avg pooling layer
prompt_strings = []
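
A rough check of the video branch (a sketch, assuming 336x336 frames, `patch_size = 14`, and 8 sampled frames) matches the 1170-token expansion asserted in the LLaVA-NeXT-Video integration test below:

```python
height = width = 336   # assumed frame resolution
patch_size = 14
num_frames = 8         # assumed number of sampled frames

num_image_tokens = (height // patch_size) * (width // patch_size)   # 576, no CLS for video features
num_video_tokens = num_image_tokens // 4 * num_frames               # 144 * 8 = 1152 after the avg pooling

# a 19-token prompt with a single <video> placeholder expands to 19 - 1 + 1152 = 1170 tokens
```
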
@@ -222,7 +237,7 @@ class LlavaNextVideoProcessor(ProcessorMixin):
orig_height, orig_width, patches_height, patches_width, scale_height, scale_width
)
# The base patch covers the entire image (+1 for the CLS)
base_features = patches_height * patches_width + 1
base_features = patches_height * patches_width + self.num_additional_image_tokens
num_image_tokens = unpadded_features + newline_features + base_features
return num_image_tokens
@@ -51,10 +51,20 @@ class VideoLlavaProcessor(ProcessorMixin):
Special token used to denote video location.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
num_additional_image_tokens (`int`, *optional*, defaults to 0):
Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other
extra tokens appended, no need to set this arg.
"""

attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template", "patch_size", "vision_feature_select_strategy", "image_token", "video_token"]
valid_kwargs = [
"chat_template",
"patch_size",
"vision_feature_select_strategy",
"image_token",
"video_token",
"num_additional_image_tokens",
]
image_processor_class = "VideoLlavaImageProcessor"
tokenizer_class = "AutoTokenizer"
@@ -67,9 +77,11 @@ class VideoLlavaProcessor(ProcessorMixin):
image_token="<image>", # set the default and let users change if they have peculiar special tokens in rare cases
video_token="<video>",
chat_template=None,
num_additional_image_tokens=0,
**kwargs,
):
self.patch_size = patch_size
self.num_additional_image_tokens = num_additional_image_tokens
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
@@ -165,13 +177,17 @@ class VideoLlavaProcessor(ProcessorMixin):
height, width = get_image_size(one_video[0])
num_frames = one_video.shape[0] # frame dim is always after batch dim

num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1
num_image_tokens = (height // self.patch_size) * (
width // self.patch_size
) + self.num_additional_image_tokens
num_video_tokens = num_image_tokens * num_frames

num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1
num_image_tokens = (height // self.patch_size) * (
width // self.patch_size
) + self.num_additional_image_tokens
num_video_tokens = num_image_tokens * num_frames
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
num_image_tokens -= self.num_additional_image_tokens

prompt_strings = []
for sample in text:
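
For intuition, a sketch of these numbers (assuming 224x224 inputs, `patch_size = 14`, one CLS token, 8 frames per video, and the "default" strategy) reproduces the 274- and 2074-token expansions asserted in the Video-LLaVA integration tests below:

```python
height = width = 224               # assumed image/frame resolution
patch_size = 14
num_additional_image_tokens = 1    # assumed CLS token
num_frames = 8                     # assumed frames per video

num_image_tokens = (height // patch_size) * (width // patch_size) + num_additional_image_tokens  # 257
num_video_tokens = num_image_tokens * num_frames                                                 # 2056

# the "default" strategy drops the extra token for images
num_image_tokens -= num_additional_image_tokens   # 256

# 19-token prompts with a single <image> / <video> placeholder expand to
# 19 - 1 + 256 = 274 and 19 - 1 + 2056 = 2074 tokens respectively
```
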
@@ -607,6 +607,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):

# check processing with expansion of inputs
processor.vision_feature_select_strategy = "default"
processor.num_additional_image_tokens = 1
processor.patch_size = 14
inputs_expanded = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 593)
@@ -614,6 +615,7 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
processor.num_additional_image_tokens = None
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs.input_ids.shape[-1] == 18)
@@ -622,6 +622,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
# check processing with expansion of inputs
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
processor.num_additional_image_tokens = 1
inputs_expanded = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to(
torch_device, torch.float16
)
@@ -630,6 +631,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
processor.num_additional_image_tokens = None
inputs = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to(
torch_device, torch.float16
)
@@ -656,12 +658,14 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
# check processing with expansion of inputs
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
processor.num_additional_image_tokens = 1
inputs_expanded = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2356)

# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
processor.num_additional_image_tokens = None
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs.input_ids.shape[-1] == 17)
@@ -558,12 +558,14 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
# check processing with expansion of inputs
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
processor.num_additional_image_tokens = 1
inputs_expanded = processor(self.prompt_video, videos=[self.video], return_tensors="pt").to(torch_device)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 1170)

# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
processor.num_additional_image_tokens = None
inputs = processor(self.prompt_video, videos=[self.video], return_tensors="pt").to(torch_device)
self.assertTrue(inputs.input_ids.shape[-1] == 19)
@@ -586,12 +588,14 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
# check processing with expansion of inputs
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
processor.num_additional_image_tokens = 1
inputs_expanded = processor(self.prompt_image, images=[self.image], return_tensors="pt").to(torch_device)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2652)

# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
processor.num_additional_image_tokens = None
inputs = processor(self.prompt_image, images=[self.image], return_tensors="pt").to(torch_device)
self.assertTrue(inputs.input_ids.shape[-1] == 19)
@@ -624,6 +628,7 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
# check processing with expansion of inputs
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
processor.num_additional_image_tokens = 1
inputs_expanded = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to(
torch_device, torch.float16
)
@@ -632,6 +637,7 @@ class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
processor.num_additional_image_tokens = None
inputs = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to(
torch_device, torch.float16
)
@@ -625,12 +625,14 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
# check processing with expansion of inputs
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
processor.num_additional_image_tokens = 1
inputs_expanded = processor(prompt, images=image, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 274)

# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
processor.num_additional_image_tokens = None
inputs = processor(prompt, images=image, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs.input_ids.shape[-1] == 19)
@@ -657,12 +659,14 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
# check processing with expansion of inputs
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
processor.num_additional_image_tokens = 1
inputs_expanded = processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2074)

# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
processor.num_additional_image_tokens = None
inputs = processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs.input_ids.shape[-1] == 19)
@@ -374,12 +374,14 @@ class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
# check processing with expansion of inputs
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
processor.num_additional_image_tokens = 1
inputs_expanded = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 593)

# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
processor.num_additional_image_tokens = None
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs.input_ids.shape[-1] == 18)