Merge branch 'main' into better-greedy-msg

Commit a5f7900b5c by Manuel de Prada Corral, 2025-05-27 16:17:22 +02:00, committed by GitHub
42 changed files with 1423 additions and 149 deletions

View File

@ -0,0 +1,59 @@
name: Self-hosted runner scale set (AMD mi300 scheduled CI caller)
# Note: For every job in this workflow, the name of the runner scale set is finalized in the runner yaml i.e. huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled_arc_scale_set.yaml
# For example, 1gpu scale set: amd-mi300-ci-1gpu
# 2gpu scale set: amd-mi300-ci-2gpu
on:
workflow_run:
workflows: ["Self-hosted runner (AMD scheduled CI caller)"]
branches: ["main"]
types: [completed]
push:
branches:
- run_amd_scheduled_ci_caller*
jobs:
model-ci:
name: Model CI
uses: huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled_arc_scale_set.yaml@main
with:
job: run_models_gpu
slack_report_channel: "#amd-hf-ci"
runner_scale_set: amd-mi300-ci
docker: huggingface/transformers-pytorch-amd-gpu
ci_event: Scheduled CI (AMD) - mi300
secrets: inherit
torch-pipeline:
name: Torch pipeline CI
uses: huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled_arc_scale_set.yaml@main
with:
job: run_pipelines_torch_gpu
slack_report_channel: "#amd-hf-ci"
runner_scale_set: amd-mi300-ci
docker: huggingface/transformers-pytorch-amd-gpu
ci_event: Scheduled CI (AMD) - mi300
secrets: inherit
example-ci:
name: Example CI
uses: huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled_arc_scale_set.yaml@main
with:
job: run_examples_gpu
slack_report_channel: "#amd-hf-ci"
runner_scale_set: amd-mi300-ci
docker: huggingface/transformers-pytorch-amd-gpu
ci_event: Scheduled CI (AMD) - mi300
secrets: inherit
deepspeed-ci:
name: DeepSpeed CI
uses: huggingface/hf-workflows/.github/workflows/transformers_amd_ci_scheduled_arc_scale_set.yaml@main
with:
job: run_torch_cuda_extensions_gpu
slack_report_channel: "#amd-hf-ci"
runner_scale_set: amd-mi300-ci
docker: huggingface/transformers-pytorch-deepspeed-amd-gpu
ci_event: Scheduled CI (AMD) - mi300
secrets: inherit

View File

@ -315,6 +315,7 @@ device = "cuda"
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
model.train()
model.codec_model.eval()
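# keeping the codec model in eval mode here presumably freezes its dropout and normalization
# behaviour so that only the CSM backbone runs with training-mode layers; this is an assumption
# about the snippet's intent, not something stated in the diff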
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
# ensure the audio is 24kHz

View File

@ -13,9 +13,11 @@
# limitations under the License.
import copy
import json
import os
import platform
import re
import string
import time
import warnings
@ -25,8 +27,15 @@ from threading import Thread
from typing import Optional
import yaml
from huggingface_hub.utils import disable_progress_bars
from transformers import AutoTokenizer, GenerationConfig, TextIteratorStreamer
from transformers import (
AutoTokenizer,
GenerationConfig,
PreTrainedTokenizer,
TextIteratorStreamer,
logging,
)
from transformers.utils import is_rich_available, is_torch_available
from . import BaseTransformersCLICommand
@ -43,7 +52,7 @@ if is_rich_available():
if is_torch_available():
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, PreTrainedModel
ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
@ -63,6 +72,7 @@ DEFAULT_EXAMPLES = {
"numbers": {"text": "Count to 10 but skip every number ending with an 'e'"},
"birds": {"text": "Why aren't birds real?"},
"socks": {"text": "Why is it important to eat socks after meditating?"},
"numbers2": {"text": "Which number is larger, 9.9 or 9.11?"},
}
# Printed at the start of a chat session
@ -71,7 +81,7 @@ HELP_STRING_MINIMAL = """
**TRANSFORMERS CHAT INTERFACE**
Chat interface to try out a model. Besides chatting with the model, here are some basic commands:
- **!help**: shows all available commands
- **!help**: shows all available commands (set generation settings, save chat, etc.)
- **!status**: shows the current status of the model and generation settings
- **!clear**: clears the current conversation and starts a new one
- **!exit**: closes the interface
@ -135,6 +145,9 @@ class RichInterface:
for i, outputs in enumerate(output_stream):
if not outputs or i == 0:
continue
# Escapes single words encased in <>, e.g. <think> -> \<think\>, for proper rendering in Markdown.
# It only escapes single words that may have `_`, optionally following a `/` (e.g. </think>)
outputs = re.sub(r"<(/*)(\w*)>", r"\<\1\2\>", outputs)
text += outputs
# Render the accumulated text as Markdown
# NOTE: this is a workaround for the rendering "unstandard markdown"
@ -219,6 +232,7 @@ class ChatArguments:
system_prompt: Optional[str] = field(default=None, metadata={"help": "System prompt."})
save_folder: str = field(default="./chat_history/", metadata={"help": "Folder to save chat history."})
examples_path: Optional[str] = field(default=None, metadata={"help": "Path to a yaml file with examples."})
verbose: bool = field(default=False, metadata={"help": "Whether to show runtime warnings in the chat interface."})
# Generation settings
generation_config: Optional[str] = field(
@ -241,7 +255,9 @@ class ChatArguments:
repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty."})
eos_tokens: Optional[str] = field(
default=None,
metadata={"help": "EOS tokens to stop the generation. If multiple they should be comma separated."},
metadata={
"help": "EOS tokens (text format) to stop the generation. If multiple they should be comma separated."
},
)
eos_token_ids: Optional[str] = field(
default=None,
@ -426,6 +442,9 @@ class ChatCommand(BaseTransformersCLICommand):
# 2. b. strings should be quoted
def is_number(s: str) -> bool:
# handle negative numbers
if s.startswith("-"):
s = s[1:]
return s.replace(".", "", 1).isdigit()
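# Rough sanity check of the helper above (illustrative values):
#   is_number("-3.5") -> True, is_number("9.11") -> True, is_number("1.2.3") -> False,
#   is_number("abc") -> False, so any non-numeric value is wrapped in quotes below.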
generate_flags_as_dict = {k: f'"{v}"' if not is_number(v) else v for k, v in generate_flags_as_dict.items()}
@ -459,16 +478,19 @@ class ChatCommand(BaseTransformersCLICommand):
return processed_generate_flags
def get_generation_parameterization(
self, args: ChatArguments, tokenizer: AutoTokenizer
self, args: ChatArguments, tokenizer: AutoTokenizer, model: PreTrainedModel
) -> tuple[GenerationConfig, dict]:
"""
Returns a GenerationConfig object holding the generation parameters for the CLI command.
"""
# No generation config arg provided -> use base generation config, apply CLI defaults
# No generation config arg provided -> use default generation config, apply CLI defaults
if args.generation_config is None:
generation_config = GenerationConfig()
# We start off from the checkpoint's generation config
generation_config = copy.deepcopy(model.generation_config)
# Apply deprecated CLI args on top of the default generation config
pad_token_id, eos_token_ids = self.parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
pad_token_id, eos_token_ids = self.parse_eos_tokens(
tokenizer, generation_config, args.eos_tokens, args.eos_token_ids
)
deprecated_kwargs = {
"max_new_tokens": args.max_new_tokens,
"do_sample": args.do_sample,
@ -499,13 +521,16 @@ class ChatCommand(BaseTransformersCLICommand):
@staticmethod
def parse_eos_tokens(
tokenizer: AutoTokenizer, eos_tokens: Optional[str], eos_token_ids: Optional[str]
tokenizer: PreTrainedTokenizer,
generation_config: GenerationConfig,
eos_tokens: Optional[str],
eos_token_ids: Optional[str],
) -> tuple[int, list[int]]:
"""Retrieves the pad token ID and all possible EOS token IDs."""
if tokenizer.pad_token_id is None:
pad_token_id = tokenizer.eos_token_id
if generation_config.pad_token_id is None:
pad_token_id = generation_config.eos_token_id
else:
pad_token_id = tokenizer.pad_token_id
pad_token_id = generation_config.pad_token_id
all_eos_token_ids = []
@ -516,7 +541,7 @@ class ChatCommand(BaseTransformersCLICommand):
all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")])
if len(all_eos_token_ids) == 0:
all_eos_token_ids.append(tokenizer.eos_token_id)
all_eos_token_ids.append(generation_config.eos_token_id)
return pad_token_id, all_eos_token_ids
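# Hedged sketch of the fallback above (names and ids are illustrative, not taken from this diff):
#   generation_config = GenerationConfig(eos_token_id=2)          # pad_token_id left unset
#   pad_id, eos_ids = ChatCommand.parse_eos_tokens(tokenizer, generation_config, None, "2,32000")
#   # -> pad_id == 2 (falls back to the config's EOS), eos_ids == [2, 32000]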
@ -583,6 +608,7 @@ class ChatCommand(BaseTransformersCLICommand):
Handles all user commands except for `!exit`. May update the chat history (e.g. reset it) or the
generation config (e.g. set a new flag).
"""
valid_command = True
if user_input == "!clear":
chat = self.clear_chat_history(args.system_prompt)
@ -644,10 +670,11 @@ class ChatCommand(BaseTransformersCLICommand):
)
else:
valid_command = False
interface.print_color(text=f"'{user_input}' is not a valid command. Showing help message.", color="red")
interface.print_help()
return chat, generation_config, model_kwargs
return chat, valid_command, generation_config, model_kwargs
# -----------------------------------------------------------------------------------------------------------------
# Main logic
@ -671,7 +698,12 @@ class ChatCommand(BaseTransformersCLICommand):
model, tokenizer = self.load_model_and_tokenizer(args)
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer)
generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer, model)
# if not verbose -> disable warnings, progress bars, etc in the chat interface
if not args.verbose:
logging.set_verbosity_error()
disable_progress_bars()
interface = RichInterface(model_name=args.model_name_or_path_positional, user_name=user)
interface.clear()
@ -689,7 +721,7 @@ class ChatCommand(BaseTransformersCLICommand):
if user_input == "!exit":
break
else:
chat, generation_config, model_kwargs = self.handle_non_exit_user_commands(
chat, valid_command, generation_config, model_kwargs = self.handle_non_exit_user_commands(
user_input=user_input,
args=args,
interface=interface,
@ -699,7 +731,7 @@ class ChatCommand(BaseTransformersCLICommand):
chat=chat,
)
# `!example` sends a user message to the model
if not user_input.startswith("!example"):
if not valid_command or not user_input.startswith("!example"):
continue
else:
chat.append({"role": "user", "content": user_input})

View File

@ -822,7 +822,7 @@ class GenerationConfig(PushToHubMixin):
warning_message = (
f"The following generation flags are not valid and may be ignored: {attributes_with_issues}."
)
if logger.getEffectiveLevel() >= logging.WARNING:
if logging.get_verbosity() >= logging.WARNING:
warning_message += " Set `TRANSFORMERS_VERBOSITY=info` for more details."
logger.warning(warning_message)
logger.info(info_message)

View File

@ -1577,7 +1577,8 @@ def _find_mismatched_keys(
# This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
# Without matching with module type or parameter type it seems like a practical way to detect valid 4bit weights.
if not (
new_state_dict[key].shape[-1] == 1
is_quantized
and new_state_dict[key].shape[-1] == 1
and new_state_dict[key].numel() * 2 == model_state_dict[key].numel()
):
mismatched_keys.append(key)
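# Worked example of the size check above (shapes are illustrative): a 4-bit-quantized (1024, 1024)
# linear weight is packed two values per byte, so the checkpoint tensor may arrive as shape
# (524288, 1); numel() * 2 == 1048576 matches the unquantized parameter's numel(), and the key is
# correctly not flagged as a size mismatch when `is_quantized` is set.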
@ -3652,7 +3653,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
for key, value in state_dict.items():
for pattern, replacement in reverse_key_mapping.items():
replacement = replacement.lstrip("^") # strip off un-needed chars and patterns
replacement = re.sub(r"\(.*?\)", "", pattern)
replacement = re.sub(r"\(.*\)", "", replacement)
key, n_replace = re.subn(pattern, replacement, key)
# Early exit of the loop
if n_replace > 0:

View File

@ -500,5 +500,26 @@ class AriaImageProcessor(BaseImageProcessor):
]
return patches
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns number of image patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*)
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of patches per image.
"""
split_image = images_kwargs.get("split_image", None) or self.split_image
max_image_size = images_kwargs.get("max_image_size", None) or self.max_image_size
resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions)
num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size
return num_patches
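# Hedged usage sketch (all values are assumptions for illustration, not taken from this diff):
#   processor.get_number_of_image_patches(980, 1470, images_kwargs={})
# with split_image=True, max_image_size=490 and select_best_resolution returning (980, 1470)
# would give (980 // 490) * (1470 // 490) = 2 * 3 = 6 patches; with split_image=False it is always 1.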
__all__ = ["AriaImageProcessor"]

View File

@ -34,7 +34,7 @@ from ...image_utils import (
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import PreTrainedModel
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils import PreTokenizedInput, TextInput
from ...utils import LossKwargs, TensorType, auto_docstring, can_return_tuple, logging
from ...utils.import_utils import is_torch_available
@ -884,11 +884,33 @@ class AriaImageProcessor(BaseImageProcessor):
]
return patches
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns number of image patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*)
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of patches per image.
"""
split_image = images_kwargs.get("split_image", None) or self.split_image
max_image_size = images_kwargs.get("max_image_size", None) or self.max_image_size
resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions)
num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size
return num_patches
class AriaProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
"images_kwargs": {
"max_image_size": 980,
@ -978,10 +1000,7 @@ class AriaProcessor(ProcessorMixin):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
if images is not None:
image_inputs = self.image_processor(
images,
**output_kwargs["images_kwargs"],
)
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
# expand the image_token according to the num_crops and tokens per image
tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
prompt_strings = []
@ -995,11 +1014,44 @@ class AriaProcessor(ProcessorMixin):
prompt_strings = text
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
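# e.g. for input_ids [[5, 7, 7, 9]] with image_token_id == 7, this produces
# mm_token_type_ids [[0, 1, 1, 0]]: a per-token flag marking image placeholder positions.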
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = AriaProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size
num_image_patches = [
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
for image_size in image_sizes
]
num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches]
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please

View File

@ -20,9 +20,11 @@
# limitations under the License.
from typing import Dict, List, Optional, Union
import numpy as np
from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils import PreTokenizedInput, TextInput
from ...utils import TensorType
from ..auto import AutoTokenizer
@ -32,6 +34,7 @@ class AriaProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
"images_kwargs": {
"max_image_size": 980,
@ -121,10 +124,7 @@ class AriaProcessor(ProcessorMixin):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
if images is not None:
image_inputs = self.image_processor(
images,
**output_kwargs["images_kwargs"],
)
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
# expand the image_token according to the num_crops and tokens per image
tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
prompt_strings = []
@ -138,11 +138,44 @@ class AriaProcessor(ProcessorMixin):
prompt_strings = text
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = AriaProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size
num_image_patches = [
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
for image_size in image_sizes
]
num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches]
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please

View File

@ -13,22 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
from transformers.processing_utils import (
ImagesKwargs,
ProcessingKwargs,
ProcessorMixin,
Unpack,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
import numpy as np
from ...image_processing_utils import BatchFeature
from ...image_utils import (
ImageInput,
make_flat_list_of_images,
)
from ...processing_utils import (
ImagesKwargs,
MultiModalData,
ProcessingKwargs,
ProcessorMixin,
Unpack,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
class AyaVisionImagesKwargs(ImagesKwargs, total=False):
@ -43,6 +44,7 @@ class AyaVisionProcessorKwargs(ProcessingKwargs, total=False):
"text_kwargs": {
"padding_side": "left",
"padding": True,
"return_mm_token_type_ids": False,
},
"images_kwargs": {
"crop_to_patches": True,
@ -121,7 +123,6 @@ class AyaVisionProcessor(ProcessorMixin):
super().__init__(image_processor, tokenizer, chat_template=chat_template)
self.image_token = image_token
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
self.patch_size = patch_size * downsample_factor
self.img_size = img_size
@ -131,6 +132,10 @@ class AyaVisionProcessor(ProcessorMixin):
self.img_line_break_token = img_line_break_token
self.tile_token = tile_token
self.tile_global_token = tile_global_token
self.image_token_id = tokenizer.convert_tokens_to_ids(self.img_patch_token)
self.image_ids = tokenizer.convert_tokens_to_ids(
[img_patch_token, tile_token, tile_global_token, start_of_img_token, end_of_img_token]
)
def _prompt_split_image(self, num_patches):
"""
@ -226,11 +231,49 @@ class AyaVisionProcessor(ProcessorMixin):
text = processed_text
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = AyaVisionProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
num_image_patches = [
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
for image_size in image_sizes
]
token_per_patch = (self.img_size // self.patch_size) ** 2
num_image_tokens = [
token_per_patch + 3 + sum(token_per_patch + 1 for _ in range(1, num_patches))
for num_patches in num_image_patches
] # Add +3 and +1 for BOI/EOI and image tile tokens
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
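# Worked example of the count above, under assumed values (img_size and effective patch_size are
# not taken from this diff): with img_size=364 and patch_size=28 the global tile contributes
# (364 // 28) ** 2 + 3 = 172 tokens and each extra tile adds 169 + 1 = 170, so a 3-patch image
# yields 172 + 2 * 170 = 512 placeholder tokens.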
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please

View File

@ -18,9 +18,18 @@ Processor class for Chameleon.
from typing import List, Optional, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack, _validate_images_text_input_order
from ...processing_utils import (
MultiModalData,
ProcessingKwargs,
ProcessorMixin,
TextKwargs,
Unpack,
_validate_images_text_input_order,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
@ -34,6 +43,7 @@ class ChameleonProcessorKwargs(ProcessingKwargs, total=False):
"text_kwargs": {
"padding": False,
"return_for_text_completion": False,
"return_mm_token_type_ids": False,
},
"common_kwargs": {
"return_tensors": "pt",
@ -73,6 +83,10 @@ class ChameleonProcessor(ProcessorMixin):
tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "<racm3:break>"
) # fixed tokens for start and end, so can hardcode
self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "<eoss>"
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
self.image_start_token_id = tokenizer.convert_tokens_to_ids(self.image_start_token)
self.image_end_token_id = tokenizer.convert_tokens_to_ids(self.image_end_token)
self.image_ids = [self.image_token_id, self.image_start_token_id, self.image_end_token_id]
super().__init__(image_processor, tokenizer)
@ -141,14 +155,45 @@ class ChameleonProcessor(ProcessorMixin):
sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode
prompt_strings.append(sample)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(prompt_strings, data, modalities=["image"])
image_inputs = {}
if images is not None:
data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
return BatchFeature(data=data, tensor_type=return_tensors)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
# add 2 for BOI and EOI tokens
num_image_tokens = [self.image_seq_length + 2] * len(image_sizes)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):

View File

@ -24,7 +24,7 @@ from typing import ClassVar, List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
from ...utils import is_torch_available
@ -256,6 +256,25 @@ class ColPaliProcessor(ProcessorMixin):
return batch_query
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
num_image_tokens = [self.image_seq_length] * len(image_sizes)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please

View File

@ -981,22 +981,23 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
# =======================================
# TODO: @eustlb, this should be batched !!!
# but requires making sure batched inference of the codec model works as intended
audio_tokens_list = []
for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
for i in range(batch_input_values_cutoffs.shape[0] - 1):
start_idx = batch_input_values_cutoffs[i]
end_idx = batch_input_values_cutoffs[i + 1]
audio_batch = batch_input_values[..., start_idx:end_idx]
codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
audio_tokens_list.append(codebook_ids[0])
with torch.no_grad():
audio_tokens_list = []
for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
for i in range(batch_input_values_cutoffs.shape[0] - 1):
start_idx = batch_input_values_cutoffs[i]
end_idx = batch_input_values_cutoffs[i + 1]
audio_batch = batch_input_values[..., start_idx:end_idx]
codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
audio_tokens_list.append(codebook_ids[0])
max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
batched_audio_token_ids = torch.stack(
[nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
)
audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
batched_audio_token_ids = torch.stack(
[nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
)
audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
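# Presumably the `torch.no_grad()` wrapper above exists because `codec_model.encode` is only used
# to derive target audio tokens, so storing activations for backprop through the codec would cost
# memory without producing gradients that the trainable parts of the model need.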
# =======================================
audio_token_id = self.config.audio_token_id
audio_token_mask = input_ids == audio_token_id
@ -1018,6 +1019,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
if labels is not None:
labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
# mask depth decoder
depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100

View File

@ -595,22 +595,23 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
# =======================================
# TODO: @eustlb, this should be batched !!!
# but requires making sure batched inference of the codec model works as intended
audio_tokens_list = []
for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
for i in range(batch_input_values_cutoffs.shape[0] - 1):
start_idx = batch_input_values_cutoffs[i]
end_idx = batch_input_values_cutoffs[i + 1]
audio_batch = batch_input_values[..., start_idx:end_idx]
codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
audio_tokens_list.append(codebook_ids[0])
with torch.no_grad():
audio_tokens_list = []
for batch_input_values, batch_input_values_cutoffs in zip(input_values, input_values_cutoffs):
batch_input_values_cutoffs = batch_input_values_cutoffs[batch_input_values_cutoffs >= 0]
for i in range(batch_input_values_cutoffs.shape[0] - 1):
start_idx = batch_input_values_cutoffs[i]
end_idx = batch_input_values_cutoffs[i + 1]
audio_batch = batch_input_values[..., start_idx:end_idx]
codec_outputs = self.codec_model.encode(audio_batch.unsqueeze(0))
codebook_ids = codec_outputs.audio_codes.transpose(1, -1)
audio_tokens_list.append(codebook_ids[0])
max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
batched_audio_token_ids = torch.stack(
[nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
)
audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
max_audio_frames = max(el.shape[0] for el in audio_tokens_list)
batched_audio_token_ids = torch.stack(
[nn.functional.pad(el, (0, 0, 0, max_audio_frames - el.shape[0])) for el in audio_tokens_list]
)
audio_codes_mask = self.codec_model.get_audio_codes_mask(input_values_mask)
# =======================================
audio_token_id = self.config.audio_token_id
audio_token_mask = input_ids == audio_token_id
@ -632,6 +633,7 @@ class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
if labels is not None:
labels_expanded = labels.unsqueeze(-1).repeat(1, 1, self.config.num_codebooks)
labels_expanded[audio_token_mask] = batched_audio_token_ids[audio_codes_mask]
labels_expanded[audio_eos_token_mask] = audio_eos_frame_ids
# mask depth decoder
depth_decoder_ignore_frames_idxs = (labels == -101).nonzero(as_tuple=True)
labels_expanded[depth_decoder_ignore_frames_idxs[0], depth_decoder_ignore_frames_idxs[1], 1:] = -100

View File

@ -353,7 +353,11 @@ class CsmProcessor(ProcessorMixin):
else:
skip_frames_idxs = audio_frame_idxs
labels = torch.where(data["input_ids"] == self.audio_token_id, data["input_ids"], -100)
labels = torch.where(
(data["input_ids"] == self.audio_token_id) | (data["input_ids"] == self.audio_eos_token_id),
data["input_ids"],
-100,
)
labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101
data["labels"] = labels

View File

@ -16,10 +16,17 @@
from typing import List, Optional, Union
import numpy as np
from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import is_vision_available
if is_vision_available():
from .image_processing_emu3 import smart_resize
class Emu3TextKwargs(TextKwargs, total=False):
@ -37,6 +44,7 @@ class Emu3ProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"return_for_image_generation": False,
"return_mm_token_type_ids": False,
},
"images_kwargs": {
"ratio": "1:1",
@ -166,7 +174,7 @@ class Emu3Processor(ProcessorMixin):
image_placeholder = f"{image_start_tokens}{height}*{width}{self.fake_token_around_image}{'<placeholder>' * image_seq_length}{image_end_tokens}"
sample = sample.replace(self.image_token, image_placeholder, 1)
sample = f"{self.bos_token}{sample}" # add BOS because PT tokenizer doesn't add it
sample = f"{self.bos_token}{sample}" # add BOS because GPT tokenizer doesn't add it
prompt_strings.append(sample)
text = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
@ -179,12 +187,51 @@ class Emu3Processor(ProcessorMixin):
# else just generate from text-only input, and we do no special treatment for text
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
data = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, data, modalities=["image"])
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
data.update(**image_features)
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data=data, tensor_type=return_tensors)
return BatchFeature(data={**text_inputs, **image_features}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
num_image_tokens = []
for height, width in image_sizes:
height, width = smart_resize(
height,
width,
self.image_processor.spatial_factor,
self.image_processor.min_pixels,
self.image_processor.max_pixels,
)
height = height // self.downsample_ratio
width = width // self.downsample_ratio
image_seq_length = height * (width + 1) # +1 for extra row when converting to BPE in modeling code
num_image_tokens.append(image_seq_length)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
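# Illustrative arithmetic for the computation above (downsample_ratio and sizes are assumed values,
# not taken from this diff): an image smart-resized to 512 x 512 with downsample_ratio=8 becomes a
# 64 x 64 grid, giving image_seq_length = 64 * (64 + 1) = 4160 placeholder tokens for that image.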
def calculate_generate_size(self, ratio, image_area, spatial_factor):
width, height = map(int, ratio.split(":"))

View File

@ -130,7 +130,7 @@ class FuyuModel(FuyuPreTrainedModel):
)
return output_embeddings
def get_image_features(self, pixel_values: torch.FloatTensor):
def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs):
"""
Encodes images into continuous embeddings that can be forwarded to the language model.

View File

@ -22,7 +22,13 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...processing_utils import (
MultiModalData,
ProcessingKwargs,
ProcessorMixin,
Unpack,
_validate_images_text_input_order,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import is_torch_available, logging, requires_backends
from ...utils.import_utils import requires
@ -64,6 +70,7 @@ class FuyuProcessorKwargs(ProcessingKwargs, total=False):
"return_token_type_ids": False,
"return_length": False,
"verbose": True,
"return_mm_token_type_ids": False,
},
"images_kwargs": {},
}
@ -355,6 +362,8 @@ class FuyuProcessor(ProcessorMixin):
self.max_position_embeddings = 16384 # TODO Can't derive this from model files: where to set it?
self.pad_token_id = 0
self.dummy_image_index = -1
self.image_token_id = tokenizer.encode("|SPEAKER|", add_special_tokens=False)[1]
self.image_newline_id = tokenizer.encode("|NEWLINE|", add_special_tokens=False)[1]
def _left_pad_inputs_with_attention_mask(self, model_inputs: List[Dict], return_attention_mask: bool):
max_length_input_ids = max(entry["input_ids"].shape[1] for entry in model_inputs)
@ -403,6 +412,11 @@ class FuyuProcessor(ProcessorMixin):
for key in batched_keys:
batched_inputs[key] = torch.cat(batched_inputs[key], dim=0)
# Cast images to tensor as well, if only one image passed and no padding needed
# NOTE: vLLM expects all processor outputs to be a tensor
if len(batched_inputs["image_patches"]) == 1:
batched_inputs["image_patches"] = torch.cat(batched_inputs["image_patches"], dim=0)
return batched_inputs
def get_sample_encoding(
@ -517,6 +531,7 @@ class FuyuProcessor(ProcessorMixin):
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
if not output_kwargs["text_kwargs"].setdefault("return_attention_mask", True):
raise ValueError("`return_attention_mask=False` is not supported for this model.")
@ -550,8 +565,6 @@ class FuyuProcessor(ProcessorMixin):
# --- Use self.tokenizer to get the ids of special tokens to insert into image ids ---
image_placeholder_id = self.tokenizer("|SPEAKER|", add_special_tokens=False)["input_ids"][1]
image_newline_id = self.tokenizer("|NEWLINE|", add_special_tokens=False)["input_ids"][1]
tensor_batch_images = torch.stack([img[0] for img in batch_images]).unsqueeze(1)
# --- Use self.image_processor again to obtain the full token ids and batch inputs ---
@ -565,16 +578,63 @@ class FuyuProcessor(ProcessorMixin):
scale_factors=[scale_factor],
image_unpadded_heights=torch.tensor([image_unpadded_height]),
image_unpadded_widths=torch.tensor([image_unpadded_width]),
image_placeholder_id=image_placeholder_id,
image_newline_id=image_newline_id,
image_placeholder_id=self.image_token_id,
image_newline_id=self.image_newline_id,
tensor_batch_images=tensor_batch_image.unsqueeze(0),
)
all_encodings.append(sample_encoding)
batch_encoding = self._left_pad_inputs_with_attention_mask(
model_inputs=all_encodings, return_attention_mask=True
)
if return_mm_token_type_ids:
input_ids = batch_encoding["input_ids"]
mm_token_type_ids = torch.zeros_like(input_ids)
mm_token_type_ids[input_ids == self.image_token_id] = 1
mm_token_type_ids[input_ids == self.image_newline_id] = 1
batch_encoding["mm_token_type_ids"] = mm_token_type_ids
return FuyuBatchFeature(data=batch_encoding)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
size = kwargs.get("size", None) or self.image_processor.size
padded_height, padded_width = size["height"], size["width"]
num_image_tokens = []
num_image_patches = [1] * len(image_sizes)
for image_size in image_sizes:
height_scale_factor = padded_height / image_size[0]
width_scale_factor = padded_width / image_size[1]
optimal_scale_factor = min(height_scale_factor, width_scale_factor)
# We can use torch here because Fuyu processor has hard dependency on torch
model_image_input = self.image_processor.preprocess_with_tokenizer_info(
image_input=torch.zeros(1, 1, 3, padded_height, padded_width),
image_present=torch.ones(1, 1, 1),
image_unpadded_h=torch.tensor([[int(image_size[0] * optimal_scale_factor)]]),
image_unpadded_w=torch.tensor([[int(image_size[1] * optimal_scale_factor)]]),
image_placeholder_id=0, # dummy ids, we can be sure `id=0` is never out-of-range
image_newline_id=0,
variable_sized=True,
)
num_image_tokens.append(model_image_input["image_input_ids"][0][0].shape[-1])
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
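# Note on the dummy ids above: `preprocess_with_tokenizer_info` is used here only to count how many
# placeholder and newline positions a padded image expands to, so passing 0 for both special ids is
# presumably safe; it is the length of `image_input_ids`, not the id values, that feeds the count.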
def post_process_box_coordinates(self, outputs, target_sizes=None):
"""
Transforms raw coordinates detected by [`FuyuForCausalLM`] to the original images' coordinate space.

View File

@ -782,7 +782,7 @@ class Gemma3MultiModalProjector(nn.Module):
return projected_vision_outputs.type_as(vision_outputs)
def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]:
"""
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
not start and end indices.
@ -792,8 +792,13 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Opti
return None
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
# If it's 1, we need to unmask it
return token_type_ids[batch_idx, kv_idx] == 1
# If the difference is less than image size, both are part of the same image block
same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
# If it's 1 for both query and key/value, we are in an image block
is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
# This is bidirectional attention whenever we are dealing with image tokens
return is_image_block & same_image_block
return inner_mask
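# Illustrative behaviour of `inner_mask` above, assuming tokens_per_image = 4 and a row of
# token_type_ids like [0, 1, 1, 1, 1, 0]: query/key pairs that both fall inside the image span and
# lie within 4 positions of each other are unmasked (bidirectional attention inside the image),
# while text tokens keep the default causal mask because `is_image_block` is False for them.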
@ -945,7 +950,7 @@ class Gemma3Model(Gemma3PreTrainedModel):
if token_type_ids is not None and inputs_embeds.shape[1] != 1:
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
token_type_ids.to(cache_position.device)
token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
)
# Create the masks
@ -1211,7 +1216,9 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
# Add the token type ids mask for generate as well
if token_type_ids is not None and input_embeds.shape[1] != 1:
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device))
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
token_type_ids.to(cache_position.device), config.mm_tokens_per_image
)
return create_masks_for_generate(**mask_kwargs)

View File

@ -722,7 +722,7 @@ class Gemma3MultiModalProjector(nn.Module):
return projected_vision_outputs.type_as(vision_outputs)
def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]:
"""
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
not start and end indices.
@ -732,8 +732,13 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Opti
return None
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
# If it's 1, we need to unmask it
return token_type_ids[batch_idx, kv_idx] == 1
# If the difference is less than image size, both are part of the same image block
same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
# If it's 1 for both query and key/value, we are in an image block
is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
# This is bidirectional attention whenever we are dealing with image tokens
return is_image_block & same_image_block
return inner_mask
@ -836,7 +841,7 @@ class Gemma3Model(PaliGemmaModel):
if token_type_ids is not None and inputs_embeds.shape[1] != 1:
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
token_type_ids.to(cache_position.device)
token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
)
# Create the masks
@ -1055,7 +1060,9 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
# Add the token type ids mask for generate as well
if token_type_ids is not None and input_embeds.shape[1] != 1:
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device))
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
token_type_ids.to(cache_position.device), config.mm_tokens_per_image
)
return create_masks_for_generate(**mask_kwargs)

View File

@ -20,7 +20,7 @@ import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, make_nested_list_of_images
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import to_py_obj
@ -38,6 +38,7 @@ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": True,
},
"images_kwargs": {
"do_pan_and_scan": False,
@ -137,17 +138,42 @@ class Gemma3Processor(ProcessorMixin):
text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
# Add token type ids manually, as tokenizer can't do arbitrary position token types
array_ids = text_inputs["input_ids"]
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs
text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(array_ids)
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
# NOTE: no image cropping supported yet
num_image_tokens = [self.image_seq_length] * len(image_sizes)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
def batch_decode(self, *args, **kwargs):
"""

View File

@ -491,5 +491,33 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
return processed_images
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns the number of patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*)
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of patches per image.
"""
min_patches = images_kwargs.get("min_patches", None) or self.min_patches
max_patches = images_kwargs.get("max_patches", None) or self.max_patches
patch_size = images_kwargs.get("size", None) or self.size
crop_to_patches = images_kwargs.get("crop_to_patches", None) or self.crop_to_patches
num_patches = 1
if crop_to_patches and max_patches > 1:
num_columns, num_rows = get_optimal_tiled_canvas(
(height, width), (patch_size["height"], patch_size["width"]), min_patches, max_patches
)
num_patches += num_columns * num_rows
return num_patches
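# Hedged usage sketch (values are illustrative): with crop_to_patches=True, max_patches=6 and
# get_optimal_tiled_canvas choosing a 2 x 2 grid for the input size, the method returns
# 1 + 2 * 2 = 5 patches, the leading 1 presumably accounting for the global view; otherwise it
# returns 1.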
__all__ = ["GotOcr2ImageProcessor"]

View File

@ -228,5 +228,33 @@ class GotOcr2ImageProcessorFast(BaseImageProcessorFast):
data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors
)
def get_number_of_image_tokens(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns the number of patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*)
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of patches per image.
"""
min_patches = images_kwargs.get("min_patches", None) or self.min_patches
max_patches = images_kwargs.get("max_patches", None) or self.max_patches
patch_size = images_kwargs.get("size", None) or self.size
crop_to_patches = images_kwargs.get("crop_to_patches", None) or self.crop_to_patches
num_patches = 1
if crop_to_patches and max_patches > 1:
num_columns, num_rows = get_optimal_tiled_canvas(
(height, width), (patch_size["height"], patch_size["width"]), min_patches, max_patches
)
num_patches += num_columns * num_rows
return num_patches
__all__ = ["GotOcr2ImageProcessorFast"]

View File

@ -850,5 +850,46 @@ class Idefics3ImageProcessor(BaseImageProcessor):
return encoding
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns number of image patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*)
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of patches per image.
"""
do_image_splitting = images_kwargs.get("do_image_splitting", None) or self.do_image_splitting
max_image_size = images_kwargs.get("max_image_size", None) or self.max_image_size
size = images_kwargs.get("size", None) or self.size
if do_image_splitting:
height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=size["longest_edge"])
height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=4096)
aspect_ratio = width / height
if width >= height:
resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
resized_height = int(width / aspect_ratio)
resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
elif height > width:
resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
resized_width = int(height * aspect_ratio)
resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
max_height = max_width = max_image_size["longest_edge"]
if resized_height > max_height or resized_width > max_width:
# Calculate the number of splits
num_rows = math.ceil(resized_height / max_height)
num_cols = math.ceil(resized_width / max_width)
num_patches = num_rows * num_cols + 1
return num_patches
__all__ = ["Idefics3ImageProcessor"]

View File

@ -16,13 +16,16 @@
Processor class for Idefics3.
"""
import math
import re
from itertools import accumulate
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import AddedToken, BatchEncoding, TextInput
from ...utils import logging
@ -98,6 +101,7 @@ class Idefics3ProcessorKwargs(ProcessingKwargs, total=False):
"add_special_tokens": True,
"padding": False,
"is_split_into_words": False,
"return_mm_token_type_ids": False,
},
"images_kwargs": {
"return_row_col_info": True,
@ -146,6 +150,12 @@ class Idefics3Processor(ProcessorMixin):
self.end_of_utterance_token = AddedToken("<end_of_utterance>", normalized=False, special=True).content
self.global_image_tag = "<global-img>" # https://github.com/huggingface/transformers/pull/32473/files/8063e5e17362571b693f1db95167f5443a3be1b2#r1734825341
self.image_seq_len = image_seq_len
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
self.fake_image_token_id = tokenizer.convert_tokens_to_ids(self.fake_image_token)
self.global_image_token_id = tokenizer.convert_tokens_to_ids(self.global_image_tag)
self.row_col_ids = [
tokenizer.convert_tokens_to_ids(f"<row_{i + 1}_col_{j + 1}>") for i in range(6) for j in range(6)
]
# This regex matches one or more occurrences of <global-img> tags (optionally surrounded by newline characters)
# or <row_x_col_y> tags (where x and y are digits, also optionally surrounded by newline characters).
@ -241,6 +251,7 @@ class Idefics3Processor(ProcessorMixin):
)
image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
n_images_in_text = []
@ -302,9 +313,11 @@ class Idefics3Processor(ProcessorMixin):
global_img_token = self.global_image_tag
prompt_strings = []
batch_image_seq_lengths = []
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
# Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
image_prompt_strings = []
image_seq_lengths = []
for n_rows, n_cols in zip(sample_rows, sample_cols):
image_prompt_string = get_image_prompt_string(
n_rows,
@ -314,8 +327,12 @@ class Idefics3Processor(ProcessorMixin):
fake_token_around_image=fake_image_token,
global_img_token=global_img_token,
)
# Add +2 and +3 for special BOI/EOI/fake_image_wrapper tokens
row_length = (self.image_seq_len + 2) * n_cols + 1
image_seq_lengths.append((self.image_seq_len + 3) + row_length * n_rows)
image_prompt_strings.append(image_prompt_string)
batch_image_seq_lengths.append(image_seq_lengths)
split_sample = sample.split(image_token)
if len(split_sample) == 0:
raise ValueError("The image token should be present in the text.")
@ -338,7 +355,59 @@ class Idefics3Processor(ProcessorMixin):
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
inputs.update(text_inputs)
return BatchFeature(inputs, tensor_type=return_tensors)
if return_mm_token_type_ids:
array_ids = np.array(inputs["input_ids"])
mm_token_type_ids = np.zeros_like(array_ids)
for i, seq_lengths in enumerate(batch_image_seq_lengths):
image_start_positions = np.where(array_ids[i] == self.fake_image_token_id)[0]
j = 0
for seq_len in seq_lengths:
if j >= len(image_start_positions):
break
start = image_start_positions[j]
end = start + seq_len
mm_token_type_ids[i, start:end] = 1
j = np.searchsorted(image_start_positions, end)
inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data=inputs, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = Idefics3ProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
num_image_patches = [
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
for image_size in image_sizes
]
base_image_length = self.image_seq_len + 3
col_length = self.image_seq_len + 2
num_image_tokens = []
for num_patches in num_image_patches:
num_cols = num_rows = int(math.sqrt(num_patches - 1))
row_length = col_length * num_cols + 1
num_image_tokens.append(base_image_length + (row_length * num_rows))
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
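To make the `+2` / `+3` bookkeeping above concrete, here is a worked example of the per-image token count, assuming an illustrative `image_seq_len` of 169 and a 3 x 3 grid of local tiles plus the global image (`num_patches = 10`):

image_seq_len = 169
n_rows = n_cols = 3  # int(math.sqrt(10 - 1))

# Each tile contributes its wrapper tokens (+2) on top of the expanded image
# sequence, and every row ends with one extra token; the global image gets
# +3 wrapper tokens, as in the formulas above.
row_length = (image_seq_len + 2) * n_cols + 1   # 171 * 3 + 1 = 514
base_image_length = image_seq_len + 3           # 172
num_image_tokens = base_image_length + row_length * n_rows

print(num_image_tokens)  # 172 + 514 * 3 = 1714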
def batch_decode(self, *args, **kwargs):
"""

View File

@ -13,25 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import numpy as np
from transformers.processing_utils import (
ImagesKwargs,
ProcessingKwargs,
ProcessorMixin,
Unpack,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from ...image_processing_utils import BatchFeature
from ...image_utils import (
ImageInput,
concatenate_list,
make_flat_list_of_images,
)
from ...processing_utils import (
ImagesKwargs,
MultiModalData,
ProcessingKwargs,
ProcessorMixin,
Unpack,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...video_utils import VideoInput, VideoMetadata, load_video, make_batched_videos
@ -46,6 +45,7 @@ class InternVLProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding_side": "left",
"return_mm_token_type_ids": False,
},
"images_kwargs": {
"crop_to_patches": True,
@ -94,9 +94,12 @@ class InternVLProcessor(ProcessorMixin):
self.image_seq_length = image_seq_length
self.start_image_token = tokenizer.start_image_token
self.end_image_token = tokenizer.end_image_token
self.start_image_token_id = tokenizer.start_image_token_id
self.end_image_token_id = tokenizer.end_image_token_id
self.image_token = tokenizer.context_image_token
self.video_token = tokenizer.video_token
self.image_token_id = tokenizer.context_image_token_id
self.image_ids = [self.image_token_id, self.start_image_token_id, self.end_image_token_id]
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template, **kwargs)
@ -261,11 +264,46 @@ class InternVLProcessor(ProcessorMixin):
image_videos_inputs = {"pixel_values": concatenate_list(image_video_patches)}
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_videos_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = InternVLProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
num_image_patches = [
self.image_processor.get_number_of_image_tokens(*image_size, images_kwargs)
for image_size in image_sizes
]
# Add 2 for BOI and EOI tokens
num_image_tokens = [2 + (self.image_seq_length * num_patches) for num_patches in num_image_patches]
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
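The `np.isin` masking above can be illustrated in isolation; the token ids below are made up purely for the example:

import numpy as np

image_ids = [9, 10, 11]                        # hypothetical start/context/end image token ids
input_ids = np.array([[1, 9, 10, 10, 11, 2]])

mm_token_type_ids = np.zeros_like(input_ids)
mm_token_type_ids[np.isin(input_ids, image_ids)] = 1   # mark every image-related position
print(mm_token_type_ids.tolist())              # [[0, 1, 1, 1, 1, 0]]

The per-image token count above follows the same bookkeeping: `2 + image_seq_length * num_patches`, with the 2 covering the start- and end-of-image tokens noted in the comment.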
def sample_indices_fn(
self, metadata: VideoMetadata, num_frames: Optional[int] = None, initial_shift: Union[bool, float, int] = True
):

View File

@ -18,9 +18,17 @@ Processor class for Llava.
from typing import List, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, get_image_size, to_numpy_array
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...processing_utils import (
MultiModalData,
ProcessingKwargs,
ProcessorMixin,
Unpack,
_validate_images_text_input_order,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
@ -30,9 +38,7 @@ logger = logging.get_logger(__name__)
class LlavaProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
},
"text_kwargs": {"padding": False, "return_mm_token_type_ids": False},
"images_kwargs": {},
}
@ -89,11 +95,7 @@ class LlavaProcessor(ProcessorMixin):
self.num_additional_image_tokens = num_additional_image_tokens
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
self.image_token_id = (
tokenizer.image_token_id
if getattr(tokenizer, "image_token_id", None)
else tokenizer.convert_tokens_to_ids(self.image_token)
)
self.image_token_id = tokenizer.encode(self.image_token, add_special_tokens=False)[0]
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__(
@ -174,10 +176,49 @@ class LlavaProcessor(ProcessorMixin):
prompt_strings.append(sample)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = LlavaProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
crop_size = images_kwargs.get("crop_size", None) or self.image_processor.crop_size
resized_height, resized_width = crop_size["height"], crop_size["width"]
num_image_tokens = (resized_height // self.patch_size) * (resized_width // self.patch_size)
num_image_tokens += self.num_additional_image_tokens
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
num_image_tokens = [num_image_tokens] * len(image_sizes)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
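A worked instance of the count above, using 336 x 336 crops and 14-pixel patches as illustrative values (the real numbers come from the checkpoint's image processor and vision config):

crop_height = crop_width = 336
patch_size = 14
num_additional_image_tokens = 1        # e.g. a CLS token

num_image_tokens = (crop_height // patch_size) * (crop_width // patch_size)  # 24 * 24 = 576
num_image_tokens += num_additional_image_tokens                              # 577
num_image_tokens -= 1                  # the "default" feature-select strategy drops one token
print(num_image_tokens)                # 576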
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""

View File

@ -18,10 +18,18 @@ Processor class for LLaVa-NeXT.
from typing import List, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_processing_utils import select_best_resolution
from ...image_utils import ImageInput, get_image_size, to_numpy_array
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...processing_utils import (
MultiModalData,
ProcessingKwargs,
ProcessorMixin,
Unpack,
_validate_images_text_input_order,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
@ -33,6 +41,7 @@ class LlavaNextProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
"images_kwargs": {
"do_pad": True,
@ -172,9 +181,16 @@ class LlavaNextProcessor(ProcessorMixin):
prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
@ -219,6 +235,48 @@ class LlavaNextProcessor(ProcessorMixin):
newline_features = current_height
return (unpadded_features, newline_features)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = LlavaNextProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
size = images_kwargs.get("size", None) or self.image_processor.size
size = (
(size["shortest_edge"], size["shortest_edge"])
if "shortest_edge" in size
else (min(size["height"], size["width"]), min(size["height"], size["width"]))
)
processed_height, processed_width = size
batch_num_image_tokens = []
num_image_patches = [1] * len(image_sizes)  # llava-next doesn't batch pixels as Idefics does, hence `1` patch per image
for image_size in image_sizes:
orig_height, orig_width = image_size
num_image_tokens = self._get_number_of_features(
orig_height, orig_width, processed_height, processed_width
)
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
batch_num_image_tokens.append(num_image_tokens)
vision_data.update({"num_image_tokens": batch_num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""

View File

@ -24,7 +24,7 @@ import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_processing_utils import select_best_resolution
from ...image_utils import ImageInput, get_image_size, to_numpy_array
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
from ...video_utils import VideoInput
@ -38,6 +38,7 @@ class LlavaOnevisionProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
"image_kwargs": {},
"videos_kwargs": {},
@ -196,9 +197,16 @@ class LlavaOnevisionProcessor(ProcessorMixin):
text = [sample.replace(self.video_token, self.video_token * num_video_tokens) for sample in text]
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs}, tensor_type=return_tensors)
def _expand_image_tokens(
@ -285,6 +293,48 @@ class LlavaOnevisionProcessor(ProcessorMixin):
return (unpadded_features, newline_features)
def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
video_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (num_frames, height, width) per each video.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = LlavaOnevisionProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
size = images_kwargs.get("size", None) or self.image_processor.size
size = (
(size["shortest_edge"], size["shortest_edge"])
if "shortest_edge" in size
else (min(size["height"], size["width"]), min(size["height"], size["width"]))
)
processed_height, processed_width = size
batch_num_image_tokens = []
num_image_patches = [1] * len(image_sizes)  # llava-ov doesn't batch pixels as Idefics does, hence `1` patch per image
for image_size in image_sizes:
orig_height, orig_width = image_size
num_image_tokens = self._get_number_of_features(
orig_height, orig_width, processed_height, processed_width
)
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
batch_num_image_tokens.append(num_image_tokens)
vision_data.update({"num_image_tokens": batch_num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""

View File

@ -18,10 +18,13 @@ Processor class for PaliGemma.
from typing import List, Optional, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images
from ...processing_utils import (
ImagesKwargs,
MultiModalData,
ProcessingKwargs,
ProcessorMixin,
TextKwargs,
@ -56,6 +59,7 @@ class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
"images_kwargs": {
"data_format": "channels_first",
@ -299,6 +303,7 @@ class PaliGemmaProcessor(ProcessorMixin):
pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
inputs = self.tokenizer(
input_strings,
text_pair=suffix,
@ -310,10 +315,37 @@ class PaliGemmaProcessor(ProcessorMixin):
return_data = {**inputs, "pixel_values": pixel_values}
if return_token_type_ids:
labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
labels = np.array(inputs["input_ids"])
labels[np.array(inputs["token_type_ids"]) == 0] = -100
return_data.update({"labels": labels})
if return_mm_token_type_ids:
array_ids = np.array(return_data["input_ids"])
mm_token_type_ids = np.zeros_like(return_data["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
return_data["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data=return_data, tensor_type=return_tensors)
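The switch to numpy above keeps the label masking framework-agnostic (it no longer assumes `input_ids` is a torch tensor). A minimal sketch of the same masking with toy ids:

import numpy as np

input_ids = [[5, 6, 7, 8, 9]]
token_type_ids = [[0, 0, 1, 1, 1]]     # 0 = prefix/image part, 1 = suffix

labels = np.array(input_ids)
labels[np.array(token_type_ids) == 0] = -100   # ignore the prefix in the loss
print(labels.tolist())                 # [[-100, -100, 7, 8, 9]]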
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
num_image_tokens = [self.image_seq_length] * len(image_sizes)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
def batch_decode(self, *args, **kwargs):
"""

View File

@ -18,11 +18,23 @@ Processor class for Pixtral.
from typing import List, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...processing_utils import (
MultiModalData,
ProcessingKwargs,
ProcessorMixin,
Unpack,
_validate_images_text_input_order,
)
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
from ...utils import is_vision_available, logging
if is_vision_available():
from .image_processing_pixtral import get_resize_output_image_size
logger = logging.get_logger(__name__)
@ -32,6 +44,7 @@ class PixtralProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
"images_kwargs": {},
"common_kwargs": {
@ -106,6 +119,10 @@ class PixtralProcessor(ProcessorMixin):
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
self.image_break_token = image_break_token
self.image_end_token = image_end_token
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
self.image_break_token_id = tokenizer.convert_tokens_to_ids(self.image_break_token)
self.image_end_token_id = tokenizer.convert_tokens_to_ids(self.image_end_token)
self.image_ids = [self.image_token_id, self.image_break_token_id, self.image_end_token_id]
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__(
@ -213,10 +230,54 @@ class PixtralProcessor(ProcessorMixin):
prompt_strings.append(sample)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = PixtralProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
size = images_kwargs.get("size", None) or self.image_processor.size
patch_size = self.patch_size * self.spatial_merge_size
num_image_tokens = []
for height, width in image_sizes:
resized_height, resized_width = get_resize_output_image_size(
image=np.zeros((height, width, 3)),
size=(size["longest_edge"], size["longest_edge"]),
patch_size=(patch_size, patch_size),
)
num_height_tokens = resized_height // patch_size
num_width_tokens = resized_width // patch_size
num_image_tokens.append((num_width_tokens + 1) * num_height_tokens)
num_image_patches = [1] * len(image_sizes)
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
return MultiModalData(**vision_data)
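A worked example of the Pixtral formula above: the resized image is covered by cells of `patch_size * spatial_merge_size` pixels, and each row of the grid is followed by one extra token (an image-break token, or the image-end token on the last row), which is where the `+ 1` comes from. The sizes below are illustrative and assume the resize already produced them:

patch_size = 16
spatial_merge_size = 1
resized_height, resized_width = 512, 768

cell = patch_size * spatial_merge_size
num_height_tokens = resized_height // cell           # 32 rows of cells
num_width_tokens = resized_width // cell             # 48 cells per row
print((num_width_tokens + 1) * num_height_tokens)    # (48 + 1) * 32 = 1568 placeholder tokens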
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""

View File

@ -22,6 +22,7 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -48,7 +49,7 @@ from ...configuration_utils import PretrainedConfig
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...modeling_flash_attention_utils import is_flash_attn_available
from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs
from ...processing_utils import MultiModalData, ProcessingKwargs, Unpack, VideosKwargs
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import is_torchdynamo_compiling, logging
from ...video_utils import VideoInput
@ -925,6 +926,7 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
"videos_kwargs": {"fps": 2.0},
}
@ -1011,10 +1013,12 @@ class Qwen2_5_VLProcessor(Qwen2VLProcessor):
image_grid_thw = image_inputs["image_grid_thw"]
if videos is not None:
# pop fps in advance for passing kwargs validation
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
video_grid_thw = videos_inputs["video_grid_thw"]
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
if isinstance(fps, (int, float)):
second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
@ -1050,11 +1054,56 @@ class Qwen2_5_VLProcessor(Qwen2VLProcessor):
text[i] = text[i].replace("<|placeholder|>", self.video_token)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
video_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (num_frames, height, width) per each video.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
num_image_patches = [
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
for image_size in image_sizes
]
num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
if video_sizes is not None:
videos_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("videos_kwargs", {})
videos_kwargs.update(kwargs)
merge_size = videos_kwargs.get("merge_size", None) or self.video_processor.merge_size
num_video_patches = [
self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
for video_size in video_sizes
]
num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
vision_data["num_video_tokens"] = num_video_tokens
return MultiModalData(**vision_data)
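The merge step above is just integer division: `merge_size x merge_size` vision patches collapse into one text placeholder token. For example, with hypothetical per-image patch counts:

merge_size = 2
num_image_patches = [1024, 256]        # as reported by the image processor (illustrative)
num_image_tokens = [p // merge_size**2 for p in num_image_patches]
print(num_image_tokens)                # [256, 64]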
__all__ = [
"Qwen2_5_VLConfig",

View File

@ -25,9 +25,11 @@
# limitations under the License.
from typing import List, Optional, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...video_utils import VideoInput
@ -50,6 +52,7 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
"videos_kwargs": {"fps": 2.0},
}
@ -149,10 +152,12 @@ class Qwen2_5_VLProcessor(ProcessorMixin):
image_grid_thw = image_inputs["image_grid_thw"]
if videos is not None:
# pop fps in advance for passing kwargs validation
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
video_grid_thw = videos_inputs["video_grid_thw"]
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
if isinstance(fps, (int, float)):
second_per_grid_ts = [self.video_processor.temporal_patch_size / fps] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
@ -188,11 +193,56 @@ class Qwen2_5_VLProcessor(ProcessorMixin):
text[i] = text[i].replace("<|placeholder|>", self.video_token)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
video_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (num_frames, height, width) per each video.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
num_image_patches = [
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
for image_size in image_sizes
]
num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
if video_sizes is not None:
videos_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("videos_kwargs", {})
videos_kwargs.update(kwargs)
merge_size = videos_kwargs.get("merge_size", None) or self.video_processor.merge_size
num_video_patches = [
self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
for video_size in video_sizes
]
num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
vision_data["num_video_tokens"] = num_video_tokens
return MultiModalData(**vision_data)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please

View File

@ -490,5 +490,31 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
return BatchFeature(data=data, tensor_type=return_tensors)
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns number of image patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*):
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of image patches per image.
"""
min_pixels = images_kwargs.get("min_pixels", None) or self.size["shortest_edge"]
max_pixels = images_kwargs.get("max_pixels", None) or self.size["longest_edge"]
patch_size = images_kwargs.get("patch_size", None) or self.patch_size
merge_size = images_kwargs.get("merge_size", None) or self.merge_size
factor = patch_size * merge_size
resized_height, resized_width = smart_resize(
height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
)
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
return grid_h * grid_w
__all__ = ["Qwen2VLImageProcessor"]

View File

@ -402,5 +402,31 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
return BatchFeature(data=data, tensor_type=return_tensors)
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns number of image patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*):
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of image patches per image.
"""
min_pixels = images_kwargs.get("min_pixels", None) or self.size["shortest_edge"]
max_pixels = images_kwargs.get("max_pixels", None) or self.size["longest_edge"]
patch_size = images_kwargs.get("patch_size", None) or self.patch_size
merge_size = images_kwargs.get("merge_size", None) or self.merge_size
factor = patch_size * merge_size
resized_height, resized_width = smart_resize(
height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
)
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
return grid_h * grid_w
__all__ = ["Qwen2VLImageProcessorFast"]

View File

@ -23,9 +23,11 @@ Processor class for Qwen2-VL.
from typing import List, Optional, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
from ...video_utils import VideoInput
@ -47,6 +49,7 @@ class Qwen2VLProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
"return_mm_token_type_ids": False,
},
}
@ -172,10 +175,56 @@ class Qwen2VLProcessor(ProcessorMixin):
text[i] = text[i].replace("<|placeholder|>", self.video_token)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
if return_mm_token_type_ids:
array_ids = np.array(text_inputs["input_ids"])
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
mm_token_type_ids[array_ids == self.image_token_id] = 1
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
"""
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
Args:
image_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (height, width) per each image.
video_sizes (`List[List[int]]`, *optional*):
The input sizes formatted as (num_frames, height, width) per each video.
Returns:
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
input modalities, along with other useful data.
"""
vision_data = {}
if image_sizes is not None:
images_kwargs = Qwen2VLProcessorKwargs._defaults.get("images_kwargs", {})
images_kwargs.update(kwargs)
merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
num_image_patches = [
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
for image_size in image_sizes
]
num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
if video_sizes is not None:
videos_kwargs = Qwen2VLProcessorKwargs._defaults.get("videos_kwargs", {})
videos_kwargs.update(kwargs)
merge_size = videos_kwargs.get("merge_size", None) or self.video_processor.merge_size
num_video_patches = [
self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
for video_size in video_sizes
]
num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
vision_data["num_video_tokens"] = num_video_tokens
return MultiModalData(**vision_data)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please

View File

@ -204,5 +204,35 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
tensor_type=return_tensors,
)
def get_num_of_video_patches(self, num_frames: int, height: int, width: int, videos_kwargs=None):
"""
A utility that returns the number of video patches for a given video size.
Args:
num_frames (`int`):
Number of frames in the input video.
height (`int`):
Height of the input video.
width (`int`):
Width of the input video.
videos_kwargs (`dict`, *optional*):
Any kwargs to override defaults of the video processor.
Returns:
`int`: Number of video patches per video.
"""
min_pixels = videos_kwargs.get("min_pixels", None) or self.size["shortest_edge"]
max_pixels = videos_kwargs.get("max_pixels", None) or self.size["longest_edge"]
patch_size = videos_kwargs.get("patch_size", None) or self.patch_size
merge_size = videos_kwargs.get("merge_size", None) or self.merge_size
temporal_patch_size = videos_kwargs.get("temporal_patch_size", None) or self.temporal_patch_size
factor = patch_size * merge_size
resized_height, resized_width = smart_resize(
height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
)
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
grid_t = num_frames // temporal_patch_size
return grid_t * grid_h * grid_w
__all__ = ["Qwen2VLVideoProcessor"]

View File

@ -847,5 +847,46 @@ class SmolVLMImageProcessor(BaseImageProcessor):
return encoding
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
"""
A utility that returns number of image patches for a given image size.
Args:
height (`int`):
Height of the input image.
width (`int`):
Width of the input image.
images_kwargs (`dict`, *optional*):
Any kwargs to override defaults of the image processor.
Returns:
`int`: Number of patches per image.
"""
do_image_splitting = images_kwargs.get("do_image_splitting", None) or self.do_image_splitting
max_image_size = images_kwargs.get("max_image_size", None) or self.max_image_size
size = images_kwargs.get("size", None) or self.size
num_patches = 1
if do_image_splitting:
height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=size["longest_edge"])
height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=4096)
aspect_ratio = width / height
if width >= height:
resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
resized_height = int(width / aspect_ratio)
resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
elif height > width:
resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
resized_width = int(height * aspect_ratio)
resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
max_height = max_width = max_image_size["longest_edge"]
if resized_height > max_height or resized_width > max_width:
# Calculate the number of splits
num_rows = math.ceil(resized_height / max_height)
num_cols = math.ceil(resized_width / max_width)
num_patches = num_rows * num_cols + 1
return num_patches
__all__ = ["SmolVLMImageProcessor"]

View File

@ -22,6 +22,7 @@ import os
import sys
import typing
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, TypedDict, Union
@ -120,6 +121,8 @@ class TextKwargs(TypedDict, total=False):
Whether or not to print more information and warnings.
padding_side (`str`, *optional*):
The side on which padding will be applied.
return_mm_token_type_ids (`bool`, *optional*):
Whether to return multimodal token type ids indicating the positions of multimodal placeholder tokens.
"""
text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]]
@ -140,6 +143,7 @@ class TextKwargs(TypedDict, total=False):
return_length: Optional[bool]
verbose: Optional[bool]
padding_side: Optional[str]
return_mm_token_type_ids: Optional[bool]
class ImagesKwargs(TypedDict, total=False):
@ -455,6 +459,32 @@ class AllKwargsForChatTemplate(
}
@dataclass
class MultiModalData:
"""
Dataclass that holds extra useful data for processing multimodal inputs.
Processors currently cannot return extra keys unless those keys are consumed
by the model's forward, so these helper methods compute and return the useful
data obtained while processing the multimodal inputs (images/videos).
Note that this dataclass is intended to be used only by vLLM, and its API
may change in the future.
"""
num_image_tokens: list[int] = None
num_video_tokens: list[int] = None
num_audio_tokens: list[int] = None
num_image_patches: list[int] = None
def __contains__(self, key):
return hasattr(self, key) and getattr(self, key) is not None
def __getitem__(self, key):
if hasattr(self, key):
return getattr(self, key)
raise AttributeError(f"{self.__class__.__name__} has no attribute {key}")
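A small usage sketch of the dataclass above, assuming it is importable from `transformers.processing_utils` as the other files in this diff do; the numbers are illustrative:

from transformers.processing_utils import MultiModalData

data = MultiModalData(num_image_tokens=[576], num_image_patches=[1])

print("num_image_tokens" in data)   # True  -> the field is set and not None
print("num_video_tokens" in data)   # False -> the field exists but is None
print(data["num_image_tokens"])     # [576]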
class ProcessorMixin(PushToHubMixin):
"""
This is a mixin used to provide saving/loading functionality for all processor classes.

View File

@ -250,7 +250,10 @@ class BaseVideoProcessor(BaseImageProcessorFast):
videos: VideoInput,
**kwargs: Unpack[VideosKwargs],
) -> BatchFeature:
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
validate_kwargs(
captured_kwargs=kwargs.keys(),
valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
)
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:

View File

@ -696,11 +696,13 @@ def group_videos_by_shape(
grouped_videos_index = {}
for i, video in enumerate(videos):
shape = video.shape[-2::]
num_frames = video.shape[-4] # video format BTCHW
shape = (num_frames, *shape)
if shape not in grouped_videos:
grouped_videos[shape] = []
grouped_videos[shape].append(video)
grouped_videos_index[i] = (shape, len(grouped_videos[shape]) - 1)
# stack videos with the same shape
# stack videos with the same size and number of frames
grouped_videos = {shape: torch.stack(videos, dim=0) for shape, videos in grouped_videos.items()}
return grouped_videos, grouped_videos_index

View File

@ -62,6 +62,20 @@ class PaliGemmaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
)
self.assertEqual(len(inputs["input_ids"][0]), 112)
@require_torch
def test_call_with_suffix(self):
input_str = "lower newer"
suffix = "upper older longer string"
image_input = self.prepare_image_inputs()
processor = self.get_processor()
inputs = processor(text=input_str, images=image_input, suffix=suffix)
self.assertTrue("labels" in inputs)
self.assertEqual(len(inputs["labels"][0]), len(inputs["input_ids"][0]))
inputs = processor(text=input_str, images=image_input, suffix=suffix, return_tensors="pt")
self.assertTrue("labels" in inputs)
self.assertEqual(len(inputs["labels"][0]), len(inputs["input_ids"][0]))
def test_text_with_image_tokens(self):
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

View File

@ -30,7 +30,7 @@ from transformers.testing_utils import (
require_torchvision,
require_vision,
)
from transformers.video_utils import make_batched_videos
from transformers.video_utils import group_videos_by_shape, make_batched_videos, reorder_videos
if is_torch_available():
@ -43,9 +43,9 @@ if is_vision_available():
from transformers.video_utils import VideoMetadata, load_video
def get_random_video(height, width, return_torch=False):
def get_random_video(height, width, num_frames=8, return_torch=False):
random_frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
video = np.array(([random_frame] * 8))
video = np.array(([random_frame] * num_frames))
if return_torch:
# move channel first
return torch.from_numpy(video).permute(0, 3, 1, 2)
@ -189,6 +189,53 @@ class BaseVideoProcessorTester(unittest.TestCase):
rgb_video = video_processor.convert_to_rgb(torch.cat([video, video[:, :1]], dim=1))
self.assertEqual(rgb_video.shape, (8, 3, 20, 20))
def test_group_and_reorder_videos(self):
"""Tests that videos can be grouped by frame size and number of frames"""
video_1 = get_random_video(20, 20, num_frames=3, return_torch=True)
video_2 = get_random_video(20, 20, num_frames=5, return_torch=True)
# Group two videos of same size but different number of frames
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_2])
self.assertEqual(len(grouped_videos), 2)
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
self.assertEqual(len(regrouped_videos), 2)
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
# Group two videos of different size but same number of frames
video_3 = get_random_video(15, 20, num_frames=3, return_torch=True)
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_3])
self.assertEqual(len(grouped_videos), 2)
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
self.assertEqual(len(regrouped_videos), 2)
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
# Group all three videos, where some share a size or a frame count
# but none share both, so we end up with 3 groups
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_2, video_3])
self.assertEqual(len(grouped_videos), 3)
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
self.assertEqual(len(regrouped_videos), 3)
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
# Group when some of the videos have identical shapes
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_1, video_3])
self.assertEqual(len(grouped_videos), 2)
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
self.assertEqual(len(regrouped_videos), 2)
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
# Group when all videos have identical shapes
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_1, video_1])
self.assertEqual(len(grouped_videos), 1)
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
self.assertEqual(len(regrouped_videos), 1)
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
@require_vision
@require_av