Shuming Hu 2025-07-02 14:49:17 -07:00 committed by GitHub
commit af826140c1
20 changed files with 3308 additions and 1 deletion

View File

@ -1035,6 +1035,8 @@
title: PaliGemma
- local: model_doc/perceiver
title: Perceiver
- local: model_doc/perception_lm
title: PerceptionLM
- local: model_doc/phi4_multimodal
title: Phi4 Multimodal
- local: model_doc/pix2struct

View File

@ -0,0 +1,68 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# PerceptionLM
## Overview
The PerceptionLM model was proposed in [PerceptionLM: Open-Access Data and Models for Detailed Visual Understanding](https://ai.meta.com/research/publications/perceptionlm-open-access-data-and-models-for-detailed-visual-understanding/) by Jang Hyun Cho et al. It is a fully open, reproducible model for transparent research in image and video understanding. PLM pairs
a vision encoder with a small-scale (<8B parameters) LLM decoder.
The abstract from the paper is the following:
*Vision-language models are integral to computer vision research, yet many high-performing models
remain closed-source, obscuring their data, design and training recipe. The research community
has responded by using distillation from black-box models to label training data, achieving strong
benchmark results, at the cost of measurable scientific progress. However, without knowing the details
of the teacher model and its data sources, scientific progress remains difficult to measure. In this
paper, we study building a Perception Language Model (PLM) in a fully open and reproducible
framework for transparent research in image and video understanding. We analyze standard training
pipelines without distillation from proprietary models and explore large-scale synthetic data to identify
critical data gaps, particularly in detailed video understanding. To bridge these gaps, we release 2.8M
human-labeled instances of fine-grained video question-answer pairs and spatio-temporally grounded
video captions. Additionally, we introduce PLMVideoBench, a suite for evaluating challenging video
understanding tasks focusing on the ability to reason about “what”, “where”, “when”, and “how” of a
video. We make our work fully reproducible by providing data, training recipes, code & models.*
This model was contributed by [shumingh](https://huggingface.co/shumingh).
The original code can be found [here](https://github.com/facebookresearch/perception_models).
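Below is a minimal inference sketch. The checkpoint name is one of the released models listed in the configuration docs; the processor call follows the generic multimodal chat-template API, so exact keyword arguments may differ slightly.

```python
from transformers import AutoProcessor, PerceptionLMForConditionalGeneration

model_id = "facebook/Perception-LM-1B"
processor = AutoProcessor.from_pretrained(model_id)
model = PerceptionLMForConditionalGeneration.from_pretrained(model_id)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
)
generated_ids = model.generate(**inputs, max_new_tokens=64)
print(processor.decode(generated_ids[0], skip_special_tokens=True))
```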
## PerceptionLMConfig
[[autodoc]] PerceptionLMConfig
## PerceptionLMProcessor
[[autodoc]] PerceptionLMProcessor
## PerceptionLMImageProcessorFast
[[autodoc]] PerceptionLMImageProcessorFast
## PerceptionLMVideoProcessor
[[autodoc]] PerceptionLMVideoProcessor
## PerceptionLMModel
[[autodoc]] PerceptionLMModel
## PerceptionLMForConditionalGeneration
[[autodoc]] PerceptionLMForConditionalGeneration
- forward

View File

@ -234,6 +234,7 @@ if TYPE_CHECKING:
from .pegasus import *
from .pegasus_x import *
from .perceiver import *
from .perception_lm import *
from .persimmon import *
from .phi import *
from .phi3 import *

View File

@ -267,6 +267,8 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("pegasus", "PegasusConfig"),
("pegasus_x", "PegasusXConfig"),
("perceiver", "PerceiverConfig"),
("perception_encoder", "TimmWrapperConfig"),
("perception_lm", "PerceptionLMConfig"),
("persimmon", "PersimmonConfig"),
("phi", "PhiConfig"),
("phi3", "Phi3Config"),
@ -663,6 +665,8 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("pegasus", "Pegasus"),
("pegasus_x", "PEGASUS-X"),
("perceiver", "Perceiver"),
("perception_encoder", "PerceptionEncoder"),
("perception_lm", "PerceptionLM"),
("persimmon", "Persimmon"),
("phi", "Phi"),
("phi3", "Phi3"),
@ -869,6 +873,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
("llama4_text", "llama4"),
("blip_2_qformer", "blip_2"),
("fastspeech2_conformer_with_hifigan", "fastspeech2_conformer"),
("perception_encoder", "perception_lm"),
]
)

View File

@ -132,6 +132,7 @@ else:
("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")),
("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")),
("perception_lm", ("PerceptionLMImageProcessorFast",)),
("phi4_multimodal", ("Phi4MultimodalImageProcessorFast",)),
("pix2struct", ("Pix2StructImageProcessor",)),
("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
@ -597,7 +598,6 @@ class AutoImageProcessor:
raise ValueError(
"This image processor cannot be instantiated. Please make sure you have `Pillow` installed."
)
raise ValueError(
f"Unrecognized image processor in {pretrained_model_name_or_path}. Should have a "
f"`image_processor_type` key in its {IMAGE_PROCESSOR_NAME} of {CONFIG_NAME}, or one of the following "

View File

@ -255,6 +255,8 @@ MODEL_MAPPING_NAMES = OrderedDict(
("pegasus", "PegasusModel"),
("pegasus_x", "PegasusXModel"),
("perceiver", "PerceiverModel"),
("perception_encoder", "PerceptionEncoder"),
("perception_lm", "PerceptionLMModel"),
("persimmon", "PersimmonModel"),
("phi", "PhiModel"),
("phi3", "Phi3Model"),
@ -933,6 +935,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
("mistral3", "Mistral3ForConditionalGeneration"),
("mllama", "MllamaForConditionalGeneration"),
("paligemma", "PaliGemmaForConditionalGeneration"),
("perception_lm", "PerceptionLMForConditionalGeneration"),
("pix2struct", "Pix2StructForConditionalGeneration"),
("pixtral", "LlavaForConditionalGeneration"),
("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),

View File

@ -100,6 +100,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("owlv2", "Owlv2Processor"),
("owlvit", "OwlViTProcessor"),
("paligemma", "PaliGemmaProcessor"),
("perception_lm", "PerceptionLMProcessor"),
("phi4_multimodal", "Phi4MultimodalProcessor"),
("pix2struct", "Pix2StructProcessor"),
("pixtral", "PixtralProcessor"),

View File

@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_perception_lm import *
from .image_processing_perception_lm_fast import *
from .modeling_perception_lm import *
from .processing_perception_lm import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@ -0,0 +1,88 @@
# coding=utf-8
# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PerceptionLM model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING, AutoConfig
from ..timm_wrapper.configuration_timm_wrapper import TimmWrapperConfig
logger = logging.get_logger(__name__)
class PerceptionLMConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PerceptionLMForConditionalGeneration`]. It is used to instantiate a
PerceptionLM model according to the specified arguments, defining the model architecture.
Example models:
- [facebook/Perception-LM-1B](https://huggingface.co/facebook/Perception-LM-1B).
- [facebook/Perception-LM-3B](https://huggingface.co/facebook/Perception-LM-3B).
- [facebook/Perception-LM-8B](https://huggingface.co/facebook/Perception-LM-8B).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`Union[TimmWrapperConfig, dict]`, *optional*, defaults to `TimmWrapperConfig()`):
The config object or dictionary of the vision backbone.
text_config (`Union[PretrainedConfig, dict]`, *optional*, defaults to `LlamaConfig()`):
The config object or dictionary of the text backbone.
vision_use_cls_token (`bool`, *optional*, defaults to `True`):
Whether the vision backbone uses a CLS token. If so, the CLS token embedding is removed from the vision output.
projector_pooling_ratio (`int`, *optional*, defaults to 1):
The pooling ratio used in the multimodal projector.
image_token_id (`int`, *optional*, defaults to 128002):
The image token index to encode the image prompt.
video_token_id (`int`, *optional*, defaults to 128003):
The video token index to encode the video prompt.
"""
model_type = "perception_lm"
sub_configs = {"text_config": AutoConfig, "vision_config": TimmWrapperConfig}
def __init__(
self,
vision_config=None,
text_config=None,
vision_use_cls_token=True,
projector_pooling_ratio=1,
image_token_id=128002,
video_token_id=128003,
**kwargs,
):
self.image_token_id = image_token_id
self.video_token_id = video_token_id
if isinstance(vision_config, dict):
vision_config = TimmWrapperConfig(**vision_config)
elif isinstance(vision_config, TimmWrapperConfig):
vision_config = vision_config
elif vision_config is None:
vision_config = TimmWrapperConfig()
self.vision_config = vision_config
self.vision_use_cls_token = vision_use_cls_token
if isinstance(text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
text_config = CONFIG_MAPPING["llama"]()
self.text_config = text_config
self.projector_pooling_ratio = projector_pooling_ratio
super().__init__(**kwargs)
__all__ = ["PerceptionLMConfig"]

View File

@ -0,0 +1,615 @@
# coding=utf-8
# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import json
import os
import tempfile
import warnings
import torch
from timm.models.eva import checkpoint_filter_fn
from tokenizers import AddedToken, processors
from transformers import (
GenerationConfig,
LlamaConfig,
LlamaTokenizer,
PreTrainedTokenizerFast,
)
from transformers.convert_slow_tokenizer import TikTokenConverter
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.perception_lm.configuration_perception_lm import (
PerceptionLMConfig,
)
from transformers.models.perception_lm.image_processing_perception_lm_fast import (
PerceptionLMImageProcessorFast,
)
from transformers.models.perception_lm.modeling_perception_lm import (
PerceptionLMForConditionalGeneration,
)
from transformers.models.perception_lm.processing_perception_lm import (
PerceptionLMProcessor,
)
from transformers.models.perception_lm.video_processing_perception_lm import (
PerceptionLMVideoProcessor,
)
from transformers.models.timm_wrapper.configuration_timm_wrapper import TimmWrapperConfig
try:
from transformers import LlamaTokenizerFast
except ImportError as e:
warnings.warn(e)
warnings.warn(
"The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
)
LlamaTokenizerFast = None
"""
Sample usage:
```
python src/transformers/models/perception_lm/convert_perception_lm_weights_to_hf.py \
--input_dir /path/to/downloaded/perception_lm/model_path --output_dir /output/path
```
Thereafter, models can be loaded via:
```py
from transformers import LlamaForCausalLM, LlamaTokenizer
model = LlamaForCausalLM.from_pretrained("/output/path")
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
```
Important note: you need to be able to host the whole model in RAM to execute this script (even though the biggest versions
come in several checkpoints, each checkpoint contains a part of every weight of the model, so they all need to be loaded in RAM).
If you want your tokenizer to add a bos automatically, you should update the tokenizer._tokenizers.post_processor:
```py
from tokenizers import processors
bos = "<|begin_of_text|>"
tokenizer._tokenizers.post_processor = processors.Sequence(
[
processors.ByteLevel(trim_offsets=False),
processors.TemplateProcessing(
single=f"{bos}:0 $A:0",
pair=f"{bos}:0 $A:0 {bos}:1 $B:1",
special_tokens=[
(bos, tokenizer.encode(bos)),
],
),
]
)
```
"""
BOS_ADDED_TOKEN = AddedToken(
"<|begin_of_text|>",
single_word=False,
lstrip=False,
rstrip=False,
normalized=False,
special=True,
)
EOS_ADDED_TOKEN = AddedToken(
"<|end_of_text|>",
single_word=False,
lstrip=False,
rstrip=False,
normalized=False,
special=True,
)
EOT_ADDED_TOKEN = AddedToken(
"<|eot_id|>",
single_word=False,
lstrip=False,
rstrip=False,
normalized=False,
special=True,
)
DEFAULT_SPECIAL_TOKENS = {
"perception_lm": [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|image|>",
"<|video|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|reserved_special_token_4|>",
"<|eot_id|>", # End of turn
]
+ [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]
}
CHAT_TEMPLATE = (
"{{- bos_token }}"
"{%- if messages[0]['role'] == 'system' -%}"
" {%- set system_message = messages[0]['content']|trim %}\n"
" {%- set messages = messages[1:] %}\n"
"{%- else %}"
" {%- set system_message = 'You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.' %}"
"{%- endif %}"
"{{- '<|start_header_id|>system<|end_header_id|>\\n\\n' }}"
"{{- system_message }}"
"{{- '<|eot_id|>' }}"
"{%- for message in messages %}"
"{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}"
"{%- for content in message['content'] | selectattr('type', 'equalto', 'image') %}"
"{{ '<|image|>' }}"
"{%- endfor %}"
"{%- for content in message['content'] | selectattr('type', 'equalto', 'video') %}"
"{{ '<|video|>' }}"
"{%- endfor %}"
"{%- for content in message['content'] | selectattr('type', 'equalto', 'text') %}"
"{{- content['text'] | trim }}"
"{%- endfor %}"
"{{'<|eot_id|>' }}"
"{%- endfor %}"
"{%- if add_generation_prompt %}"
"{{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}"
"{%- endif %}"
)
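# Rendering sketch (assuming standard Jinja chat-template semantics): a single user turn with
# one image and the text "Describe this image." renders, after the default system message, as
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>
#   <|start_header_id|>user<|end_header_id|>\n\n<|image|>Describe this image.<|eot_id|>
# followed, when add_generation_prompt=True, by a trailing
#   <|start_header_id|>assistant<|end_header_id|>\n\n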
def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
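# Worked example (sketch): with n=2048, ffn_dim_multiplier=1 and multiple_of=256,
# int(8 * 2048 / 3) = 5461 is rounded up to the next multiple of 256, giving an
# intermediate_size of 5632.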
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def write_json(text, path):
with open(path, "w") as f:
json.dump(text, f)
def write_weights(state_dict, index_dict, param_count, filename):
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, filename)
print(f"Saved {filename}")
return param_count
def write_model(
model_path,
input_base_path,
params,
image_token_id,
safe_serialization=True,
tokenizer=None,
num_shards=None,
push_to_hub=False,
):
print("Converting the model.")
num_shards = 1
model_params = params.get("model", params)
n_layers = model_params["n_layers"]
n_heads = model_params["n_heads"]
dim = model_params["dim"]
dims_per_head = dim // n_heads
base = model_params.get("rope_theta", 10000.0)
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
context_length = model_params["max_seqlen"]
max_position_embeddings = context_length
tie_word_embeddings = model_params.get("weight_tying", False)
projector_pooling_ratio = model_params.get("pooling_ratio", 1)
if model_params.get("n_kv_heads", None) is not None:
num_key_value_heads = model_params["n_kv_heads"] # for GQA / MQA
key_value_dim = dims_per_head * num_key_value_heads
else: # compatibility with other checkpoints
num_key_value_heads = n_heads
key_value_dim = dim
# permute for sliced rotary
def permute(w, n_heads, dim1=dim, dim2=dim):
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
with tempfile.TemporaryDirectory() as tmp_model_path:
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
# Load weights
if num_shards == 1:
# Not sharded
# (The sharded implementation would also work, but this is simpler.)
loaded = torch.load(
os.path.join(input_base_path, "consolidated.pth"),
map_location="cpu",
weights_only=True,
)
else:
# Sharded
checkpoint_list = sorted([file for file in os.listdir(input_base_path) if file.endswith(".pth")])
print("Loading in order:", checkpoint_list)
loaded = [
torch.load(
os.path.join(input_base_path, file),
map_location="cpu",
weights_only=True,
)
for file in checkpoint_list
]
param_count = 0
index_dict = {"weight_map": {}}
for layer_i in range(n_layers):
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 2}.bin"
assert num_shards == 1, "PerceptionLM does not support sharded weights"
state_dict = {
f"model.language_model.layers.{layer_i}.self_attn.q_proj.weight": permute(
loaded[f"layers.{layer_i}.attention.wq.weight"], n_heads=n_heads
),
f"model.language_model.layers.{layer_i}.self_attn.k_proj.weight": permute(
loaded[f"layers.{layer_i}.attention.wk.weight"],
n_heads=num_key_value_heads,
dim1=key_value_dim,
),
f"model.language_model.layers.{layer_i}.self_attn.v_proj.weight": loaded[
f"layers.{layer_i}.attention.wv.weight"
],
f"model.language_model.layers.{layer_i}.self_attn.o_proj.weight": loaded[
f"layers.{layer_i}.attention.wo.weight"
],
f"model.language_model.layers.{layer_i}.mlp.gate_proj.weight": loaded[
f"layers.{layer_i}.feed_forward.w1.weight"
],
f"model.language_model.layers.{layer_i}.mlp.down_proj.weight": loaded[
f"layers.{layer_i}.feed_forward.w2.weight"
],
f"model.language_model.layers.{layer_i}.mlp.up_proj.weight": loaded[
f"layers.{layer_i}.feed_forward.w3.weight"
],
f"model.language_model.layers.{layer_i}.input_layernorm.weight": loaded[
f"layers.{layer_i}.attention_norm.weight"
],
f"model.language_model.layers.{layer_i}.post_attention_layernorm.weight": loaded[
f"layers.{layer_i}.ffn_norm.weight"
],
}
state_dict[f"model.language_model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
print(f"Saved {filename}")
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 2}.bin"
state_dict = {
"model.language_model.embed_tokens.weight": loaded["tok_embeddings.weight"],
"model.language_model.norm.weight": loaded["norm.weight"],
"model.multi_modal_projector.projector.0.weight": loaded["vision_projector.projector.0.weight"],
"model.multi_modal_projector.projector.2.weight": loaded["vision_projector.projector.2.weight"],
"model.multi_modal_projector.projector.0.bias": loaded["vision_projector.projector.0.bias"],
"model.multi_modal_projector.projector.2.bias": loaded["vision_projector.projector.2.bias"],
}
if not tie_word_embeddings:
state_dict["lm_head.weight"] = loaded["output.weight"]
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
print(f"Saved {filename}")
filename = f"pytorch_model-{n_layers + 2}-of-{n_layers + 2}.bin"
state_dict = {k.replace("vision_model.", ""): v for k, v in loaded.items() if "vision_model" in k}
vision_params = model_params["vision_model"]
if vision_params["layers"] == 23 and vision_params["width"] == 1024:
architecture = "vit_pe_core_large_patch14_336"
elif vision_params["layers"] == 47 and vision_params["width"] == 1536:
architecture = "vit_pe_core_gigantic_patch14_448"
else:
raise ValueError(
f"Unsupported PE config: {vision_params['layers']} layers and {vision_params['width']} width"
)
vision_config = TimmWrapperConfig.from_pretrained(
f"timm/{architecture}.fb",
model_args={
"embed_dim": vision_params["width"],
"depth": vision_params["layers"],
"img_size": (vision_params["image_size"], vision_params["image_size"]),
"global_pool": "",
"use_post_transformer_norm": vision_params["use_ln_post"],
"init_values": vision_params["ls_init_value"],
"ref_feat_shape": (
vision_params["image_size"] // vision_params["patch_size"],
vision_params["image_size"] // vision_params["patch_size"],
),
},
)
perception_encoder = AutoModel.from_config(vision_config)
state_dict = checkpoint_filter_fn(state_dict, perception_encoder)
state_dict = {"model.vision_tower.timm_model." + k: v for k, v in state_dict.items()}
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
print(f"Saved {filename}")
# Write configs
index_dict["metadata"] = {"total_size": param_count * 2}
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
ffn_dim_multiplier = model_params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in model_params else 1
multiple_of = model_params["multiple_of"] if "multiple_of" in model_params else 256
bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>")
eos_token_id = [tokenizer.convert_tokens_to_ids(t) for t in ["<|end_of_text|>", "<|eot_id|>"]]
use_scaled_rope = model_params["use_scaled_rope"]
if use_scaled_rope:
rope_scaling = {
"factor": model_params["rope_scale_factor"] * 1.0,
"low_freq_factor": model_params.get("low_freq_factor", 1.0) * 1.0,
"high_freq_factor": model_params.get("high_freq_factor", 4.0) * 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
}
else:
rope_scaling = None
text_config = LlamaConfig(
hidden_size=dim,
intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of),
num_attention_heads=model_params["n_heads"],
num_hidden_layers=model_params["n_layers"],
rms_norm_eps=model_params["norm_eps"],
num_key_value_heads=num_key_value_heads,
vocab_size=len(tokenizer),
rope_theta=base,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
)
config = PerceptionLMConfig(
text_config=text_config.to_dict(),
vision_config=vision_config.to_dict(),
projector_pooling_ratio=projector_pooling_ratio,
vision_use_cls_token=vision_params["use_cls_token"],
image_token_id=tokenizer.image_token_id,
video_token_id=tokenizer.video_token_id,
)
config.save_pretrained(tmp_model_path)
generation_config = GenerationConfig(
do_sample=False,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
)
generation_config.save_pretrained(tmp_model_path)
# Make space so we can load the model properly now.
del state_dict
# output_weight = loaded.get("output.weight", None)
del loaded
gc.collect()
print("Loading the checkpoint in a PerceptionLM model.")
model = PerceptionLMForConditionalGeneration.from_pretrained(
tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
)
# if not tie_word_embeddings:
# if output_weight is None:
# raise ValueError("Output weight/lm_head is not found in the checkpoint.")
# model.lm_head.load_state_dict({"weight": output_weight})
# Avoid saving this as part of the config.
del model.config._name_or_path
model.config.torch_dtype = torch.bfloat16
print("Saving in the Transformers format.")
if push_to_hub:
print("Pushing to the hub.")
model.push_to_hub(
model_path,
safe_serialization=safe_serialization,
private=True,
use_temp_dir=True,
)
else:
print("Saving to disk.")
model.save_pretrained(model_path, safe_serialization=safe_serialization)
class Llama3Converter(TikTokenConverter):
def __init__(
self,
vocab_file,
special_tokens=None,
context_length=11520,
**kwargs,
):
super().__init__(vocab_file, additional_special_tokens=special_tokens, **kwargs)
tokenizer = self.converted()
self.converted_tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
bos_token="<|begin_of_text|>",
eos_token="<|eot_id|>",
model_input_names=["input_ids", "attention_mask"],
model_max_length=context_length,
clean_up_tokenization_spaces=True,
extra_special_tokens={
"image_token": "<|image|>",
"video_token": "<|video|>",
"pad_token": "<|end_of_text|>",
},
)
self.converted_tokenizer.image_token_id = self.converted_tokenizer.encode(
self.converted_tokenizer.image_token, add_special_tokens=False
)[0]
self.converted_tokenizer.video_token_id = self.converted_tokenizer.encode(
self.converted_tokenizer.video_token, add_special_tokens=False
)[0]
self.update_post_processor(self.converted_tokenizer)
# finer special_tokens_map.json
self.converted_tokenizer._bos_token = BOS_ADDED_TOKEN
self.converted_tokenizer._eos_token = EOT_ADDED_TOKEN
# We can't do this while building the tokenizer because we have no easy access to the bos token id
def update_post_processor(self, tokenizer):
tokenizer._tokenizer.post_processor = processors.Sequence(
[
processors.ByteLevel(trim_offsets=False),
processors.TemplateProcessing(
single="<|begin_of_text|> $A",
pair="<|begin_of_text|>:0 $A:0 <|begin_of_text|>:1 $B:1",
special_tokens=[
(
"<|begin_of_text|>",
tokenizer.convert_tokens_to_ids("<|begin_of_text|>"),
),
],
),
]
)
def write_tokenizer(
tokenizer_path,
input_tokenizer_path,
special_tokens=None,
params=None,
push_to_hub=False,
):
print("Converting the tokenizer.")
tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
context_length = params["model"]["max_seqlen"]
tokenizer = Llama3Converter(
input_tokenizer_path,
special_tokens,
context_length,
).converted_tokenizer
tokenizer.image_token_id = tokenizer.encode(tokenizer.image_token, add_special_tokens=False)[0]
processor_config = {
"pooling_ratio": params["model"]["pooling_ratio"],
"patch_size": params["model"]["vision_model"]["patch_size"],
"processor_class": "PerceptionLMProcessor",
}
tile_size = params["model"]["vision_model"]["image_size"]
image_preprocessor_config = {
"image_processor_type": "PerceptionLMImageProcessorFast",
"vision_input_type": params["data"]["vision_input_type"],
"tile_size": tile_size,
"max_num_tiles": params["data"]["max_num_tiles"],
"max_frame_tiles": 1,
"size": {"height": tile_size, "width": tile_size},
"do_resize": True,
"do_rescale": True,
"do_normalize": True,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5],
}
image_preprocessor = PerceptionLMImageProcessorFast(**image_preprocessor_config)
video_preprocessor_config = {
"video_processor_type": "PerceptionLMVideoProcessor",
"size": {"height": tile_size, "width": tile_size},
}
video_preprocessor = PerceptionLMVideoProcessor(**video_preprocessor_config)
processor = PerceptionLMProcessor(
image_processor=image_preprocessor,
video_processor=video_preprocessor,
tokenizer=tokenizer,
chat_template=CHAT_TEMPLATE,
**processor_config,
)
if push_to_hub:
print(f"Pushing a {tokenizer_class.__name__} to the Hub repo - {tokenizer_path}.")
processor.push_to_hub(tokenizer_path, private=True, use_temp_dir=True)
else:
print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
processor.save_pretrained(tokenizer_path)
return tokenizer
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of Llama weights, which contains tokenizer.model and model folders",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--push_to_hub",
help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally.",
action="store_true",
default=False,
)
parser.add_argument(
"--safe_serialization",
action="store_true",
default=True,
help="Whether or not to save using `safetensors`.",
)
parser.add_argument(
"--num_shards",
default=None,
type=int,
help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth",
)
parser.add_argument(
"--special_tokens",
default=None,
type=list[str],
help="The list of special tokens that should be added to the model.",
)
args = parser.parse_args()
if args.special_tokens is None:
# no special tokens by default
args.special_tokens = DEFAULT_SPECIAL_TOKENS.get("perception_lm", [])
params = read_json(os.path.join(args.input_dir, "params.json"))
spm_path = os.path.join(args.input_dir, "tokenizer.model")
tokenizer = write_tokenizer(
args.output_dir,
spm_path,
special_tokens=args.special_tokens,
params=params,
push_to_hub=args.push_to_hub,
)
write_model(
model_path=args.output_dir,
input_base_path=args.input_dir,
params=params,
image_token_id=tokenizer.image_token_id,
safe_serialization=args.safe_serialization,
tokenizer=tokenizer,
num_shards=args.num_shards,
push_to_hub=args.push_to_hub,
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,306 @@
# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for PerceptionLM."""
import math
from functools import reduce
from typing import Optional, Union
import numpy as np
from ...image_processing_utils import (
BatchFeature,
)
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
get_image_size,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
PILImageResampling,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
)
if is_torch_available():
import torch
if is_torchvision_available():
from torchvision.transforms import functional as F
class PerceptionLMFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
vision_input_type: str = "thumb+tile"
tile_size: int = 448
max_num_tiles: int = 36
@add_start_docstrings(
"Constructs a fast PerceptionLM image processor.",
)
class PerceptionLMImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
do_resize = True
do_center_crop = False
do_rescale = True
do_normalize = True
do_convert_rgb = True
size = {"width": 448, "height": 448} # for backward compatibility in tests
valid_kwargs = PerceptionLMFastImageProcessorKwargs
def __init__(self, **kwargs: Unpack[PerceptionLMFastImageProcessorKwargs]) -> None:
super().__init__(**kwargs)
@staticmethod
def _factors(n: int):
"""Return all factors of a number."""
return set(
reduce(
list.__add__,
([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0),
)
)
def _find_supported_aspect_ratios(self):
"""
This function computes all the allowed aspect ratios for a fixed
number of input chunks. The order of the returned items matters for the result of the `_fit_image_to_canvas` function.
If there is a tie in `_fit_image_to_canvas`, the entry that appears later in `_find_supported_aspect_ratios` wins.
For example, with `num_tiles=5`, it will return:
{
0.2: [(1, 5)],
5.0: [(5, 1)],
0.25: [(1, 4)],
1.0: [(2, 2), (1, 1)],
4.0: [(4, 1)],
0.3333333333333333: [(1, 3)],
3.0: [(3, 1)],
0.5: [(1, 2)],
2.0: [(2, 1)]
}
"""
asp_dict = {}
for chunk_size in range(self.max_num_tiles, 0, -1):
_factors = sorted(self._factors(chunk_size))
_asp_ratios = [(x, chunk_size // x) for x in _factors]
for ratio in _asp_ratios:
k = ratio[0] / ratio[1]
if k not in asp_dict:
asp_dict[k] = [ratio]
else:
asp_dict[k].append(ratio)
return asp_dict
def _get_image_height_width(
self, image_width: int, image_height: int, target_width: int, target_height: int
) -> tuple[int, int]:
"""
Given the image width and height and the target canvas width and height, return the dimensions the image would be
resized to while preserving its aspect ratio.
"""
scale = image_width / image_height
if scale > 1.0:
# Width is larger than height
# Rescaling factor is the minimum of the two scaling factors. Else one side would be outside of the canvas.
rescaling_factor = min(target_width / image_width, target_height / image_height)
# Set new width to target width and height to the rescaled height.
new_w = rescaling_factor * image_width
new_h = math.floor(new_w / scale)
else:
# Height is larger than width
# Rescaling factor is the minimum of the two scaling factors. Else one side would be outside of the canvas.
rescaling_factor = min(target_width / image_width, target_height / image_height)
# Set new height to target height and width to the rescaled width.
new_h = rescaling_factor * image_height
new_w = math.floor(new_h * scale)
return new_w, new_h
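# Sizing sketch: a 1200 x 800 image fit to a 3 x 2 grid of 448-pixel tiles (a 1344 x 896 canvas)
# gives rescaling_factor = min(1344 / 1200, 896 / 800) = 1.12, so the image is resized to 1344 x 896.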
def _fit_image_to_canvas(self, img_width: int, img_height: int, tile_size: int):
"""
Given an image width, height and target number of chunks, this function checks whether the image
can be fit into any of the canvases that can be built by arranging the tiles in a grid.
If the image fits onto several canvases, it returns the canvas where the shorter edge
of the image will be largest.
"""
# Initialize the optimal canvas to None. If no canvas is found where the image fits, the function returns None.
optimal_canvas = None
optimal_image_width_height = None
scale = img_width / img_height
# Gather all potential supported image resolutions and iterate through them to find best match
potential_arrangements = [
item for sublist in self._find_supported_aspect_ratios().values() for item in sublist
]
for n_w, n_h in potential_arrangements:
# Compute the canvas size
canvas_width, canvas_height = n_w * tile_size, n_h * tile_size
# Check if image can fit into the canvas without downsampling
if canvas_width >= img_width and canvas_height >= img_height:
# If we did not find a good canvas yet, we will use the current one
if optimal_canvas is None:
# Set optimal canvas and determine the actual image height and width in the canvas with aspect ratio preserving resampling
optimal_canvas = (n_w, n_h)
optimal_image_width_height = self._get_image_height_width(
image_width=img_width,
image_height=img_height,
target_width=n_w * tile_size,
target_height=n_h * tile_size,
)
else:
# If we already found an optimal canvas before, we will check if the shorter edge of the image will be larger than the current optimal canvas.
# This means we can potentially upsample the image resolution which is beneficial to performance.
image_width_height = self._get_image_height_width(
image_width=img_width,
image_height=img_height,
target_width=n_w * tile_size,
target_height=n_h * tile_size,
)
# Llama3V dynamic tiling. Prioritize the biggest canvas.
if (scale < 1.0 and (image_width_height[0] >= optimal_image_width_height[0])) or (
scale >= 1.0 and (image_width_height[1] >= optimal_image_width_height[1])
):
optimal_canvas = (n_w, n_h)
optimal_image_width_height = image_width_height
return optimal_canvas
def _find_closest_aspect_ratio(self, img_width: int, img_height: int, tile_size: int) -> tuple:
"""
Given an image width, height and target number of chunks
this function will find the closest supported aspect ratio.
"""
target_aspect_ratio = img_width / img_height
asp_dict = self._find_supported_aspect_ratios()
closest_aspect_ratio = None
if target_aspect_ratio >= 1:
closest_aspect_ratio = min(
[k for k in asp_dict.keys() if k <= target_aspect_ratio],
key=lambda x: abs(x - target_aspect_ratio),
)
tiles_given_aspect_ratio = asp_dict[closest_aspect_ratio]
# select largest width
return max(tiles_given_aspect_ratio, key=lambda x: x[0])
else:
closest_aspect_ratio = min(
[k for k in asp_dict.keys() if k > target_aspect_ratio],
key=lambda x: abs(1 / x - 1 / target_aspect_ratio),
)
tiles_given_aspect_ratio = asp_dict[closest_aspect_ratio]
# select largest height
return max(tiles_given_aspect_ratio, key=lambda x: x[1])
def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
# Split image into number of required tiles (width x height)
batch_size, num_channels, height, width = image.size()
image = image.view(batch_size, num_channels, nch, height // nch, ncw, width // ncw)
# Permute dimensions to reorder the axes
image = image.permute(0, 2, 4, 1, 3, 5).contiguous()
# Reshape into the desired output shape (batch_size, ncw * nch, num_channels, height // nch, width // ncw)
image = image.view(batch_size, ncw * nch, num_channels, height // nch, width // ncw)
return image
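# Shape sketch: with ncw=3, nch=2 and an input canvas of shape (batch, 3, 896, 1344)
# (2 rows x 3 columns of 448 x 448 tiles), the output has shape (batch, 6, 3, 448, 448).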
def resize(
self,
image: np.ndarray,
tile_size: int,
max_num_tiles: int,
resample: PILImageResampling = PILImageResampling.BICUBIC,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
height, width = get_image_size(image, channel_dim=input_data_format)
if max_num_tiles > 1:
aspect_ratio = self._fit_image_to_canvas(img_width=width, img_height=height, tile_size=tile_size)
if aspect_ratio is None:
# If we did not find a canvas, we have to find the closest aspect ratio and downsample the image
aspect_ratio = self._find_closest_aspect_ratio(img_width=width, img_height=height, tile_size=tile_size)
else:
aspect_ratio = (1, 1)
new_width, new_height = aspect_ratio[0] * tile_size, aspect_ratio[1] * tile_size
image = F.resize(image, (new_height, new_width), interpolation=resample)
return image, aspect_ratio
def _preprocess(
self,
images: list["torch.Tensor"],
do_resize: bool,
do_rescale: Optional[bool],
rescale_factor: Optional[Union[int, float]],
do_normalize: Optional[bool],
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
tile_size: int,
max_num_tiles: int,
return_tensors: Optional[Union[str, TensorType]],
disable_grouping: bool,
**kwargs: Unpack[PerceptionLMFastImageProcessorKwargs],
) -> BatchFeature:
# Group images by size for batched transformation
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
if self.vision_input_type == "thumb+tile":
thumbnails, _ = self.resize(stacked_images, tile_size, max_num_tiles=1)
images_for_tiling, (tiles_w, tiles_h) = self.resize(
stacked_images, tile_size, max_num_tiles=max_num_tiles
)
image_tiles = self._split(images_for_tiling, tiles_w, tiles_h)
stacked_images = torch.cat([thumbnails.unsqueeze(1), image_tiles], dim=1)
else: # vanilla single tile for low memory devices
stacked_images, _ = self.resize(stacked_images, tile_size, max_num_tiles=1)
resized_images_grouped[shape] = stacked_images
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images,
do_rescale,
rescale_factor,
do_normalize,
image_mean,
image_std,
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
__all__ = ["PerceptionLMImageProcessorFast"]

View File

@ -0,0 +1,515 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/perception_lm/modular_perception_lm.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_perception_lm.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from transformers.generation.utils import GenerationMixin
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, can_return_tuple
from ..auto import AutoModel
from .configuration_perception_lm import PerceptionLMConfig
class AdaptiveAvgPooling(nn.Module):
def __init__(self, pooling_ratio=2):
super(AdaptiveAvgPooling, self).__init__()
self.pooling_ratio = pooling_ratio
def forward(self, hidden_states):
b, num_tokens, c = hidden_states.shape
h = int(math.sqrt(num_tokens))
if h * h != num_tokens:
raise ValueError(f"num_tokens {num_tokens} is expected to be a square number")
shape = (h // self.pooling_ratio, h // self.pooling_ratio)
hidden_states = hidden_states.permute(0, 2, 1).reshape(b, -1, h, h)
hidden_states = F.adaptive_avg_pool2d(hidden_states, shape)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
return hidden_states
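# Shape sketch: with pooling_ratio=2 and 1024 vision tokens per tile (a 32 x 32 grid),
# the output has (32 // 2) ** 2 = 256 tokens, i.e. 4x fewer tokens per tile.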
class PerceptionLMMultiModalProjector(nn.Module):
def __init__(self, config: PerceptionLMConfig):
super().__init__()
input_size = config.vision_config.model_args["embed_dim"]
output_size = config.text_config.hidden_size
self.projector = nn.ModuleList(
[
nn.Linear(
in_features=input_size,
out_features=output_size,
bias=True,
),
nn.GELU(),
nn.Linear(
in_features=output_size,
out_features=output_size,
bias=True,
),
]
)
self.pooling = (
AdaptiveAvgPooling(config.projector_pooling_ratio) if config.projector_pooling_ratio > 1 else nn.Identity()
)
def forward(self, features):
features = features.permute(1, 0, 2) # NLD -> LND
for layer in self.projector:
features = layer(features)
features = features.permute(1, 0, 2) # LND -> NLD
features = self.pooling(features)
return features
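# Shape sketch: per-tile features of shape (num_tiles, num_patches, embed_dim) pass through
# Linear -> GELU -> Linear on the last dimension (the NLD/LND permutes cancel out), giving
# (num_tiles, num_patches, text_hidden_size); AdaptiveAvgPooling then reduces num_patches
# by pooling_ratio**2 when projector_pooling_ratio > 1.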
@auto_docstring
class PerceptionLMPreTrainedModel(PreTrainedModel):
config_class = PerceptionLMConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_flex_attn = True
_supports_attention_backend = True
def _init_weights(self, module):
# important: this ported version of PerceptionLM isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/facebookresearch/perception_models should serve for that purpose
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
@dataclass
@auto_docstring(
custom_intro="""
Base class for PerceptionLM outputs, with hidden states and attentions.
"""
)
class PerceptionLMModelOutputWithPast(BaseModelOutputWithPast):
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
video_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
@auto_docstring(
custom_intro="""
Base class for PerceptionLM causal language model (or autoregressive) outputs.
"""
)
class PerceptionLMCausalLMOutputWithPast(ModelOutput):
r"""
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[list[torch.FloatTensor]] = None
hidden_states: Optional[tuple[torch.FloatTensor]] = None
attentions: Optional[tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
video_hidden_states: Optional[torch.FloatTensor] = None
@auto_docstring
class PerceptionLMModel(PerceptionLMPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: PerceptionLMConfig):
super().__init__(config)
self.multi_modal_projector = PerceptionLMMultiModalProjector(config)
self.language_model = AutoModel.from_config(config.text_config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def set_decoder(self, decoder):
self.language_model = decoder
def get_decoder(self):
return self.language_model
def get_image_features(
self,
pixel_values: torch.FloatTensor,
**kwargs,
):
"""
Obtains the image's last hidden states from the vision tower and applies the multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_tiles, channels, height, width)`):
The tensors corresponding to the input images.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_tiles, num_patches, embed_dim)`.
"""
image_outputs = self.vision_tower(pixel_values.flatten(0, 1))
image_outputs = image_outputs.last_hidden_state
if self.config.vision_use_cls_token:
image_outputs = image_outputs[:, 1:, :]
image_features = self.multi_modal_projector(image_outputs)
return image_features
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[tuple, PerceptionLMModelOutputWithPast]:
"""
Forward pass of the PerceptionLM model.
Args:
input_ids (`torch.LongTensor`, *optional*):
Indices of input sequence tokens in the vocabulary.
pixel_values (`torch.FloatTensor`, *optional*):
Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`.
pixel_values_videos (`torch.FloatTensor`, *optional*):
Input video tensor of shape `(batch_size, num_frames, channels, height, width)`.
attention_mask (`torch.Tensor`, *optional*):
Mask to avoid performing attention on padding token indices.
position_ids (`torch.LongTensor`, *optional*):
Indices of positions of each input sequence token in the position embeddings.
past_key_values (`list[torch.FloatTensor]`, *optional*):
Precomputed key and value hidden states for fast autoregressive generation.
inputs_embeds (`torch.FloatTensor`, *optional*):
Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation.
use_cache (`bool`, *optional*):
Whether or not to use past key values to speed up decoding.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor`, *optional*):
Position indices for caching.
logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
Number of logits to keep.
**lm_kwargs:
Additional keyword arguments for the language model.
Returns:
[`PerceptionLMModelOutputWithPast`] or `tuple`:
Model outputs as a `PerceptionLMModelOutputWithPast` if `return_dict=True`, otherwise a tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both (pixel_values or pixel_values_videos) and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
image_features = None
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values.to(inputs_embeds),
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
self.check_mask_feature_size_match(special_image_mask, image_features)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
video_features = None
if pixel_values_videos is not None:
video_features = self.get_image_features(
pixel_values=pixel_values_videos.to(inputs_embeds),
)
special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
self.check_mask_feature_size_match(special_video_mask, video_features)
special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
video_features = video_features.to(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
return PerceptionLMModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
hidden_states=outputs.hidden_states,
past_key_values=outputs.past_key_values,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
video_hidden_states=(video_features if pixel_values_videos is not None else None),
)
def check_mask_feature_size_match(self, media_mask, media_features):
media_token_count = media_mask.sum()
media_feature_size = media_features.size()[:-1].numel()
if media_token_count != media_feature_size:
raise ValueError(
f"The number of tokens in the media mask ({media_token_count}) does not match the number of features in the media features ({media_feature_size}. Features shape: {media_features.shape})"
)
@auto_docstring
class PerceptionLMForConditionalGeneration(PerceptionLMPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: PerceptionLMConfig, **super_kwargs):
super().__init__(config, **super_kwargs)
self.model = PerceptionLMModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
# Make modules available through the conditional class for backward compatibility
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
def set_input_embeddings(self, new_embeddings):
self.model.set_input_embeddings(new_embeddings)
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_output_embeddings(self):
return self.lm_head
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
pixel_values=None,
pixel_values_videos=None,
attention_mask=None,
cache_position=None,
logits_to_keep=None,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**kwargs,
)
if cache_position[0] == 0:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values
model_inputs["pixel_values_videos"] = pixel_values_videos
return model_inputs
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[tuple, PerceptionLMCausalLMOutputWithPast]:
"""
Forward pass for the PerceptionLMForConditionalGeneration model.
Args:
input_ids (`torch.LongTensor`, *optional*):
Indices of input sequence tokens in the vocabulary.
pixel_values (`torch.FloatTensor`, *optional*):
Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`.
pixel_values_videos (`torch.FloatTensor`, *optional*):
Input video tensor of shape `(batch_size, num_frames, channels, height, width)`.
attention_mask (`torch.Tensor`, *optional*):
Mask to avoid performing attention on padding token indices.
position_ids (`torch.LongTensor`, *optional*):
Indices of positions of each input sequence token in the position embeddings.
past_key_values (`list[torch.FloatTensor]`, *optional*):
Precomputed key and value hidden states for fast autoregressive generation.
inputs_embeds (`torch.FloatTensor`, *optional*):
Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation.
labels (`torch.LongTensor`, *optional*):
Labels for computing the language modeling loss.
use_cache (`bool`, *optional*):
Whether or not to use past key values to speed up decoding.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor`, *optional*):
Position indices for caching.
logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
Number of logits to keep.
**lm_kwargs:
Additional keyword arguments for the language model.
Returns:
[`PerceptionLMCausalLMOutputWithPast`] or `tuple`:
Model outputs as a `PerceptionLMCausalLMOutputWithPast` if `return_dict=True`, otherwise a tuple.
"""
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.text_config.vocab_size,
**lm_kwargs,
)
return PerceptionLMCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
video_hidden_states=outputs.video_hidden_states,
)
__all__ = ["PerceptionLMForConditionalGeneration", "PerceptionLMPreTrainedModel", "PerceptionLMModel"]

View File

@ -0,0 +1,444 @@
# coding=utf-8
# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch PerceptionLM model."""
import math
from typing import Optional, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from transformers.generation.utils import GenerationMixin
from ...utils import (
auto_docstring,
can_return_tuple,
logging,
)
from ..auto import AutoModel
from ..llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaModel,
LlavaModelOutputWithPast,
LlavaPreTrainedModel,
)
from .configuration_perception_lm import PerceptionLMConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "PerceptionLMConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/Perception-LM-1B"
class AdaptiveAvgPooling(nn.Module):
def __init__(self, pooling_ratio=2):
super().__init__()
self.pooling_ratio = pooling_ratio
def forward(self, hidden_states):
b, num_tokens, c = hidden_states.shape
h = int(math.sqrt(num_tokens))
if h * h != num_tokens:
raise ValueError(f"num_tokens {num_tokens} is expected to be a square number")
shape = (h // self.pooling_ratio, h // self.pooling_ratio)
hidden_states = hidden_states.permute(0, 2, 1).reshape(b, -1, h, h)
hidden_states = F.adaptive_avg_pool2d(hidden_states, shape)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
return hidden_states
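# Illustrative sketch (not part of the module): with pooling_ratio=2, a 32x32 grid of
# vision tokens is averaged down to 16x16, i.e. 1024 -> 256 tokens per tile.
# >>> pool = AdaptiveAvgPooling(pooling_ratio=2)
# >>> x = torch.randn(1, 1024, 64)  # (batch, num_tokens, channels), 32x32 grid
# >>> pool(x).shape
# torch.Size([1, 256, 64])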
class PerceptionLMMultiModalProjector(nn.Module):
def __init__(self, config: PerceptionLMConfig):
super().__init__()
input_size = config.vision_config.model_args["embed_dim"]
output_size = config.text_config.hidden_size
self.projector = nn.ModuleList(
[
nn.Linear(
in_features=input_size,
out_features=output_size,
bias=True,
),
nn.GELU(),
nn.Linear(
in_features=output_size,
out_features=output_size,
bias=True,
),
]
)
self.pooling = (
AdaptiveAvgPooling(config.projector_pooling_ratio) if config.projector_pooling_ratio > 1 else nn.Identity()
)
def forward(self, features):
features = features.permute(1, 0, 2) # NLD -> LND
for layer in self.projector:
features = layer(features)
features = features.permute(1, 0, 2) # LND -> NLD
features = self.pooling(features)
return features
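# Shape walk-through (embed_dim and hidden_size values are assumed for illustration only):
# features from the vision tower: (num_tiles=5, num_patches=1024, embed_dim=1024)
# -> Linear + GELU + Linear:      (5, 1024, hidden_size=2048)
# -> AdaptiveAvgPooling(ratio=2): (5, 256, hidden_size=2048)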
class PerceptionLMPreTrainedModel(LlavaPreTrainedModel):
base_model_prefix = "model"
class PerceptionLMModelOutputWithPast(LlavaModelOutputWithPast):
video_hidden_states: Optional[torch.FloatTensor] = None
class PerceptionLMCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
video_hidden_states: Optional[torch.FloatTensor] = None
@auto_docstring
class PerceptionLMModel(LlavaModel):
def __init__(self, config: PerceptionLMConfig):
super().__init__(config)
del self.vision_tower
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = PerceptionLMMultiModalProjector(config)
self.language_model = AutoModel.from_config(config.text_config)
def get_image_features(
self,
pixel_values: torch.FloatTensor,
**kwargs,
):
"""
Obtains image last hidden states from the vision tower and applies the multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_tiles, channels, height, width)`):
The tensors corresponding to the input images.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_tiles, num_patches, embed_dim)`.
"""
image_outputs = self.vision_tower(pixel_values.flatten(0, 1))
image_outputs = image_outputs.last_hidden_state
if self.config.vision_use_cls_token:
image_outputs = image_outputs[:, 1:, :]
image_features = self.multi_modal_projector(image_outputs)
return image_features
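# Note on shapes: `pixel_values.flatten(0, 1)` merges the batch and tile axes, so the vision
# tower runs on (batch_size * num_tiles, channels, height, width); the optional CLS token is
# dropped before projection, and the projector's pooling further reduces the patch count.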
def check_mask_feature_size_match(self, media_mask, media_features):
media_token_count = media_mask.sum()
media_feature_size = media_features.size()[:-1].numel()
if media_token_count != media_feature_size:
raise ValueError(
f"The number of tokens in the media mask ({media_token_count}) does not match the number of features in the media features ({media_feature_size}. Features shape: {media_features.shape})"
)
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[tuple, PerceptionLMModelOutputWithPast]:
"""
Forward pass of the PerceptionLM model.
Args:
input_ids (`torch.LongTensor`, *optional*):
Indices of input sequence tokens in the vocabulary.
pixel_values (`torch.FloatTensor`, *optional*):
Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`.
pixel_values_videos (`torch.FloatTensor`, *optional*):
Input video tensor of shape `(batch_size, num_frames, channels, height, width)`.
attention_mask (`torch.Tensor`, *optional*):
Mask to avoid performing attention on padding token indices.
position_ids (`torch.LongTensor`, *optional*):
Indices of positions of each input sequence token in the position embeddings.
past_key_values (`list[torch.FloatTensor]`, *optional*):
Precomputed key and value hidden states for fast autoregressive generation.
inputs_embeds (`torch.FloatTensor`, *optional*):
Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation.
use_cache (`bool`, *optional*):
Whether or not to use past key values to speed up decoding.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor`, *optional*):
Position indices for caching.
logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
Number of logits to keep.
**lm_kwargs:
Additional keyword arguments for the language model.
Returns:
[`PerceptionLMModelOutputWithPast`] or `tuple`:
Model outputs as a `PerceptionLMModelOutputWithPast` if `return_dict=True`, otherwise a tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both (pixel_values or pixel_values_videos) and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
image_features = None
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values.to(inputs_embeds),
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
self.check_mask_feature_size_match(special_image_mask, image_features)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
video_features = None
if pixel_values_videos is not None:
video_features = self.get_image_features(
pixel_values=pixel_values_videos.to(inputs_embeds),
)
special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
self.check_mask_feature_size_match(special_video_mask, video_features)
special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
video_features = video_features.to(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
return PerceptionLMModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
hidden_states=outputs.hidden_states,
past_key_values=outputs.past_key_values,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
video_hidden_states=(video_features if pixel_values_videos is not None else None),
)
@auto_docstring
class PerceptionLMForConditionalGeneration(PerceptionLMPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: PerceptionLMConfig, **super_kwargs):
super().__init__(config, **super_kwargs)
self.model = PerceptionLMModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
def set_input_embeddings(self, new_embeddings):
self.model.set_input_embeddings(new_embeddings)
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_output_embeddings(self):
return self.lm_head
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
pixel_values=None,
pixel_values_videos=None,
attention_mask=None,
cache_position=None,
logits_to_keep=None,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**kwargs,
)
if cache_position[0] == 0:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values
model_inputs["pixel_values_videos"] = pixel_values_videos
return model_inputs
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[tuple, PerceptionLMCausalLMOutputWithPast]:
"""
Forward pass for the PerceptionLMForConditionalGeneration model.
Args:
input_ids (`torch.LongTensor`, *optional*):
Indices of input sequence tokens in the vocabulary.
pixel_values (`torch.FloatTensor`, *optional*):
Input image tensor of shape `(batch_size, num_tiles, channels, height, width)`.
pixel_values_videos (`torch.FloatTensor`, *optional*):
Input video tensor of shape `(batch_size, num_frames, channels, height, width)`.
attention_mask (`torch.Tensor`, *optional*):
Mask to avoid performing attention on padding token indices.
position_ids (`torch.LongTensor`, *optional*):
Indices of positions of each input sequence token in the position embeddings.
past_key_values (`list[torch.FloatTensor]`, *optional*):
Precomputed key and value hidden states for fast autoregressive generation.
inputs_embeds (`torch.FloatTensor`, *optional*):
Optionally, instead of passing `input_ids`, you can choose to directly pass an embedded representation.
labels (`torch.LongTensor`, *optional*):
Labels for computing the language modeling loss.
use_cache (`bool`, *optional*):
Whether or not to use past key values to speed up decoding.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor`, *optional*):
Position indices for caching.
logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
Number of logits to keep.
**lm_kwargs:
Additional keyword arguments for the language model.
Returns:
[`PerceptionLMCausalLMOutputWithPast`] or `tuple`:
Model outputs as a `PerceptionLMCausalLMOutputWithPast` if `return_dict=True`, otherwise a tuple.
"""
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.text_config.vocab_size,
**lm_kwargs,
)
return PerceptionLMCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
video_hidden_states=outputs.video_hidden_states,
)
__all__ = [
"PerceptionLMForConditionalGeneration",
"PerceptionLMPreTrainedModel",
"PerceptionLMModel",
]

View File

@ -0,0 +1,207 @@
# coding=utf-8
# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for PerceptionLM.
"""
from typing import Iterable, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, get_image_size, to_numpy_array
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
from ...video_utils import VideoInput
logger = logging.get_logger(__name__)
class PerceptionLMProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
},
}
class PerceptionLMProcessor(ProcessorMixin):
r"""
Constructs a PerceptionLM processor which wraps a PerceptionLM image processor, a PerceptionLM video processor, and a tokenizer into a single processor.
[`PerceptionLMProcessor`] offers all the functionalities of [`PerceptionLMImageProcessorFast`], [`PerceptionLMVideoProcessor`], and the tokenizer (e.g. [`LlamaTokenizerFast`]). See the
[`~PerceptionLMProcessor.__call__`] and [`~PerceptionLMProcessor.decode`] for more information.
Args:
video_processor ([`PerceptionLMVideoProcessor`], *optional*):
The video processor to process video inputs.
image_processor ([`PerceptionLMImageProcessorFast`], *optional*):
The image processor to process image inputs.
tokenizer ([`LlamaTokenizerFast`] or similar, *optional*):
The tokenizer to process text inputs.
patch_size (`int`, *optional*):
Patch size from the vision tower.
chat_template (`str`, *optional*):
A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
pooling_ratio (`int`, *optional*, defaults to 2):
Pooling ratio for vision tokens. If not 1, 2D adaptive pooling is applied over projected vision tokens.
"""
attributes = ["video_processor", "image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
video_processor_class = "AutoVideoProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
video_processor=None,
image_processor=None,
tokenizer=None,
patch_size=None,
chat_template=None,
pooling_ratio=2,
**kwargs,
):
self.patch_size = patch_size
self.pooling_ratio = pooling_ratio
self.image_token = tokenizer.image_token
self.video_token = tokenizer.video_token
self.image_token_id = tokenizer.image_token_id
self.video_token_id = tokenizer.video_token_id
super().__init__(video_processor, image_processor, tokenizer, chat_template=chat_template)
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
audio=None,
videos: VideoInput = None,
**kwargs: Unpack[PerceptionLMProcessorKwargs],
) -> BatchFeature:
"""
Prepares a batch containing one or more sequences of text and/or images and/or videos.
If `text` is provided, it is tokenized using the tokenizer.
If `images` is provided, they are processed using the image processor.
If `videos` is provided, they are processed using the video processor.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
The image or batch of images to be processed. Each image can be a PIL image, NumPy array, or PyTorch tensor.
Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, *optional*):
The sequence or batch of sequences to be tokenized. Each sequence can be a string.
videos (`Any`, *optional*):
The video or batch of videos to be processed.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is provided.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is provided).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is provided.
- **pixel_values_videos** -- Video pixel values to be fed to a model. Returned when `videos` is provided.
"""
if text is None:
raise ValueError(
"You have to specify at least `text` input. Optionally, you can also specify `images` or `videos`."
)
output_kwargs = self._merge_kwargs(
PerceptionLMProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None:
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
else:
image_inputs = {}
if videos is not None:
videos_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"])
else:
videos_inputs = {}
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) or not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
# try to expand inputs in processing if we have the necessary parts
prompt_strings = []
pixel_values = iter(image_inputs.get("pixel_values", []))
pixel_values_videos = iter(videos_inputs.get("pixel_values_videos", []))
for sample in text:
# Replace the media token with the expanded media token sequence
sample = self._expand_media_tokens(sample, self.tokenizer.image_token, pixel_values)
sample = self._expand_media_tokens(sample, self.tokenizer.video_token, pixel_values_videos)
prompt_strings.append(sample)
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image", "video"])
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
def _expand_media_tokens(self, sample, media_token: str, media_iter: Iterable):
media_count = sample.count(media_token)
if media_count > 0:
media_list = [next(media_iter) for _ in range(media_count)]
sample_splits = sample.split(media_token)
media_token_list = []
for media in media_list:
height, width = get_image_size(to_numpy_array(media))
num_tiles = media.shape[0]
num_media_tokens = (
(height // self.patch_size // self.pooling_ratio)
* (width // self.patch_size // self.pooling_ratio)
* num_tiles
)
media_token_list.append(num_media_tokens)
sample = ""
for i, num_media_tokens in enumerate(media_token_list):
sample += sample_splits[i]
sample += media_token * num_media_tokens
sample += sample_splits[-1]
return sample
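# Worked example (numbers taken from the tests in this PR): with patch_size=14, pooling_ratio=2
# and 448x448 tiles, each tile expands to (448 // 14 // 2) * (448 // 14 // 2) = 16 * 16 = 256
# media tokens, so a thumbnail plus four tiles yields 5 * 256 = 1280 `<|image|>` tokens.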
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to the tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to the tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
__all__ = ["PerceptionLMProcessor"]

View File

@ -0,0 +1,53 @@
# coding=utf-8
# Copyright 2025 Meta Platforms, Inc. and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Video processor class for PerceptionLM."""
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
)
from ...processing_utils import Unpack, VideosKwargs
from ...utils import is_vision_available
from ...utils.import_utils import requires
from ...video_processing_utils import (
BaseVideoProcessor,
)
if is_vision_available():
from ...image_utils import PILImageResampling
class PerceptionLMFastVideoProcessorInitKwargs(VideosKwargs): ...
@requires(backends=("torchvision",))
class PerceptionLMVideoProcessor(BaseVideoProcessor):
resample = PILImageResampling.BICUBIC
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
size = {"height": 448, "width": 448}
do_resize = True
do_center_crop = False
do_rescale = True
do_normalize = True
do_convert_rgb = True
valid_kwargs = PerceptionLMFastVideoProcessorInitKwargs
model_input_names = ["pixel_values_videos"]
def __init__(self, **kwargs: Unpack[PerceptionLMFastVideoProcessorInitKwargs]):
super().__init__(**kwargs)
__all__ = ["PerceptionLMVideoProcessor"]

View File

View File

@ -0,0 +1,224 @@
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
if is_torchvision_available():
from transformers import PerceptionLMImageProcessorFast
class PerceptionLMImageProcessingTester:
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
image_size=18,
min_resolution=30,
max_resolution=400,
do_resize=True,
tile_size=16,
do_normalize=True,
image_mean=IMAGENET_STANDARD_MEAN,
image_std=IMAGENET_STANDARD_STD,
do_convert_rgb=True,
max_num_tiles=4,
vision_input_type="thumb+tile",
resample=Image.Resampling.BICUBIC, # dummy value
size={"shortest_edge": 20}, # dummy value
):
super().__init__()
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.tile_size = tile_size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb
self.max_num_tiles = max_num_tiles
self.vision_input_type = vision_input_type
self.resample = resample
self.size = size
def prepare_image_processor_dict(self):
return {
"do_resize": self.do_resize,
"tile_size": self.tile_size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_convert_rgb": self.do_convert_rgb,
"max_num_tiles": self.max_num_tiles,
"vision_input_type": self.vision_input_type,
"resample": self.resample,
"size": self.size,
}
def expected_output_image_shape(self, images):
return self.num_channels, self.tile_size, self.tile_size
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)
@require_torch
@require_vision
class PerceptionLMImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
fast_image_processing_class = PerceptionLMImageProcessorFast if is_torchvision_available() else None
test_slow_image_processor = False
def setUp(self):
super().setUp()
self.image_processor_tester = PerceptionLMImageProcessingTester(self)
@property
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "tile_size"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
self.assertTrue(hasattr(image_processing, "max_num_tiles"))
self.assertTrue(hasattr(image_processing, "vision_input_type"))
def test_image_processor_from_dict_with_kwargs(self):
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.tile_size, 16)
self.assertEqual(image_processor.max_num_tiles, 4)
self.assertEqual(image_processor.vision_input_type, "thumb+tile")
image_processor = image_processing_class.from_dict(
self.image_processor_dict, tile_size=42, max_num_tiles=9
)
self.assertEqual(image_processor.tile_size, 42)
self.assertEqual(image_processor.max_num_tiles, 9)
self.assertEqual(image_processor.vision_input_type, "thumb+tile")
def test_call_pil(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 5, 3, 16, 16)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 5, 3, 16, 16)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_numpy(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 5, 3, 16, 16)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 5, 3, 16, 16)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_pytorch(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 5, 3, 16, 16)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 5, 3, 16, 16)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
@unittest.skip(reason="PerceptionLMImageProcessor doesn't treat 4 channel PIL and numpy consistently yet")
def test_call_numpy_4_channels(self):
pass
def test_nested_input(self):
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
# Test batched as a list of images
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 5, 3, 16, 16)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched as a nested list of images, where each sublist is one batch
image_inputs_nested = [image_inputs[:3], image_inputs[3:]]
encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 5, 3, 16, 16)
self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)
# Image processor should return the same pixel values, independently of input format
self.assertTrue((encoded_images_nested == encoded_images).all())

View File

@ -0,0 +1,474 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch PerceptionLM model."""
import unittest
from huggingface_hub import hf_hub_download
from transformers import (
AutoProcessor,
PerceptionLMConfig,
PerceptionLMForConditionalGeneration,
PerceptionLMModel,
is_torch_available,
)
from transformers.testing_utils import (
cleanup,
require_bitsandbytes,
require_torch,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
class PerceptionLMVisionText2TextModelTester:
def __init__(
self,
parent,
image_token_id=0,
video_token_id=2,
seq_length=7,
tie_word_embeddings=True,
projector_pooling_ratio=1,
text_config={
"model_type": "llama",
"seq_length": 7,
"is_training": True,
"use_input_mask": True,
"use_token_type_ids": False,
"use_labels": True,
"vocab_size": 99,
"hidden_size": 32,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"intermediate_size": 37,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"attention_probs_dropout_prob": 0.1,
"max_position_embeddings": 512,
"type_vocab_size": 16,
"type_sequence_label_size": 2,
"initializer_range": 0.02,
"num_labels": 3,
"num_choices": 4,
"pad_token_id": 1,
},
is_training=True,
vision_config={
"architecture": "vit_pe_core_large_patch14_336",
"model_args": {
"embed_dim": 64,
"img_size": (14, 14),
"depth": 2,
"global_pool": "",
"use_post_transformer_norm": False,
"init_values": 0.1,
"ref_feat_shape": (1, 1),
},
},
):
self.parent = parent
self.image_token_id = image_token_id
self.video_token_id = video_token_id
self.text_config = text_config
self.vision_config = vision_config
self.pad_token_id = text_config["pad_token_id"]
self.num_hidden_layers = text_config["num_hidden_layers"]
self.vocab_size = text_config["vocab_size"]
self.hidden_size = text_config["hidden_size"]
self.num_attention_heads = text_config["num_attention_heads"]
self.is_training = is_training
self.tie_word_embeddings = tie_word_embeddings
self.batch_size = 3
self.num_tiles = 1
self.num_frames = 1
self.num_channels = 3
self.image_size = self.vision_config["model_args"]["img_size"][0]
self.num_image_tokens = (self.vision_config["model_args"]["img_size"][0] // 14) ** 2
self.num_video_tokens = (self.vision_config["model_args"]["img_size"][0] // 14) ** 2
self.seq_length = seq_length + self.num_image_tokens
self.encoder_seq_length = self.seq_length
def get_config(self):
return PerceptionLMConfig(
text_config=self.text_config,
vision_config=self.vision_config,
vision_use_cls_token=True,
image_token_id=self.image_token_id,
video_token_id=self.video_token_id,
tie_word_embeddings=self.tie_word_embeddings,
)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor(
[
self.batch_size,
self.num_tiles,
self.num_channels,
self.vision_config["model_args"]["img_size"][0],
self.vision_config["model_args"]["img_size"][1],
]
)
pixel_values_videos = floats_tensor(
[
self.batch_size,
self.num_frames,
self.num_channels,
self.vision_config["model_args"]["img_size"][0],
self.vision_config["model_args"]["img_size"][1],
]
)
config = self.get_config()
return config, pixel_values, pixel_values_videos
def prepare_config_and_inputs_for_common(self):
config, pixel_values, pixel_values_videos = self.prepare_config_and_inputs()
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 2) + 2
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
input_ids[input_ids == config.image_token_id] = self.pad_token_id
input_ids[input_ids == config.video_token_id] = self.pad_token_id
input_ids[:, : self.num_image_tokens] = config.image_token_id
input_ids[:, self.num_image_tokens : self.num_video_tokens + self.num_image_tokens] = config.video_token_id
inputs_dict = {
"pixel_values": pixel_values,
"pixel_values_videos": pixel_values_videos,
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return config, inputs_dict
@require_torch
class PerceptionLMForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
"""
Model tester for `PerceptionLMForConditionalGeneration`.
"""
all_model_classes = (
(
PerceptionLMModel,
PerceptionLMForConditionalGeneration,
)
if is_torch_available()
else ()
)
test_pruning = False
test_head_masking = False
_is_composite = True
def setUp(self):
self.model_tester = PerceptionLMVisionText2TextModelTester(self)
common_properties = [
"image_token_id",
"video_token_id",
]
self.config_tester = ConfigTester(
self,
config_class=PerceptionLMConfig,
has_text_modality=False,
common_properties=common_properties,
)
def test_config(self):
self.config_tester.run_common_tests()
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
del inputs["pixel_values_videos"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
del inputs["pixel_values_videos"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
def test_mismatching_num_image_tokens(self):
"""
Tests that VLMs raise an error with an explicit message saying what is wrong
when the number of images doesn't match the number of image tokens in the text.
We also need to test multi-image cases where one prompt has multiple image tokens.
"""
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if model_class == PerceptionLMModel:
continue
model = model_class(config).to(torch_device)
_ = model(**input_dict) # successful forward with no modifications
# remove one image but leave the image token in text
input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...]
with self.assertRaises(ValueError):
_ = model(**input_dict)
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
input_ids = input_dict["input_ids"][:1]
pixel_values = input_dict["pixel_values"][:1]
input_ids = torch.cat([input_ids, input_ids], dim=0)
# one image and two image tokens raise an error
with self.assertRaises(ValueError):
_ = model(input_ids=input_ids, pixel_values=pixel_values)
# two images and two image tokens don't raise an error
pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values)
def test_training(self):
self.all_model_classes = (PerceptionLMForConditionalGeneration,) if is_torch_available() else ()
super().test_training()
def test_training_gradient_checkpointing(self):
self.all_model_classes = (PerceptionLMForConditionalGeneration,) if is_torch_available() else ()
super().test_training_gradient_checkpointing()
def test_training_gradient_checkpointing_use_reentrant(self):
self.all_model_classes = (PerceptionLMForConditionalGeneration,) if is_torch_available() else ()
super().test_training_gradient_checkpointing_use_reentrant()
def test_training_gradient_checkpointing_use_reentrant_false(self):
self.all_model_classes = (PerceptionLMForConditionalGeneration,) if is_torch_available() else ()
super().test_training_gradient_checkpointing_use_reentrant_false()
@unittest.skip(reason="Timm Eva (PE) weights cannot be fully constructed in _init_weights")
def test_can_init_all_missing_weights(self):
pass
@unittest.skip(reason="Timm Eva (PE) weights cannot be fully constructed in _init_weights")
def test_initialization(self):
pass
@unittest.skip(
reason="PE/TIMM's attention implementation is self configured and won't raise ValueError on global attention implementation."
)
def test_flash_attn_2_can_dispatch_composite_models(self):
pass
@unittest.skip(
"VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test"
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("ViT PE cannot be tested with meta device")
def test_can_be_initialized_on_meta(self):
pass
@unittest.skip("ViT PE cannot be tested with meta device")
def test_can_load_with_meta_device_context_manager(self):
pass
@unittest.skip("Specifying both inputs_embeds and pixel_values are not supported for PerceptionLM")
def test_generate_from_inputs_embeds_0_greedy(self):
pass
@unittest.skip("Specifying both inputs_embeds and pixel_values are not supported for PerceptionLM")
def test_generate_from_inputs_embeds_1_beam_search(self):
pass
@unittest.skip("Specifying both inputs_embeds and pixel_values are not supported for PerceptionLM")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
## Skip flash attention related tests below
## correct configuration:
## from_pretrained(model_id, attn_implementation={"text_config": "flash_attention_2", "vision_config": "eager"})
@unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
def test_eager_matches_fa2_generate(self):
pass
@unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
def test_flash_attn_2_fp32_ln(self):
pass
@unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
def test_flash_attn_2_from_config(self):
pass
@unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
def test_eager_matches_sdpa_generate_with_dynamic_cache(self):
pass
@unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
def test_flash_attn_2_inference_equivalence_right_padding(self):
pass
@unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
def test_eager_matches_sdpa_generate(self):
pass
@unittest.skip("Flash attn test is not configured correctly as we need to configure vision/timm model to 'eager'.")
def test_flash_attn_2_inference_equivalence(self):
pass
TEST_MODEL_PATH = "shumingh/plm_1b_hf"
@require_torch
class PerceptionLMForConditionalGenerationIntegrationTest(unittest.TestCase):
def setUp(self):
self.processor = AutoProcessor.from_pretrained(TEST_MODEL_PATH)
self.image_file = hf_hub_download(
repo_id="shumingh/perception_lm_test_images",
filename="14496_0.PNG",
repo_type="dataset",
)
self.video_file = hf_hub_download(
repo_id="shumingh/perception_lm_test_videos",
filename="GUWR5TyiY-M_000012_000022.mp4",
repo_type="dataset",
)
self.conversation1 = [
{
"role": "user",
"content": [
{"type": "image", "url": self.image_file},
{"type": "text", "text": "Describe the bar plot in the image."},
],
}
]
self.conversation2 = [
{
"role": "user",
"content": [
{
"type": "video",
"url": self.video_file,
},
{"type": "text", "text": "Can you describe the video in detail?"},
],
}
]
def tearDown(self):
cleanup(torch_device, gc_collect=True)
@slow
@require_bitsandbytes
def test_small_model_integration_test(self):
model = PerceptionLMForConditionalGeneration.from_pretrained(
TEST_MODEL_PATH, load_in_4bit=True, cache_dir="./"
)
inputs = self.processor.apply_chat_template(
[self.conversation1],
num_frames=32,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
video_load_backend="decord",
padding=True,
padding_side="left",
).to(torch_device)
generate_ids = model.generate(**inputs, max_new_tokens=18)
input_length = inputs["input_ids"].shape[1]
generate_ids_without_inputs = generate_ids[:, input_length:]
EXPECTED_DECODED_TEXT = "The bar plot displays the values of four categories: step, horror, mood, and lumber" # fmt: skip
self.assertEqual(
self.processor.decode(generate_ids_without_inputs[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@slow
@require_bitsandbytes
def test_small_model_integration_test_batched(self):
model = PerceptionLMForConditionalGeneration.from_pretrained(TEST_MODEL_PATH, load_in_4bit=True)
processor = AutoProcessor.from_pretrained(TEST_MODEL_PATH)
inputs = processor.apply_chat_template(
[self.conversation1, self.conversation2],
num_frames=32,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
video_load_backend="decord",
padding=True,
padding_side="left",
).to(torch_device)
generate_ids = model.generate(**inputs, max_new_tokens=18)
input_length = inputs["input_ids"].shape[1]
generate_ids_without_inputs = generate_ids[:, input_length:]
EXPECTED_DECODED_TEXT = ['The bar plot displays the values of four categories: step, horror, mood, and lumber', 'The video shows a group of people in green shirts and white shorts performing a jump rope routine'] # fmt: skip
self.assertEqual(
processor.batch_decode(generate_ids_without_inputs, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@slow
@require_bitsandbytes
def test_generation_no_images(self):
# model_id = "facebook/Perception-LM-1B"
model = PerceptionLMForConditionalGeneration.from_pretrained(TEST_MODEL_PATH, load_in_4bit=True)
processor = AutoProcessor.from_pretrained(TEST_MODEL_PATH)
# Prepare inputs with no images
inputs = processor(text="Hello, I am", return_tensors="pt").to(torch_device)
# Make sure that `generate` works
_ = model.generate(**inputs, max_new_tokens=20)

View File

@ -0,0 +1,145 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import shutil
import tempfile
import unittest
from transformers import (
AutoProcessor,
AutoTokenizer,
PerceptionLMProcessor,
)
from transformers.testing_utils import require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from transformers import PerceptionLMImageProcessorFast, PerceptionLMVideoProcessor
if is_torch_available():
import torch
TEST_MODEL_PATH = "shumingh/plm_1b_hf"
@require_vision
class PerceptionLMProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = PerceptionLMProcessor
@classmethod
def setUpClass(cls):
cls.tmpdirname = tempfile.mkdtemp()
image_processor = PerceptionLMImageProcessorFast(
tile_size=448, max_num_tiles=4, vision_input_type="thumb+tile"
)
video_processor = PerceptionLMVideoProcessor()
tokenizer = AutoTokenizer.from_pretrained(TEST_MODEL_PATH)
tokenizer.add_special_tokens({"additional_special_tokens": ["<|image|>", "<|video|>"]})
processor_kwargs = cls.prepare_processor_dict()
processor = PerceptionLMProcessor(
image_processor=image_processor, video_processor=video_processor, tokenizer=tokenizer, **processor_kwargs
)
processor.save_pretrained(cls.tmpdirname)
cls.image_token_id = processor.image_token_id
cls.video_token_id = processor.video_token_id
def get_tokenizer(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
@staticmethod
def prepare_processor_dict():
return {
"chat_template": CHAT_TEMPLATE,
"patch_size": 14,
"pooling_ratio": 2,
} # fmt: skip
def test_chat_template_is_saved(self):
processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
processor_dict_loaded = json.loads(processor_loaded.to_json_string())
# chat templates aren't serialized to json in processors
self.assertFalse("chat_template" in processor_dict_loaded.keys())
# they have to be saved as separate file and loaded back from that file
# so we check if the same template is loaded
processor_dict = self.prepare_processor_dict()
self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
def test_image_token_filling(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
# Important to check with non square image
image = torch.randn((1, 3, 450, 500))
# 5 tiles (thumbnail tile + 4 tiles)
# 448/patch_size/pooling_ratio = 16 => 16*16 tokens per tile
expected_image_tokens = 16 * 16 * 5
image_token_index = processor.image_token_id
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
inputs = processor(
text=[processor.apply_chat_template(messages)],
images=[image],
return_tensors="pt",
)
image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
self.assertEqual(expected_image_tokens, image_tokens)
CHAT_TEMPLATE = (
"{{- bos_token }}"
"{%- if messages[0]['role'] == 'system' -%}"
" {%- set system_message = messages[0]['content']|trim %}\n"
" {%- set messages = messages[1:] %}\n"
"{%- else %}"
" {%- set system_message = 'You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.' %}"
"{%- endif %}"
"{{- '<|start_header_id|>system<|end_header_id|>\\n\\n' }}"
"{{- system_message }}"
"{{- '<|eot_id|>' }}"
"{%- for message in messages %}"
"{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}"
"{%- for content in message['content'] | selectattr('type', 'equalto', 'image') %}"
"{{ '<|image|>' }}"
"{%- endfor %}"
"{%- for content in message['content'] | selectattr('type', 'equalto', 'video') %}"
"{{ '<|video|>' }}"
"{%- endfor %}"
"{%- for content in message['content'] | selectattr('type', 'equalto', 'text') %}"
"{{- content['text'] | trim }}"
"{%- endfor %}"
"{{'<|eot_id|>' }}"
"{%- endfor %}"
"{%- if add_generation_prompt %}"
"{{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}"
"{%- endif %}"
)
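# For a single user turn with one image and one text entry (as in test_image_token_filling),
# the template above renders, segment by segment (illustration only; segments are shown on
# separate lines here for readability):
#   <bos token><|start_header_id|>system<|end_header_id|>\n\n<system message><|eot_id|>
#   <|start_header_id|>user<|end_header_id|>\n\n<|image|>What is shown in this image?<|eot_id|>
#   <|start_header_id|>assistant<|end_header_id|>\n\n   (only if add_generation_prompt is set)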

View File

@ -0,0 +1,127 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_video_processing_common import VideoProcessingTestMixin, prepare_video_inputs
if is_torch_available():
    pass  # nothing is imported from torch directly; torch availability is enforced via @require_torch below
if is_vision_available():
if is_torchvision_available():
from transformers import PerceptionLMVideoProcessor


class PerceptionLMVideoProcessingTester:
def __init__(
self,
parent,
batch_size=5,
num_frames=8,
num_channels=3,
min_resolution=30,
max_resolution=80,
do_resize=True,
size=None,
do_center_crop=True,
crop_size=None,
do_normalize=True,
image_mean=IMAGENET_STANDARD_MEAN,
image_std=IMAGENET_STANDARD_STD,
do_convert_rgb=True,
):
size = size if size is not None else {"height": 20, "width": 20}
crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_frames = num_frames
self.num_channels = num_channels
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = size
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb
def prepare_video_processor_dict(self):
return {
"do_resize": self.do_resize,
"size": self.size,
"do_center_crop": self.do_center_crop,
"crop_size": self.crop_size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_convert_rgb": self.do_convert_rgb,
}
def expected_output_video_shape(self, images):
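        # with the defaults above this evaluates to (8, 3, 18, 18)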
return self.num_frames, self.num_channels, self.crop_size["height"], self.crop_size["width"]
def prepare_video_inputs(self, equal_resolution=False, return_tensors="pil"):
videos = prepare_video_inputs(
batch_size=self.batch_size,
num_frames=self.num_frames,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
return_tensors=return_tensors,
)
return videos


@require_torch
@require_vision
class PerceptionLMVideoProcessingTest(VideoProcessingTestMixin, unittest.TestCase):
fast_video_processing_class = PerceptionLMVideoProcessor if is_torchvision_available() else None
def setUp(self):
super().setUp()
self.video_processor_tester = PerceptionLMVideoProcessingTester(self)
@property
def video_processor_dict(self):
return self.video_processor_tester.prepare_video_processor_dict()
def test_video_processor_properties(self):
video_processing = self.fast_video_processing_class(**self.video_processor_dict)
self.assertTrue(hasattr(video_processing, "do_resize"))
self.assertTrue(hasattr(video_processing, "size"))
self.assertTrue(hasattr(video_processing, "do_center_crop"))
self.assertTrue(hasattr(video_processing, "center_crop"))
self.assertTrue(hasattr(video_processing, "do_normalize"))
self.assertTrue(hasattr(video_processing, "image_mean"))
self.assertTrue(hasattr(video_processing, "image_std"))
self.assertTrue(hasattr(video_processing, "do_convert_rgb"))
def test_video_processor_from_dict_with_kwargs(self):
video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict)
self.assertEqual(video_processor.size, {"height": 20, "width": 20})
self.assertEqual(video_processor.crop_size, {"height": 18, "width": 18})
video_processor = self.fast_video_processing_class.from_dict(self.video_processor_dict, size=42, crop_size=84)
self.assertEqual(video_processor.size, {"height": 42, "width": 42})
self.assertEqual(video_processor.crop_size, {"height": 84, "width": 84})
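

# A minimal usage sketch (hedged: it assumes the fast video processor accepts per-frame numpy
# arrays and returns "pixel_values_videos", as other fast video processors do):
#   import numpy as np
#   video_processor = PerceptionLMVideoProcessor(
#       do_resize=True, size={"height": 20, "width": 20},
#       do_center_crop=True, crop_size={"height": 18, "width": 18},
#   )
#   video = [np.random.randint(0, 256, (30, 40, 3), dtype=np.uint8) for _ in range(8)]
#   outputs = video_processor([video], return_tensors="pt")
#   # outputs["pixel_values_videos"].shape should be (1, 8, 3, 18, 18)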