[transformers x vLLM] standardize processors (#37915)

* standardize

* fix tests

* batch update some processors, not final yet

* okay, now I tested that everything indeed runs. Still needs prettification

* emu3

* fixup

* gemma3 but it doesn't generate anything

* fuyu

* update

* why?

* Update src/transformers/models/aya_vision/processing_aya_vision.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* address comments

* bc

* why do we need to guard import this every time?

* i hate guarded imports

* i am blind

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Raushan Turganbay · 2025-05-27 11:30:30 +02:00 · committed by GitHub
parent b5ececb900
commit 9e1017b479
28 changed files with 1168 additions and 80 deletions


@ -500,5 +500,26 @@ class AriaImageProcessor(BaseImageProcessor):
]
return patches
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns number of image patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*):
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of patches per image.
"""
split_image = images_kwargs.get("split_image", None) or self.split_image
max_image_size = images_kwargs.get("max_image_size", None) or self.max_image_size
resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions)
num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size
return num_patches
__all__ = ["AriaImageProcessor"]

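For reference, a minimal sketch of how the new utility might be called on its own; the checkpoint id and the override dict are illustrative assumptions, not values taken from this diff.

```python
from transformers import AutoImageProcessor

# "rhymes-ai/Aria" is assumed to resolve to an Aria checkpoint.
image_processor = AutoImageProcessor.from_pretrained("rhymes-ai/Aria")

# How many crops would a 768x1024 (height, width) image be split into?
# The dict override mirrors the `images_kwargs.get(...) or self.<attr>` fallbacks above.
num_patches = image_processor.get_number_of_image_patches(768, 1024, {"split_image": True})
print(num_patches)
```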

@ -34,7 +34,7 @@ from ...image_utils import (
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import PreTrainedModel
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils import PreTokenizedInput, TextInput
from ...utils import LossKwargs, TensorType, auto_docstring, can_return_tuple, logging
from ...utils.import_utils import is_torch_available
@ -884,11 +884,33 @@ class AriaImageProcessor(BaseImageProcessor):
]
return patches
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns number of image patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*):
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of patches per image.
"""
split_image = images_kwargs.get("split_image", None) or self.split_image
max_image_size = images_kwargs.get("max_image_size", None) or self.max_image_size
resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions)
num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size
return num_patches
class AriaProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
"images_kwargs": {
"max_image_size": 980,
@ -978,10 +1000,7 @@ class AriaProcessor(ProcessorMixin):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
if images is not None:
image_inputs = self.image_processor(
images,
**output_kwargs["images_kwargs"],
)
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
# expand the image_token according to the num_crops and tokens per image
tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
prompt_strings = []
@ -995,11 +1014,44 @@ class AriaProcessor(ProcessorMixin):
prompt_strings = text
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = AriaProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size
num_image_patches = [
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
for image_size in image_sizes
]
num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches]
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please

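A hedged usage sketch of the standardized surface this PR adds (`_get_num_multimodal_tokens` and `return_mm_token_type_ids`); the checkpoint id, the `image_token` attribute, and attribute-style access on `MultiModalData` are assumptions for illustration.

```python
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("rhymes-ai/Aria")  # assumed checkpoint id

# Dry run: how many placeholder tokens and patches would a 480x640 image expand to?
mm_data = processor._get_num_multimodal_tokens(image_sizes=[(480, 640)])
print(mm_data.num_image_tokens, mm_data.num_image_patches)   # attribute access assumed

# Real call: also request the multimodal token-type mask so a serving engine
# (e.g. vLLM) can locate image placeholder positions inside `input_ids`.
image = Image.new("RGB", (640, 480))
prompt = f"{processor.image_token} Describe this image."     # `image_token` attribute assumed
inputs = processor(text=prompt, images=image, return_mm_token_type_ids=True)
print(inputs["mm_token_type_ids"])
```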

@ -20,9 +20,11 @@
# limitations under the License.
from typing import Dict, List, Optional, Union
import numpy as np
from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils import PreTokenizedInput, TextInput
from ...utils import TensorType
from ..auto import AutoTokenizer
@ -32,6 +34,7 @@ class AriaProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
"images_kwargs": {
"max_image_size": 980,
@ -121,10 +124,7 @@ class AriaProcessor(ProcessorMixin):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
if images is not None:
image_inputs = self.image_processor(
images,
**output_kwargs["images_kwargs"],
)
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
# expand the image_token according to the num_crops and tokens per image
tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
prompt_strings = []
@ -138,11 +138,44 @@ class AriaProcessor(ProcessorMixin):
prompt_strings = text
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = AriaProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size
num_image_patches = [
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
for image_size in image_sizes
]
num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches]
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please


@ -13,22 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
from transformers.processing_utils import (
ImagesKwargs,
ProcessingKwargs,
ProcessorMixin,
Unpack,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
import numpy as np
from ...image_processing_utils import BatchFeature
from ...image_utils import (
ImageInput,
make_flat_list_of_images,
)
from ...processing_utils import (
ImagesKwargs,
MultiModalData,
ProcessingKwargs,
ProcessorMixin,
Unpack,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
class AyaVisionImagesKwargs(ImagesKwargs, total=False):
@ -43,6 +44,7 @@ class AyaVisionProcessorKwargs(ProcessingKwargs, total=False):
"text_kwargs": {
"padding_side": "left",
"padding": True,
"return_mm_token_type_ids": False,
},
"images_kwargs": {
"crop_to_patches": True,
@ -121,7 +123,6 @@ class AyaVisionProcessor(ProcessorMixin):
super().__init__(image_processor, tokenizer, chat_template=chat_template)
self.image_token = image_token
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
self.patch_size = patch_size * downsample_factor
self.img_size = img_size
@ -131,6 +132,10 @@ class AyaVisionProcessor(ProcessorMixin):
self.img_line_break_token = img_line_break_token
self.tile_token = tile_token
self.tile_global_token = tile_global_token
self.image_token_id = tokenizer.convert_tokens_to_ids(self.img_patch_token)
self.image_ids = tokenizer.convert_tokens_to_ids(
[img_patch_token, tile_token, tile_global_token, start_of_img_token, end_of_img_token]
)
def _prompt_split_image(self, num_patches):
"""
@ -226,11 +231,49 @@ class AyaVisionProcessor(ProcessorMixin):
text = processed_text
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = AyaVisionProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
num_image_patches = [
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
for image_size in image_sizes
]
token_per_patch = (self.img_size // self.patch_size) ** 2
num_image_tokens = [
token_per_patch + 3 + sum(token_per_patch + 1 for _ in range(1, num_patches))
for num_patches in num_image_patches
] # Add +3 and +1 for BOI/EOI and image tile tokens
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please

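A worked instance of the Aya Vision count above; the `img_size`/`patch_size` values and the 4-tile split are illustrative assumptions rather than values from a real checkpoint.

```python
img_size = 364        # assumed self.img_size
patch_size = 28       # assumed patch_size * downsample_factor
num_patches = 5       # e.g. 4 local tiles + 1 global tile

token_per_patch = (img_size // patch_size) ** 2              # 13 * 13 = 169
num_image_tokens = token_per_patch + 3 + sum(
    token_per_patch + 1 for _ in range(1, num_patches)       # +1 tile token per extra patch
)
print(num_image_tokens)  # 169 + 3 + 4 * 170 = 852
```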

@ -18,9 +18,18 @@ Processor class for Chameleon.
from typing import List, Optional, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack, _validate_images_text_input_order
from ...processing_utils import (
MultiModalData,
ProcessingKwargs,
ProcessorMixin,
TextKwargs,
Unpack,
_validate_images_text_input_order,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
@ -34,6 +43,7 @@ class ChameleonProcessorKwargs(ProcessingKwargs, total=False):
"text_kwargs": {
"padding": False,
"return_for_text_completion": False,
"return_mm_token_type_ids": False,
},
"common_kwargs": {
"return_tensors": "pt",
@ -73,6 +83,10 @@ class ChameleonProcessor(ProcessorMixin):
tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "<racm3:break>"
) # fixed tokens for start and end, so can hardcode
self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "<eoss>"
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
self.image_start_token_id = tokenizer.convert_tokens_to_ids(self.image_start_token)
self.image_end_token_id = tokenizer.convert_tokens_to_ids(self.image_end_token)
self.image_ids = [self.image_token_id, self.image_start_token_id, self.image_end_token_id]
super().__init__(image_processor, tokenizer)
@ -141,14 +155,45 @@ class ChameleonProcessor(ProcessorMixin):
sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode
prompt_strings.append(sample)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(prompt_strings, data, modalities=["image"])
image_inputs = {}
if images is not None:
data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
return BatchFeature(data=data, tensor_type=return_tensors)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
# add 2 for BOI and EOI tokens
num_image_tokens = [self.image_seq_length + 2] * len(image_sizes)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):


@ -24,7 +24,7 @@ from typing import ClassVar, List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
from ...utils import is_torch_available
@ -256,6 +256,25 @@ class ColPaliProcessor(ProcessorMixin):
return batch_query
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
num_image_tokens = [self.image_seq_length] * len(image_sizes)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please


@ -16,10 +16,17 @@
from typing import List, Optional, Union
import numpy as np
from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import is_vision_available
if is_vision_available():
from .image_processing_emu3 import smart_resize
class Emu3TextKwargs(TextKwargs, total=False):
@ -37,6 +44,7 @@ class Emu3ProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"return_for_image_generation": False,
"return_mm_token_type_ids": False,
},
"images_kwargs": {
"ratio": "1:1",
@ -166,7 +174,7 @@ class Emu3Processor(ProcessorMixin):
image_placeholder = f"{image_start_tokens}{height}*{width}{self.fake_token_around_image}{'<placeholder>' * image_seq_length}{image_end_tokens}"
sample = sample.replace(self.image_token, image_placeholder, 1)
sample = f"{self.bos_token}{sample}" # add BOS because PT tokenizer doesn't add it
sample = f"{self.bos_token}{sample}" # add BOS because GPT tokenizer doesn't add it
prompt_strings.append(sample)
text = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
@ -179,12 +187,51 @@ class Emu3Processor(ProcessorMixin):
# else just generate from text-only input, and we do no special treatment for text
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
data = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, data, modalities=["image"])
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
data.update(**image_features)
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data=data, tensor_type=return_tensors)
return BatchFeature(data={**text_inputs, **image_features}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
num_image_tokens = []
for height, width in image_sizes:
height, width = smart_resize(
height,
width,
self.image_processor.spatial_factor,
self.image_processor.min_pixels,
self.image_processor.max_pixels,
)
height = height // self.downsample_ratio
width = width // self.downsample_ratio
image_seq_length = height * (width + 1) # +1 for extra row when converting to BPE in modeling code
num_image_tokens.append(image_seq_length)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
def calculate_generate_size(self, ratio, image_area, spatial_factor):
width, height = map(int, ratio.split(":"))

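A worked instance of the Emu3 count above; the resized size and downsample ratio are assumptions standing in for the `smart_resize` output and `self.downsample_ratio`.

```python
resized_height, resized_width = 512, 512   # assumed smart_resize(...) output
downsample_ratio = 8                       # assumed self.downsample_ratio

h = resized_height // downsample_ratio     # 64
w = resized_width // downsample_ratio      # 64
image_seq_length = h * (w + 1)             # +1 column per row, as noted above
print(image_seq_length)                    # 64 * 65 = 4160
```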

@ -130,7 +130,7 @@ class FuyuModel(FuyuPreTrainedModel):
)
return output_embeddings
def get_image_features(self, pixel_values: torch.FloatTensor):
def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.


@ -22,7 +22,13 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...processing_utils import (
MultiModalData,
ProcessingKwargs,
ProcessorMixin,
Unpack,
_validate_images_text_input_order,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import is_torch_available, logging, requires_backends
from ...utils.import_utils import requires
@ -64,6 +70,7 @@ class FuyuProcessorKwargs(ProcessingKwargs, total=False):
"return_token_type_ids": False,
"return_length": False,
"verbose": True,
"return_mm_token_type_ids": False,
},
"images_kwargs": {},
}
@ -355,6 +362,8 @@ class FuyuProcessor(ProcessorMixin):
self.max_position_embeddings = 16384 # TODO Can't derive this from model files: where to set it?
self.pad_token_id = 0
self.dummy_image_index = -1
self.image_token_id = tokenizer.encode("|SPEAKER|", add_special_tokens=False)[1]
self.image_newline_id = tokenizer.encode("|NEWLINE|", add_special_tokens=False)[1]
def _left_pad_inputs_with_attention_mask(self, model_inputs: List[Dict], return_attention_mask: bool):
max_length_input_ids = max(entry["input_ids"].shape[1] for entry in model_inputs)
@ -403,6 +412,11 @@ class FuyuProcessor(ProcessorMixin):
for key in batched_keys:
batched_inputs[key] = torch.cat(batched_inputs[key], dim=0)
# Cast images to tensor as well, if only one image passed and no padding needed
# NOTE: vLLM expects all processor outputs to be a tensor
if len(batched_inputs["image_patches"]) == 1:
batched_inputs["image_patches"] = torch.cat(batched_inputs["image_patches"], dim=0)
return batched_inputs
def get_sample_encoding(
@ -517,6 +531,7 @@ class FuyuProcessor(ProcessorMixin):
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
if not output_kwargs["text_kwargs"].setdefault("return_attention_mask", True):
raise ValueError("`return_attention_mask=False` is not supported for this model.")
@ -550,8 +565,6 @@ class FuyuProcessor(ProcessorMixin):
# --- Use self.tokenizer to get the ids of special tokens to insert into image ids ---
image_placeholder_id = self.tokenizer("|SPEAKER|", add_special_tokens=False)["input_ids"][1]
image_newline_id = self.tokenizer("|NEWLINE|", add_special_tokens=False)["input_ids"][1]
tensor_batch_images = torch.stack([img[0] for img in batch_images]).unsqueeze(1)
# --- Use self.image_processor again to obtain the full token ids and batch inputs ---
@ -565,16 +578,63 @@ class FuyuProcessor(ProcessorMixin):
scale_factors=[scale_factor],
image_unpadded_heights=torch.tensor([image_unpadded_height]),
image_unpadded_widths=torch.tensor([image_unpadded_width]),
image_placeholder_id=image_placeholder_id,
image_newline_id=image_newline_id,
image_placeholder_id=self.image_token_id,
image_newline_id=self.image_newline_id,
tensor_batch_images=tensor_batch_image.unsqueeze(0),
)
all_encodings.append(sample_encoding)
batch_encoding = self._left_pad_inputs_with_attention_mask(
model_inputs=all_encodings, return_attention_mask=True
)
if return_mm_token_type_ids:
input_ids = batch_encoding["input_ids"]
mm_token_type_ids = torch.zeros_like(input_ids)
mm_token_type_ids[input_ids == self.image_token_id] = 1
mm_token_type_ids[input_ids == self.image_newline_id] = 1
batch_encoding["mm_token_type_ids"] = mm_token_type_ids
return FuyuBatchFeature(data=batch_encoding)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
size = kwargs.get("size", None) or self.image_processor.size
padded_height, padded_width = size["height"], size["width"]
num_image_tokens = []
num_image_patches = [1] * len(image_sizes)
for image_size in image_sizes:
height_scale_factor = padded_height / image_size[0]
width_scale_factor = padded_width / image_size[1]
optimal_scale_factor = min(height_scale_factor, width_scale_factor)
# We can use torch here because Fuyu processor has hard dependency on torch
model_image_input = self.image_processor.preprocess_with_tokenizer_info(
image_input=torch.zeros(1, 1, 3, padded_height, padded_width),
image_present=torch.ones(1, 1, 1),
image_unpadded_h=torch.tensor([[int(image_size[0] * optimal_scale_factor)]]),
image_unpadded_w=torch.tensor([[int(image_size[1] * optimal_scale_factor)]]),
image_placeholder_id=0, # dummy ids, we can be sure `id=0` is never out-of-range
image_newline_id=0,
variable_sized=True,
)
num_image_tokens.append(model_image_input["image_input_ids"][0][0].shape[-1])
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
def post_process_box_coordinates(self, outputs, target_sizes=None):
"""
Transforms raw coordinates detected by [`FuyuForCausalLM`] to the original images' coordinate space.

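A hedged dry-run sketch of the Fuyu variant; `adept/fuyu-8b` is assumed to be a valid checkpoint id, and attribute access on the returned `MultiModalData` is an assumption.

```python
from transformers import FuyuProcessor

processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")   # assumed checkpoint id

# Internally this pads/rescales and runs `preprocess_with_tokenizer_info` on a dummy
# zero image, so token counts are available without touching real pixel data.
mm_data = processor._get_num_multimodal_tokens(image_sizes=[(1080, 1920)])
print(mm_data.num_image_tokens)   # one count per image, |NEWLINE| tokens included
```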

@ -20,7 +20,7 @@ import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, make_nested_list_of_images
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import to_py_obj
@ -38,6 +38,7 @@ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": True,
},
"images_kwargs": {
"do_pan_and_scan": False,
@ -137,17 +138,42 @@ class Gemma3Processor(ProcessorMixin):
text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
# Add token type ids manually, as tokenizer can't do arbitrary position token types
array_ids = text_inputs["input_ids"]
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs
text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(array_ids)
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
# NOTE: no image cropping supported yet
num_image_tokens = [self.image_seq_length] * len(image_sizes)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
def batch_decode(self, *args, **kwargs):
"""


@ -491,5 +491,33 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
return processed_images
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns the number of patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*):
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of patches per image.
"""
min_patches = images_kwargs.get("min_patches", None) or self.min_patches
max_patches = images_kwargs.get("max_patches", None) or self.max_patches
patch_size = images_kwargs.get("size", None) or self.size
crop_to_patches = images_kwargs.get("crop_to_patches", None) or self.crop_to_patches
num_patches = 1
if crop_to_patches and max_patches > 1:
num_columns, num_rows = get_optimal_tiled_canvas(
(height, width), (patch_size["height"], patch_size["width"]), min_patches, max_patches
)
num_patches += num_columns * num_rows
return num_patches
__all__ = ["GotOcr2ImageProcessor"]

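A worked instance of the GOT-OCR2 patch count above; the tiling grid is an assumed output of `get_optimal_tiled_canvas`, not computed here.

```python
crop_to_patches, max_patches = True, 6
num_columns, num_rows = 3, 2        # assumed optimal tiling for the input size

num_patches = 1                     # the base image itself
if crop_to_patches and max_patches > 1:
    num_patches += num_columns * num_rows   # 1 + 3 * 2 = 7
print(num_patches)  # 7
```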

@ -228,5 +228,33 @@ class GotOcr2ImageProcessorFast(BaseImageProcessorFast):
data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors
)
def get_number_of_image_tokens(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns the number of patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*):
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of patches per image.
"""
min_patches = images_kwargs.get("min_patches", None) or self.min_patches
max_patches = images_kwargs.get("max_patches", None) or self.max_patches
patch_size = images_kwargs.get("size", None) or self.size
crop_to_patches = images_kwargs.get("crop_to_patches", None) or self.crop_to_patches
num_patches = 1
if crop_to_patches and max_patches > 1:
num_columns, num_rows = get_optimal_tiled_canvas(
(height, width), (patch_size["height"], patch_size["width"]), min_patches, max_patches
)
num_patches += num_columns * num_rows
return num_patches
__all__ = ["GotOcr2ImageProcessorFast"]


@ -850,5 +850,46 @@ class Idefics3ImageProcessor(BaseImageProcessor):
return encoding
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns number of image patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*):
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of patches per image.
"""
do_image_splitting = images_kwargs.get("do_image_splitting", None) or self.do_image_splitting
max_image_size = images_kwargs.get("max_image_size", None) or self.max_image_size
size = images_kwargs.get("size", None) or self.size
if do_image_splitting:
height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=size["longest_edge"])
height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=4096)
aspect_ratio = width / height
if width >= height:
resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
resized_height = int(width / aspect_ratio)
resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
elif height > width:
resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
resized_width = int(height * aspect_ratio)
resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
max_height = max_width = max_image_size["longest_edge"]
if resized_height > max_height or resized_width > max_width:
# Calculate the number of splits
num_rows = math.ceil(resized_height / max_height)
num_cols = math.ceil(resized_width / max_width)
num_patches = num_rows * num_cols + 1
return num_patches
__all__ = ["Idefics3ImageProcessor"]


@ -16,13 +16,16 @@
Processor class for Idefics3.
"""
import math
import re
from itertools import accumulate
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import AddedToken, BatchEncoding, TextInput
from ...utils import logging
@ -98,6 +101,7 @@ class Idefics3ProcessorKwargs(ProcessingKwargs, total=False):
"add_special_tokens": True,
"padding": False,
"is_split_into_words": False,
"return_mm_token_type_ids": False,
},
"images_kwargs": {
"return_row_col_info": True,
@ -146,6 +150,12 @@ class Idefics3Processor(ProcessorMixin):
self.end_of_utterance_token = AddedToken("<end_of_utterance>", normalized=False, special=True).content
self.global_image_tag = "<global-img>" # https://github.com/huggingface/transformers/pull/32473/files/8063e5e17362571b693f1db95167f5443a3be1b2#r1734825341
self.image_seq_len = image_seq_len
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
self.fake_image_token_id = tokenizer.convert_tokens_to_ids(self.fake_image_token)
self.global_image_token_id = tokenizer.convert_tokens_to_ids(self.global_image_tag)
self.row_col_ids = [
tokenizer.convert_tokens_to_ids(f"<row_{i + 1}_col_{j + 1}>") for i in range(6) for j in range(6)
]
# This regex matches one or more occurrences of <global-img> tags (optionally surrounded by newline characters)
# or <row_x_col_y> tags (where x and y are digits, also optionally surrounded by newline characters).
@ -241,6 +251,7 @@ class Idefics3Processor(ProcessorMixin):
)
image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
n_images_in_text = []
@ -302,9 +313,11 @@ class Idefics3Processor(ProcessorMixin):
global_img_token = self.global_image_tag
prompt_strings = []
batch_image_seq_lengths = []
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
# Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
image_prompt_strings = []
image_seq_lengths = []
for n_rows, n_cols in zip(sample_rows, sample_cols):
image_prompt_string = get_image_prompt_string(
n_rows,
@ -314,8 +327,12 @@ class Idefics3Processor(ProcessorMixin):
fake_token_around_image=fake_image_token,
global_img_token=global_img_token,
)
# Add +2 and +3 for special BOI/EOI/fake_image_wrapper tokens
row_length = (self.image_seq_len + 2) * n_cols + 1
image_seq_lengths.append((self.image_seq_len + 3) + row_length * n_rows)
image_prompt_strings.append(image_prompt_string)
batch_image_seq_lengths.append(image_seq_lengths)
split_sample = sample.split(image_token)
if len(split_sample) == 0:
raise ValueError("The image token should be present in the text.")
@ -338,7 +355,59 @@ class Idefics3Processor(ProcessorMixin):
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
inputs.update(text_inputs)
return BatchFeature(inputs, tensor_type=return_tensors)
if return_mm_token_type_ids:
array_ids = np.array(inputs["input_ids"])
mm_token_type_ids = np.zeros_like(array_ids)
for i, seq_lengths in enumerate(batch_image_seq_lengths):
image_start_positions = np.where(array_ids[i] == self.fake_image_token_id)[0]
j = 0
for seq_len in seq_lengths:
if j >= len(image_start_positions):
break
start = image_start_positions[j]
end = start + seq_len
mm_token_type_ids[i, start:end] = 1
j = np.searchsorted(image_start_positions, end)
inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data=inputs, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = Idefics3ProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
num_image_patches = [
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
for image_size in image_sizes
]
base_image_length = self.image_seq_len + 3
col_length = self.image_seq_len + 2
num_image_tokens = []
for num_patches in num_image_patches:
num_cols = num_rows = int(math.sqrt(num_patches - 1))
row_length = col_length * num_cols + 1
num_image_tokens.append(base_image_length + (row_length * num_rows))
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
def batch_decode(self, *args, **kwargs):
"""

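A worked instance of the Idefics3 sequence length above; `image_seq_len = 169` and the 2x2 grid are illustrative assumptions.

```python
image_seq_len = 169
n_rows = n_cols = 2                          # image split into a 2x2 grid plus a global image

col_length = image_seq_len + 2               # per-patch wrapper tokens (the "+2" above)
row_length = col_length * n_cols + 1         # +1 separator token per row
num_image_tokens = (image_seq_len + 3) + row_length * n_rows   # global image uses the "+3"
print(num_image_tokens)  # 172 + 343 * 2 = 858
```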

@ -13,25 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import numpy as np
from transformers.processing_utils import (
ImagesKwargs,
ProcessingKwargs,
ProcessorMixin,
Unpack,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from ...image_processing_utils import BatchFeature
from ...image_utils import (
ImageInput,
concatenate_list,
make_flat_list_of_images,
)
from ...processing_utils import (
ImagesKwargs,
MultiModalData,
ProcessingKwargs,
ProcessorMixin,
Unpack,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...video_utils import VideoInput, VideoMetadata, load_video, make_batched_videos
@ -46,6 +45,7 @@ class InternVLProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding_side": "left",
"return_mm_token_type_ids": False,
},
"images_kwargs": {
"crop_to_patches": True,
@ -94,9 +94,12 @@ class InternVLProcessor(ProcessorMixin):
self.image_seq_length = image_seq_length
self.start_image_token = tokenizer.start_image_token
self.end_image_token = tokenizer.end_image_token
self.start_image_token_id = tokenizer.start_image_token_id
self.end_image_token_id = tokenizer.end_image_token_id
self.image_token = tokenizer.context_image_token
self.video_token = tokenizer.video_token
self.image_token_id = tokenizer.context_image_token_id
self.image_ids = [self.image_token_id, self.start_image_token_id, self.end_image_token_id]
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template, **kwargs)
@ -261,11 +264,46 @@ class InternVLProcessor(ProcessorMixin):
image_videos_inputs = {"pixel_values": concatenate_list(image_video_patches)}
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_videos_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = InternVLProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
num_image_patches = [
self.image_processor.get_number_of_image_tokens(*image_size, images_kwargs)
for image_size in image_sizes
]
# Add 2 for BOI and EOI tokens
num_image_tokens = [2 + (self.image_seq_length * num_patches) for num_patches in num_image_patches]
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
def sample_indices_fn(
self, metadata: VideoMetadata, num_frames: Optional[int] = None, initial_shift: Union[bool, float, int] = True
):


@ -18,9 +18,17 @@ Processor class for Llava.
from typing import List, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, get_image_size, to_numpy_array
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...processing_utils import (
MultiModalData,
ProcessingKwargs,
ProcessorMixin,
Unpack,
_validate_images_text_input_order,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
@ -30,9 +38,7 @@ logger = logging.get_logger(__name__)
class LlavaProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
},
"text_kwargs": {"padding": False, "return_mm_token_type_ids": False},
"images_kwargs": {},
}
@ -89,11 +95,7 @@ class LlavaProcessor(ProcessorMixin):
self.num_additional_image_tokens = num_additional_image_tokens
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
self.image_token_id = (
tokenizer.image_token_id
if getattr(tokenizer, "image_token_id", None)
else tokenizer.convert_tokens_to_ids(self.image_token)
)
self.image_token_id = tokenizer.encode(self.image_token, add_special_tokens=False)[0]
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__(
@ -174,10 +176,49 @@ class LlavaProcessor(ProcessorMixin):
prompt_strings.append(sample)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = LlavaProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
crop_size = images_kwargs.get("crop_size", None) or self.image_processor.crop_size
resized_height, resized_width = crop_size["height"], crop_size["width"]
num_image_tokens = (resized_height // self.patch_size) * (resized_width // self.patch_size)
num_image_tokens += self.num_additional_image_tokens
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
num_image_tokens = [num_image_tokens] * len(image_sizes)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""

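A worked instance of the LLaVA count above; CLIP-ViT-style values (336 crop, 14 patch, one extra CLS token) are assumed for illustration.

```python
crop = 336                            # assumed crop_size["height"] == crop_size["width"]
patch_size = 14                       # assumed self.patch_size
num_additional_image_tokens = 1       # assumed: the CLS token
vision_feature_select_strategy = "default"

num_image_tokens = (crop // patch_size) * (crop // patch_size)   # 24 * 24 = 576
num_image_tokens += num_additional_image_tokens                  # -> 577
if vision_feature_select_strategy == "default":
    num_image_tokens -= 1                                        # CLS dropped -> 576
print(num_image_tokens)  # 576
```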

@ -18,10 +18,18 @@ Processor class for LLaVa-NeXT.
from typing import List, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_processing_utils import select_best_resolution
from ...image_utils import ImageInput, get_image_size, to_numpy_array
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...processing_utils import (
MultiModalData,
ProcessingKwargs,
ProcessorMixin,
Unpack,
_validate_images_text_input_order,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
@ -33,6 +41,7 @@ class LlavaNextProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
"images_kwargs": {
"do_pad": True,
@ -172,9 +181,16 @@ class LlavaNextProcessor(ProcessorMixin):
prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
@ -219,6 +235,48 @@ class LlavaNextProcessor(ProcessorMixin):
newline_features = current_height
return (unpadded_features, newline_features)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
video_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (num_frames, height, width) per each video.
audio_lengths (`List[int]`, *optional*):
The input length formatted as per each audio.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = LlavaNextProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
size = images_kwargs.get("size", None) or self.image_processor.size
size = (
(size["shortest_edge"], size["shortest_edge"])
if "shortest_edge" in size
else (min(size["height"], size["width"]), min(size["height"], size["width"]))
)
processed_height, processed_width = size
batch_num_image_tokens = []
num_image_patches = [1] * len(image_sizes)  # llava-next doesn't batch pixels like Idefics, thus `1` patch
for image_size in image_sizes:
orig_height, orig_width = image_size
num_image_tokens = self._get_number_of_features(
orig_height, orig_width, processed_height, processed_width
)
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
batch_num_image_tokens.append(num_image_tokens)
vision_data.update({"num_image_tokens": batch_num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""


@ -24,7 +24,7 @@ import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_processing_utils import select_best_resolution
from ...image_utils import ImageInput, get_image_size, to_numpy_array
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
from ...video_utils import VideoInput
@ -38,6 +38,7 @@ class LlavaOnevisionProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
"image_kwargs": {},
"videos_kwargs": {},
@ -196,9 +197,16 @@ class LlavaOnevisionProcessor(ProcessorMixin):
text = [sample.replace(self.video_token, self.video_token * num_video_tokens) for sample in text]
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs}, tensor_type=return_tensors)
def _expand_image_tokens(
@ -285,6 +293,48 @@ class LlavaOnevisionProcessor(ProcessorMixin):
return (unpadded_features, newline_features)
def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
video_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (num_frames, height, width) per each video.
audio_lengths (`List[int]`, *optional*):
The input length formatted as per each audio.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = LlavaOnevisionProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
size = images_kwargs.get("size", None) or self.image_processor.size
size = (
(size["shortest_edge"], size["shortest_edge"])
if "shortest_edge" in size
else (min(size["height"], size["width"]), min(size["height"], size["width"]))
)
processed_height, processed_width = size
batch_num_image_tokens = []
num_image_patches = [1] * len(image_sizes)  # llava-ov doesn't batch pixels like Idefics, thus `1` patch
for image_size in image_sizes:
orig_height, orig_width = image_size
num_image_tokens = self._get_number_of_features(
orig_height, orig_width, processed_height, processed_width
)
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
batch_num_image_tokens.append(num_image_tokens)
vision_data.update({"num_image_tokens": batch_num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""


@ -18,10 +18,13 @@ Processor class for PaliGemma.
from typing import List, Optional, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images
from ...processing_utils import (
ImagesKwargs,
MultiModalData,
ProcessingKwargs,
ProcessorMixin,
TextKwargs,
@ -56,6 +59,7 @@ class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
"images_kwargs": {
"data_format": "channels_first",
@ -299,6 +303,7 @@ class PaliGemmaProcessor(ProcessorMixin):
pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
inputs = self.tokenizer(
input_strings,
text_pair=suffix,
@ -312,8 +317,34 @@ class PaliGemmaProcessor(ProcessorMixin):
if return_token_type_ids:
labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
return_data.update({"labels": labels})
if return_mm_token_type_ids:
array_ids = np.array(return_data["input_ids"])
mm_token_type_ids = np.zeros_like(return_data["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
return_data["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data=return_data, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
num_image_tokens = [self.image_seq_length] * len(image_sizes)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
def batch_decode(self, *args, **kwargs):
"""


@ -18,11 +18,23 @@ Processor class for Pixtral.
from typing import List, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...processing_utils import (
MultiModalData,
ProcessingKwargs,
ProcessorMixin,
Unpack,
_validate_images_text_input_order,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
from ...utils import is_vision_available, logging
if is_vision_available():
from .image_processing_pixtral import get_resize_output_image_size
logger = logging.get_logger(__name__)
@ -32,6 +44,7 @@ class PixtralProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
"images_kwargs": {},
"common_kwargs": {
@ -106,6 +119,10 @@ class PixtralProcessor(ProcessorMixin):
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
self.image_break_token = image_break_token
self.image_end_token = image_end_token
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
self.image_break_token_id = tokenizer.convert_tokens_to_ids(self.image_break_token)
self.image_end_token_id = tokenizer.convert_tokens_to_ids(self.image_end_token)
self.image_ids = [self.image_token_id, self.image_break_token_id, self.image_end_token_id]
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__(
@ -213,10 +230,54 @@ class PixtralProcessor(ProcessorMixin):
prompt_strings.append(sample)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = PixtralProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
size = images_kwargs.get("size", None) or self.image_processor.size
patch_size = self.patch_size * self.spatial_merge_size
num_image_tokens = []
for height, width in image_sizes:
resized_height, resized_width = get_resize_output_image_size(
image=np.zeros((height, width, 3)),
size=(size["longest_edge"], size["longest_edge"]),
patch_size=(patch_size, patch_size),
)
num_height_tokens = resized_height // patch_size
num_width_tokens = resized_width // patch_size
num_image_tokens.append((num_width_tokens + 1) * num_height_tokens)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
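# Illustrative sketch (not part of the diff) of the count above: every row of patch tokens is followed
# by one [IMG_BREAK] token (the last row by [IMG_END]), hence the `+ 1` per row and why those ids are
# included in `self.image_ids`. Assuming patch_size * spatial_merge_size = 16 and a 512x768 resize:
effective_patch = 16                                    # assumed patch_size * spatial_merge_size
resized_height, resized_width = 512, 768                # assumed get_resize_output_image_size output
num_height_tokens = resized_height // effective_patch   # 32 rows
num_width_tokens = resized_width // effective_patch     # 48 patch tokens per row
num_image_tokens = (num_width_tokens + 1) * num_height_tokens   # (48 + 1) * 32 = 1568 placeholder tokens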
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""

View File

@ -22,6 +22,7 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -48,7 +49,7 @@ from ...configuration_utils import PretrainedConfig
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...modeling_flash_attention_utils import is_flash_attn_available
from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs
from ...processing_utils import MultiModalData, ProcessingKwargs, Unpack, VideosKwargs
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import is_torchdynamo_compiling, logging
from ...video_utils import VideoInput
@ -925,6 +926,7 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
"videos_kwargs": {"fps": 2.0},
}
@ -1050,11 +1052,56 @@ class Qwen2_5_VLProcessor(Qwen2VLProcessor):
text[i] = text[i].replace("<|placeholder|>", self.video_token)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
video_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (num_frames, height, width) per each video.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
num_image_patches = [
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
for image_size in image_sizes
]
num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
if video_sizes is not None:
videos_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("videos_kwargs", {})
videos_kwargs.update(kwargs)
num_video_patches = [
self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
for video_size in video_sizes
]
num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
vision_data["num_video_tokens"] = num_video_tokens
return MultiModalData(**vision_data)
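# Illustrative sketch (not part of the diff) of the patch-to-token conversion above, assuming the
# defaults patch_size=14 and merge_size=2 and an image that smart_resize keeps at 560x784:
merge_size = 2
num_patches = (560 // 14) * (784 // 14)          # 40 * 56 = 2240 vision patches
num_image_tokens = num_patches // merge_size**2  # 2240 // 4 = 560 <|image_pad|> placeholders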
__all__ = [
"Qwen2_5_VLConfig",

View File

@ -25,9 +25,11 @@
# limitations under the License.
from typing import List, Optional, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...video_utils import VideoInput
@ -50,6 +52,7 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
"videos_kwargs": {"fps": 2.0},
}
@ -188,11 +191,56 @@ class Qwen2_5_VLProcessor(ProcessorMixin):
text[i] = text[i].replace("<|placeholder|>", self.video_token)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
video_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (num_frames, height, width) per each video.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
num_image_patches = [
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
for image_size in image_sizes
]
num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
if video_sizes is not None:
videos_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("videos_kwargs", {})
videos_kwargs.update(kwargs)
num_video_patches = [
self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
for video_size in video_sizes
]
num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
vision_data["num_video_tokens"] = num_video_tokens
return MultiModalData(**vision_data)
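# Hedged usage sketch (example checkpoint name; exact counts depend on the resize defaults). This is
# the query an engine such as vLLM makes to size the prompt before any pixel data is processed.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
mm_data = processor._get_num_multimodal_tokens(
image_sizes=[(720, 1280)], video_sizes=[(16, 360, 640)]
)
print(mm_data["num_image_tokens"], mm_data["num_image_patches"], mm_data["num_video_tokens"])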
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please

View File

@ -490,5 +490,31 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
return BatchFeature(data=data, tensor_type=return_tensors)
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns number of image patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*)
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of image patches per image.
"""
min_pixels = images_kwargs.get("min_pixels", None) or self.size["shortest_edge"]
max_pixels = images_kwargs.get("max_pixels", None) or self.size["longest_edge"]
patch_size = images_kwargs.get("patch_size", None) or self.patch_size
merge_size = images_kwargs.get("merge_size", None) or self.merge_size
factor = patch_size * merge_size
resized_height, resized_width = smart_resize(
height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
)
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
return grid_h * grid_w
__all__ = ["Qwen2VLImageProcessor"]

View File

@ -402,5 +402,31 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
return BatchFeature(data=data, tensor_type=return_tensors)
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns number of image patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*)
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of image patches per image.
"""
min_pixels = images_kwargs.get("min_pixels", None) or self.size["shortest_edge"]
max_pixels = images_kwargs.get("max_pixels", None) or self.size["longest_edge"]
patch_size = images_kwargs.get("patch_size", None) or self.patch_size
merge_size = images_kwargs.get("merge_size", None) or self.merge_size
factor = patch_size * merge_size
resized_height, resized_width = smart_resize(
height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
)
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
return grid_h * grid_w
__all__ = ["Qwen2VLImageProcessorFast"]

View File

@ -23,9 +23,11 @@ Processor class for Qwen2-VL.
from typing import List, Optional, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
from ...video_utils import VideoInput
@ -47,6 +49,7 @@ class Qwen2VLProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
}
@ -172,10 +175,56 @@ class Qwen2VLProcessor(ProcessorMixin):
text[i] = text[i].replace("<|placeholder|>", self.video_token)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
video_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (num_frames, height, width) per each video.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = Qwen2VLProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
num_image_patches = [
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
for image_size in image_sizes
]
num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
if video_sizes is not None:
videos_kwargs = Qwen2VLProcessorKwargs._defaults.get("videos_kwargs", {})
videos_kwargs.update(kwargs)
num_video_patches = [
self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
for video_size in video_sizes
]
num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
vision_data["num_video_tokens"] = num_video_tokens
return MultiModalData(**vision_data)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please

View File

@ -204,5 +204,35 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
tensor_type=return_tensors,
)
def get_number_of_video_patches(self, num_frames: int, height: int, width: int, videos_kwargs=None):
"""
A utility that returns the number of video patches for a given video size.
Args:
num_frames (`int`):
Number of frames in the input video.
height (`int`):
Height of the input video.
width (`int`):
Width of the input video.
videos_kwargs (`dict`, *optional*)
Any kwargs to override defaults of the video processor.
Returns:
`int`: Number of video patches for the given video size.
"""
min_pixels = videos_kwargs.get("min_pixels", None) or self.size["shortest_edge"]
max_pixels = videos_kwargs.get("max_pixels", None) or self.size["longest_edge"]
patch_size = videos_kwargs.get("patch_size", None) or self.patch_size
merge_size = videos_kwargs.get("merge_size", None) or self.merge_size
temporal_patch_size = videos_kwargs.get("temporal_patch_size", None) or self.temporal_patch_size
factor = patch_size * merge_size
resized_height, resized_width = smart_resize(
height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
)
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
grid_t = num_frames // temporal_patch_size
return grid_t * grid_h * grid_w
__all__ = ["Qwen2VLVideoProcessor"]

View File

@ -847,5 +847,46 @@ class SmolVLMImageProcessor(BaseImageProcessor):
return encoding
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns number of image patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*)
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of patches per image.
"""
do_image_splitting = images_kwargs.get("do_image_splitting", None) or self.do_image_splitting
max_image_size = images_kwargs.get("max_image_size", None) or self.max_image_size
size = images_kwargs.get("size", None) or self.size
num_patches = 1  # a single patch when no splitting is applied or the resized image already fits
if do_image_splitting:
height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=size["longest_edge"])
height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=4096)
aspect_ratio = width / height
if width >= height:
resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
resized_height = int(width / aspect_ratio)
resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
elif height > width:
resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
resized_width = int(height * aspect_ratio)
resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
max_height = max_width = max_image_size["longest_edge"]
if resized_height > max_height or resized_width > max_width:
# Calculate the number of splits
num_rows = math.ceil(resized_height / max_height)
num_cols = math.ceil(resized_width / max_width)
num_patches = num_rows * num_cols + 1
return num_patches
__all__ = ["SmolVLMImageProcessor"]

View File

@ -22,6 +22,7 @@ import os
import sys
import typing
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, TypedDict, Union
@ -120,6 +121,8 @@ class TextKwargs(TypedDict, total=False):
Whether or not to print more information and warnings.
padding_side (`str`, *optional*):
The side on which padding will be applied.
return_mm_token_type_ids (`bool`, *optional*):
Whether to return multimodal token type ids marking the positions of multimodal placeholder tokens.
"""
text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]]
@ -140,6 +143,7 @@ class TextKwargs(TypedDict, total=False):
return_length: Optional[bool]
verbose: Optional[bool]
padding_side: Optional[str]
return_mm_token_type_ids: Optional[bool]
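# Hedged usage sketch (example checkpoint, assuming a processor updated in this PR, e.g. Qwen2-VL):
# positions holding multimodal placeholder tokens come back as 1, all other tokens as 0.
import numpy as np
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
out = processor(
images=image,
text="<|vision_start|><|image_pad|><|vision_end|>Describe this image.",
return_mm_token_type_ids=True,
)
mask = out["mm_token_type_ids"][0]
print(sum(mask), "placeholder positions out of", len(mask), "tokens")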
class ImagesKwargs(TypedDict, total=False):
@ -455,6 +459,32 @@ class AllKwargsForChatTemplate(
}
@dataclass
class MultiModalData:
"""
Dataclass that holds extra useful data for processing multimodal inputs.
Processors currently cannot return keys that are not consumed by the model's
forward pass, so these helper methods compute and return useful metadata about
the multimodal inputs (images/videos) separately.
Note that this dataclass is intended to be used only by vLLM and its API
may change in the future.
"""
num_image_tokens: list[int] = None
num_video_tokens: list[int] = None
num_audio_tokens: list[int] = None
num_image_patches: list[int] = None
def __contains__(self, key):
return hasattr(self, key) and getattr(self, key) is not None
def __getitem__(self, key):
if hasattr(self, key):
return getattr(self, key)
raise AttributeError(f"{self.__class__.__name__} has no attribute {key}")
class ProcessorMixin(PushToHubMixin):
"""
This is a mixin used to provide saving/loading functionality for all processor classes.