Merge d50b8c9684 into ebfbcd42da (commit af826140c1)
@@ -1035,6 +1035,8 @@
        title: PaliGemma
      - local: model_doc/perceiver
        title: Perceiver
      - local: model_doc/perception_lm
        title: PerceptionLM
      - local: model_doc/phi4_multimodal
        title: Phi4 Multimodal
      - local: model_doc/pix2struct
docs/source/en/model_doc/perception_lm.md (new file, 68 lines)
@@ -0,0 +1,68 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# PerceptionLM

## Overview

The PerceptionLM model was proposed in [PerceptionLM: Open-Access Data and Models for Detailed Visual Understanding](https://ai.meta.com/research/publications/perceptionlm-open-access-data-and-models-for-detailed-visual-understanding/) by Jang Hyun Cho et al. It is a fully open, reproducible model for transparent research in image and video understanding. PLM consists of a vision encoder paired with a small-scale (<8B parameters) LLM decoder.

The abstract from the paper is the following:

*Vision-language models are integral to computer vision research, yet many high-performing models
remain closed-source, obscuring their data, design and training recipe. The research community
has responded by using distillation from black-box models to label training data, achieving strong
benchmark results, at the cost of measurable scientific progress. However, without knowing the details
of the teacher model and its data sources, scientific progress remains difficult to measure. In this
paper, we study building a Perception Language Model (PLM) in a fully open and reproducible
framework for transparent research in image and video understanding. We analyze standard training
pipelines without distillation from proprietary models and explore large-scale synthetic data to identify
critical data gaps, particularly in detailed video understanding. To bridge these gaps, we release 2.8M
human-labeled instances of fine-grained video question-answer pairs and spatio-temporally grounded
video captions. Additionally, we introduce PLM–VideoBench, a suite for evaluating challenging video
understanding tasks focusing on the ability to reason about “what”, “where”, “when”, and “how” of a
video. We make our work fully reproducible by providing data, training recipes, code & models.*

This model was contributed by [shumingh](https://huggingface.co/shumingh).
The original code can be found [here](https://github.com/facebookresearch/perception_models).
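## Usage example

Below is a minimal usage sketch. The checkpoint id follows the example models listed in [`PerceptionLMConfig`]; the exact prompt handling and generation settings shown here are illustrative assumptions rather than part of this PR.

```python
import torch
import requests
from PIL import Image

from transformers import AutoProcessor, PerceptionLMForConditionalGeneration

model_id = "facebook/Perception-LM-1B"  # assumed checkpoint id, as listed in PerceptionLMConfig
processor = AutoProcessor.from_pretrained(model_id)
model = PerceptionLMForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# The chat template inserts the <|image|> placeholder for image content.
messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is shown in this image?"}]},
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(images=image, text=prompt, return_tensors="pt")

generated_ids = model.generate(**inputs, max_new_tokens=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```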
## PerceptionLMConfig

[[autodoc]] PerceptionLMConfig

## PerceptionLMProcessor

[[autodoc]] PerceptionLMProcessor

## PerceptionLMImageProcessorFast

[[autodoc]] PerceptionLMImageProcessorFast

## PerceptionLMVideoProcessor

[[autodoc]] PerceptionLMVideoProcessor

## PerceptionLMModel

[[autodoc]] PerceptionLMModel

## PerceptionLMForConditionalGeneration

[[autodoc]] PerceptionLMForConditionalGeneration
    - forward
@@ -234,6 +234,7 @@ if TYPE_CHECKING:
    from .pegasus import *
    from .pegasus_x import *
    from .perceiver import *
    from .perception_lm import *
    from .persimmon import *
    from .phi import *
    from .phi3 import *
@@ -267,6 +267,8 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
        ("pegasus", "PegasusConfig"),
        ("pegasus_x", "PegasusXConfig"),
        ("perceiver", "PerceiverConfig"),
        ("perception_encoder", "TimmWrapperConfig"),
        ("perception_lm", "PerceptionLMConfig"),
        ("persimmon", "PersimmonConfig"),
        ("phi", "PhiConfig"),
        ("phi3", "Phi3Config"),
@@ -663,6 +665,8 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
        ("pegasus", "Pegasus"),
        ("pegasus_x", "PEGASUS-X"),
        ("perceiver", "Perceiver"),
        ("perception_encoder", "PerceptionEncoder"),
        ("perception_lm", "PerceptionLM"),
        ("persimmon", "Persimmon"),
        ("phi", "Phi"),
        ("phi3", "Phi3"),
@@ -869,6 +873,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
        ("llama4_text", "llama4"),
        ("blip_2_qformer", "blip_2"),
        ("fastspeech2_conformer_with_hifigan", "fastspeech2_conformer"),
        ("perception_encoder", "perception_lm"),
    ]
)
@@ -132,6 +132,7 @@ else:
        ("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")),
        ("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
        ("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")),
        ("perception_lm", ("PerceptionLMImageProcessorFast",)),
        ("phi4_multimodal", ("Phi4MultimodalImageProcessorFast",)),
        ("pix2struct", ("Pix2StructImageProcessor",)),
        ("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
@@ -597,7 +598,6 @@ class AutoImageProcessor:
            raise ValueError(
                "This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
            )

        raise ValueError(
            f"Unrecognized image processor in {pretrained_model_name_or_path}. Should have a "
            f"`image_processor_type` key in its {IMAGE_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following "
@@ -255,6 +255,8 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("pegasus", "PegasusModel"),
        ("pegasus_x", "PegasusXModel"),
        ("perceiver", "PerceiverModel"),
        ("perception_encoder", "PerceptionEncoder"),
        ("perception_lm", "PerceptionLMModel"),
        ("persimmon", "PersimmonModel"),
        ("phi", "PhiModel"),
        ("phi3", "Phi3Model"),
@@ -933,6 +935,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
        ("mistral3", "Mistral3ForConditionalGeneration"),
        ("mllama", "MllamaForConditionalGeneration"),
        ("paligemma", "PaliGemmaForConditionalGeneration"),
        ("perception_lm", "PerceptionLMForConditionalGeneration"),
        ("pix2struct", "Pix2StructForConditionalGeneration"),
        ("pixtral", "LlavaForConditionalGeneration"),
        ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
@@ -100,6 +100,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
        ("owlv2", "Owlv2Processor"),
        ("owlvit", "OwlViTProcessor"),
        ("paligemma", "PaliGemmaProcessor"),
        ("perception_lm", "PerceptionLMProcessor"),
        ("phi4_multimodal", "Phi4MultimodalProcessor"),
        ("pix2struct", "Pix2StructProcessor"),
        ("pixtral", "PixtralProcessor"),
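The mappings above are what make the generic Auto classes resolve `perception_lm` checkpoints to the new classes. A minimal sketch; the checkpoint id is an assumption based on the examples in `PerceptionLMConfig`:

```python
from transformers import AutoConfig, AutoModelForImageTextToText, AutoProcessor

checkpoint = "facebook/Perception-LM-1B"  # assumed checkpoint id
config = AutoConfig.from_pretrained(checkpoint)        # -> PerceptionLMConfig
processor = AutoProcessor.from_pretrained(checkpoint)  # -> PerceptionLMProcessor
model = AutoModelForImageTextToText.from_pretrained(checkpoint)  # -> PerceptionLMForConditionalGeneration
```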
src/transformers/models/perception_lm/__init__.py (new file, 29 lines)
@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_perception_lm import *
    from .image_processing_perception_lm_fast import *
    from .modeling_perception_lm import *
    from .processing_perception_lm import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
src/transformers/models/perception_lm/configuration_perception_lm.py (new file, 88 lines)
@@ -0,0 +1,88 @@
# coding=utf-8
# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PerceptionLM model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING, AutoConfig
from ..timm_wrapper.configuration_timm_wrapper import TimmWrapperConfig


logger = logging.get_logger(__name__)


class PerceptionLMConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`PerceptionLMForConditionalGeneration`]. It is used to instantiate a
    PerceptionLM model according to the specified arguments, defining the model architecture.

    Example models:
    - [facebook/Perception-LM-1B](https://huggingface.co/facebook/Perception-LM-1B).
    - [facebook/Perception-LM-3B](https://huggingface.co/facebook/Perception-LM-3B).
    - [facebook/Perception-LM-8B](https://huggingface.co/facebook/Perception-LM-8B).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`Union[TimmWrapperConfig, dict]`, *optional*, defaults to `TimmWrapperConfig()`):
            The config object or dictionary of the vision backbone.
        text_config (`Union[PretrainedConfig, dict]`, *optional*, defaults to `LlamaConfig()`):
            The config object or dictionary of the text backbone.
        vision_use_cls_token (`bool`, *optional*, defaults to `True`):
            Whether the CLS token is used in the vision backbone. If used, the CLS token embedding is removed from the vision output.
        projector_pooling_ratio (`int`, *optional*, defaults to 1):
            The pooling ratio used in the multimodal projector.
        image_token_id (`int`, *optional*, defaults to 128002):
            The image token index to encode the image prompt.
        video_token_id (`int`, *optional*, defaults to 128003):
            The video token index to encode the video prompt.
    """

    model_type = "perception_lm"
    sub_configs = {"text_config": AutoConfig, "vision_config": TimmWrapperConfig}

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        vision_use_cls_token=True,
        projector_pooling_ratio=1,
        image_token_id=128002,
        video_token_id=128003,
        **kwargs,
    ):
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        if isinstance(vision_config, dict):
            vision_config = TimmWrapperConfig(**vision_config)
        elif isinstance(vision_config, TimmWrapperConfig):
            vision_config = vision_config
        elif vision_config is None:
            vision_config = TimmWrapperConfig()
        self.vision_config = vision_config
        self.vision_use_cls_token = vision_use_cls_token

        if isinstance(text_config, dict):
            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["llama"]()

        self.text_config = text_config
        self.projector_pooling_ratio = projector_pooling_ratio
        super().__init__(**kwargs)


__all__ = ["PerceptionLMConfig"]
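A short sketch of how the constructor above resolves its sub-configs; the hyperparameter values below are arbitrary examples:

```python
from transformers import PerceptionLMConfig

# Defaults: a TimmWrapperConfig vision tower and a LlamaConfig text backbone.
config = PerceptionLMConfig()

# Dict sub-configs are converted automatically; "model_type" falls back to "llama".
config = PerceptionLMConfig(
    text_config={"hidden_size": 2048, "num_hidden_layers": 16},
    projector_pooling_ratio=2,
)
print(type(config.text_config).__name__)  # LlamaConfig
```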
src/transformers/models/perception_lm/convert_perception_lm_weights_to_hf.py (new file, 615 lines)
@@ -0,0 +1,615 @@
# coding=utf-8
# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import json
import os
import tempfile
import warnings

import torch
from timm.models.eva import checkpoint_filter_fn
from tokenizers import AddedToken, processors

from transformers import (
    GenerationConfig,
    LlamaConfig,
    LlamaTokenizer,
    PreTrainedTokenizerFast,
)
from transformers.convert_slow_tokenizer import TikTokenConverter
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.perception_lm.configuration_perception_lm import (
    PerceptionLMConfig,
)
from transformers.models.perception_lm.image_processing_perception_lm_fast import (
    PerceptionLMImageProcessorFast,
)
from transformers.models.perception_lm.modeling_perception_lm import (
    PerceptionLMForConditionalGeneration,
)
from transformers.models.perception_lm.processing_perception_lm import (
    PerceptionLMProcessor,
)
from transformers.models.perception_lm.video_processing_perception_lm import (
    PerceptionLMVideoProcessor,
)
from transformers.models.timm_wrapper.configuration_timm_wrapper import TimmWrapperConfig


try:
    from transformers import LlamaTokenizerFast
except ImportError as e:
    warnings.warn(e)
    warnings.warn(
        "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
    )
    LlamaTokenizerFast = None

"""
Sample usage:

```
python src/transformers/models/perception_lm/convert_perception_lm_weights_to_hf.py \
    --input_dir /path/to/downloaded/perception_lm/model_path --output_dir /output/path
```

Thereafter, models can be loaded via:

```py
from transformers import LlamaForCausalLM, LlamaTokenizer

model = LlamaForCausalLM.from_pretrained("/output/path")
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
```

Important note: you need to be able to host the whole model in RAM to execute this script (even though the biggest
versions come in several checkpoints, each checkpoint contains a part of every weight of the model, so all of them
need to be loaded in RAM).

If you want your tokenizer to add a bos automatically you should update the tokenizer._tokenizers.post_processor:

```py
from tokenizers import processors
bos = "<|begin_of_text|>"
tokenizer._tokenizers.post_processor = processors.Sequence(
    [
        processors.ByteLevel(trim_offsets=False),
        processors.TemplateProcessing(
            single=f"{bos}:0 $A:0",
            pair=f"{bos}:0 $A:0 {bos}:1 $B:1",
            special_tokens=[
                (bos, tokenizer.encode(bos)),
            ],
        ),
    ]
)
```
"""

BOS_ADDED_TOKEN = AddedToken(
    "<|begin_of_text|>",
    single_word=False,
    lstrip=False,
    rstrip=False,
    normalized=False,
    special=True,
)
EOS_ADDED_TOKEN = AddedToken(
    "<|end_of_text|>",
    single_word=False,
    lstrip=False,
    rstrip=False,
    normalized=False,
    special=True,
)
EOT_ADDED_TOKEN = AddedToken(
    "<|eot_id|>",
    single_word=False,
    lstrip=False,
    rstrip=False,
    normalized=False,
    special=True,
)

DEFAULT_SPECIAL_TOKENS = {
    "perception_lm": [
        "<|begin_of_text|>",
        "<|end_of_text|>",
        "<|image|>",
        "<|video|>",
        "<|reserved_special_token_2|>",
        "<|reserved_special_token_3|>",
        "<|start_header_id|>",
        "<|end_header_id|>",
        "<|reserved_special_token_4|>",
        "<|eot_id|>",  # End of turn
    ]
    + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
}

CHAT_TEMPLATE = (
    "{{- bos_token }}"
    "{%- if messages[0]['role'] == 'system' -%}"
    "    {%- set system_message = messages[0]['content']|trim %}\n"
    "    {%- set messages = messages[1:] %}\n"
    "{%- else %}"
    "    {%- set system_message = 'You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.' %}"
    "{%- endif %}"
    "{{- '<|start_header_id|>system<|end_header_id|>\\n\\n' }}"
    "{{- system_message }}"
    "{{- '<|eot_id|>' }}"
    "{%- for message in messages %}"
    "{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}"
    "{%- for content in message['content'] | selectattr('type', 'equalto', 'image') %}"
    "{{ '<|image|>' }}"
    "{%- endfor %}"
    "{%- for content in message['content'] | selectattr('type', 'equalto', 'video') %}"
    "{{ '<|video|>' }}"
    "{%- endfor %}"
    "{%- for content in message['content'] | selectattr('type', 'equalto', 'text') %}"
    "{{- content['text'] | trim }}"
    "{%- endfor %}"
    "{{'<|eot_id|>' }}"
    "{%- endfor %}"
    "{%- if add_generation_prompt %}"
    "{{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}"
    "{%- endif %}"
)
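For reference, a hedged sketch of how this template behaves once it is attached to the saved processor (the rendered string below is abridged; the template's default system prompt is inserted when no system message is given):

```python
messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe this image."}]},
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
# Roughly:
# <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<default system prompt><|eot_id|>
# <|start_header_id|>user<|end_header_id|>\n\n<|image|>Describe this image.<|eot_id|>
# <|start_header_id|>assistant<|end_header_id|>\n\n
```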
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
    return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
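A quick worked example of the helper above: it takes the SwiGLU hidden size 8n/3, applies the optional multiplier, and rounds up to a multiple of `multiple_of`:

```python
# int(8 * 2048 / 3) = 5461; (5461 + 255) // 256 = 22; 22 * 256 = 5632
assert compute_intermediate_size(2048) == 5632
```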
def read_json(path):
    with open(path, "r") as f:
        return json.load(f)


def write_json(text, path):
    with open(path, "w") as f:
        json.dump(text, f)


def write_weights(state_dict, index_dict, param_count, filename):
    for k, v in state_dict.items():
        index_dict["weight_map"][k] = filename
        param_count += v.numel()
    torch.save(state_dict, filename)
    print(f"Saved {filename}")
    return param_count


def write_model(
    model_path,
    input_base_path,
    params,
    image_token_id,
    safe_serialization=True,
    tokenizer=None,
    num_shards=None,
    push_to_hub=False,
):
    print("Converting the model.")
    num_shards = 1
    model_params = params.get("model", params)
    n_layers = model_params["n_layers"]
    n_heads = model_params["n_heads"]
    dim = model_params["dim"]
    dims_per_head = dim // n_heads
    base = model_params.get("rope_theta", 10000.0)
    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
    context_length = model_params["max_seqlen"]
    max_position_embeddings = context_length
    tie_word_embeddings = model_params.get("weight_tying", False)
    projector_pooling_ratio = model_params.get("pooling_ratio", 1)

    if model_params.get("n_kv_heads", None) is not None:
        num_key_value_heads = model_params["n_kv_heads"]  # for GQA / MQA
        key_value_dim = dims_per_head * num_key_value_heads
    else:  # compatibility with other checkpoints
        num_key_value_heads = n_heads
        key_value_dim = dim

    # permute for sliced rotary
    def permute(w, n_heads, dim1=dim, dim2=dim):
        return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

    with tempfile.TemporaryDirectory() as tmp_model_path:
        print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
        # Load weights
        if num_shards == 1:
            # Not sharded
            # (The sharded implementation would also work, but this is simpler.)
            loaded = torch.load(
                os.path.join(input_base_path, "consolidated.pth"),
                map_location="cpu",
                weights_only=True,
            )
        else:
            # Sharded
            checkpoint_list = sorted([file for file in os.listdir(input_base_path) if file.endswith(".pth")])
            print("Loading in order:", checkpoint_list)
            loaded = [
                torch.load(
                    os.path.join(input_base_path, file),
                    map_location="cpu",
                    weights_only=True,
                )
                for file in checkpoint_list
            ]
        param_count = 0
        index_dict = {"weight_map": {}}
        for layer_i in range(n_layers):
            filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 2}.bin"
            assert num_shards == 1, "PerceptionLM does not support sharded weights"
            state_dict = {
                f"model.language_model.layers.{layer_i}.self_attn.q_proj.weight": permute(
                    loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads
                ),
                f"model.language_model.layers.{layer_i}.self_attn.k_proj.weight": permute(
                    loaded[f"layers.{layer_i}.attention.wk.weight"],
                    n_heads=num_key_value_heads,
                    dim1=key_value_dim,
                ),
                f"model.language_model.layers.{layer_i}.self_attn.v_proj.weight": loaded[
                    f"layers.{layer_i}.attention.wv.weight"
                ],
                f"model.language_model.layers.{layer_i}.self_attn.o_proj.weight": loaded[
                    f"layers.{layer_i}.attention.wo.weight"
                ],
                f"model.language_model.layers.{layer_i}.mlp.gate_proj.weight": loaded[
                    f"layers.{layer_i}.feed_forward.w1.weight"
                ],
                f"model.language_model.layers.{layer_i}.mlp.down_proj.weight": loaded[
                    f"layers.{layer_i}.feed_forward.w2.weight"
                ],
                f"model.language_model.layers.{layer_i}.mlp.up_proj.weight": loaded[
                    f"layers.{layer_i}.feed_forward.w3.weight"
                ],
                f"model.language_model.layers.{layer_i}.input_layernorm.weight": loaded[
                    f"layers.{layer_i}.attention_norm.weight"
                ],
                f"model.language_model.layers.{layer_i}.post_attention_layernorm.weight": loaded[
                    f"layers.{layer_i}.ffn_norm.weight"
                ],
            }
            state_dict[f"model.language_model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
            for k, v in state_dict.items():
                index_dict["weight_map"][k] = filename
                param_count += v.numel()
            torch.save(state_dict, os.path.join(tmp_model_path, filename))
            print(f"Saved {filename}")

        filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 2}.bin"

        state_dict = {
            "model.language_model.embed_tokens.weight": loaded["tok_embeddings.weight"],
            "model.language_model.norm.weight": loaded["norm.weight"],
            "model.multi_modal_projector.projector.0.weight": loaded["vision_projector.projector.0.weight"],
            "model.multi_modal_projector.projector.2.weight": loaded["vision_projector.projector.2.weight"],
            "model.multi_modal_projector.projector.0.bias": loaded["vision_projector.projector.0.bias"],
            "model.multi_modal_projector.projector.2.bias": loaded["vision_projector.projector.2.bias"],
        }
        if not tie_word_embeddings:
            state_dict["lm_head.weight"] = loaded["output.weight"]
        for k, v in state_dict.items():
            index_dict["weight_map"][k] = filename
            param_count += v.numel()
        torch.save(state_dict, os.path.join(tmp_model_path, filename))
        print(f"Saved {filename}")

        filename = f"pytorch_model-{n_layers + 2}-of-{n_layers + 2}.bin"
        state_dict = {k.replace("vision_model.", ""): v for k, v in loaded.items() if "vision_model" in k}
        vision_params = model_params["vision_model"]
        if vision_params["layers"] == 23 and vision_params["width"] == 1024:
            architecture = "vit_pe_core_large_patch14_336"
        elif vision_params["layers"] == 47 and vision_params["width"] == 1536:
            architecture = "vit_pe_core_gigantic_patch14_448"
        else:
            raise ValueError(
                f"Unsupported PE config: {vision_params['layers']} layers and {vision_params['width']} width"
            )

        vision_config = TimmWrapperConfig.from_pretrained(
            f"timm/{architecture}.fb",
            model_args={
                "embed_dim": vision_params["width"],
                "depth": vision_params["layers"],
                "img_size": (vision_params["image_size"], vision_params["image_size"]),
                "global_pool": "",
                "use_post_transformer_norm": vision_params["use_ln_post"],
                "init_values": vision_params["ls_init_value"],
                "ref_feat_shape": (
                    vision_params["image_size"] // vision_params["patch_size"],
                    vision_params["image_size"] // vision_params["patch_size"],
                ),
            },
        )

        perception_encoder = AutoModel.from_config(vision_config)
        state_dict = checkpoint_filter_fn(state_dict, perception_encoder)
        state_dict = {"model.vision_tower.timm_model." + k: v for k, v in state_dict.items()}
        for k, v in state_dict.items():
            index_dict["weight_map"][k] = filename
            param_count += v.numel()
        torch.save(state_dict, os.path.join(tmp_model_path, filename))
        print(f"Saved {filename}")

        # Write configs
        index_dict["metadata"] = {"total_size": param_count * 2}
        write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
        ffn_dim_multiplier = model_params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in model_params else 1
        multiple_of = model_params["multiple_of"] if "multiple_of" in model_params else 256

        bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>")
        eos_token_id = [tokenizer.convert_tokens_to_ids(t) for t in ["<|end_of_text|>", "<|eot_id|>"]]

        use_scaled_rope = model_params["use_scaled_rope"]
        if use_scaled_rope:
            rope_scaling = {
                "factor": model_params["rope_scale_factor"] * 1.0,
                "low_freq_factor": model_params.get("low_freq_factor", 1.0) * 1.0,
                "high_freq_factor": model_params.get("high_freq_factor", 4.0) * 1.0,
                "original_max_position_embeddings": 8192,
                "rope_type": "llama3",
            }
        else:
            rope_scaling = None

        text_config = LlamaConfig(
            hidden_size=dim,
            intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
            num_attention_heads=model_params["n_heads"],
            num_hidden_layers=model_params["n_layers"],
            rms_norm_eps=model_params["norm_eps"],
            num_key_value_heads=num_key_value_heads,
            vocab_size=len(tokenizer),
            rope_theta=base,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
        )

        config = PerceptionLMConfig(
            text_config=text_config.to_dict(),
            vision_config=vision_config.to_dict(),
            projector_pooling_ratio=projector_pooling_ratio,
            vision_use_cls_token=vision_params["use_cls_token"],
            image_token_id=tokenizer.image_token_id,
            video_token_id=tokenizer.video_token_id,
        )

        config.save_pretrained(tmp_model_path)

        generation_config = GenerationConfig(
            do_sample=False,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
        )
        generation_config.save_pretrained(tmp_model_path)

        # Make space so we can load the model properly now.
        del state_dict
        # output_weight = loaded.get("output.weight", None)
        del loaded
        gc.collect()

        print("Loading the checkpoint in a PerceptionLM model.")
        model = PerceptionLMForConditionalGeneration.from_pretrained(
            tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
        )
        # if not tie_word_embeddings:
        #     if output_weight is None:
        #         raise ValueError("Output weight/lm_head is not found in the checkpoint.")
        #     model.lm_head.load_state_dict({"weight": output_weight})

        # Avoid saving this as part of the config.
        del model.config._name_or_path
        model.config.torch_dtype = torch.bfloat16

    print("Saving in the Transformers format.")
    if push_to_hub:
        print("Pushing to the hub.")
        model.push_to_hub(
            model_path,
            safe_serialization=safe_serialization,
            private=True,
            use_temp_dir=True,
        )
    else:
        print("Saving to disk.")
        model.save_pretrained(model_path, safe_serialization=safe_serialization)


class Llama3Converter(TikTokenConverter):
    def __init__(
        self,
        vocab_file,
        special_tokens=None,
        context_length=11520,
        **kwargs,
    ):
        super().__init__(vocab_file, additional_special_tokens=special_tokens, **kwargs)
        tokenizer = self.converted()

        self.converted_tokenizer = PreTrainedTokenizerFast(
            tokenizer_object=tokenizer,
            bos_token="<|begin_of_text|>",
            eos_token="<|eot_id|>",
            model_input_names=["input_ids", "attention_mask"],
            model_max_length=context_length,
            clean_up_tokenization_spaces=True,
            extra_special_tokens={
                "image_token": "<|image|>",
                "video_token": "<|video|>",
                "pad_token": "<|end_of_text|>",
            },
        )
        self.converted_tokenizer.image_token_id = self.converted_tokenizer.encode(
            self.converted_tokenizer.image_token, add_special_tokens=False
        )[0]
        self.converted_tokenizer.video_token_id = self.converted_tokenizer.encode(
            self.converted_tokenizer.video_token, add_special_tokens=False
        )[0]
        self.update_post_processor(self.converted_tokenizer)
        # finer special_tokens_map.json
        self.converted_tokenizer._bos_token = BOS_ADDED_TOKEN
        self.converted_tokenizer._eos_token = EOT_ADDED_TOKEN

    # We can't do this while building the tokenizer because we have no easy access to the bos token id
    def update_post_processor(self, tokenizer):
        tokenizer._tokenizer.post_processor = processors.Sequence(
            [
                processors.ByteLevel(trim_offsets=False),
                processors.TemplateProcessing(
                    single="<|begin_of_text|> $A",
                    pair="<|begin_of_text|>:0 $A:0 <|begin_of_text|>:1 $B:1",
                    special_tokens=[
                        (
                            "<|begin_of_text|>",
                            tokenizer.convert_tokens_to_ids("<|begin_of_text|>"),
                        ),
                    ],
                ),
            ]
        )


def write_tokenizer(
    tokenizer_path,
    input_tokenizer_path,
    special_tokens=None,
    params=None,
    push_to_hub=False,
):
    print("Converting the tokenizer.")
    tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
    context_length = params["model"]["max_seqlen"]
    tokenizer = Llama3Converter(
        input_tokenizer_path,
        special_tokens,
        context_length,
    ).converted_tokenizer

    tokenizer.image_token_id = tokenizer.encode(tokenizer.image_token, add_special_tokens=False)[0]
    processor_config = {
        "pooling_ratio": params["model"]["pooling_ratio"],
        "patch_size": params["model"]["vision_model"]["patch_size"],
        "processor_class": "PerceptionLMProcessor",
    }
    tile_size = params["model"]["vision_model"]["image_size"]

    image_preprocessor_config = {
        "image_processor_type": "PerceptionLMImageProcessorFast",
        "vision_input_type": params["data"]["vision_input_type"],
        "tile_size": tile_size,
        "max_num_tiles": params["data"]["max_num_tiles"],
        "max_frame_tiles": 1,
        "size": {"height": tile_size, "width": tile_size},
        "do_resize": True,
        "do_rescale": True,
        "do_normalize": True,
        "image_mean": [0.5, 0.5, 0.5],
        "image_std": [0.5, 0.5, 0.5],
    }
    image_preprocessor = PerceptionLMImageProcessorFast(**image_preprocessor_config)
    video_preprocessor_config = {
        "video_processor_type": "PerceptionLMVideoProcessor",
        "size": {"height": tile_size, "width": tile_size},
    }
    video_preprocessor = PerceptionLMVideoProcessor(**video_preprocessor_config)
    processor = PerceptionLMProcessor(
        image_processor=image_preprocessor,
        video_processor=video_preprocessor,
        tokenizer=tokenizer,
        chat_template=CHAT_TEMPLATE,
        **processor_config,
    )

    if push_to_hub:
        print(f"Pushing a {tokenizer_class.__name__} to the Hub repo - {tokenizer_path}.")
        processor.push_to_hub(tokenizer_path, private=True, use_temp_dir=True)
    else:
        print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
        processor.save_pretrained(tokenizer_path)
    return tokenizer


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dir",
        help="Location of Llama weights, which contains tokenizer.model and model folders",
    )
    parser.add_argument(
        "--output_dir",
        help="Location to write HF model and tokenizer",
    )
    parser.add_argument(
        "--push_to_hub",
        help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally.",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--safe_serialization",
        action="store_true",
        default=True,
        help="Whether or not to save using `safetensors`.",
    )
    parser.add_argument(
        "--num_shards",
        default=None,
        type=int,
        help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth",
    )
    parser.add_argument(
        "--special_tokens",
        default=None,
        type=list[str],
        help="The list of special tokens that should be added to the model.",
    )
    args = parser.parse_args()
    if args.special_tokens is None:
        # no special tokens by default
        args.special_tokens = DEFAULT_SPECIAL_TOKENS.get("perception_lm", [])

    params = read_json(os.path.join(args.input_dir, "params.json"))

    spm_path = os.path.join(args.input_dir, "tokenizer.model")
    tokenizer = write_tokenizer(
        args.output_dir,
        spm_path,
        special_tokens=args.special_tokens,
        params=params,
        push_to_hub=args.push_to_hub,
    )
    write_model(
        model_path=args.output_dir,
        input_base_path=args.input_dir,
        params=params,
        image_token_id=tokenizer.image_token_id,
        safe_serialization=args.safe_serialization,
        tokenizer=tokenizer,
        num_shards=args.num_shards,
        push_to_hub=args.push_to_hub,
    )


if __name__ == "__main__":
    main()
src/transformers/models/perception_lm/image_processing_perception_lm_fast.py (new file, 306 lines)
@@ -0,0 +1,306 @@
# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for PerceptionLM."""

import math
from functools import reduce
from typing import Optional, Union

import numpy as np

from ...image_processing_utils import (
    BatchFeature,
)
from ...image_processing_utils_fast import (
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    get_image_size,
    group_images_by_shape,
    reorder_images,
)
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    PILImageResampling,
)
from ...processing_utils import Unpack
from ...utils import (
    TensorType,
    add_start_docstrings,
    is_torch_available,
    is_torchvision_available,
)


if is_torch_available():
    import torch

if is_torchvision_available():
    from torchvision.transforms import functional as F


class PerceptionLMFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    vision_input_type: str = "thumb+tile"
    tile_size: int = 448
    max_num_tiles: int = 36


@add_start_docstrings(
    "Constructs a fast PerceptionLM image processor.",
)
class PerceptionLMImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.BICUBIC
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    do_resize = True
    do_center_crop = False
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    size = {"width": 448, "height": 448}  # for backward compatibility in tests
    valid_kwargs = PerceptionLMFastImageProcessorKwargs

    def __init__(self, **kwargs: Unpack[PerceptionLMFastImageProcessorKwargs]) -> None:
        super().__init__(**kwargs)

    @staticmethod
    def _factors(n: int):
        """Return all factors of a number."""
        return set(
            reduce(
                list.__add__,
                ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0),
            )
        )

    def _find_supported_aspect_ratios(self):
        """
        This function computes all the allowed aspect ratios for a fixed
        number of input chunks. The order of returned items matters for the result of the `_fit_image_to_canvas` function.
        If a tie exists in `_fit_image_to_canvas`, the later entry in `_find_supported_aspect_ratios` wins.

        For example, with `num_tiles=5`, it will return:
        {
            0.2: [(1, 5)],
            5.0: [(5, 1)],
            0.25: [(1, 4)],
            1.0: [(2, 2), (1, 1)],
            4.0: [(4, 1)],
            0.3333333333333333: [(1, 3)],
            3.0: [(3, 1)],
            0.5: [(1, 2)],
            2.0: [(2, 1)]
        }
        """
        asp_dict = {}
        for chunk_size in range(self.max_num_tiles, 0, -1):
            _factors = sorted(self._factors(chunk_size))
            _asp_ratios = [(x, chunk_size // x) for x in _factors]
            for ratio in _asp_ratios:
                k = ratio[0] / ratio[1]
                if k not in asp_dict:
                    asp_dict[k] = [ratio]
                else:
                    asp_dict[k].append(ratio)
        return asp_dict

    def _get_image_height_width(
        self, image_width: int, image_height: int, target_width: int, target_height: int
    ) -> tuple[int, int]:
        """
        Given the image width and height and the target width and height of the canvas, return the dimensions the
        image would be resized to with aspect ratio preservation.
        """
        scale = image_width / image_height

        if scale > 1.0:
            # Width is larger than height

            # Rescaling factor is the minimum of the two scaling factors. Else one side would be outside of the canvas.
            rescaling_factor = min(target_width / image_width, target_height / image_height)

            # Set new width to target width and height to the rescaled height.
            new_w = rescaling_factor * image_width
            new_h = math.floor(new_w / scale)

        else:
            # Height is larger than width

            # Rescaling factor is the minimum of the two scaling factors. Else one side would be outside of the canvas.
            rescaling_factor = min(target_width / image_width, target_height / image_height)

            # Set new height to target height and width to the rescaled width.
            new_h = rescaling_factor * image_height
            new_w = math.floor(new_h * scale)

        return new_w, new_h

    def _fit_image_to_canvas(self, img_width: int, img_height: int, tile_size: int):
        """
        Given an image width, height and target number of chunks, this function will see if the image
        can be fit into any of the canvases that can be built from arranging the tiles in a grid.
        If the image can be fit onto several canvases, it will return the canvas where the shorter edge
        of the image will be largest.
        """
        # Initialize the optimal canvas to None. If no canvas is found where image fits, function returns None.
        optimal_canvas = None
        optimal_image_width_height = None

        scale = img_width / img_height

        # Gather all potential supported image resolutions and iterate through them to find best match
        potential_arrangements = [
            item for sublist in self._find_supported_aspect_ratios().values() for item in sublist
        ]
        for n_w, n_h in potential_arrangements:
            # Compute the canvas size
            canvas_width, canvas_height = n_w * tile_size, n_h * tile_size

            # Check if image can fit into the canvas without downsampling
            if canvas_width >= img_width and canvas_height >= img_height:
                # If we did not find a good canvas yet, we will use the current one
                if optimal_canvas is None:
                    # Set optimal canvas and determine the actual image height and width in the canvas with aspect ratio preserving resampling
                    optimal_canvas = (n_w, n_h)
                    optimal_image_width_height = self._get_image_height_width(
                        image_width=img_width,
                        image_height=img_height,
                        target_width=n_w * tile_size,
                        target_height=n_h * tile_size,
                    )
                else:
                    # If we already found an optimal canvas before, we will check if the shorter edge of the image will be larger than the current optimal canvas.
                    # This means we can potentially upsample the image resolution which is beneficial to performance.
                    image_width_height = self._get_image_height_width(
                        image_width=img_width,
                        image_height=img_height,
                        target_width=n_w * tile_size,
                        target_height=n_h * tile_size,
                    )
                    # Llama3V dynamic tiling. Prioritize the biggest canvas.
                    if (scale < 1.0 and (image_width_height[0] >= optimal_image_width_height[0])) or (
                        scale >= 1.0 and (image_width_height[1] >= optimal_image_width_height[1])
                    ):
                        optimal_canvas = (n_w, n_h)
                        optimal_image_width_height = image_width_height
        return optimal_canvas

    def _find_closest_aspect_ratio(self, img_width: int, img_height: int, tile_size: int) -> tuple:
        """
        Given an image width, height and target number of chunks,
        this function will find the closest supported aspect ratio.
        """
        target_aspect_ratio = img_width / img_height
        asp_dict = self._find_supported_aspect_ratios()
        closest_aspect_ratio = None
        if target_aspect_ratio >= 1:
            closest_aspect_ratio = min(
                [k for k in asp_dict.keys() if k <= target_aspect_ratio],
                key=lambda x: abs(x - target_aspect_ratio),
            )
            tiles_given_aspect_ratio = asp_dict[closest_aspect_ratio]
            # select largest width
            return max(tiles_given_aspect_ratio, key=lambda x: x[0])
        else:
            closest_aspect_ratio = min(
                [k for k in asp_dict.keys() if k > target_aspect_ratio],
                key=lambda x: abs(1 / x - 1 / target_aspect_ratio),
            )
            tiles_given_aspect_ratio = asp_dict[closest_aspect_ratio]
            # select largest height
            return max(tiles_given_aspect_ratio, key=lambda x: x[1])

    def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
        # Split image into number of required tiles (width x height)
        batch_size, num_channels, height, width = image.size()
        image = image.view(batch_size, num_channels, nch, height // nch, ncw, width // ncw)
        # Permute dimensions to reorder the axes
        image = image.permute(0, 2, 4, 1, 3, 5).contiguous()
        # Reshape into the desired output shape (batch_size, ncw * nch, num_channels, height // nch, width // ncw)
        image = image.view(batch_size, ncw * nch, num_channels, height // nch, width // ncw)
        return image

    def resize(
        self,
        image: np.ndarray,
        tile_size: int,
        max_num_tiles: int,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        height, width = get_image_size(image, channel_dim=input_data_format)
        if max_num_tiles > 1:
            aspect_ratio = self._fit_image_to_canvas(img_width=width, img_height=height, tile_size=tile_size)
            if aspect_ratio is None:
                # If we did not find a canvas, we have to find the closest aspect ratio and downsample the image
                aspect_ratio = self._find_closest_aspect_ratio(img_width=width, img_height=height, tile_size=tile_size)
        else:
            aspect_ratio = (1, 1)
        new_width, new_height = aspect_ratio[0] * tile_size, aspect_ratio[1] * tile_size
        image = F.resize(image, (new_height, new_width), interpolation=resample)
        return image, aspect_ratio

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        do_rescale: Optional[bool],
        rescale_factor: Optional[Union[int, float]],
        do_normalize: Optional[bool],
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        tile_size: int,
        max_num_tiles: int,
        return_tensors: Optional[Union[str, TensorType]],
        disable_grouping: bool,
        **kwargs: Unpack[PerceptionLMFastImageProcessorKwargs],
    ) -> BatchFeature:
        # Group images by size for batched transformation
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                if self.vision_input_type == "thumb+tile":
                    thumbnails, _ = self.resize(stacked_images, tile_size, max_num_tiles=1)
                    images_for_tiling, (tiles_w, tiles_h) = self.resize(
                        stacked_images, tile_size, max_num_tiles=max_num_tiles
                    )
                    image_tiles = self._split(images_for_tiling, tiles_w, tiles_h)
                    stacked_images = torch.cat([thumbnails.unsqueeze(1), image_tiles], dim=1)
                else:  # vanilla single tile for low memory devices
                    stacked_images, _ = self.resize(stacked_images, tile_size, max_num_tiles=1)

            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images,
                do_rescale,
                rescale_factor,
                do_normalize,
                image_mean,
                image_std,
            )
            processed_images_grouped[shape] = stacked_images
        processed_images = reorder_images(processed_images_grouped, grouped_images_index)

        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)


__all__ = ["PerceptionLMImageProcessorFast"]
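A hedged usage sketch of the processor above in `thumb+tile` mode; the selected tile count, and therefore the output shape, depends on the input resolution, so the values in the comments are illustrative:

```python
import torch
from transformers import PerceptionLMImageProcessorFast

image_processor = PerceptionLMImageProcessorFast(
    vision_input_type="thumb+tile", tile_size=448, max_num_tiles=4
)
image = torch.randint(0, 256, (3, 800, 600), dtype=torch.uint8)  # dummy RGB image
out = image_processor(images=image, return_tensors="pt")
# pixel_values packs a 448x448 thumbnail plus the selected tiles along dim=1,
# e.g. roughly (1, 1 + num_tiles, 3, 448, 448) for a single image.
print(out["pixel_values"].shape)
```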
src/transformers/models/perception_lm/modeling_perception_lm.py (new file, 515 lines)
@@ -0,0 +1,515 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/perception_lm/modular_perception_lm.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_perception_lm.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from transformers.generation.utils import GenerationMixin

from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, can_return_tuple
from ..auto import AutoModel
from .configuration_perception_lm import PerceptionLMConfig


class AdaptiveAvgPooling(nn.Module):
    def __init__(self, pooling_ratio=2):
        super(AdaptiveAvgPooling, self).__init__()
        self.pooling_ratio = pooling_ratio

    def forward(self, hidden_states):
        b, num_tokens, c = hidden_states.shape
        h = int(math.sqrt(num_tokens))
        if h * h != num_tokens:
            raise ValueError(f"num_tokens {num_tokens} is expected to be a square number")

        shape = (h // self.pooling_ratio, h // self.pooling_ratio)
        hidden_states = hidden_states.permute(0, 2, 1).reshape(b, -1, h, h)
        hidden_states = F.adaptive_avg_pool2d(hidden_states, shape)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        return hidden_states
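A small sketch of the pooling above: it treats the token sequence as a square grid and average-pools it, so a 32×32 grid with `pooling_ratio=2` becomes a 16×16 grid:

```python
pool = AdaptiveAvgPooling(pooling_ratio=2)
tokens = torch.randn(1, 1024, 768)   # 1024 tokens = 32 x 32 patch grid
pooled = pool(tokens)
print(pooled.shape)                  # torch.Size([1, 256, 768]) -> 16 x 16 grid
```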
class PerceptionLMMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: PerceptionLMConfig):
|
||||
super().__init__()
|
||||
input_size = config.vision_config.model_args["embed_dim"]
|
||||
output_size = config.text_config.hidden_size
|
||||
self.projector = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(
|
||||
in_features=input_size,
|
||||
out_features=output_size,
|
||||
bias=True,
|
||||
),
|
||||
nn.GELU(),
|
||||
nn.Linear(
|
||||
in_features=output_size,
|
||||
out_features=output_size,
|
||||
bias=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
self.pooling = (
|
||||
AdaptiveAvgPooling(config.projector_pooling_ratio) if config.projector_pooling_ratio > 1 else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, features):
|
||||
features = features.permute(1, 0, 2) # NLD -> LND
|
||||
for layer in self.projector:
|
||||
features = layer(features)
|
||||
features = features.permute(1, 0, 2) # LND -> NLD
|
||||
features = self.pooling(features)
|
||||
return features
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class PerceptionLMPreTrainedModel(PreTrainedModel):
|
||||
config_class = PerceptionLMConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of PerceptionLM isn't meant for training from scratch - only
|
||||
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
|
||||
# https://github.com/haotian-liu/PerceptionLM/tree/main/perception_lm should serve for that purpose
|
||||
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
|
||||
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@dataclass
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for PerceptionLM outputs, with hidden states and attentions.
|
||||
"""
|
||||
)
|
||||
class PerceptionLMModelOutputWithPast(BaseModelOutputWithPast):
|
||||
r"""
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||
"""
|
||||
|
||||
image_hidden_states: Optional[torch.FloatTensor] = None
|
||||
video_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for PerceptionLM causal language model (or autoregressive) outputs.
|
||||
"""
|
||||
)
|
||||
class PerceptionLMCausalLMOutputWithPast(ModelOutput):
|
||||
r"""
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None
|
||||
hidden_states: Optional[tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[tuple[torch.FloatTensor]] = None
|
||||
image_hidden_states: Optional[torch.FloatTensor] = None
|
||||
video_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class PerceptionLMModel(PerceptionLMPreTrainedModel):
|
||||
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
|
||||
|
||||
def __init__(self, config: PerceptionLMConfig):
|
||||
super().__init__(config)
|
||||
self.multi_modal_projector = PerceptionLMMultiModalProjector(config)
|
||||
self.language_model = AutoModel.from_config(config.text_config)
|
||||
self.vision_tower = AutoModel.from_config(config.vision_config)
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.language_model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.language_model
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and applies the multimodal projection.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_tiles, channels, height, width)`):
The tensors corresponding to the input images.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_tiles, num_patches, embed_dim)`.
|
||||
"""
|
||||
image_outputs = self.vision_tower(pixel_values.flatten(0, 1))
|
||||
image_outputs = image_outputs.last_hidden_state
|
||||
if self.config.vision_use_cls_token:
|
||||
image_outputs = image_outputs[:, 1:, :]
|
||||
image_features = self.multi_modal_projector(image_outputs)
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**lm_kwargs,
|
||||
) -> Union[tuple, PerceptionLMModelOutputWithPast]:
|
||||
"""
|
||||
Forward pass of the PerceptionLM model.
|
||||
|
||||
Args:
|
||||
input_ids (`torch.LongTensor`, *optional*):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
pixel_values (`torch.FloatTensor`, *optional*):
|
||||
Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`.
|
||||
pixel_values_videos (`torch.FloatTensor`, *optional*):
|
||||
Input video tensor of shape `(batch_size, num_frames, channels, height, width)`.
|
||||
attention_mask (`torch.Tensor`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
position_ids (`torch.LongTensor`, *optional*):
|
||||
Indices of positions of each input sequence token in the position embeddings.
|
||||
past_key_values (`list[torch.FloatTensor]`, *optional*):
|
||||
Precomputed key and value hidden states for fast autoregressive generation.
|
||||
inputs_embeds (`torch.FloatTensor`, *optional*):
|
||||
Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation.
|
||||
use_cache (`bool`, *optional*):
|
||||
Whether or not to use past key values to speed up decoding.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor`, *optional*):
|
||||
Position indices for caching.
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
|
||||
Number of logits to keep.
|
||||
**lm_kwargs:
|
||||
Additional keyword arguments for the language model.
|
||||
|
||||
Returns:
|
||||
[`PerceptionLMModelOutputWithPast`] or `tuple`:
|
||||
Model outputs as a `PerceptionLMModelOutputWithPast` if `return_dict=True`, otherwise a tuple.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both (pixel_values or pixel_values_videos) and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
image_features = None
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values.to(inputs_embeds),
|
||||
)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
self.check_mask_feature_size_match(special_image_mask, image_features)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = image_features.to(inputs_embeds)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
video_features = None
|
||||
if pixel_values_videos is not None:
|
||||
video_features = self.get_image_features(
|
||||
pixel_values=pixel_values_videos.to(inputs_embeds),
|
||||
)
|
||||
special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
|
||||
self.check_mask_feature_size_match(special_video_mask, video_features)
|
||||
special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
video_features = video_features.to(inputs_embeds)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**lm_kwargs,
|
||||
)
|
||||
return PerceptionLMModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
hidden_states=outputs.hidden_states,
|
||||
past_key_values=outputs.past_key_values,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
video_hidden_states=(video_features if pixel_values_videos is not None else None),
|
||||
)
|
||||
|
||||
def check_mask_feature_size_match(self, media_mask, media_features):
|
||||
media_token_count = media_mask.sum()
|
||||
media_feature_size = media_features.size()[:-1].numel()
|
||||
if media_token_count != media_feature_size:
|
||||
raise ValueError(
|
||||
f"The number of tokens in the media mask ({media_token_count}) does not match the number of features in the media features ({media_feature_size}. Features shape: {media_features.shape})"
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class PerceptionLMForConditionalGeneration(PerceptionLMPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config: PerceptionLMConfig, **super_kwargs):
|
||||
super().__init__(config, **super_kwargs)
|
||||
self.model = PerceptionLMModel(config)
|
||||
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
||||
self.post_init()
|
||||
|
||||
# Make modules available through the conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
return self.model.language_model
|
||||
|
||||
@property
|
||||
def vision_tower(self):
|
||||
return self.model.vision_tower
|
||||
|
||||
@property
|
||||
def multi_modal_projector(self):
|
||||
return self.model.multi_modal_projector
|
||||
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.model.set_input_embeddings(new_embeddings)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
pixel_values=None,
|
||||
pixel_values_videos=None,
|
||||
attention_mask=None,
|
||||
cache_position=None,
|
||||
logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
|
||||
|
||||
model_inputs = super().prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if cache_position[0] == 0:
|
||||
# Pixel values are only forwarded on the first (prefill) step: in the cached decoding stage the
# input ids no longer contain the special image tokens, so pixel values should be None.
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
model_inputs["pixel_values_videos"] = pixel_values_videos
|
||||
return model_inputs
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**lm_kwargs,
|
||||
) -> Union[tuple, PerceptionLMCausalLMOutputWithPast]:
|
||||
"""
|
||||
Forward pass for the PerceptionLMForConditionalGeneration model.
|
||||
|
||||
Args:
|
||||
input_ids (`torch.LongTensor`, *optional*):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
pixel_values (`torch.FloatTensor`, *optional*):
|
||||
Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`.
|
||||
pixel_values_videos (`torch.FloatTensor`, *optional*):
|
||||
Input video tensor of shape `(batch_size, num_frames, channels, height, width)`.
|
||||
attention_mask (`torch.Tensor`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
position_ids (`torch.LongTensor`, *optional*):
|
||||
Indices of positions of each input sequence token in the position embeddings.
|
||||
past_key_values (`list[torch.FloatTensor]`, *optional*):
|
||||
Precomputed key and value hidden states for fast autoregressive generation.
|
||||
inputs_embeds (`torch.FloatTensor`, *optional*):
|
||||
Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation.
|
||||
labels (`torch.LongTensor`, *optional*):
|
||||
Labels for computing the language modeling loss.
|
||||
use_cache (`bool`, *optional*):
|
||||
Whether or not to use past key values to speed up decoding.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor`, *optional*):
|
||||
Position indices for caching.
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
|
||||
Number of logits to keep.
|
||||
**lm_kwargs:
|
||||
Additional keyword arguments for the language model.
|
||||
|
||||
Returns:
|
||||
[`PerceptionLMCausalLMOutputWithPast`] or `tuple`:
|
||||
Model outputs as a `PerceptionLMCausalLMOutputWithPast` if `return_dict=True`, otherwise a tuple.
|
||||
"""
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
pixel_values_videos=pixel_values_videos,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.text_config.vocab_size,
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
return PerceptionLMCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=outputs.image_hidden_states,
|
||||
video_hidden_states=outputs.video_hidden_states,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["PerceptionLMForConditionalGeneration", "PerceptionLMPreTrainedModel", "PerceptionLMModel"]
|
444
src/transformers/models/perception_lm/modular_perception_lm.py
Normal file
@ -0,0 +1,444 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch PerceptionLM model."""
|
||||
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.generation.utils import GenerationMixin
|
||||
|
||||
from ...utils import (
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
)
|
||||
from ..auto import AutoModel
|
||||
from ..llava.modeling_llava import (
|
||||
LlavaCausalLMOutputWithPast,
|
||||
LlavaModel,
|
||||
LlavaModelOutputWithPast,
|
||||
LlavaPreTrainedModel,
|
||||
)
|
||||
from .configuration_perception_lm import PerceptionLMConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "PerceptionLMConfig"
|
||||
|
||||
# Base docstring
|
||||
_CHECKPOINT_FOR_DOC = "facebook/Perception-LM-1B"
|
||||
|
||||
|
||||
class AdaptiveAvgPooling(nn.Module):
|
||||
def __init__(self, pooling_ratio=2):
|
||||
super().__init__()
|
||||
self.pooling_ratio = pooling_ratio
|
||||
|
||||
def forward(self, hidden_states):
|
||||
b, num_tokens, c = hidden_states.shape
|
||||
h = int(math.sqrt(num_tokens))
|
||||
if h * h != num_tokens:
|
||||
raise ValueError(f"num_tokens {num_tokens} is expected to be a square number")
|
||||
|
||||
shape = (h // self.pooling_ratio, h // self.pooling_ratio)
|
||||
hidden_states = hidden_states.permute(0, 2, 1).reshape(b, -1, h, h)
|
||||
hidden_states = F.adaptive_avg_pool2d(hidden_states, shape)
|
||||
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PerceptionLMMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: PerceptionLMConfig):
|
||||
super().__init__()
|
||||
input_size = config.vision_config.model_args["embed_dim"]
|
||||
output_size = config.text_config.hidden_size
|
||||
self.projector = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(
|
||||
in_features=input_size,
|
||||
out_features=output_size,
|
||||
bias=True,
|
||||
),
|
||||
nn.GELU(),
|
||||
nn.Linear(
|
||||
in_features=output_size,
|
||||
out_features=output_size,
|
||||
bias=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
self.pooling = (
|
||||
AdaptiveAvgPooling(config.projector_pooling_ratio) if config.projector_pooling_ratio > 1 else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, features):
|
||||
features = features.permute(1, 0, 2) # NLD -> LND
|
||||
for layer in self.projector:
|
||||
features = layer(features)
|
||||
features = features.permute(1, 0, 2) # LND -> NLD
|
||||
features = self.pooling(features)
|
||||
return features
|
||||
|
||||
|
||||
class PerceptionLMPreTrainedModel(LlavaPreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
|
||||
|
||||
class PerceptionLMModelOutputWithPast(LlavaModelOutputWithPast):
|
||||
video_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
class PerceptionLMCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
|
||||
video_hidden_states: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class PerceptionLMModel(LlavaModel):
|
||||
def __init__(self, config: PerceptionLMConfig):
|
||||
super().__init__(config)
|
||||
del self.vision_tower
|
||||
self.vision_tower = AutoModel.from_config(config.vision_config)
|
||||
self.multi_modal_projector = PerceptionLMMultiModalProjector(config)
|
||||
self.language_model = AutoModel.from_config(config.text_config)
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and applies the multimodal projection.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_tiles, channels, height, width)`):
The tensors corresponding to the input images.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_tiles, num_patches, embed_dim)`.
|
||||
"""
|
||||
image_outputs = self.vision_tower(pixel_values.flatten(0, 1))
|
||||
image_outputs = image_outputs.last_hidden_state
|
||||
if self.config.vision_use_cls_token:
|
||||
image_outputs = image_outputs[:, 1:, :]
|
||||
image_features = self.multi_modal_projector(image_outputs)
|
||||
return image_features
|
||||
|
||||
def check_mask_feature_size_match(self, media_mask, media_features):
|
||||
media_token_count = media_mask.sum()
|
||||
media_feature_size = media_features.size()[:-1].numel()
|
||||
if media_token_count != media_feature_size:
|
||||
raise ValueError(
|
||||
f"The number of tokens in the media mask ({media_token_count}) does not match the number of features in the media features ({media_feature_size}. Features shape: {media_features.shape})"
|
||||
)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**lm_kwargs,
|
||||
) -> Union[tuple, PerceptionLMModelOutputWithPast]:
|
||||
"""
|
||||
Forward pass of the PerceptionLM model.
|
||||
|
||||
Args:
|
||||
input_ids (`torch.LongTensor`, *optional*):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
pixel_values (`torch.FloatTensor`, *optional*):
|
||||
Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`.
|
||||
pixel_values_videos (`torch.FloatTensor`, *optional*):
|
||||
Input video tensor of shape `(batch_size, num_frames, channels, height, width)`.
|
||||
attention_mask (`torch.Tensor`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
position_ids (`torch.LongTensor`, *optional*):
|
||||
Indices of positions of each input sequence token in the position embeddings.
|
||||
past_key_values (`list[torch.FloatTensor]`, *optional*):
|
||||
Precomputed key and value hidden states for fast autoregressive generation.
|
||||
inputs_embeds (`torch.FloatTensor`, *optional*):
|
||||
Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation.
|
||||
use_cache (`bool`, *optional*):
|
||||
Whether or not to use past key values to speed up decoding.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor`, *optional*):
|
||||
Position indices for caching.
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
|
||||
Number of logits to keep.
|
||||
**lm_kwargs:
|
||||
Additional keyword arguments for the language model.
|
||||
|
||||
Returns:
|
||||
[`PerceptionLMModelOutputWithPast`] or `tuple`:
|
||||
Model outputs as a `PerceptionLMModelOutputWithPast` if `return_dict=True`, otherwise a tuple.
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both (pixel_values or pixel_values_videos) and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
image_features = None
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values.to(inputs_embeds),
|
||||
)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
self.check_mask_feature_size_match(special_image_mask, image_features)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = image_features.to(inputs_embeds)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
video_features = None
|
||||
if pixel_values_videos is not None:
|
||||
video_features = self.get_image_features(
|
||||
pixel_values=pixel_values_videos.to(inputs_embeds),
|
||||
)
|
||||
special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
|
||||
self.check_mask_feature_size_match(special_video_mask, video_features)
|
||||
special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
video_features = video_features.to(inputs_embeds)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**lm_kwargs,
|
||||
)
|
||||
return PerceptionLMModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
hidden_states=outputs.hidden_states,
|
||||
past_key_values=outputs.past_key_values,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
video_hidden_states=(video_features if pixel_values_videos is not None else None),
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class PerceptionLMForConditionalGeneration(PerceptionLMPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config: PerceptionLMConfig, **super_kwargs):
|
||||
super().__init__(config, **super_kwargs)
|
||||
self.model = PerceptionLMModel(config)
|
||||
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
||||
self.post_init()
|
||||
|
||||
# Make modules available through the conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
return self.model.language_model
|
||||
|
||||
@property
|
||||
def vision_tower(self):
|
||||
return self.model.vision_tower
|
||||
|
||||
@property
|
||||
def multi_modal_projector(self):
|
||||
return self.model.multi_modal_projector
|
||||
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.model.set_input_embeddings(new_embeddings)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
pixel_values=None,
|
||||
pixel_values_videos=None,
|
||||
attention_mask=None,
|
||||
cache_position=None,
|
||||
logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
|
||||
|
||||
model_inputs = super().prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if cache_position[0] == 0:
|
||||
# Pixel values are only forwarded on the first (prefill) step: in the cached decoding stage the
# input ids no longer contain the special image tokens, so pixel values should be None.
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
model_inputs["pixel_values_videos"] = pixel_values_videos
|
||||
return model_inputs
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**lm_kwargs,
|
||||
) -> Union[tuple, PerceptionLMCausalLMOutputWithPast]:
|
||||
"""
|
||||
Forward pass for the PerceptionLMForConditionalGeneration model.
|
||||
|
||||
Args:
|
||||
input_ids (`torch.LongTensor`, *optional*):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
pixel_values (`torch.FloatTensor`, *optional*):
|
||||
Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`.
|
||||
pixel_values_videos (`torch.FloatTensor`, *optional*):
|
||||
Input video tensor of shape `(batch_size, num_frames, channels, height, width)`.
|
||||
attention_mask (`torch.Tensor`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
position_ids (`torch.LongTensor`, *optional*):
|
||||
Indices of positions of each input sequence token in the position embeddings.
|
||||
past_key_values (`list[torch.FloatTensor]`, *optional*):
|
||||
Precomputed key and value hidden states for fast autoregressive generation.
|
||||
inputs_embeds (`torch.FloatTensor`, *optional*):
|
||||
Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation.
|
||||
labels (`torch.LongTensor`, *optional*):
|
||||
Labels for computing the language modeling loss.
|
||||
use_cache (`bool`, *optional*):
|
||||
Whether or not to use past key values to speed up decoding.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor`, *optional*):
|
||||
Position indices for caching.
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
|
||||
Number of logits to keep.
|
||||
**lm_kwargs:
|
||||
Additional keyword arguments for the language model.
|
||||
|
||||
Returns:
|
||||
[`PerceptionLMCausalLMOutputWithPast`] or `tuple`:
|
||||
Model outputs as a `PerceptionLMCausalLMOutputWithPast` if `return_dict=True`, otherwise a tuple.
|
||||
"""
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
pixel_values_videos=pixel_values_videos,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
logits_to_keep=logits_to_keep,
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.text_config.vocab_size,
|
||||
**lm_kwargs,
|
||||
)
|
||||
|
||||
return PerceptionLMCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=outputs.image_hidden_states,
|
||||
video_hidden_states=outputs.video_hidden_states,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PerceptionLMForConditionalGeneration",
|
||||
"PerceptionLMPreTrainedModel",
|
||||
"PerceptionLMModel",
|
||||
]
|
@ -0,0 +1,207 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor class for PerceptionLM.
|
||||
"""
|
||||
|
||||
from typing import Iterable, Union
|
||||
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, get_image_size, to_numpy_array
|
||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import logging
|
||||
from ...video_utils import VideoInput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class PerceptionLMProcessorKwargs(ProcessingKwargs, total=False):
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class PerceptionLMProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a PerceptionLM processor which wraps a PerceptionLM image processor, a PerceptionLM video processor, and a tokenizer into a single processor.
|
||||
|
||||
[`PerceptionLMProcessor`] offers all the functionalities of [`PerceptionLMImageProcessorFast`], [`PerceptionLMVideoProcessor`], and the tokenizer (e.g. [`LlamaTokenizerFast`]). See the
|
||||
[`~PerceptionLMProcessor.__call__`] and [`~PerceptionLMProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
video_processor ([`PerceptionLMVideoProcessor`], *optional*):
|
||||
The video processor to process video inputs.
|
||||
image_processor ([`PerceptionLMImageProcessorFast`], *optional*):
|
||||
The image processor to process image inputs.
|
||||
tokenizer ([`LlamaTokenizerFast`] or similar, *optional*):
|
||||
The tokenizer to process text inputs.
|
||||
patch_size (`int`, *optional*):
|
||||
Patch size from the vision tower.
|
||||
chat_template (`str`, *optional*):
|
||||
A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
|
||||
pooling_ratio (`int`, *optional*, defaults to 2):
|
||||
Pooling ratio for vision tokens. If not 1, 2D adaptive pooling is applied over projected vision tokens.
|
||||
"""
|
||||
|
||||
attributes = ["video_processor", "image_processor", "tokenizer"]
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
video_processor_class = "AutoVideoProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
video_processor=None,
|
||||
image_processor=None,
|
||||
tokenizer=None,
|
||||
patch_size=None,
|
||||
chat_template=None,
|
||||
pooling_ratio=2,
|
||||
**kwargs,
|
||||
):
|
||||
self.patch_size = patch_size
|
||||
self.pooling_ratio = pooling_ratio
|
||||
self.image_token = tokenizer.image_token
|
||||
self.video_token = tokenizer.video_token
|
||||
self.image_token_id = tokenizer.image_token_id
|
||||
self.video_token_id = tokenizer.video_token_id
|
||||
super().__init__(video_processor, image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
|
||||
audio=None,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[PerceptionLMProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Prepares a batch containing one or more sequences of text and/or images and/or videos.
|
||||
|
||||
If `text` is provided, it is tokenized using the tokenizer.
|
||||
If `images` is provided, they are processed using the image processor.
|
||||
If `videos` is provided, they are processed using the video processor.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
|
||||
The image or batch of images to be processed. Each image can be a PIL image, NumPy array, or PyTorch tensor.
|
||||
Both channels-first and channels-last formats are supported.
|
||||
text (`str`, `List[str]`, *optional*):
|
||||
The sequence or batch of sequences to be tokenized. Each sequence can be a string.
|
||||
videos (`Any`, *optional*):
|
||||
The video or batch of videos to be processed.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is provided.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is provided).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is provided.
|
||||
- **pixel_values_videos** -- Video pixel values to be fed to a model. Returned when `videos` is provided.
|
||||
"""
|
||||
if text is None:
|
||||
raise ValueError(
|
||||
"You have to specify at least `text` input. Optionally, you can also specify `images` or `videos`."
|
||||
)
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
PerceptionLMProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
if images is not None:
|
||||
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
|
||||
else:
|
||||
image_inputs = {}
|
||||
|
||||
if videos is not None:
|
||||
videos_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"])
|
||||
else:
|
||||
videos_inputs = {}
|
||||
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
# try to expand inputs in processing if we have the necessary parts
|
||||
prompt_strings = []
|
||||
|
||||
pixel_values = iter(image_inputs.get("pixel_values", []))
|
||||
pixel_values_videos = iter(videos_inputs.get("pixel_values_videos", []))
|
||||
for sample in text:
|
||||
# Replace the media token with the expanded media token sequence
|
||||
sample = self._expand_media_tokens(sample, self.tokenizer.image_token, pixel_values)
|
||||
sample = self._expand_media_tokens(sample, self.tokenizer.video_token, pixel_values_videos)
|
||||
prompt_strings.append(sample)
|
||||
|
||||
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
|
||||
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
|
||||
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image", "video"])
|
||||
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
|
||||
|
||||
def _expand_media_tokens(self, sample, media_token: str, media_iter: Iterable):
|
||||
media_count = sample.count(media_token)
|
||||
if media_count > 0:
|
||||
media_list = [next(media_iter) for _ in range(media_count)]
|
||||
sample_splits = sample.split(media_token)
|
||||
media_token_list = []
|
||||
for media in media_list:
|
||||
height, width = get_image_size(to_numpy_array(media))
|
||||
num_tiles = media.shape[0]
|
||||
num_media_tokens = (
|
||||
(height // self.patch_size // self.pooling_ratio)
|
||||
* (width // self.patch_size // self.pooling_ratio)
|
||||
* num_tiles
|
||||
)
|
||||
media_token_list.append(num_media_tokens)
|
||||
sample = ""
|
||||
for i, num_media_tokens in enumerate(media_token_list):
|
||||
sample += sample_splits[i]
|
||||
sample += media_token * num_media_tokens
|
||||
sample += sample_splits[-1]
|
||||
return sample
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to the underlying tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to the underlying tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
||||
|
||||
|
||||
__all__ = ["PerceptionLMProcessor"]
|
@ -0,0 +1,53 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Video processor class for PerceptionLM."""
|
||||
|
||||
from ...image_utils import (
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
)
|
||||
from ...processing_utils import Unpack, VideosKwargs
|
||||
from ...utils import is_vision_available
|
||||
from ...utils.import_utils import requires
|
||||
from ...video_processing_utils import (
|
||||
BaseVideoProcessor,
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from ...image_utils import PILImageResampling
|
||||
|
||||
|
||||
class PerceptionLMFastVideoProcessorInitKwargs(VideosKwargs): ...
|
||||
|
||||
|
||||
@requires(backends=("torchvision",))
|
||||
class PerceptionLMVideoProcessor(BaseVideoProcessor):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = IMAGENET_STANDARD_MEAN
|
||||
image_std = IMAGENET_STANDARD_STD
|
||||
size = {"height": 448, "width": 448}
|
||||
do_resize = True
|
||||
do_center_crop = False
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
valid_kwargs = PerceptionLMFastVideoProcessorInitKwargs
|
||||
model_input_names = ["pixel_values_videos"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[PerceptionLMFastVideoProcessorInitKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
__all__ = ["PerceptionLMVideoProcessor"]
|
0
tests/models/perception_lm/__init__.py
Normal file
@ -0,0 +1,224 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import PerceptionLMImageProcessorFast
|
||||
|
||||
|
||||
class PerceptionLMImageProcessingTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
num_channels=3,
|
||||
image_size=18,
|
||||
min_resolution=30,
|
||||
max_resolution=400,
|
||||
do_resize=True,
|
||||
tile_size=16,
|
||||
do_normalize=True,
|
||||
image_mean=IMAGENET_STANDARD_MEAN,
|
||||
image_std=IMAGENET_STANDARD_STD,
|
||||
do_convert_rgb=True,
|
||||
max_num_tiles=4,
|
||||
vision_input_type="thumb+tile",
|
||||
resample=Image.Resampling.BICUBIC, # dummy value
|
||||
size={"shortest_edge": 20}, # dummy value
|
||||
):
|
||||
super().__init__()
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.min_resolution = min_resolution
|
||||
self.max_resolution = max_resolution
|
||||
self.do_resize = do_resize
|
||||
self.tile_size = tile_size
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean
|
||||
self.image_std = image_std
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self.max_num_tiles = max_num_tiles
|
||||
self.vision_input_type = vision_input_type
|
||||
self.resample = resample
|
||||
self.size = size
|
||||
|
||||
def prepare_image_processor_dict(self):
|
||||
return {
|
||||
"do_resize": self.do_resize,
|
||||
"tile_size": self.tile_size,
|
||||
"do_normalize": self.do_normalize,
|
||||
"image_mean": self.image_mean,
|
||||
"image_std": self.image_std,
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
"max_num_tiles": self.max_num_tiles,
|
||||
"vision_input_type": self.vision_input_type,
|
||||
"resample": self.resample,
|
||||
"size": self.size,
|
||||
}
|
||||
|
||||
def expected_output_image_shape(self, images):
|
||||
return self.num_channels, self.crop_size["height"], self.crop_size["width"]
|
||||
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
return prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
class PerceptionLMImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
fast_image_processing_class = PerceptionLMImageProcessorFast if is_torchvision_available() else None
|
||||
test_slow_image_processor = False
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.image_processor_tester = PerceptionLMImageProcessingTester(self)
|
||||
|
||||
@property
|
||||
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict
|
||||
def image_processor_dict(self):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "tile_size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
|
||||
self.assertTrue(hasattr(image_processing, "max_num_tiles"))
|
||||
self.assertTrue(hasattr(image_processing, "vision_input_type"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.tile_size, 16)
|
||||
self.assertEqual(image_processor.max_num_tiles, 4)
|
||||
self.assertEqual(image_processor.vision_input_type, "thumb+tile")
|
||||
|
||||
image_processor = image_processing_class.from_dict(
|
||||
self.image_processor_dict, tile_size=42, max_num_tiles=9
|
||||
)
|
||||
self.assertEqual(image_processor.tile_size, 42)
|
||||
self.assertEqual(image_processor.max_num_tiles, 9)
|
||||
self.assertEqual(image_processor.vision_input_type, "thumb+tile")
|
||||
|
||||
def test_call_pil(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
# create random PIL images
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (1, 5, 3, 16, 16)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (7, 5, 3, 16, 16)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_numpy(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (1, 5, 3, 16, 16)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (7, 5, 3, 16, 16)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
# Initialize image_processing
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
|
||||
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (1, 5, 3, 16, 16)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (7, 5, 3, 16, 16)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
@unittest.skip(reason="PerceptionLMImageProcessor doesn't treat 4 channel PIL and numpy consistently yet")
|
||||
def test_call_numpy_4_channels(self):
|
||||
pass
|
||||
|
||||
def test_nested_input(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
|
||||
|
||||
# Test batched as a list of images
|
||||
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (7, 5, 3, 16, 16)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched as a nested list of images, where each sublist is one batch
|
||||
image_inputs_nested = [image_inputs[:3], image_inputs[3:]]
|
||||
encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = (7, 5, 3, 16, 16)
|
||||
self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)
|
||||
|
||||
# Image processor should return the same pixel values, independently of input format
|
||||
self.assertTrue((encoded_images_nested == encoded_images).all())
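# --- Illustrative note (not part of the original diff): where the "5" in the expected shapes above comes
# from. With vision_input_type="thumb+tile", each image is presumably split into max_num_tiles tiles plus
# one thumbnail, each a tile_size x tile_size crop, so (batch, 4 + 1, 3, 16, 16).
if __name__ == "__main__":
    max_num_tiles, tile_size, num_channels = 4, 16, 3
    expected_per_image = (max_num_tiles + 1, num_channels, tile_size, tile_size)
    assert expected_per_image == (5, 3, 16, 16)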
|
474
tests/models/perception_lm/test_modeling_perception_lm.py
Normal file
@ -0,0 +1,474 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch PerceptionLM model."""

import unittest

from huggingface_hub import hf_hub_download

from transformers import (
    AutoProcessor,
    PerceptionLMConfig,
    PerceptionLMForConditionalGeneration,
    PerceptionLMModel,
    is_torch_available,
)
from transformers.testing_utils import (
    cleanup,
    require_bitsandbytes,
    require_torch,
    slow,
    torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor


if is_torch_available():
    import torch


class PerceptionLMVisionText2TextModelTester:
    def __init__(
        self,
        parent,
        image_token_id=0,
        video_token_id=2,
        seq_length=7,
        tie_word_embeddings=True,
        projector_pooling_ratio=1,
        text_config={
            "model_type": "llama",
            "seq_length": 7,
            "is_training": True,
            "use_input_mask": True,
            "use_token_type_ids": False,
            "use_labels": True,
            "vocab_size": 99,
            "hidden_size": 32,
            "num_hidden_layers": 2,
            "num_attention_heads": 4,
            "intermediate_size": 37,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "attention_probs_dropout_prob": 0.1,
            "max_position_embeddings": 512,
            "type_vocab_size": 16,
            "type_sequence_label_size": 2,
            "initializer_range": 0.02,
            "num_labels": 3,
            "num_choices": 4,
            "pad_token_id": 1,
        },
        is_training=True,
        vision_config={
            "architecture": "vit_pe_core_large_patch14_336",
            "model_args": {
                "embed_dim": 64,
                "img_size": (14, 14),
                "depth": 2,
                "global_pool": "",
                "use_post_transformer_norm": False,
                "init_values": 0.1,
                "ref_feat_shape": (1, 1),
            },
        },
    ):
        self.parent = parent
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.text_config = text_config
        self.vision_config = vision_config
        self.pad_token_id = text_config["pad_token_id"]

        self.num_hidden_layers = text_config["num_hidden_layers"]
        self.vocab_size = text_config["vocab_size"]
        self.hidden_size = text_config["hidden_size"]
        self.num_attention_heads = text_config["num_attention_heads"]
        self.is_training = is_training
        self.tie_word_embeddings = tie_word_embeddings

        self.batch_size = 3
        self.num_tiles = 1
        self.num_frames = 1
        self.num_channels = 3
        self.image_size = self.vision_config["model_args"]["img_size"][0]
        self.num_image_tokens = (self.vision_config["model_args"]["img_size"][0] // 14) ** 2
        self.num_video_tokens = (self.vision_config["model_args"]["img_size"][0] // 14) ** 2
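        # With img_size=(14, 14) and the patch size of 14 implied by the architecture name above, the two
        # expressions evaluate to (14 // 14) ** 2 = 1, i.e. one image token and one video token per sample --
        # a deliberately tiny setup that keeps the common model tests fast.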
        self.seq_length = seq_length + self.num_image_tokens
        self.encoder_seq_length = self.seq_length

    def get_config(self):
        return PerceptionLMConfig(
            text_config=self.text_config,
            vision_config=self.vision_config,
            vision_use_cls_token=True,
            image_token_id=self.image_token_id,
            video_token_id=self.video_token_id,
            tie_word_embeddings=self.tie_word_embeddings,
        )

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor(
            [
                self.batch_size,
                self.num_tiles,
                self.num_channels,
                self.vision_config["model_args"]["img_size"][0],
                self.vision_config["model_args"]["img_size"][1],
            ]
        )
        pixel_values_videos = floats_tensor(
            [
                self.batch_size,
                self.num_frames,
                self.num_channels,
                self.vision_config["model_args"]["img_size"][0],
                self.vision_config["model_args"]["img_size"][1],
            ]
        )
        config = self.get_config()

        return config, pixel_values, pixel_values_videos

    def prepare_config_and_inputs_for_common(self):
        config, pixel_values, pixel_values_videos = self.prepare_config_and_inputs()
        input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
        attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
        input_ids[input_ids == config.image_token_id] = self.pad_token_id
        input_ids[input_ids == config.video_token_id] = self.pad_token_id
        input_ids[:, : self.num_image_tokens] = config.image_token_id
        input_ids[:, self.num_image_tokens : self.num_video_tokens + self.num_image_tokens] = config.video_token_id
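        # Placeholder layout: the first num_image_tokens positions hold image placeholder ids and the next
        # num_video_tokens positions hold video placeholder ids, so the placeholder count can line up with the
        # visual features built from pixel_values / pixel_values_videos. A mismatch between the two is exactly
        # what test_mismatching_num_image_tokens below expects to surface as a ValueError.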

        inputs_dict = {
            "pixel_values": pixel_values,
            "pixel_values_videos": pixel_values_videos,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict


@require_torch
class PerceptionLMForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    """
    Model tester for `PerceptionLMForConditionalGeneration`.
    """

    all_model_classes = (
        (
            PerceptionLMModel,
            PerceptionLMForConditionalGeneration,
        )
        if is_torch_available()
        else ()
    )
    test_pruning = False
    test_head_masking = False
    _is_composite = True

    def setUp(self):
        self.model_tester = PerceptionLMVisionText2TextModelTester(self)
        common_properties = [
            "image_token_id",
            "video_token_id",
        ]
        self.config_tester = ConfigTester(
            self,
            config_class=PerceptionLMConfig,
            has_text_modality=False,
            common_properties=common_properties,
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)

            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["pixel_values"]
            del inputs["pixel_values_videos"]

            wte = model.get_input_embeddings()
            inputs["inputs_embeds"] = wte(input_ids)

            with torch.no_grad():
                model(**inputs)

    # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
    # while some other models require pixel_values to be present
    def test_inputs_embeds_matches_input_ids(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["pixel_values"]
            del inputs["pixel_values_videos"]

            inputs_embeds = model.get_input_embeddings()(input_ids)

            with torch.no_grad():
                out_ids = model(input_ids=input_ids, **inputs)[0]
                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            torch.testing.assert_close(out_embeds, out_ids)

    def test_mismatching_num_image_tokens(self):
"""
|
||||
Tests that VLMs through an error with explicit message saying what is wrong
|
||||
when number of images doesn't match number of image tokens in the text.
|
||||
Also we need to test multi-image cases when one prompr has multiple image tokens.
|
||||
"""
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            if model_class == PerceptionLMModel:
                continue
            model = model_class(config).to(torch_device)
            _ = model(**input_dict)  # successful forward with no modifications

            # remove one image but leave the image token in text
            input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
            with self.assertRaises(ValueError):
                _ = model(**input_dict)

            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
            input_ids = input_dict["input_ids"][:1]
            pixel_values = input_dict["pixel_values"][:1]
            input_ids = torch.cat([input_ids, input_ids], dim=0)

            # one image and two image tokens raise an error
            with self.assertRaises(ValueError):
                _ = model(input_ids=input_ids, pixel_values=pixel_values)

            # two images and two image tokens don't raise an error
            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
            _ = model(input_ids=input_ids, pixel_values=pixel_values)

    def test_training(self):
        self.all_model_classes = (PerceptionLMForConditionalGeneration,) if is_torch_available() else ()
        super().test_training()

    def test_training_gradient_checkpointing(self):
        self.all_model_classes = (PerceptionLMForConditionalGeneration,) if is_torch_available() else ()
        super().test_training_gradient_checkpointing()

    def test_training_gradient_checkpointing_use_reentrant(self):
        self.all_model_classes = (PerceptionLMForConditionalGeneration,) if is_torch_available() else ()
        super().test_training_gradient_checkpointing_use_reentrant()

    def test_training_gradient_checkpointing_use_reentrant_false(self):
        self.all_model_classes = (PerceptionLMForConditionalGeneration,) if is_torch_available() else ()
        super().test_training_gradient_checkpointing_use_reentrant_false()

    @unittest.skip(reason="Timm Eva (PE) weights cannot be fully constructed in _init_weights")
    def test_can_init_all_missing_weights(self):
        pass

    @unittest.skip(reason="Timm Eva (PE) weights cannot be fully constructed in _init_weights")
    def test_initialization(self):
        pass

    @unittest.skip(
        reason="PE/TIMM's attention implementation is self configured and won't raise ValueError on global attention implementation."
    )
    def test_flash_attn_2_can_dispatch_composite_models(self):
        pass

    @unittest.skip(
        "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
    )
    def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
        pass

    @unittest.skip("ViT PE cannot be tested with meta device")
    def test_can_be_initialized_on_meta(self):
        pass

    @unittest.skip("ViT PE cannot be tested with meta device")
    def test_can_load_with_meta_device_context_manager(self):
        pass

    @unittest.skip("Specifying both inputs_embeds and pixel_values are not supported for PerceptionLM")
    def test_generate_from_inputs_embeds_0_greedy(self):
        pass

    @unittest.skip("Specifying both inputs_embeds and pixel_values are not supported for PerceptionLM")
    def test_generate_from_inputs_embeds_1_beam_search(self):
        pass

    @unittest.skip("Specifying both inputs_embeds and pixel_values are not supported for PerceptionLM")
    def test_generate_from_inputs_embeds_with_static_cache(self):
        pass

    ## Skip the flash-attention-related tests below.
    ## The correct configuration would be:
    ## from_pretrained(model_id, attn_implementation={"text_config": "flash_attention_2", "vision_config": "eager"})
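    ## A minimal sketch of that call (it follows the comment above; the exact keyword layout is an assumption
    ## here, not a verified invocation):
    ##
    ##     model = PerceptionLMForConditionalGeneration.from_pretrained(
    ##         TEST_MODEL_PATH,
    ##         attn_implementation={"text_config": "flash_attention_2", "vision_config": "eager"},
    ##     )
    ##
    ## With the vision/timm tower pinned to "eager", the flash-attention tests skipped below could in
    ## principle be re-enabled.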
    @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
    def test_eager_matches_fa2_generate(self):
        pass

    @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
    def test_flash_attn_2_fp32_ln(self):
        pass

    @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
    def test_flash_attn_2_from_config(self):
        pass

    @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
    def test_eager_matches_sdpa_generate_with_dynamic_cache(self):
        pass

    @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
    def test_flash_attn_2_inference_equivalence_right_padding(self):
        pass

    @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
    def test_eager_matches_sdpa_generate(self):
        pass

    @unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
    def test_flash_attn_2_inference_equivalence(self):
        pass


TEST_MODEL_PATH = "shumingh/plm_1b_hf"


@require_torch
class PerceptionLMForConditionalGenerationIntegrationTest(unittest.TestCase):
    def setUp(self):
        self.processor = AutoProcessor.from_pretrained(TEST_MODEL_PATH)
        self.image_file = hf_hub_download(
            repo_id="shumingh/perception_lm_test_images",
            filename="14496_0.PNG",
            repo_type="dataset",
        )
        self.video_file = hf_hub_download(
            repo_id="shumingh/perception_lm_test_videos",
            filename="GUWR5TyiY-M_000012_000022.mp4",
            repo_type="dataset",
        )
        self.conversation1 = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "url": self.image_file},
                    {"type": "text", "text": "Describe the bar plot in the image."},
                ],
            }
        ]
        self.conversation2 = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "video",
                        "url": self.video_file,
                    },
                    {"type": "text", "text": "Can you describe the video in detail?"},
                ],
            }
        ]

    def tearDown(self):
        cleanup(torch_device, gc_collect=True)

    @slow
    @require_bitsandbytes
    def test_small_model_integration_test(self):
        model = PerceptionLMForConditionalGeneration.from_pretrained(
            TEST_MODEL_PATH, load_in_4bit=True, cache_dir="./"
        )

        inputs = self.processor.apply_chat_template(
            [self.conversation1],
            num_frames=32,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            video_load_backend="decord",
            padding=True,
            padding_side="left",
        ).to(torch_device)

        generate_ids = model.generate(**inputs, max_new_tokens=18)
        input_length = inputs["input_ids"].shape[1]
        generate_ids_without_inputs = generate_ids[:, input_length:]
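        # Slicing off the first input_length tokens keeps only the newly generated ids, so the decode below
        # returns the model's answer without the (image-token-heavy) prompt.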

        EXPECTED_DECODED_TEXT = "The bar plot displays the values of four categories: step, horror, mood, and lumber"  # fmt: skip

        self.assertEqual(
            self.processor.decode(generate_ids_without_inputs[0], skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )

    @slow
    @require_bitsandbytes
    def test_small_model_integration_test_batched(self):
        model = PerceptionLMForConditionalGeneration.from_pretrained(TEST_MODEL_PATH, load_in_4bit=True)
        processor = AutoProcessor.from_pretrained(TEST_MODEL_PATH)
        inputs = processor.apply_chat_template(
            [self.conversation1, self.conversation2],
            num_frames=32,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            video_load_backend="decord",
            padding=True,
            padding_side="left",
        ).to(torch_device)

        generate_ids = model.generate(**inputs, max_new_tokens=18)
        input_length = inputs["input_ids"].shape[1]
        generate_ids_without_inputs = generate_ids[:, input_length:]

        EXPECTED_DECODED_TEXT = ['The bar plot displays the values of four categories: step, horror, mood, and lumber', 'The video shows a group of people in green shirts and white shorts performing a jump rope routine']  # fmt: skip

        self.assertEqual(
            processor.batch_decode(generate_ids_without_inputs, skip_special_tokens=True),
            EXPECTED_DECODED_TEXT,
        )

    @slow
    @require_bitsandbytes
    def test_generation_no_images(self):
        # model_id = "facebook/Perception-LM-1B"
        model = PerceptionLMForConditionalGeneration.from_pretrained(TEST_MODEL_PATH, load_in_4bit=True)
        processor = AutoProcessor.from_pretrained(TEST_MODEL_PATH)

        # Prepare inputs with no images
        inputs = processor(text="Hello, I am", return_tensors="pt").to(torch_device)

        # Make sure that `generate` works
        _ = model.generate(**inputs, max_new_tokens=20)
tests/models/perception_lm/test_processor_perception_lm.py (new file, 145 lines)
@@ -0,0 +1,145 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import shutil
import tempfile
import unittest

from transformers import (
    AutoProcessor,
    AutoTokenizer,
    PerceptionLMProcessor,
)
from transformers.testing_utils import require_vision
from transformers.utils import is_torch_available, is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
    from transformers import PerceptionLMImageProcessorFast, PerceptionLMVideoProcessor

if is_torch_available():
    import torch


TEST_MODEL_PATH = "shumingh/plm_1b_hf"


@require_vision
class PerceptionLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = PerceptionLMProcessor

    @classmethod
    def setUpClass(cls):
        cls.tmpdirname = tempfile.mkdtemp()

        image_processor = PerceptionLMImageProcessorFast(
            tile_size=448, max_num_tiles=4, vision_input_type="thumb+tile"
        )
        video_processor = PerceptionLMVideoProcessor()
        tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_PATH)
        tokenizer.add_special_tokens({"additional_special_tokens": ["<|image|>", "<|video|>"]})
        processor_kwargs = cls.prepare_processor_dict()
        processor = PerceptionLMProcessor(
            image_processor=image_processor, video_processor=video_processor, tokenizer=tokenizer, **processor_kwargs
        )
        processor.save_pretrained(cls.tmpdirname)
        cls.image_token_id = processor.image_token_id
        cls.video_token_id = processor.video_token_id

    def get_tokenizer(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer

    def get_image_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.tmpdirname, ignore_errors=True)

    @staticmethod
    def prepare_processor_dict():
        return {
            "chat_template": CHAT_TEMPLATE,
            "patch_size": 14,
            "pooling_ratio": 2,
        }  # fmt: skip
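
    # How these kwargs relate to the image processor built in setUpClass (arithmetic only, mirroring the
    # comments in test_image_token_filling below): with tile_size=448, patch_size=14 and pooling_ratio=2,
    # each tile yields (448 / 14 / 2) ** 2 = 16 * 16 = 256 tokens, and "thumb+tile" with max_num_tiles=4
    # produces up to 1 thumbnail + 4 tiles = 5 tiles, i.e. 16 * 16 * 5 = 1280 image tokens for a large image.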

    def test_chat_template_is_saved(self):
        processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
        processor_dict_loaded = json.loads(processor_loaded.to_json_string())
        # chat templates aren't serialized to json in processors
        self.assertFalse("chat_template" in processor_dict_loaded.keys())

        # they have to be saved as separate file and loaded back from that file
        # so we check if the same template is loaded
        processor_dict = self.prepare_processor_dict()
        self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))

    def test_image_token_filling(self):
        processor = self.processor_class.from_pretrained(self.tmpdirname)
        # Important to check with non square image
        image = torch.randn((1, 3, 450, 500))
        # 5 tiles (thumbnail tile + 4 tiles)
        # 448/patch_size/pooling_ratio = 16 => 16*16 tokens per tile
        expected_image_tokens = 16 * 16 * 5
        image_token_index = processor.image_token_id

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]
        inputs = processor(
            text=[processor.apply_chat_template(messages)],
            images=[image],
            return_tensors="pt",
        )
        image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
        self.assertEqual(expected_image_tokens, image_tokens)


CHAT_TEMPLATE = (
    "{{- bos_token }}"
    "{%- if messages[0]['role'] == 'system' -%}"
    " {%- set system_message = messages[0]['content']|trim %}\n"
    " {%- set messages = messages[1:] %}\n"
    "{%- else %}"
    " {%- set system_message = 'You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.' %}"
    "{%- endif %}"
    "{{- '<|start_header_id|>system<|end_header_id|>\\n\\n' }}"
    "{{- system_message }}"
    "{{- '<|eot_id|>' }}"
    "{%- for message in messages %}"
    "{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}"
    "{%- for content in message['content'] | selectattr('type', 'equalto', 'image') %}"
    "{{ '<|image|>' }}"
    "{%- endfor %}"
    "{%- for content in message['content'] | selectattr('type', 'equalto', 'video') %}"
    "{{ '<|video|>' }}"
    "{%- endfor %}"
    "{%- for content in message['content'] | selectattr('type', 'equalto', 'text') %}"
    "{{- content['text'] | trim }}"
    "{%- endfor %}"
    "{{'<|eot_id|>' }}"
    "{%- endfor %}"
    "{%- if add_generation_prompt %}"
    "{{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}"
    "{%- endif %}"
)
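
# A rough reading of the template above (inferred from the Jinja itself, not from a captured render): a default
# system turn is emitted unless the conversation supplies one, each user turn expands its image/video entries to
# <|image|>/<|video|> placeholders ahead of the text, and add_generation_prompt=True appends the assistant
# header so generation starts inside the assistant turn.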
@@ -0,0 +1,127 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

from ...test_video_processing_common import VideoProcessingTestMixin, prepare_video_inputs


if is_torch_available():
    pass

if is_vision_available():
    if is_torchvision_available():
        from transformers import PerceptionLMVideoProcessor


class PerceptionLMVideoProcessingTester:
    def __init__(
        self,
        parent,
        batch_size=5,
        num_frames=8,
        num_channels=3,
        min_resolution=30,
        max_resolution=80,
        do_resize=True,
        size=None,
        do_center_crop=True,
        crop_size=None,
        do_normalize=True,
        image_mean=IMAGENET_STANDARD_MEAN,
        image_std=IMAGENET_STANDARD_STD,
        do_convert_rgb=True,
    ):
        size = size if size is not None else {"height": 20, "width": 20}
        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
        self.parent = parent
        self.batch_size = batch_size
        self.num_frames = num_frames
        self.num_channels = num_channels
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_convert_rgb = do_convert_rgb

    def prepare_video_processor_dict(self):
        return {
            "do_resize": self.do_resize,
            "size": self.size,
            "do_center_crop": self.do_center_crop,
            "crop_size": self.crop_size,
            "do_normalize": self.do_normalize,
            "image_mean": self.image_mean,
            "image_std": self.image_std,
            "do_convert_rgb": self.do_convert_rgb,
        }

    def expected_output_video_shape(self, images):
        return self.num_frames, self.num_channels, self.crop_size["height"], self.crop_size["width"]
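        # Under the defaults above this is (8, 3, 18, 18) per video: frames are resized toward size=20x20 and
        # then center-cropped to crop_size=18x18 (a description of this tester's configuration, not of the
        # processor's full behavior).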

    def prepare_video_inputs(self, equal_resolution=False, return_tensors="pil"):
        videos = prepare_video_inputs(
            batch_size=self.batch_size,
            num_frames=self.num_frames,
            num_channels=self.num_channels,
            min_resolution=self.min_resolution,
            max_resolution=self.max_resolution,
            equal_resolution=equal_resolution,
            return_tensors=return_tensors,
        )
        return videos


@require_torch
@require_vision
class PerceptionLMVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase):
    fast_video_processing_class = PerceptionLMVideoProcessor if is_torchvision_available() else None

    def setUp(self):
        super().setUp()
        self.video_processor_tester = PerceptionLMVideoProcessingTester(self)

    @property
    def video_processor_dict(self):
        return self.video_processor_tester.prepare_video_processor_dict()

    def test_video_processor_properties(self):
        video_processing = self.fast_video_processing_class(**self.video_processor_dict)
        self.assertTrue(hasattr(video_processing, "do_resize"))
        self.assertTrue(hasattr(video_processing, "size"))
        self.assertTrue(hasattr(video_processing, "do_center_crop"))
        self.assertTrue(hasattr(video_processing, "center_crop"))
        self.assertTrue(hasattr(video_processing, "do_normalize"))
        self.assertTrue(hasattr(video_processing, "image_mean"))
        self.assertTrue(hasattr(video_processing, "image_std"))
        self.assertTrue(hasattr(video_processing, "do_convert_rgb"))

    def test_video_processor_from_dict_with_kwargs(self):
        video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict)
        self.assertEqual(video_processor.size, {"height": 20, "width": 20})
        self.assertEqual(video_processor.crop_size, {"height": 18, "width": 18})

        video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict, size=42, crop_size=84)
        self.assertEqual(video_processor.size, {"height": 42, "width": 42})
        self.assertEqual(video_processor.crop_size, {"height": 84, "width": 84})