mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-23 14:29:01 +06:00
[transformers x vLLM] standardize processors (#37915)
* standardize * fix tests * batch update some processors, not final yet * oke, now I tested that everything indeed runs. Still needs prettification * emu3 * fixup * gemma3 but it doesn't generate anything * fuyu * update * why? * Update src/transformers/models/aya_vision/processing_aya_vision.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * address comments * bc * why do we need to guard import this every time? * i hate guarded imports * i am blind --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
b5ececb900
commit
9e1017b479
@ -500,5 +500,26 @@ class AriaImageProcessor(BaseImageProcessor):
|
|||||||
]
|
]
|
||||||
return patches
|
return patches
|
||||||
|
|
||||||
|
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
|
||||||
|
"""
|
||||||
|
A utility that returns number of image patches for a given image size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
height (`int`):
|
||||||
|
Height of the input image.
|
||||||
|
width (`int`):
|
||||||
|
Width of the input image.
|
||||||
|
images_kwargs (`dict`, *optional*)
|
||||||
|
Any kwargs to override defaults of the image processor.
|
||||||
|
Returns:
|
||||||
|
`int`: Number of patches per image.
|
||||||
|
"""
|
||||||
|
split_image = images_kwargs.get("split_image", None) or self.split_image
|
||||||
|
max_image_size = images_kwargs.get("max_image_size", None) or self.max_image_size
|
||||||
|
|
||||||
|
resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions)
|
||||||
|
num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size
|
||||||
|
return num_patches
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["AriaImageProcessor"]
|
__all__ = ["AriaImageProcessor"]
|
||||||
|
@ -34,7 +34,7 @@ from ...image_utils import (
|
|||||||
)
|
)
|
||||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
|
||||||
from ...tokenization_utils import PreTokenizedInput, TextInput
|
from ...tokenization_utils import PreTokenizedInput, TextInput
|
||||||
from ...utils import LossKwargs, TensorType, auto_docstring, can_return_tuple, logging
|
from ...utils import LossKwargs, TensorType, auto_docstring, can_return_tuple, logging
|
||||||
from ...utils.import_utils import is_torch_available
|
from ...utils.import_utils import is_torch_available
|
||||||
@ -884,11 +884,33 @@ class AriaImageProcessor(BaseImageProcessor):
|
|||||||
]
|
]
|
||||||
return patches
|
return patches
|
||||||
|
|
||||||
|
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
|
||||||
|
"""
|
||||||
|
A utility that returns number of image patches for a given image size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
height (`int`):
|
||||||
|
Height of the input image.
|
||||||
|
width (`int`):
|
||||||
|
Width of the input image.
|
||||||
|
images_kwargs (`dict`, *optional*)
|
||||||
|
Any kwargs to override defaults of the image processor.
|
||||||
|
Returns:
|
||||||
|
`int`: Number of patches per image.
|
||||||
|
"""
|
||||||
|
split_image = images_kwargs.get("split_image", None) or self.split_image
|
||||||
|
max_image_size = images_kwargs.get("max_image_size", None) or self.max_image_size
|
||||||
|
|
||||||
|
resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions)
|
||||||
|
num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size
|
||||||
|
return num_patches
|
||||||
|
|
||||||
|
|
||||||
class AriaProcessorKwargs(ProcessingKwargs, total=False):
|
class AriaProcessorKwargs(ProcessingKwargs, total=False):
|
||||||
_defaults = {
|
_defaults = {
|
||||||
"text_kwargs": {
|
"text_kwargs": {
|
||||||
"padding": False,
|
"padding": False,
|
||||||
|
"return_mm_token_type_ids": False,
|
||||||
},
|
},
|
||||||
"images_kwargs": {
|
"images_kwargs": {
|
||||||
"max_image_size": 980,
|
"max_image_size": 980,
|
||||||
@ -978,10 +1000,7 @@ class AriaProcessor(ProcessorMixin):
|
|||||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||||
|
|
||||||
if images is not None:
|
if images is not None:
|
||||||
image_inputs = self.image_processor(
|
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||||
images,
|
|
||||||
**output_kwargs["images_kwargs"],
|
|
||||||
)
|
|
||||||
# expand the image_token according to the num_crops and tokens per image
|
# expand the image_token according to the num_crops and tokens per image
|
||||||
tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
|
tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
|
||||||
prompt_strings = []
|
prompt_strings = []
|
||||||
@ -995,11 +1014,44 @@ class AriaProcessor(ProcessorMixin):
|
|||||||
prompt_strings = text
|
prompt_strings = text
|
||||||
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
||||||
|
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
|
||||||
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
|
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
|
||||||
|
|
||||||
|
if return_mm_token_type_ids:
|
||||||
|
array_ids = np.array(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
||||||
|
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
||||||
|
|
||||||
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
Args:
|
||||||
|
image_sizes (`List[List[int]]`, *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
Returns:
|
||||||
|
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
||||||
|
input modalities, along with other useful data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
images_kwargs = AriaProcessorKwargs._defaults.get("images_kwargs", {})
|
||||||
|
images_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size
|
||||||
|
num_image_patches = [
|
||||||
|
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
|
||||||
|
for image_size in image_sizes
|
||||||
|
]
|
||||||
|
num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches]
|
||||||
|
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
def batch_decode(self, *args, **kwargs):
|
def batch_decode(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||||
|
@ -20,9 +20,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from ...image_processing_utils import BatchFeature
|
from ...image_processing_utils import BatchFeature
|
||||||
from ...image_utils import ImageInput
|
from ...image_utils import ImageInput
|
||||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
|
||||||
from ...tokenization_utils import PreTokenizedInput, TextInput
|
from ...tokenization_utils import PreTokenizedInput, TextInput
|
||||||
from ...utils import TensorType
|
from ...utils import TensorType
|
||||||
from ..auto import AutoTokenizer
|
from ..auto import AutoTokenizer
|
||||||
@ -32,6 +34,7 @@ class AriaProcessorKwargs(ProcessingKwargs, total=False):
|
|||||||
_defaults = {
|
_defaults = {
|
||||||
"text_kwargs": {
|
"text_kwargs": {
|
||||||
"padding": False,
|
"padding": False,
|
||||||
|
"return_mm_token_type_ids": False,
|
||||||
},
|
},
|
||||||
"images_kwargs": {
|
"images_kwargs": {
|
||||||
"max_image_size": 980,
|
"max_image_size": 980,
|
||||||
@ -121,10 +124,7 @@ class AriaProcessor(ProcessorMixin):
|
|||||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||||
|
|
||||||
if images is not None:
|
if images is not None:
|
||||||
image_inputs = self.image_processor(
|
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||||
images,
|
|
||||||
**output_kwargs["images_kwargs"],
|
|
||||||
)
|
|
||||||
# expand the image_token according to the num_crops and tokens per image
|
# expand the image_token according to the num_crops and tokens per image
|
||||||
tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
|
tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
|
||||||
prompt_strings = []
|
prompt_strings = []
|
||||||
@ -138,11 +138,44 @@ class AriaProcessor(ProcessorMixin):
|
|||||||
prompt_strings = text
|
prompt_strings = text
|
||||||
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
||||||
|
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
|
||||||
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
|
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
|
||||||
|
|
||||||
|
if return_mm_token_type_ids:
|
||||||
|
array_ids = np.array(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
||||||
|
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
||||||
|
|
||||||
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
Args:
|
||||||
|
image_sizes (`List[List[int]]`, *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
Returns:
|
||||||
|
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
||||||
|
input modalities, along with other useful data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
images_kwargs = AriaProcessorKwargs._defaults.get("images_kwargs", {})
|
||||||
|
images_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size
|
||||||
|
num_image_patches = [
|
||||||
|
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
|
||||||
|
for image_size in image_sizes
|
||||||
|
]
|
||||||
|
num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches]
|
||||||
|
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
def batch_decode(self, *args, **kwargs):
|
def batch_decode(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||||
|
@ -13,22 +13,23 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from transformers.processing_utils import (
|
import numpy as np
|
||||||
ImagesKwargs,
|
|
||||||
ProcessingKwargs,
|
|
||||||
ProcessorMixin,
|
|
||||||
Unpack,
|
|
||||||
)
|
|
||||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
|
||||||
|
|
||||||
from ...image_processing_utils import BatchFeature
|
from ...image_processing_utils import BatchFeature
|
||||||
from ...image_utils import (
|
from ...image_utils import (
|
||||||
ImageInput,
|
ImageInput,
|
||||||
make_flat_list_of_images,
|
make_flat_list_of_images,
|
||||||
)
|
)
|
||||||
|
from ...processing_utils import (
|
||||||
|
ImagesKwargs,
|
||||||
|
MultiModalData,
|
||||||
|
ProcessingKwargs,
|
||||||
|
ProcessorMixin,
|
||||||
|
Unpack,
|
||||||
|
)
|
||||||
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
|
|
||||||
|
|
||||||
class AyaVisionImagesKwargs(ImagesKwargs, total=False):
|
class AyaVisionImagesKwargs(ImagesKwargs, total=False):
|
||||||
@ -43,6 +44,7 @@ class AyaVisionProcessorKwargs(ProcessingKwargs, total=False):
|
|||||||
"text_kwargs": {
|
"text_kwargs": {
|
||||||
"padding_side": "left",
|
"padding_side": "left",
|
||||||
"padding": True,
|
"padding": True,
|
||||||
|
"return_mm_token_type_ids": False,
|
||||||
},
|
},
|
||||||
"images_kwargs": {
|
"images_kwargs": {
|
||||||
"crop_to_patches": True,
|
"crop_to_patches": True,
|
||||||
@ -121,7 +123,6 @@ class AyaVisionProcessor(ProcessorMixin):
|
|||||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||||
|
|
||||||
self.image_token = image_token
|
self.image_token = image_token
|
||||||
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
|
||||||
self.patch_size = patch_size * downsample_factor
|
self.patch_size = patch_size * downsample_factor
|
||||||
self.img_size = img_size
|
self.img_size = img_size
|
||||||
|
|
||||||
@ -131,6 +132,10 @@ class AyaVisionProcessor(ProcessorMixin):
|
|||||||
self.img_line_break_token = img_line_break_token
|
self.img_line_break_token = img_line_break_token
|
||||||
self.tile_token = tile_token
|
self.tile_token = tile_token
|
||||||
self.tile_global_token = tile_global_token
|
self.tile_global_token = tile_global_token
|
||||||
|
self.image_token_id = tokenizer.convert_tokens_to_ids(self.img_patch_token)
|
||||||
|
self.image_ids = tokenizer.convert_tokens_to_ids(
|
||||||
|
[img_patch_token, tile_token, tile_global_token, start_of_img_token, end_of_img_token]
|
||||||
|
)
|
||||||
|
|
||||||
def _prompt_split_image(self, num_patches):
|
def _prompt_split_image(self, num_patches):
|
||||||
"""
|
"""
|
||||||
@ -226,11 +231,49 @@ class AyaVisionProcessor(ProcessorMixin):
|
|||||||
text = processed_text
|
text = processed_text
|
||||||
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
||||||
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
|
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
|
||||||
|
|
||||||
|
if return_mm_token_type_ids:
|
||||||
|
array_ids = np.array(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
|
||||||
|
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
||||||
|
|
||||||
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_sizes (`List[List[int]]`, *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
||||||
|
input modalities, along with other useful data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
images_kwargs = AyaVisionProcessorKwargs._defaults.get("images_kwargs", {})
|
||||||
|
images_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
num_image_patches = [
|
||||||
|
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
|
||||||
|
for image_size in image_sizes
|
||||||
|
]
|
||||||
|
|
||||||
|
token_per_patch = (self.img_size // self.patch_size) ** 2
|
||||||
|
num_image_tokens = [
|
||||||
|
token_per_patch + 3 + sum(token_per_patch + 1 for _ in range(1, num_patches))
|
||||||
|
for num_patches in num_image_patches
|
||||||
|
] # Add +3 and +1 for BOI/EOI and image tile tokens
|
||||||
|
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
def batch_decode(self, *args, **kwargs):
|
def batch_decode(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||||
|
@ -18,9 +18,18 @@ Processor class for Chameleon.
|
|||||||
|
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from ...feature_extraction_utils import BatchFeature
|
from ...feature_extraction_utils import BatchFeature
|
||||||
from ...image_utils import ImageInput
|
from ...image_utils import ImageInput
|
||||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack, _validate_images_text_input_order
|
from ...processing_utils import (
|
||||||
|
MultiModalData,
|
||||||
|
ProcessingKwargs,
|
||||||
|
ProcessorMixin,
|
||||||
|
TextKwargs,
|
||||||
|
Unpack,
|
||||||
|
_validate_images_text_input_order,
|
||||||
|
)
|
||||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
|
|
||||||
|
|
||||||
@ -34,6 +43,7 @@ class ChameleonProcessorKwargs(ProcessingKwargs, total=False):
|
|||||||
"text_kwargs": {
|
"text_kwargs": {
|
||||||
"padding": False,
|
"padding": False,
|
||||||
"return_for_text_completion": False,
|
"return_for_text_completion": False,
|
||||||
|
"return_mm_token_type_ids": False,
|
||||||
},
|
},
|
||||||
"common_kwargs": {
|
"common_kwargs": {
|
||||||
"return_tensors": "pt",
|
"return_tensors": "pt",
|
||||||
@ -73,6 +83,10 @@ class ChameleonProcessor(ProcessorMixin):
|
|||||||
tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "<racm3:break>"
|
tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "<racm3:break>"
|
||||||
) # fixed tokens for start and end, so can hardcode
|
) # fixed tokens for start and end, so can hardcode
|
||||||
self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "<eoss>"
|
self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "<eoss>"
|
||||||
|
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
||||||
|
self.image_start_token_id = tokenizer.convert_tokens_to_ids(self.image_start_token)
|
||||||
|
self.image_end_token_id = tokenizer.convert_tokens_to_ids(self.image_end_token)
|
||||||
|
self.image_ids = [self.image_token_id, self.image_start_token_id, self.image_end_token_id]
|
||||||
|
|
||||||
super().__init__(image_processor, tokenizer)
|
super().__init__(image_processor, tokenizer)
|
||||||
|
|
||||||
@ -141,14 +155,45 @@ class ChameleonProcessor(ProcessorMixin):
|
|||||||
sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode
|
sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode
|
||||||
prompt_strings.append(sample)
|
prompt_strings.append(sample)
|
||||||
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
image_inputs = {}
|
||||||
data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
|
|
||||||
self._check_special_mm_tokens(prompt_strings, data, modalities=["image"])
|
|
||||||
|
|
||||||
if images is not None:
|
if images is not None:
|
||||||
data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
|
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||||
|
|
||||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
||||||
|
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
|
||||||
|
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
|
||||||
|
|
||||||
|
if return_mm_token_type_ids:
|
||||||
|
array_ids = np.array(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
|
||||||
|
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
||||||
|
|
||||||
|
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_sizes (`List[List[int]]`, *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
||||||
|
input modalities, along with other useful data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
# add 2 for BOI and EOI tokens
|
||||||
|
num_image_tokens = [self.image_seq_length + 2] * len(image_sizes)
|
||||||
|
num_image_patches = [1] * len(image_sizes)
|
||||||
|
|
||||||
|
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||||
def batch_decode(self, *args, **kwargs):
|
def batch_decode(self, *args, **kwargs):
|
||||||
|
@ -24,7 +24,7 @@ from typing import ClassVar, List, Optional, Union
|
|||||||
|
|
||||||
from ...feature_extraction_utils import BatchFeature
|
from ...feature_extraction_utils import BatchFeature
|
||||||
from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images
|
from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images
|
||||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
|
||||||
from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
|
from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
|
||||||
from ...utils import is_torch_available
|
from ...utils import is_torch_available
|
||||||
|
|
||||||
@ -256,6 +256,25 @@ class ColPaliProcessor(ProcessorMixin):
|
|||||||
|
|
||||||
return batch_query
|
return batch_query
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_sizes (List[List[str]], *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
Returns:
|
||||||
|
Dict[str, List[int]]: A dictionary mapping each modality ("image", "video", "audio")
|
||||||
|
to a list containing the number of placeholder tokens required. If the model doesn't accept
|
||||||
|
a certain modality or no input sizes are provided, the dict value is set to an empty list.
|
||||||
|
"""
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
num_image_tokens = [self.image_seq_length] * len(image_sizes)
|
||||||
|
num_image_patches = [1] * len(image_sizes)
|
||||||
|
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
def batch_decode(self, *args, **kwargs):
|
def batch_decode(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||||
|
@ -16,10 +16,17 @@
|
|||||||
|
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from ...image_processing_utils import BatchFeature
|
from ...image_processing_utils import BatchFeature
|
||||||
from ...image_utils import ImageInput
|
from ...image_utils import ImageInput
|
||||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
|
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
|
||||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
|
from ...utils import is_vision_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from .image_processing_emu3 import smart_resize
|
||||||
|
|
||||||
|
|
||||||
class Emu3TextKwargs(TextKwargs, total=False):
|
class Emu3TextKwargs(TextKwargs, total=False):
|
||||||
@ -37,6 +44,7 @@ class Emu3ProcessorKwargs(ProcessingKwargs, total=False):
|
|||||||
_defaults = {
|
_defaults = {
|
||||||
"text_kwargs": {
|
"text_kwargs": {
|
||||||
"return_for_image_generation": False,
|
"return_for_image_generation": False,
|
||||||
|
"return_mm_token_type_ids": False,
|
||||||
},
|
},
|
||||||
"images_kwargs": {
|
"images_kwargs": {
|
||||||
"ratio": "1:1",
|
"ratio": "1:1",
|
||||||
@ -166,7 +174,7 @@ class Emu3Processor(ProcessorMixin):
|
|||||||
|
|
||||||
image_placeholder = f"{image_start_tokens}{height}*{width}{self.fake_token_around_image}{'<placeholder>' * image_seq_length}{image_end_tokens}"
|
image_placeholder = f"{image_start_tokens}{height}*{width}{self.fake_token_around_image}{'<placeholder>' * image_seq_length}{image_end_tokens}"
|
||||||
sample = sample.replace(self.image_token, image_placeholder, 1)
|
sample = sample.replace(self.image_token, image_placeholder, 1)
|
||||||
sample = f"{self.bos_token}{sample}" # add BOS because PT tokenizer doesn't add it
|
sample = f"{self.bos_token}{sample}" # add BOS because GPT tokenizer doesn't add it
|
||||||
prompt_strings.append(sample)
|
prompt_strings.append(sample)
|
||||||
text = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
|
text = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
|
||||||
|
|
||||||
@ -179,12 +187,51 @@ class Emu3Processor(ProcessorMixin):
|
|||||||
|
|
||||||
# else just generate from text-only input, and we do no special treatment for text
|
# else just generate from text-only input, and we do no special treatment for text
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
data = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
||||||
self._check_special_mm_tokens(text, data, modalities=["image"])
|
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
|
||||||
|
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
|
||||||
|
|
||||||
data.update(**image_features)
|
if return_mm_token_type_ids:
|
||||||
|
array_ids = np.array(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
||||||
|
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
||||||
|
|
||||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
return BatchFeature(data={**text_inputs, **image_features}, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_sizes (`List[List[int]]`, *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
||||||
|
input modalities, along with other useful data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
num_image_tokens = []
|
||||||
|
for height, width in image_sizes:
|
||||||
|
height, width = smart_resize(
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
self.image_processor.spatial_factor,
|
||||||
|
self.image_processor.min_pixels,
|
||||||
|
self.image_processor.max_pixels,
|
||||||
|
)
|
||||||
|
height = height // self.downsample_ratio
|
||||||
|
width = width // self.downsample_ratio
|
||||||
|
image_seq_length = height * (width + 1) # +1 for extra row when converting to BPE in modeling code
|
||||||
|
num_image_tokens.append(image_seq_length)
|
||||||
|
|
||||||
|
num_image_patches = [1] * len(image_sizes)
|
||||||
|
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
def calculate_generate_size(self, ratio, image_area, spatial_factor):
|
def calculate_generate_size(self, ratio, image_area, spatial_factor):
|
||||||
width, height = map(int, ratio.split(":"))
|
width, height = map(int, ratio.split(":"))
|
||||||
|
@ -130,7 +130,7 @@ class FuyuModel(FuyuPreTrainedModel):
|
|||||||
)
|
)
|
||||||
return output_embeddings
|
return output_embeddings
|
||||||
|
|
||||||
def get_image_features(self, pixel_values: torch.FloatTensor):
|
def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs):
|
||||||
"""
|
"""
|
||||||
Encodes images into continuous embeddings that can be forwarded to the language model.
|
Encodes images into continuous embeddings that can be forwarded to the language model.
|
||||||
|
|
||||||
|
@ -22,7 +22,13 @@ from typing import Dict, List, Optional, Tuple, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ...image_utils import ImageInput
|
from ...image_utils import ImageInput
|
||||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
|
from ...processing_utils import (
|
||||||
|
MultiModalData,
|
||||||
|
ProcessingKwargs,
|
||||||
|
ProcessorMixin,
|
||||||
|
Unpack,
|
||||||
|
_validate_images_text_input_order,
|
||||||
|
)
|
||||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
from ...utils import is_torch_available, logging, requires_backends
|
from ...utils import is_torch_available, logging, requires_backends
|
||||||
from ...utils.import_utils import requires
|
from ...utils.import_utils import requires
|
||||||
@ -64,6 +70,7 @@ class FuyuProcessorKwargs(ProcessingKwargs, total=False):
|
|||||||
"return_token_type_ids": False,
|
"return_token_type_ids": False,
|
||||||
"return_length": False,
|
"return_length": False,
|
||||||
"verbose": True,
|
"verbose": True,
|
||||||
|
"return_mm_token_type_ids": False,
|
||||||
},
|
},
|
||||||
"images_kwargs": {},
|
"images_kwargs": {},
|
||||||
}
|
}
|
||||||
@ -355,6 +362,8 @@ class FuyuProcessor(ProcessorMixin):
|
|||||||
self.max_position_embeddings = 16384 # TODO Can't derive this from model files: where to set it?
|
self.max_position_embeddings = 16384 # TODO Can't derive this from model files: where to set it?
|
||||||
self.pad_token_id = 0
|
self.pad_token_id = 0
|
||||||
self.dummy_image_index = -1
|
self.dummy_image_index = -1
|
||||||
|
self.image_token_id = tokenizer.encode("|SPEAKER|", add_special_tokens=False)[1]
|
||||||
|
self.image_newline_id = tokenizer.encode("|NEWLINE|", add_special_tokens=False)[1]
|
||||||
|
|
||||||
def _left_pad_inputs_with_attention_mask(self, model_inputs: List[Dict], return_attention_mask: bool):
|
def _left_pad_inputs_with_attention_mask(self, model_inputs: List[Dict], return_attention_mask: bool):
|
||||||
max_length_input_ids = max(entry["input_ids"].shape[1] for entry in model_inputs)
|
max_length_input_ids = max(entry["input_ids"].shape[1] for entry in model_inputs)
|
||||||
@ -403,6 +412,11 @@ class FuyuProcessor(ProcessorMixin):
|
|||||||
for key in batched_keys:
|
for key in batched_keys:
|
||||||
batched_inputs[key] = torch.cat(batched_inputs[key], dim=0)
|
batched_inputs[key] = torch.cat(batched_inputs[key], dim=0)
|
||||||
|
|
||||||
|
# Cast images to tensor as well, if only one image passed and no padding needed
|
||||||
|
# NOTE: vLLM expects all processor outputs to be a tensor
|
||||||
|
if len(batched_inputs["image_patches"]) == 1:
|
||||||
|
batched_inputs["image_patches"] = torch.cat(batched_inputs["image_patches"], dim=0)
|
||||||
|
|
||||||
return batched_inputs
|
return batched_inputs
|
||||||
|
|
||||||
def get_sample_encoding(
|
def get_sample_encoding(
|
||||||
@ -517,6 +531,7 @@ class FuyuProcessor(ProcessorMixin):
|
|||||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
||||||
|
|
||||||
if not output_kwargs["text_kwargs"].setdefault("return_attention_mask", True):
|
if not output_kwargs["text_kwargs"].setdefault("return_attention_mask", True):
|
||||||
raise ValueError("`return_attention_mask=False` is not supported for this model.")
|
raise ValueError("`return_attention_mask=False` is not supported for this model.")
|
||||||
@ -550,8 +565,6 @@ class FuyuProcessor(ProcessorMixin):
|
|||||||
|
|
||||||
# --- Use self.tokenizer to get the ids of special tokens to insert into image ids ---
|
# --- Use self.tokenizer to get the ids of special tokens to insert into image ids ---
|
||||||
|
|
||||||
image_placeholder_id = self.tokenizer("|SPEAKER|", add_special_tokens=False)["input_ids"][1]
|
|
||||||
image_newline_id = self.tokenizer("|NEWLINE|", add_special_tokens=False)["input_ids"][1]
|
|
||||||
tensor_batch_images = torch.stack([img[0] for img in batch_images]).unsqueeze(1)
|
tensor_batch_images = torch.stack([img[0] for img in batch_images]).unsqueeze(1)
|
||||||
|
|
||||||
# --- Use self.image_processor again to obtain the full token ids and batch inputs ---
|
# --- Use self.image_processor again to obtain the full token ids and batch inputs ---
|
||||||
@ -565,16 +578,63 @@ class FuyuProcessor(ProcessorMixin):
|
|||||||
scale_factors=[scale_factor],
|
scale_factors=[scale_factor],
|
||||||
image_unpadded_heights=torch.tensor([image_unpadded_height]),
|
image_unpadded_heights=torch.tensor([image_unpadded_height]),
|
||||||
image_unpadded_widths=torch.tensor([image_unpadded_width]),
|
image_unpadded_widths=torch.tensor([image_unpadded_width]),
|
||||||
image_placeholder_id=image_placeholder_id,
|
image_placeholder_id=self.image_token_id,
|
||||||
image_newline_id=image_newline_id,
|
image_newline_id=self.image_newline_id,
|
||||||
tensor_batch_images=tensor_batch_image.unsqueeze(0),
|
tensor_batch_images=tensor_batch_image.unsqueeze(0),
|
||||||
)
|
)
|
||||||
all_encodings.append(sample_encoding)
|
all_encodings.append(sample_encoding)
|
||||||
|
|
||||||
batch_encoding = self._left_pad_inputs_with_attention_mask(
|
batch_encoding = self._left_pad_inputs_with_attention_mask(
|
||||||
model_inputs=all_encodings, return_attention_mask=True
|
model_inputs=all_encodings, return_attention_mask=True
|
||||||
)
|
)
|
||||||
|
if return_mm_token_type_ids:
|
||||||
|
input_ids = batch_encoding["input_ids"]
|
||||||
|
mm_token_type_ids = torch.zeros_like(input_ids)
|
||||||
|
mm_token_type_ids[input_ids == self.image_token_id] = 1
|
||||||
|
mm_token_type_ids[input_ids == self.image_newline_id] = 1
|
||||||
|
batch_encoding["mm_token_type_ids"] = mm_token_type_ids
|
||||||
|
|
||||||
return FuyuBatchFeature(data=batch_encoding)
|
return FuyuBatchFeature(data=batch_encoding)
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_sizes (`List[List[int]]`, *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
||||||
|
input modalities, along with other useful data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
size = kwargs.get("size", None) or self.image_processor.size
|
||||||
|
padded_height, padded_width = size["height"], size["width"]
|
||||||
|
|
||||||
|
num_image_tokens = []
|
||||||
|
num_image_patches = [1] * len(image_sizes)
|
||||||
|
for image_size in image_sizes:
|
||||||
|
height_scale_factor = padded_height / image_size[0]
|
||||||
|
width_scale_factor = padded_width / image_size[1]
|
||||||
|
optimal_scale_factor = min(height_scale_factor, width_scale_factor)
|
||||||
|
|
||||||
|
# We can use torch here because Fuyu processor has hard dependency on torch
|
||||||
|
model_image_input = self.image_processor.preprocess_with_tokenizer_info(
|
||||||
|
image_input=torch.zeros(1, 1, 3, padded_height, padded_width),
|
||||||
|
image_present=torch.ones(1, 1, 1),
|
||||||
|
image_unpadded_h=torch.tensor([[int(image_size[0] * optimal_scale_factor)]]),
|
||||||
|
image_unpadded_w=torch.tensor([[int(image_size[1] * optimal_scale_factor)]]),
|
||||||
|
image_placeholder_id=0, # dummy ids, we can be sure `id=0` is never out-of-range
|
||||||
|
image_newline_id=0,
|
||||||
|
variable_sized=True,
|
||||||
|
)
|
||||||
|
num_image_tokens.append(model_image_input["image_input_ids"][0][0].shape[-1])
|
||||||
|
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
def post_process_box_coordinates(self, outputs, target_sizes=None):
|
def post_process_box_coordinates(self, outputs, target_sizes=None):
|
||||||
"""
|
"""
|
||||||
Transforms raw coordinates detected by [`FuyuForCausalLM`] to the original images' coordinate space.
|
Transforms raw coordinates detected by [`FuyuForCausalLM`] to the original images' coordinate space.
|
||||||
|
@ -20,7 +20,7 @@ import numpy as np
|
|||||||
|
|
||||||
from ...feature_extraction_utils import BatchFeature
|
from ...feature_extraction_utils import BatchFeature
|
||||||
from ...image_utils import ImageInput, make_nested_list_of_images
|
from ...image_utils import ImageInput, make_nested_list_of_images
|
||||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
|
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
|
||||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
from ...utils import to_py_obj
|
from ...utils import to_py_obj
|
||||||
|
|
||||||
@ -38,6 +38,7 @@ class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
|
|||||||
_defaults = {
|
_defaults = {
|
||||||
"text_kwargs": {
|
"text_kwargs": {
|
||||||
"padding": False,
|
"padding": False,
|
||||||
|
"return_mm_token_type_ids": True,
|
||||||
},
|
},
|
||||||
"images_kwargs": {
|
"images_kwargs": {
|
||||||
"do_pan_and_scan": False,
|
"do_pan_and_scan": False,
|
||||||
@ -137,17 +138,42 @@ class Gemma3Processor(ProcessorMixin):
|
|||||||
text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
|
text = [prompt.replace(self.boi_token, self.full_image_sequence) for prompt in text]
|
||||||
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"], return_tensors="np")
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
||||||
|
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
|
||||||
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
|
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
|
||||||
|
|
||||||
# Add token type ids manually, as tokenizer can't do arbitrary position token types
|
# Add token type ids manually, as tokenizer can't do arbitrary position token types
|
||||||
array_ids = text_inputs["input_ids"]
|
if return_mm_token_type_ids:
|
||||||
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
array_ids = np.array(text_inputs["input_ids"])
|
||||||
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
mm_token_type_ids = np.zeros_like(array_ids)
|
||||||
text_inputs = {k: v.tolist() for k, v in text_inputs.items()} # in case user requested list inputs
|
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
||||||
text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
|
text_inputs["token_type_ids"] = mm_token_type_ids.tolist()
|
||||||
|
|
||||||
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_sizes (`List[List[int]]`, *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
||||||
|
input modalities, along with other useful data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
# NOTE: no image cropping supported yet
|
||||||
|
num_image_tokens = [self.image_seq_length] * len(image_sizes)
|
||||||
|
num_image_patches = [1] * len(image_sizes)
|
||||||
|
|
||||||
|
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
|
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
|
||||||
def batch_decode(self, *args, **kwargs):
|
def batch_decode(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -491,5 +491,33 @@ class GotOcr2ImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
return processed_images
|
return processed_images
|
||||||
|
|
||||||
|
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
|
||||||
|
"""
|
||||||
|
A utility that returns number patches for a given image size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
height (`int`):
|
||||||
|
Height of the input image.
|
||||||
|
width (`int`):
|
||||||
|
Width of the input image.
|
||||||
|
images_kwargs (`dict`, *optional*)
|
||||||
|
Any kwargs to override defaults of the image processor.
|
||||||
|
Returns:
|
||||||
|
`int`: Number of patches per image.
|
||||||
|
"""
|
||||||
|
min_patches = images_kwargs.get("min_patches", None) or self.min_patches
|
||||||
|
max_patches = images_kwargs.get("max_patches", None) or self.max_patches
|
||||||
|
patch_size = images_kwargs.get("size", None) or self.size
|
||||||
|
crop_to_patches = images_kwargs.get("crop_to_patches", None) or self.crop_to_patches
|
||||||
|
|
||||||
|
num_patches = 1
|
||||||
|
if crop_to_patches and max_patches > 1:
|
||||||
|
num_columns, num_rows = get_optimal_tiled_canvas(
|
||||||
|
(height, width), (patch_size["height"], patch_size["width"]), min_patches, max_patches
|
||||||
|
)
|
||||||
|
num_patches += num_columns * num_rows
|
||||||
|
|
||||||
|
return num_patches
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["GotOcr2ImageProcessor"]
|
__all__ = ["GotOcr2ImageProcessor"]
|
||||||
|
@ -228,5 +228,33 @@ class GotOcr2ImageProcessorFast(BaseImageProcessorFast):
|
|||||||
data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors
|
data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_number_of_image_tokens(self, height: int, width: int, images_kwargs=None):
|
||||||
|
"""
|
||||||
|
A utility that returns number patches for a given image size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
height (`int`):
|
||||||
|
Height of the input image.
|
||||||
|
width (`int`):
|
||||||
|
Width of the input image.
|
||||||
|
images_kwargs (`dict`, *optional*)
|
||||||
|
Any kwargs to override defaults of the image processor.
|
||||||
|
Returns:
|
||||||
|
`int`: Number of patches per image.
|
||||||
|
"""
|
||||||
|
min_patches = images_kwargs.get("min_patches", None) or self.min_patches
|
||||||
|
max_patches = images_kwargs.get("max_patches", None) or self.max_patches
|
||||||
|
patch_size = images_kwargs.get("size", None) or self.size
|
||||||
|
crop_to_patches = images_kwargs.get("crop_to_patches", None) or self.crop_to_patches
|
||||||
|
|
||||||
|
num_patches = 1
|
||||||
|
if crop_to_patches and max_patches > 1:
|
||||||
|
num_columns, num_rows = get_optimal_tiled_canvas(
|
||||||
|
(height, width), (patch_size["height"], patch_size["width"]), min_patches, max_patches
|
||||||
|
)
|
||||||
|
num_patches += num_columns * num_rows
|
||||||
|
|
||||||
|
return num_patches
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["GotOcr2ImageProcessorFast"]
|
__all__ = ["GotOcr2ImageProcessorFast"]
|
||||||
|
@ -850,5 +850,46 @@ class Idefics3ImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
return encoding
|
return encoding
|
||||||
|
|
||||||
|
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
|
||||||
|
"""
|
||||||
|
A utility that returns number of image patches for a given image size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
height (`int`):
|
||||||
|
Height of the input image.
|
||||||
|
width (`int`):
|
||||||
|
Width of the input image.
|
||||||
|
images_kwargs (`dict`, *optional*)
|
||||||
|
Any kwargs to override defaults of the image processor.
|
||||||
|
Returns:
|
||||||
|
`int`: Number of patches per image.
|
||||||
|
"""
|
||||||
|
do_image_splitting = images_kwargs.get("do_image_splitting", None) or self.do_image_splitting
|
||||||
|
max_image_size = images_kwargs.get("max_image_size", None) or self.max_image_size
|
||||||
|
size = images_kwargs.get("size", None) or self.size
|
||||||
|
|
||||||
|
if do_image_splitting:
|
||||||
|
height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=size["longest_edge"])
|
||||||
|
height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=4096)
|
||||||
|
aspect_ratio = width / height
|
||||||
|
|
||||||
|
if width >= height:
|
||||||
|
resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
|
||||||
|
resized_height = int(width / aspect_ratio)
|
||||||
|
resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
|
||||||
|
elif height > width:
|
||||||
|
resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
|
||||||
|
resized_width = int(height * aspect_ratio)
|
||||||
|
resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
|
||||||
|
|
||||||
|
max_height = max_width = max_image_size["longest_edge"]
|
||||||
|
if resized_height > max_height or resized_width > max_width:
|
||||||
|
# Calculate the number of splits
|
||||||
|
num_rows = math.ceil(resized_height / max_height)
|
||||||
|
num_cols = math.ceil(resized_width / max_width)
|
||||||
|
num_patches = num_rows * num_cols + 1
|
||||||
|
|
||||||
|
return num_patches
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Idefics3ImageProcessor"]
|
__all__ = ["Idefics3ImageProcessor"]
|
||||||
|
@ -16,13 +16,16 @@
|
|||||||
Processor class for Idefics3.
|
Processor class for Idefics3.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
import re
|
import re
|
||||||
from itertools import accumulate
|
from itertools import accumulate
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from ...feature_extraction_utils import BatchFeature
|
from ...feature_extraction_utils import BatchFeature
|
||||||
from ...image_utils import ImageInput, is_valid_image, load_image
|
from ...image_utils import ImageInput, is_valid_image, load_image
|
||||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
|
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
|
||||||
from ...tokenization_utils_base import AddedToken, BatchEncoding, TextInput
|
from ...tokenization_utils_base import AddedToken, BatchEncoding, TextInput
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
@ -98,6 +101,7 @@ class Idefics3ProcessorKwargs(ProcessingKwargs, total=False):
|
|||||||
"add_special_tokens": True,
|
"add_special_tokens": True,
|
||||||
"padding": False,
|
"padding": False,
|
||||||
"is_split_into_words": False,
|
"is_split_into_words": False,
|
||||||
|
"return_mm_token_type_ids": False,
|
||||||
},
|
},
|
||||||
"images_kwargs": {
|
"images_kwargs": {
|
||||||
"return_row_col_info": True,
|
"return_row_col_info": True,
|
||||||
@ -146,6 +150,12 @@ class Idefics3Processor(ProcessorMixin):
|
|||||||
self.end_of_utterance_token = AddedToken("<end_of_utterance>", normalized=False, special=True).content
|
self.end_of_utterance_token = AddedToken("<end_of_utterance>", normalized=False, special=True).content
|
||||||
self.global_image_tag = "<global-img>" # https://github.com/huggingface/transformers/pull/32473/files/8063e5e17362571b693f1db95167f5443a3be1b2#r1734825341
|
self.global_image_tag = "<global-img>" # https://github.com/huggingface/transformers/pull/32473/files/8063e5e17362571b693f1db95167f5443a3be1b2#r1734825341
|
||||||
self.image_seq_len = image_seq_len
|
self.image_seq_len = image_seq_len
|
||||||
|
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
||||||
|
self.fake_image_token_id = tokenizer.convert_tokens_to_ids(self.fake_image_token)
|
||||||
|
self.global_image_token_id = tokenizer.convert_tokens_to_ids(self.global_image_tag)
|
||||||
|
self.row_col_ids = [
|
||||||
|
tokenizer.convert_tokens_to_ids(f"<row_{i + 1}_col_{j + 1}>") for i in range(6) for j in range(6)
|
||||||
|
]
|
||||||
|
|
||||||
# This regex matches one or more occurrences of <global-img> tags (optionally surrounded by newline characters)
|
# This regex matches one or more occurrences of <global-img> tags (optionally surrounded by newline characters)
|
||||||
# or <row_x_col_y> tags (where x and y are digits, also optionally surrounded by newline characters).
|
# or <row_x_col_y> tags (where x and y are digits, also optionally surrounded by newline characters).
|
||||||
@ -241,6 +251,7 @@ class Idefics3Processor(ProcessorMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len
|
image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len
|
||||||
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
|
|
||||||
n_images_in_text = []
|
n_images_in_text = []
|
||||||
@ -302,9 +313,11 @@ class Idefics3Processor(ProcessorMixin):
|
|||||||
global_img_token = self.global_image_tag
|
global_img_token = self.global_image_tag
|
||||||
|
|
||||||
prompt_strings = []
|
prompt_strings = []
|
||||||
|
batch_image_seq_lengths = []
|
||||||
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
|
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
|
||||||
# Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
|
# Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
|
||||||
image_prompt_strings = []
|
image_prompt_strings = []
|
||||||
|
image_seq_lengths = []
|
||||||
for n_rows, n_cols in zip(sample_rows, sample_cols):
|
for n_rows, n_cols in zip(sample_rows, sample_cols):
|
||||||
image_prompt_string = get_image_prompt_string(
|
image_prompt_string = get_image_prompt_string(
|
||||||
n_rows,
|
n_rows,
|
||||||
@ -314,8 +327,12 @@ class Idefics3Processor(ProcessorMixin):
|
|||||||
fake_token_around_image=fake_image_token,
|
fake_token_around_image=fake_image_token,
|
||||||
global_img_token=global_img_token,
|
global_img_token=global_img_token,
|
||||||
)
|
)
|
||||||
|
# Add +2 and +3 for special BOI/EOI/fake_image_wrapper tokens
|
||||||
|
row_length = (self.image_seq_len + 2) * n_cols + 1
|
||||||
|
image_seq_lengths.append((self.image_seq_len + 3) + row_length * n_rows)
|
||||||
image_prompt_strings.append(image_prompt_string)
|
image_prompt_strings.append(image_prompt_string)
|
||||||
|
|
||||||
|
batch_image_seq_lengths.append(image_seq_lengths)
|
||||||
split_sample = sample.split(image_token)
|
split_sample = sample.split(image_token)
|
||||||
if len(split_sample) == 0:
|
if len(split_sample) == 0:
|
||||||
raise ValueError("The image token should be present in the text.")
|
raise ValueError("The image token should be present in the text.")
|
||||||
@ -338,7 +355,59 @@ class Idefics3Processor(ProcessorMixin):
|
|||||||
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
|
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
|
||||||
inputs.update(text_inputs)
|
inputs.update(text_inputs)
|
||||||
|
|
||||||
return BatchFeature(inputs, tensor_type=return_tensors)
|
if return_mm_token_type_ids:
|
||||||
|
array_ids = np.array(inputs["input_ids"])
|
||||||
|
mm_token_type_ids = np.zeros_like(array_ids)
|
||||||
|
for i, seq_lengths in enumerate(batch_image_seq_lengths):
|
||||||
|
image_start_positions = np.where(array_ids[i] == self.fake_image_token_id)[0]
|
||||||
|
j = 0
|
||||||
|
for seq_len in seq_lengths:
|
||||||
|
if j >= len(image_start_positions):
|
||||||
|
break
|
||||||
|
start = image_start_positions[j]
|
||||||
|
end = start + seq_len
|
||||||
|
mm_token_type_ids[i, start:end] = 1
|
||||||
|
j = np.searchsorted(image_start_positions, end)
|
||||||
|
|
||||||
|
inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
||||||
|
|
||||||
|
return BatchFeature(data=inputs, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_sizes (`List[List[int]]`, *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
||||||
|
input modalities, along with other useful data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
images_kwargs = Idefics3ProcessorKwargs._defaults.get("images_kwargs", {})
|
||||||
|
images_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
num_image_patches = [
|
||||||
|
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
|
||||||
|
for image_size in image_sizes
|
||||||
|
]
|
||||||
|
|
||||||
|
base_image_length = self.image_seq_len + 3
|
||||||
|
col_length = self.image_seq_len + 2
|
||||||
|
num_image_tokens = []
|
||||||
|
|
||||||
|
for num_patches in num_image_patches:
|
||||||
|
num_cols = num_rows = int(math.sqrt(num_patches - 1))
|
||||||
|
row_length = col_length * num_cols + 1
|
||||||
|
num_image_tokens.append(base_image_length + (row_length * num_rows))
|
||||||
|
|
||||||
|
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
def batch_decode(self, *args, **kwargs):
|
def batch_decode(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -13,25 +13,24 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers.processing_utils import (
|
|
||||||
ImagesKwargs,
|
|
||||||
ProcessingKwargs,
|
|
||||||
ProcessorMixin,
|
|
||||||
Unpack,
|
|
||||||
)
|
|
||||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
|
||||||
|
|
||||||
from ...image_processing_utils import BatchFeature
|
from ...image_processing_utils import BatchFeature
|
||||||
from ...image_utils import (
|
from ...image_utils import (
|
||||||
ImageInput,
|
ImageInput,
|
||||||
concatenate_list,
|
concatenate_list,
|
||||||
make_flat_list_of_images,
|
make_flat_list_of_images,
|
||||||
)
|
)
|
||||||
|
from ...processing_utils import (
|
||||||
|
ImagesKwargs,
|
||||||
|
MultiModalData,
|
||||||
|
ProcessingKwargs,
|
||||||
|
ProcessorMixin,
|
||||||
|
Unpack,
|
||||||
|
)
|
||||||
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
from ...video_utils import VideoInput, VideoMetadata, load_video, make_batched_videos
|
from ...video_utils import VideoInput, VideoMetadata, load_video, make_batched_videos
|
||||||
|
|
||||||
|
|
||||||
@ -46,6 +45,7 @@ class InternVLProcessorKwargs(ProcessingKwargs, total=False):
|
|||||||
_defaults = {
|
_defaults = {
|
||||||
"text_kwargs": {
|
"text_kwargs": {
|
||||||
"padding_side": "left",
|
"padding_side": "left",
|
||||||
|
"return_mm_token_type_ids": False,
|
||||||
},
|
},
|
||||||
"images_kwargs": {
|
"images_kwargs": {
|
||||||
"crop_to_patches": True,
|
"crop_to_patches": True,
|
||||||
@ -94,9 +94,12 @@ class InternVLProcessor(ProcessorMixin):
|
|||||||
self.image_seq_length = image_seq_length
|
self.image_seq_length = image_seq_length
|
||||||
self.start_image_token = tokenizer.start_image_token
|
self.start_image_token = tokenizer.start_image_token
|
||||||
self.end_image_token = tokenizer.end_image_token
|
self.end_image_token = tokenizer.end_image_token
|
||||||
|
self.start_image_token_id = tokenizer.start_image_token_id
|
||||||
|
self.end_image_token_id = tokenizer.end_image_token_id
|
||||||
self.image_token = tokenizer.context_image_token
|
self.image_token = tokenizer.context_image_token
|
||||||
self.video_token = tokenizer.video_token
|
self.video_token = tokenizer.video_token
|
||||||
self.image_token_id = tokenizer.context_image_token_id
|
self.image_token_id = tokenizer.context_image_token_id
|
||||||
|
self.image_ids = [self.image_token_id, self.start_image_token_id, self.end_image_token_id]
|
||||||
|
|
||||||
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template, **kwargs)
|
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template, **kwargs)
|
||||||
|
|
||||||
@ -261,11 +264,46 @@ class InternVLProcessor(ProcessorMixin):
|
|||||||
image_videos_inputs = {"pixel_values": concatenate_list(image_video_patches)}
|
image_videos_inputs = {"pixel_values": concatenate_list(image_video_patches)}
|
||||||
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
|
||||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||||
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
|
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
|
||||||
|
|
||||||
|
if return_mm_token_type_ids:
|
||||||
|
array_ids = np.array(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
|
||||||
|
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
||||||
|
|
||||||
return BatchFeature(data={**text_inputs, **image_videos_inputs}, tensor_type=return_tensors)
|
return BatchFeature(data={**text_inputs, **image_videos_inputs}, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_sizes (`List[List[int]]`, *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
||||||
|
input modalities, along with other useful data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
images_kwargs = InternVLProcessorKwargs._defaults.get("images_kwargs", {})
|
||||||
|
images_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
num_image_patches = [
|
||||||
|
self.image_processor.get_number_of_image_tokens(*image_size, images_kwargs)
|
||||||
|
for image_size in image_sizes
|
||||||
|
]
|
||||||
|
# Add 2 for BOI and EOI tokens
|
||||||
|
num_image_tokens = [2 + (self.image_seq_length * num_patches) for num_patches in num_image_patches]
|
||||||
|
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
def sample_indices_fn(
|
def sample_indices_fn(
|
||||||
self, metadata: VideoMetadata, num_frames: Optional[int] = None, initial_shift: Union[bool, float, int] = True
|
self, metadata: VideoMetadata, num_frames: Optional[int] = None, initial_shift: Union[bool, float, int] = True
|
||||||
):
|
):
|
||||||
|
@ -18,9 +18,17 @@ Processor class for Llava.
|
|||||||
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from ...feature_extraction_utils import BatchFeature
|
from ...feature_extraction_utils import BatchFeature
|
||||||
from ...image_utils import ImageInput, get_image_size, to_numpy_array
|
from ...image_utils import ImageInput, get_image_size, to_numpy_array
|
||||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
|
from ...processing_utils import (
|
||||||
|
MultiModalData,
|
||||||
|
ProcessingKwargs,
|
||||||
|
ProcessorMixin,
|
||||||
|
Unpack,
|
||||||
|
_validate_images_text_input_order,
|
||||||
|
)
|
||||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
@ -30,9 +38,7 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
class LlavaProcessorKwargs(ProcessingKwargs, total=False):
|
class LlavaProcessorKwargs(ProcessingKwargs, total=False):
|
||||||
_defaults = {
|
_defaults = {
|
||||||
"text_kwargs": {
|
"text_kwargs": {"padding": False, "return_mm_token_type_ids": False},
|
||||||
"padding": False,
|
|
||||||
},
|
|
||||||
"images_kwargs": {},
|
"images_kwargs": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,11 +95,7 @@ class LlavaProcessor(ProcessorMixin):
|
|||||||
self.num_additional_image_tokens = num_additional_image_tokens
|
self.num_additional_image_tokens = num_additional_image_tokens
|
||||||
self.vision_feature_select_strategy = vision_feature_select_strategy
|
self.vision_feature_select_strategy = vision_feature_select_strategy
|
||||||
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
|
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
|
||||||
self.image_token_id = (
|
self.image_token_id = tokenizer.encode(self.image_token, add_special_tokens=False)[0]
|
||||||
tokenizer.image_token_id
|
|
||||||
if getattr(tokenizer, "image_token_id", None)
|
|
||||||
else tokenizer.convert_tokens_to_ids(self.image_token)
|
|
||||||
)
|
|
||||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
@ -174,10 +176,49 @@ class LlavaProcessor(ProcessorMixin):
|
|||||||
prompt_strings.append(sample)
|
prompt_strings.append(sample)
|
||||||
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
||||||
|
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
|
||||||
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
|
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
|
||||||
|
|
||||||
|
if return_mm_token_type_ids:
|
||||||
|
array_ids = np.array(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
||||||
|
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
||||||
|
|
||||||
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_sizes (`List[List[int]]`, *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
||||||
|
input modalities, along with other useful data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
images_kwargs = LlavaProcessorKwargs._defaults.get("images_kwargs", {})
|
||||||
|
images_kwargs.update(kwargs)
|
||||||
|
crop_size = images_kwargs.get("crop_size", None) or self.image_processor.crop_size
|
||||||
|
resized_height, resized_width = crop_size["height"], crop_size["width"]
|
||||||
|
|
||||||
|
num_image_tokens = (resized_height // self.patch_size) * (resized_width // self.patch_size)
|
||||||
|
num_image_tokens += self.num_additional_image_tokens
|
||||||
|
if self.vision_feature_select_strategy == "default":
|
||||||
|
num_image_tokens -= 1
|
||||||
|
|
||||||
|
num_image_tokens = [num_image_tokens] * len(image_sizes)
|
||||||
|
num_image_patches = [1] * len(image_sizes)
|
||||||
|
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||||
def batch_decode(self, *args, **kwargs):
|
def batch_decode(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -18,10 +18,18 @@ Processor class for LLaVa-NeXT.
|
|||||||
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from ...feature_extraction_utils import BatchFeature
|
from ...feature_extraction_utils import BatchFeature
|
||||||
from ...image_processing_utils import select_best_resolution
|
from ...image_processing_utils import select_best_resolution
|
||||||
from ...image_utils import ImageInput, get_image_size, to_numpy_array
|
from ...image_utils import ImageInput, get_image_size, to_numpy_array
|
||||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
|
from ...processing_utils import (
|
||||||
|
MultiModalData,
|
||||||
|
ProcessingKwargs,
|
||||||
|
ProcessorMixin,
|
||||||
|
Unpack,
|
||||||
|
_validate_images_text_input_order,
|
||||||
|
)
|
||||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
@ -33,6 +41,7 @@ class LlavaNextProcessorKwargs(ProcessingKwargs, total=False):
|
|||||||
_defaults = {
|
_defaults = {
|
||||||
"text_kwargs": {
|
"text_kwargs": {
|
||||||
"padding": False,
|
"padding": False,
|
||||||
|
"return_mm_token_type_ids": False,
|
||||||
},
|
},
|
||||||
"images_kwargs": {
|
"images_kwargs": {
|
||||||
"do_pad": True,
|
"do_pad": True,
|
||||||
@ -172,9 +181,16 @@ class LlavaNextProcessor(ProcessorMixin):
|
|||||||
prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
|
prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
|
||||||
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
|
||||||
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
|
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
|
||||||
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
|
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
|
||||||
|
|
||||||
|
if return_mm_token_type_ids:
|
||||||
|
array_ids = np.array(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
||||||
|
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
||||||
|
|
||||||
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
||||||
|
|
||||||
def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
|
def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
|
||||||
@ -219,6 +235,48 @@ class LlavaNextProcessor(ProcessorMixin):
|
|||||||
newline_features = current_height
|
newline_features = current_height
|
||||||
return (unpadded_features, newline_features)
|
return (unpadded_features, newline_features)
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
Args:
|
||||||
|
image_sizes (List[List[str]], *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
video_sizes (List[List[str]], *optional*):
|
||||||
|
The input sizes formatted as (num_frames, height, width) per each video.
|
||||||
|
audio_lengths (List[int], *optional*):
|
||||||
|
The input length formatted as per each audio.
|
||||||
|
Returns:
|
||||||
|
Dict[str, List[int]]: A dictionary mapping each modality ("image", "video", "audio")
|
||||||
|
to a list containing the number of placeholder tokens required. If the model doesn't accept
|
||||||
|
a certain modality or no input sizes are provided, the dict value is set to an empty list.
|
||||||
|
"""
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
images_kwargs = LlavaNextProcessorKwargs._defaults.get("images_kwargs", {})
|
||||||
|
images_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
size = images_kwargs.get("size", None) or self.image_processor.size
|
||||||
|
size = (
|
||||||
|
(size["shortest_edge"], size["shortest_edge"])
|
||||||
|
if "shortest_edge" in size
|
||||||
|
else (min(size["height"], size["width"]), min(size["height"], size["width"]))
|
||||||
|
)
|
||||||
|
processed_height, processed_width = size
|
||||||
|
|
||||||
|
batch_num_image_tokens = []
|
||||||
|
num_image_patches = [1] * len(image_sizes) # llava-next doesn't batch pixels as Idefics, thus `1` patch`
|
||||||
|
for image_size in image_sizes:
|
||||||
|
orig_height, orig_width = image_size
|
||||||
|
num_image_tokens = self._get_number_of_features(
|
||||||
|
orig_height, orig_width, processed_height, processed_width
|
||||||
|
)
|
||||||
|
if self.vision_feature_select_strategy == "default":
|
||||||
|
num_image_tokens -= 1
|
||||||
|
batch_num_image_tokens.append(num_image_tokens)
|
||||||
|
vision_data.update({"num_image_tokens": batch_num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||||
def batch_decode(self, *args, **kwargs):
|
def batch_decode(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -24,7 +24,7 @@ import numpy as np
|
|||||||
from ...feature_extraction_utils import BatchFeature
|
from ...feature_extraction_utils import BatchFeature
|
||||||
from ...image_processing_utils import select_best_resolution
|
from ...image_processing_utils import select_best_resolution
|
||||||
from ...image_utils import ImageInput, get_image_size, to_numpy_array
|
from ...image_utils import ImageInput, get_image_size, to_numpy_array
|
||||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
|
||||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ...video_utils import VideoInput
|
from ...video_utils import VideoInput
|
||||||
@ -38,6 +38,7 @@ class LlavaOnevisionProcessorKwargs(ProcessingKwargs, total=False):
|
|||||||
_defaults = {
|
_defaults = {
|
||||||
"text_kwargs": {
|
"text_kwargs": {
|
||||||
"padding": False,
|
"padding": False,
|
||||||
|
"return_mm_token_type_ids": False,
|
||||||
},
|
},
|
||||||
"image_kwargs": {},
|
"image_kwargs": {},
|
||||||
"videos_kwargs": {},
|
"videos_kwargs": {},
|
||||||
@ -196,9 +197,16 @@ class LlavaOnevisionProcessor(ProcessorMixin):
|
|||||||
text = [sample.replace(self.video_token, self.video_token * num_video_tokens) for sample in text]
|
text = [sample.replace(self.video_token, self.video_token * num_video_tokens) for sample in text]
|
||||||
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
|
||||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||||
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
|
self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
|
||||||
|
|
||||||
|
if return_mm_token_type_ids:
|
||||||
|
array_ids = np.array(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
||||||
|
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
||||||
|
|
||||||
return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs}, tensor_type=return_tensors)
|
return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs}, tensor_type=return_tensors)
|
||||||
|
|
||||||
def _expand_image_tokens(
|
def _expand_image_tokens(
|
||||||
@ -285,6 +293,48 @@ class LlavaOnevisionProcessor(ProcessorMixin):
|
|||||||
|
|
||||||
return (unpadded_features, newline_features)
|
return (unpadded_features, newline_features)
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
Args:
|
||||||
|
image_sizes (List[List[str]], *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
video_sizes (List[List[str]], *optional*):
|
||||||
|
The input sizes formatted as (num_frames, height, width) per each video.
|
||||||
|
audio_lengths (List[int], *optional*):
|
||||||
|
The input length formatted as per each audio.
|
||||||
|
Returns:
|
||||||
|
Dict[str, List[int]]: A dictionary mapping each modality ("image", "video", "audio")
|
||||||
|
to a list containing the number of placeholder tokens required. If the model doesn't accept
|
||||||
|
a certain modality or no input sizes are provided, the dict value is set to an empty list.
|
||||||
|
"""
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
images_kwargs = LlavaOnevisionProcessorKwargs._defaults.get("images_kwargs", {})
|
||||||
|
images_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
size = images_kwargs.get("size", None) or self.image_processor.size
|
||||||
|
size = (
|
||||||
|
(size["shortest_edge"], size["shortest_edge"])
|
||||||
|
if "shortest_edge" in size
|
||||||
|
else (min(size["height"], size["width"]), min(size["height"], size["width"]))
|
||||||
|
)
|
||||||
|
processed_height, processed_width = size
|
||||||
|
|
||||||
|
batch_num_image_tokens = []
|
||||||
|
num_image_patches = [1] * len(image_sizes) # llava-ov doesn't batch pixels as Idefics, thus `1` patch`
|
||||||
|
for image_size in image_sizes:
|
||||||
|
orig_height, orig_width = image_size
|
||||||
|
num_image_tokens = self._get_number_of_features(
|
||||||
|
orig_height, orig_width, processed_height, processed_width
|
||||||
|
)
|
||||||
|
if self.vision_feature_select_strategy == "default":
|
||||||
|
num_image_tokens -= 1
|
||||||
|
batch_num_image_tokens.append(num_image_tokens)
|
||||||
|
vision_data.update({"num_image_tokens": batch_num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||||
def batch_decode(self, *args, **kwargs):
|
def batch_decode(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -18,10 +18,13 @@ Processor class for PaliGemma.
|
|||||||
|
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from ...feature_extraction_utils import BatchFeature
|
from ...feature_extraction_utils import BatchFeature
|
||||||
from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images
|
from ...image_utils import ImageInput, is_valid_image, make_flat_list_of_images
|
||||||
from ...processing_utils import (
|
from ...processing_utils import (
|
||||||
ImagesKwargs,
|
ImagesKwargs,
|
||||||
|
MultiModalData,
|
||||||
ProcessingKwargs,
|
ProcessingKwargs,
|
||||||
ProcessorMixin,
|
ProcessorMixin,
|
||||||
TextKwargs,
|
TextKwargs,
|
||||||
@ -56,6 +59,7 @@ class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False):
|
|||||||
_defaults = {
|
_defaults = {
|
||||||
"text_kwargs": {
|
"text_kwargs": {
|
||||||
"padding": False,
|
"padding": False,
|
||||||
|
"return_mm_token_type_ids": False,
|
||||||
},
|
},
|
||||||
"images_kwargs": {
|
"images_kwargs": {
|
||||||
"data_format": "channels_first",
|
"data_format": "channels_first",
|
||||||
@ -299,6 +303,7 @@ class PaliGemmaProcessor(ProcessorMixin):
|
|||||||
pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
|
pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
|
||||||
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
|
||||||
inputs = self.tokenizer(
|
inputs = self.tokenizer(
|
||||||
input_strings,
|
input_strings,
|
||||||
text_pair=suffix,
|
text_pair=suffix,
|
||||||
@ -312,8 +317,34 @@ class PaliGemmaProcessor(ProcessorMixin):
|
|||||||
if return_token_type_ids:
|
if return_token_type_ids:
|
||||||
labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
|
labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
|
||||||
return_data.update({"labels": labels})
|
return_data.update({"labels": labels})
|
||||||
|
|
||||||
|
if return_mm_token_type_ids:
|
||||||
|
array_ids = np.array(return_data["input_ids"])
|
||||||
|
mm_token_type_ids = np.zeros_like(return_data["input_ids"])
|
||||||
|
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
||||||
|
return_data["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
||||||
|
|
||||||
return BatchFeature(data=return_data, tensor_type=return_tensors)
|
return BatchFeature(data=return_data, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_sizes (List[List[str]], *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
Returns:
|
||||||
|
Dict[str, List[int]]: A dictionary mapping each modality ("image", "video", "audio")
|
||||||
|
to a list containing the number of placeholder tokens required. If the model doesn't accept
|
||||||
|
a certain modality or no input sizes are provided, the dict value is set to an empty list.
|
||||||
|
"""
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
num_image_tokens = [self.image_seq_length] * len(image_sizes)
|
||||||
|
num_image_patches = [1] * len(image_sizes)
|
||||||
|
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
|
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
|
||||||
def batch_decode(self, *args, **kwargs):
|
def batch_decode(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -18,11 +18,23 @@ Processor class for Pixtral.
|
|||||||
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from ...feature_extraction_utils import BatchFeature
|
from ...feature_extraction_utils import BatchFeature
|
||||||
from ...image_utils import ImageInput, is_valid_image, load_image
|
from ...image_utils import ImageInput, is_valid_image, load_image
|
||||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
|
from ...processing_utils import (
|
||||||
|
MultiModalData,
|
||||||
|
ProcessingKwargs,
|
||||||
|
ProcessorMixin,
|
||||||
|
Unpack,
|
||||||
|
_validate_images_text_input_order,
|
||||||
|
)
|
||||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
from ...utils import logging
|
from ...utils import is_vision_available, logging
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from .image_processing_pixtral import get_resize_output_image_size
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@ -32,6 +44,7 @@ class PixtralProcessorKwargs(ProcessingKwargs, total=False):
|
|||||||
_defaults = {
|
_defaults = {
|
||||||
"text_kwargs": {
|
"text_kwargs": {
|
||||||
"padding": False,
|
"padding": False,
|
||||||
|
"return_mm_token_type_ids": False,
|
||||||
},
|
},
|
||||||
"images_kwargs": {},
|
"images_kwargs": {},
|
||||||
"common_kwargs": {
|
"common_kwargs": {
|
||||||
@ -106,6 +119,10 @@ class PixtralProcessor(ProcessorMixin):
|
|||||||
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
||||||
self.image_break_token = image_break_token
|
self.image_break_token = image_break_token
|
||||||
self.image_end_token = image_end_token
|
self.image_end_token = image_end_token
|
||||||
|
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
||||||
|
self.image_break_token_id = tokenizer.convert_tokens_to_ids(self.image_break_token)
|
||||||
|
self.image_end_token_id = tokenizer.convert_tokens_to_ids(self.image_end_token)
|
||||||
|
self.image_ids = [self.image_token_id, self.image_break_token_id, self.image_end_token_id]
|
||||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
@ -213,10 +230,54 @@ class PixtralProcessor(ProcessorMixin):
|
|||||||
prompt_strings.append(sample)
|
prompt_strings.append(sample)
|
||||||
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
||||||
|
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
|
||||||
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
|
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
|
||||||
|
|
||||||
|
if return_mm_token_type_ids:
|
||||||
|
array_ids = np.array(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
|
||||||
|
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
||||||
|
|
||||||
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_sizes (`List[List[int]]`, *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
||||||
|
input modalities, along with other useful data.
|
||||||
|
"""
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
images_kwargs = PixtralProcessorKwargs._defaults.get("images_kwargs", {})
|
||||||
|
images_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
size = images_kwargs.get("size", None) or self.image_processor.size
|
||||||
|
patch_size = self.patch_size * self.spatial_merge_size
|
||||||
|
|
||||||
|
num_image_tokens = []
|
||||||
|
for height, width in image_sizes:
|
||||||
|
resized_height, resized_width = get_resize_output_image_size(
|
||||||
|
image=np.zeros((height, width, 3)),
|
||||||
|
size=(size["longest_edge"], size["longest_edge"]),
|
||||||
|
patch_size=(patch_size, patch_size),
|
||||||
|
)
|
||||||
|
num_height_tokens = resized_height // patch_size
|
||||||
|
num_width_tokens = resized_width // patch_size
|
||||||
|
num_image_tokens.append((num_width_tokens + 1) * num_height_tokens)
|
||||||
|
|
||||||
|
num_image_patches = [1] * len(image_sizes)
|
||||||
|
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
|
||||||
def batch_decode(self, *args, **kwargs):
|
def batch_decode(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -22,6 +22,7 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -48,7 +49,7 @@ from ...configuration_utils import PretrainedConfig
|
|||||||
from ...feature_extraction_utils import BatchFeature
|
from ...feature_extraction_utils import BatchFeature
|
||||||
from ...image_utils import ImageInput
|
from ...image_utils import ImageInput
|
||||||
from ...modeling_flash_attention_utils import is_flash_attn_available
|
from ...modeling_flash_attention_utils import is_flash_attn_available
|
||||||
from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs
|
from ...processing_utils import MultiModalData, ProcessingKwargs, Unpack, VideosKwargs
|
||||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
from ...utils import is_torchdynamo_compiling, logging
|
from ...utils import is_torchdynamo_compiling, logging
|
||||||
from ...video_utils import VideoInput
|
from ...video_utils import VideoInput
|
||||||
@ -925,6 +926,7 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
|
|||||||
_defaults = {
|
_defaults = {
|
||||||
"text_kwargs": {
|
"text_kwargs": {
|
||||||
"padding": False,
|
"padding": False,
|
||||||
|
"return_mm_token_type_ids": False,
|
||||||
},
|
},
|
||||||
"videos_kwargs": {"fps": 2.0},
|
"videos_kwargs": {"fps": 2.0},
|
||||||
}
|
}
|
||||||
@ -1050,11 +1052,56 @@ class Qwen2_5_VLProcessor(Qwen2VLProcessor):
|
|||||||
text[i] = text[i].replace("<|placeholder|>", self.video_token)
|
text[i] = text[i].replace("<|placeholder|>", self.video_token)
|
||||||
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
|
||||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||||
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
|
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
|
||||||
|
|
||||||
|
if return_mm_token_type_ids:
|
||||||
|
array_ids = np.array(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
||||||
|
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
||||||
|
|
||||||
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
|
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
Args:
|
||||||
|
image_sizes (`List[List[int]]`, *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
video_sizes (`List[List[int]]`, *optional*):
|
||||||
|
The input sizes formatted as (num_frames, height, width) per each video.
|
||||||
|
Returns:
|
||||||
|
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
||||||
|
input modalities, along with other useful data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
images_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("images_kwargs", {})
|
||||||
|
images_kwargs.update(kwargs)
|
||||||
|
merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
|
||||||
|
|
||||||
|
num_image_patches = [
|
||||||
|
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
|
||||||
|
for image_size in image_sizes
|
||||||
|
]
|
||||||
|
num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
|
||||||
|
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
|
||||||
|
if video_sizes is not None:
|
||||||
|
videos_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("videos_kwargs", {})
|
||||||
|
videos_kwargs.update(kwargs)
|
||||||
|
num_video_patches = [
|
||||||
|
self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
|
||||||
|
for video_size in video_sizes
|
||||||
|
]
|
||||||
|
num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
|
||||||
|
vision_data["num_video_tokens"] = num_video_tokens
|
||||||
|
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Qwen2_5_VLConfig",
|
"Qwen2_5_VLConfig",
|
||||||
|
@ -25,9 +25,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from ...feature_extraction_utils import BatchFeature
|
from ...feature_extraction_utils import BatchFeature
|
||||||
from ...image_utils import ImageInput
|
from ...image_utils import ImageInput
|
||||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
|
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
|
||||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
from ...video_utils import VideoInput
|
from ...video_utils import VideoInput
|
||||||
|
|
||||||
@ -50,6 +52,7 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
|
|||||||
_defaults = {
|
_defaults = {
|
||||||
"text_kwargs": {
|
"text_kwargs": {
|
||||||
"padding": False,
|
"padding": False,
|
||||||
|
"return_mm_token_type_ids": False,
|
||||||
},
|
},
|
||||||
"videos_kwargs": {"fps": 2.0},
|
"videos_kwargs": {"fps": 2.0},
|
||||||
}
|
}
|
||||||
@ -188,11 +191,56 @@ class Qwen2_5_VLProcessor(ProcessorMixin):
|
|||||||
text[i] = text[i].replace("<|placeholder|>", self.video_token)
|
text[i] = text[i].replace("<|placeholder|>", self.video_token)
|
||||||
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
|
||||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||||
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
|
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
|
||||||
|
|
||||||
|
if return_mm_token_type_ids:
|
||||||
|
array_ids = np.array(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
||||||
|
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
||||||
|
|
||||||
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
|
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
Args:
|
||||||
|
image_sizes (`List[List[int]]`, *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
video_sizes (`List[List[int]]`, *optional*):
|
||||||
|
The input sizes formatted as (num_frames, height, width) per each video.
|
||||||
|
Returns:
|
||||||
|
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
||||||
|
input modalities, along with other useful data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
images_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("images_kwargs", {})
|
||||||
|
images_kwargs.update(kwargs)
|
||||||
|
merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
|
||||||
|
|
||||||
|
num_image_patches = [
|
||||||
|
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
|
||||||
|
for image_size in image_sizes
|
||||||
|
]
|
||||||
|
num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
|
||||||
|
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
|
||||||
|
if video_sizes is not None:
|
||||||
|
videos_kwargs = Qwen2_5_VLProcessorKwargs._defaults.get("videos_kwargs", {})
|
||||||
|
videos_kwargs.update(kwargs)
|
||||||
|
num_video_patches = [
|
||||||
|
self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
|
||||||
|
for video_size in video_sizes
|
||||||
|
]
|
||||||
|
num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
|
||||||
|
vision_data["num_video_tokens"] = num_video_tokens
|
||||||
|
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
def batch_decode(self, *args, **kwargs):
|
def batch_decode(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||||
|
@ -490,5 +490,31 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
|
||||||
|
"""
|
||||||
|
A utility that returns number of image patches for a given image size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
height (`int`):
|
||||||
|
Height of the input image.
|
||||||
|
width (`int`):
|
||||||
|
Width of the input image.
|
||||||
|
images_kwargs (`dict`, *optional*)
|
||||||
|
Any kwargs to override defaults of the image processor.
|
||||||
|
Returns:
|
||||||
|
`int`: Number of image patches per image.
|
||||||
|
"""
|
||||||
|
min_pixels = images_kwargs.get("min_pixels", None) or self.size["shortest_edge"]
|
||||||
|
max_pixels = images_kwargs.get("max_pixels", None) or self.size["longest_edge"]
|
||||||
|
patch_size = images_kwargs.get("patch_size", None) or self.patch_size
|
||||||
|
merge_size = images_kwargs.get("merge_size", None) or self.merge_size
|
||||||
|
|
||||||
|
factor = patch_size * merge_size
|
||||||
|
resized_height, resized_width = smart_resize(
|
||||||
|
height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
|
||||||
|
)
|
||||||
|
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
|
||||||
|
return grid_h * grid_w
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Qwen2VLImageProcessor"]
|
__all__ = ["Qwen2VLImageProcessor"]
|
||||||
|
@ -402,5 +402,31 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast):
|
|||||||
|
|
||||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
|
||||||
|
"""
|
||||||
|
A utility that returns number of image patches for a given image size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
height (`int`):
|
||||||
|
Height of the input image.
|
||||||
|
width (`int`):
|
||||||
|
Width of the input image.
|
||||||
|
images_kwargs (`dict`, *optional*)
|
||||||
|
Any kwargs to override defaults of the image processor.
|
||||||
|
Returns:
|
||||||
|
`int`: Number of image patches per image.
|
||||||
|
"""
|
||||||
|
min_pixels = images_kwargs.get("min_pixels", None) or self.size["shortest_edge"]
|
||||||
|
max_pixels = images_kwargs.get("max_pixels", None) or self.size["longest_edge"]
|
||||||
|
patch_size = images_kwargs.get("patch_size", None) or self.patch_size
|
||||||
|
merge_size = images_kwargs.get("merge_size", None) or self.merge_size
|
||||||
|
|
||||||
|
factor = patch_size * merge_size
|
||||||
|
resized_height, resized_width = smart_resize(
|
||||||
|
height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
|
||||||
|
)
|
||||||
|
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
|
||||||
|
return grid_h * grid_w
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Qwen2VLImageProcessorFast"]
|
__all__ = ["Qwen2VLImageProcessorFast"]
|
||||||
|
@ -23,9 +23,11 @@ Processor class for Qwen2-VL.
|
|||||||
|
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from ...feature_extraction_utils import BatchFeature
|
from ...feature_extraction_utils import BatchFeature
|
||||||
from ...image_utils import ImageInput
|
from ...image_utils import ImageInput
|
||||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
|
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
|
||||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ...video_utils import VideoInput
|
from ...video_utils import VideoInput
|
||||||
@ -47,6 +49,7 @@ class Qwen2VLProcessorKwargs(ProcessingKwargs, total=False):
|
|||||||
_defaults = {
|
_defaults = {
|
||||||
"text_kwargs": {
|
"text_kwargs": {
|
||||||
"padding": False,
|
"padding": False,
|
||||||
|
"return_mm_token_type_ids": False,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -172,10 +175,56 @@ class Qwen2VLProcessor(ProcessorMixin):
|
|||||||
text[i] = text[i].replace("<|placeholder|>", self.video_token)
|
text[i] = text[i].replace("<|placeholder|>", self.video_token)
|
||||||
|
|
||||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
|
||||||
|
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None)
|
||||||
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
|
self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
|
||||||
|
|
||||||
|
if return_mm_token_type_ids:
|
||||||
|
array_ids = np.array(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
|
||||||
|
mm_token_type_ids[array_ids == self.image_token_id] = 1
|
||||||
|
text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
|
||||||
|
|
||||||
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
|
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
|
||||||
|
Args:
|
||||||
|
image_sizes (`List[List[int]]`, *optional*):
|
||||||
|
The input sizes formatted as (height, width) per each image.
|
||||||
|
video_sizes (`List[List[int]]`, *optional*):
|
||||||
|
The input sizes formatted as (num_frames, height, width) per each video.
|
||||||
|
Returns:
|
||||||
|
`MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
|
||||||
|
input modalities, along with other useful data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vision_data = {}
|
||||||
|
if image_sizes is not None:
|
||||||
|
images_kwargs = Qwen2VLProcessorKwargs._defaults.get("images_kwargs", {})
|
||||||
|
images_kwargs.update(kwargs)
|
||||||
|
merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
|
||||||
|
|
||||||
|
num_image_patches = [
|
||||||
|
self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
|
||||||
|
for image_size in image_sizes
|
||||||
|
]
|
||||||
|
num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
|
||||||
|
vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
|
||||||
|
|
||||||
|
if video_sizes is not None:
|
||||||
|
videos_kwargs = Qwen2VLProcessorKwargs._defaults.get("videos_kwargs", {})
|
||||||
|
videos_kwargs.update(kwargs)
|
||||||
|
num_video_patches = [
|
||||||
|
self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
|
||||||
|
for video_size in video_sizes
|
||||||
|
]
|
||||||
|
num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
|
||||||
|
vision_data["num_video_tokens"] = num_video_tokens
|
||||||
|
|
||||||
|
return MultiModalData(**vision_data)
|
||||||
|
|
||||||
def batch_decode(self, *args, **kwargs):
|
def batch_decode(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||||
|
@ -204,5 +204,35 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
|
|||||||
tensor_type=return_tensors,
|
tensor_type=return_tensors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_num_of_video_patches(self, num_frames: int, height: int, width: int, videos_kwargs=None):
|
||||||
|
"""
|
||||||
|
A utility that returns number of video patches a given video size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_frames (`int`):
|
||||||
|
Number of frames in the input video.
|
||||||
|
height (`int`):
|
||||||
|
Height of the input video.
|
||||||
|
width (`int`):
|
||||||
|
Width of the input video.
|
||||||
|
videos_kwargs (`dict`, *optional*)
|
||||||
|
Any kwargs to override defaults of the video processor.
|
||||||
|
Returns:
|
||||||
|
`Tuple(int, int)`: Number of placeholder tokens required and number of patches per image.
|
||||||
|
"""
|
||||||
|
min_pixels = videos_kwargs.get("min_pixels", None) or self.size["shortest_edge"]
|
||||||
|
max_pixels = videos_kwargs.get("max_pixels", None) or self.size["longest_edge"]
|
||||||
|
patch_size = videos_kwargs.get("patch_size", None) or self.patch_size
|
||||||
|
merge_size = videos_kwargs.get("merge_size", None) or self.merge_size
|
||||||
|
temporal_patch_size = videos_kwargs.get("temporal_patch_size", None) or self.temporal_patch_size
|
||||||
|
|
||||||
|
factor = patch_size * merge_size
|
||||||
|
resized_height, resized_width = smart_resize(
|
||||||
|
height, width, factor, min_pixels=min_pixels, max_pixels=max_pixels
|
||||||
|
)
|
||||||
|
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
|
||||||
|
grid_t = num_frames // temporal_patch_size
|
||||||
|
return grid_t * grid_h * grid_w
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Qwen2VLVideoProcessor"]
|
__all__ = ["Qwen2VLVideoProcessor"]
|
||||||
|
@ -847,5 +847,46 @@ class SmolVLMImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
return encoding
|
return encoding
|
||||||
|
|
||||||
|
def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
|
||||||
|
"""
|
||||||
|
A utility that returns number of image patches for a given image size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
height (`int`):
|
||||||
|
Height of the input image.
|
||||||
|
width (`int`):
|
||||||
|
Width of the input image.
|
||||||
|
images_kwargs (`dict`, *optional*)
|
||||||
|
Any kwargs to override defaults of the image processor.
|
||||||
|
Returns:
|
||||||
|
`int`: Number of patches per image.
|
||||||
|
"""
|
||||||
|
do_image_splitting = images_kwargs.get("do_image_splitting", None) or self.do_image_splitting
|
||||||
|
max_image_size = images_kwargs.get("max_image_size", None) or self.max_image_size
|
||||||
|
size = images_kwargs.get("size", None) or self.size
|
||||||
|
|
||||||
|
if do_image_splitting:
|
||||||
|
height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=size["longest_edge"])
|
||||||
|
height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=4096)
|
||||||
|
aspect_ratio = width / height
|
||||||
|
|
||||||
|
if width >= height:
|
||||||
|
resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
|
||||||
|
resized_height = int(width / aspect_ratio)
|
||||||
|
resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
|
||||||
|
elif height > width:
|
||||||
|
resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
|
||||||
|
resized_width = int(height * aspect_ratio)
|
||||||
|
resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
|
||||||
|
|
||||||
|
max_height = max_width = max_image_size["longest_edge"]
|
||||||
|
if resized_height > max_height or resized_width > max_width:
|
||||||
|
# Calculate the number of splits
|
||||||
|
num_rows = math.ceil(resized_height / max_height)
|
||||||
|
num_cols = math.ceil(resized_width / max_width)
|
||||||
|
num_patches = num_rows * num_cols + 1
|
||||||
|
|
||||||
|
return num_patches
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["SmolVLMImageProcessor"]
|
__all__ = ["SmolVLMImageProcessor"]
|
||||||
|
@ -22,6 +22,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import typing
|
import typing
|
||||||
import warnings
|
import warnings
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, TypedDict, Union
|
from typing import Any, Dict, List, Optional, TypedDict, Union
|
||||||
|
|
||||||
@ -120,6 +121,8 @@ class TextKwargs(TypedDict, total=False):
|
|||||||
Whether or not to print more information and warnings.
|
Whether or not to print more information and warnings.
|
||||||
padding_side (`str`, *optional*):
|
padding_side (`str`, *optional*):
|
||||||
The side on which padding will be applied.
|
The side on which padding will be applied.
|
||||||
|
return_mm_token_type_ids (`bool`, *optional*):
|
||||||
|
Whether to return multimodal token type ids indicating mm placeholder token positions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]]
|
text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]]
|
||||||
@ -140,6 +143,7 @@ class TextKwargs(TypedDict, total=False):
|
|||||||
return_length: Optional[bool]
|
return_length: Optional[bool]
|
||||||
verbose: Optional[bool]
|
verbose: Optional[bool]
|
||||||
padding_side: Optional[str]
|
padding_side: Optional[str]
|
||||||
|
return_mm_token_type_ids: Optional[bool]
|
||||||
|
|
||||||
|
|
||||||
class ImagesKwargs(TypedDict, total=False):
|
class ImagesKwargs(TypedDict, total=False):
|
||||||
@ -455,6 +459,32 @@ class AllKwargsForChatTemplate(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MultiModalData:
|
||||||
|
"""
|
||||||
|
Dataclass that holds extra useful data for processing
|
||||||
|
multimodal data. Processors currently cannot return keys,
|
||||||
|
unless it is used in model's forward. Thus we have helper
|
||||||
|
methods that calculate and return useful data from processing
|
||||||
|
input multimodals (images/videos).
|
||||||
|
Note that this dataclass is aimed to be used only in vLLM
|
||||||
|
and we might change its API in the future.
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_image_tokens: list[int] = None
|
||||||
|
num_video_tokens: list[int] = None
|
||||||
|
num_audio_tokens: list[int] = None
|
||||||
|
num_image_patches: list[int] = None
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
return hasattr(self, key) and getattr(self, key) is not None
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
if hasattr(self, key):
|
||||||
|
return getattr(self, key)
|
||||||
|
raise AttributeError(f"{self.__class__.__name__} has no attribute {key}")
|
||||||
|
|
||||||
|
|
||||||
class ProcessorMixin(PushToHubMixin):
|
class ProcessorMixin(PushToHubMixin):
|
||||||
"""
|
"""
|
||||||
This is a mixin used to provide saving/loading functionality for all processor classes.
|
This is a mixin used to provide saving/loading functionality for all processor classes.
|
||||||
|
Loading…
Reference in New Issue
Block a user