Add Phi4 multimodal (#36939)

* raw start

* update

* update

* add to imports

* update

* up

* simplify configs

* clean configs

* style

* typos

* Update convert_phi4_multimodal_weights_to_hf.py

* Update convert_phi4_multimodal_weights_to_hf.py

* fix

* up

* up

* up

* Update convert_phi4_multimodal_weights_to_hf.py

* Update convert_phi4_multimodal_weights_to_hf.py

* up

* up

* up

* Update feature_extraction_phi4_multimodal.py

* up

* up

* up

* up

* up

* simplify configs

* typo

* cut code

* typo

* typo

* typo

* re

* typo

* up

* up

* up

* add tests

* fix

* fix

* Update test_modeling_phi4_multimodal.py

* up

* Update test_modeling_phi4_multimodal.py

* doc

* fix

* up

* up

* up

* up

* up

* up

* simplify

* up

* simplify

* config docstrings

* cleanup

* clean

* typo

* typo

* fix

* Update phi4_multimodal.md

* fix

* fix

* Update test_modeling_phi4_multimodal.py

* update

* simplify reshapes and permutes

* up

* simplify special tokens

* simplify processor a lot

* Update processing_phi4_multimodal.py

* Update processing_phi4_multimodal.py

* switch to fast processor

* image processor

* Update image_processing_phi4_multimodal_fast.py

* add lora extraction to converter

* Update convert_phi4_multimodal_weights_to_hf.py

* Update __init__.py

* add AudioInput type in audio_utils

* rewrite feature_extraction: support torch batched FFT

* input_audio_embeds -> audio_input_features, input_image_embeds -> image_pixel_values

* test update

* not mono channel warning update

* remove auto maps from processor

* kwargs dispatch in processor

* simplify kwargs dispatch

* simplify merging

* remove default sampling rate

* style

* Update test_modeling_phi4_multimodal.py

* update doc

* doc

* torch only feature extractor

* make fake tokens adjustable

* Update feature_extraction_phi4_multimodal.py

* fix

* Update processing_phi4_multimodal.py

* simplify mask

* last touch

* fix copies

* style

* Update audio_utils.py

* style

* Update feature_extraction_phi4_multimodal.py

* Update __init__.py

* docstrings

* copies

* fix all checks

* back to fix-copies

* trigger CIs

* Update feature_extraction_phi4_multimodal.py

* improve tests with multimodal inputs

* trigger CIs

---------

Co-authored-by: Eustache Le Bihan <eulebihan@gmail.com>
Cyril Vallez 2025-03-25 09:55:21 +01:00 committed by GitHub
parent 47e5432805
commit 4303d88c09
24 changed files with 6380 additions and 1 deletion

View File

@@ -583,6 +583,8 @@
title: Phi
- local: model_doc/phi3
title: Phi-3
- local: model_doc/phi4_multimodal
title: Phi4 Multimodal
- local: model_doc/phimoe
title: PhiMoE
- local: model_doc/phobert

View File

@@ -0,0 +1,149 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Phi4 Multimodal
## Overview
Phi4 Multimodal is a lightweight open multimodal foundation model that leverages the language, vision, and speech research and datasets used for the Phi-3.5 and 4.0 models. The model processes text, image, and audio inputs, generates text outputs, and comes with a 128K token context length. It underwent an enhancement process incorporating supervised fine-tuning, direct preference optimization, and RLHF (Reinforcement Learning from Human Feedback) to support precise instruction adherence and safety measures. The languages supported by each modality are the following:
- Text: Arabic, Chinese, Czech, Danish, Dutch, English, Finnish, French, German, Hebrew, Hungarian, Italian, Japanese, Korean, Norwegian, Polish, Portuguese, Russian, Spanish, Swedish, Thai, Turkish, Ukrainian
- Vision: English
- Audio: English, Chinese, German, French, Italian, Japanese, Spanish, Portuguese
This model was contributed by [Cyril Vallez](https://huggingface.co/cyrilvallez). The most recent code can be
found [here](https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py).
## Usage tips
`Phi-4-multimodal-instruct` can be found on the [Hugging Face Hub](https://huggingface.co/microsoft/Phi-4-multimodal-instruct).
The example below demonstrates how to run inference with the different input modalities (text, image, audio).
```python
import requests
import torch
import os
import io
from PIL import Image
import soundfile as sf
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from urllib.request import urlopen
# Define model path
model_path = "microsoft/Phi-4-multimodal-instruct"
device = "cuda:0"
# Load model and processor
processor = AutoProcessor.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, torch_dtype=torch.float16)
# Optional: load the adapters (note that without them, the base model will very likely not work well)
model.load_adapter(model_path, adapter_name="speech", device_map=device, adapter_kwargs={"subfolder": 'speech-lora'})
model.load_adapter(model_path, adapter_name="vision", device_map=device, adapter_kwargs={"subfolder": 'vision-lora'})
# Define prompt structure
user_prompt = '<|user|>'
assistant_prompt = '<|assistant|>'
prompt_suffix = '<|end|>'
# Part 1: Image Processing
model.set_adapter("vision") # if loaded, activate the vision adapter
print("\n--- IMAGE PROCESSING ---")
image_url = 'https://www.ilankelman.org/stopsigns/australia.jpg'
prompt = f'{user_prompt}<|image_1|>What is shown in this image?{prompt_suffix}{assistant_prompt}'
print(f'>>> Prompt\n{prompt}')
# Download and open image
image = Image.open(requests.get(image_url, stream=True).raw)
inputs = processor(text=prompt, images=image, return_tensors='pt').to(device)
# Generate response
generate_ids = model.generate(
**inputs,
max_new_tokens=1000,
do_sample=False,
)
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
response = processor.batch_decode(
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
print(f'>>> Response\n{response}')
# Part 2: Audio Processing
model.set_adapter("speech") # if loaded, activate the speech adapter
print("\n--- AUDIO PROCESSING ---")
audio_url = "https://upload.wikimedia.org/wikipedia/commons/b/b0/Barbara_Sahakian_BBC_Radio4_The_Life_Scientific_29_May_2012_b01j5j24.flac"
speech_prompt = "Transcribe the audio to text, and then translate the audio to French. Use <sep> as a separator between the original transcript and the translation."
prompt = f'{user_prompt}<|audio_1|>{speech_prompt}{prompt_suffix}{assistant_prompt}'
print(f'>>> Prompt\n{prompt}')
# Download and open audio file
audio, sample_rate = sf.read(io.BytesIO(urlopen(audio_url).read()))
# Process with the model
inputs = processor(text=prompt, audios=audio, sample_rate=sample_rate, return_tensors='pt').to(device)
generate_ids = model.generate(
**inputs,
max_new_tokens=1000,
do_sample=False,
)
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
response = processor.batch_decode(
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
print(f'>>> Response\n{response}')
```
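A text-only prompt uses the same chat format and does not require any adapter. Below is a minimal sketch reusing the `model` and `processor` objects loaded above (the prompt content is purely illustrative):

```python
# Part 3: Text-only Processing
model.disable_adapters()  # optional: deactivate the LoRA adapters loaded above for pure text generation
prompt = f'{user_prompt}Write a short haiku about multimodal models.{prompt_suffix}{assistant_prompt}'
inputs = processor(text=prompt, return_tensors='pt').to(device)

generate_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False)
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
response = processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
print(f'>>> Response\n{response}')
```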
## Phi4MultimodalFeatureExtractor
[[autodoc]] Phi4MultimodalFeatureExtractor
## Phi4MultimodalImageProcessorFast
[[autodoc]] Phi4MultimodalImageProcessorFast
## Phi4MultimodalProcessor
[[autodoc]] Phi4MultimodalProcessor
## Phi4MultimodalAudioConfig
[[autodoc]] Phi4MultimodalAudioConfig
## Phi4MultimodalVisionConfig
[[autodoc]] Phi4MultimodalVisionConfig
## Phi4MultimodalConfig
[[autodoc]] Phi4MultimodalConfig
## Phi4MultimodalAudioModel
[[autodoc]] Phi4MultimodalAudioModel
## Phi4MultimodalVisionModel
[[autodoc]] Phi4MultimodalVisionModel
## Phi4MultimodalModel
[[autodoc]] Phi4MultimodalModel
- forward
## Phi4MultimodalForCausalLM
[[autodoc]] Phi4MultimodalForCausalLM
- forward

View File

@@ -699,6 +699,13 @@ _import_structure = {
"models.persimmon": ["PersimmonConfig"],
"models.phi": ["PhiConfig"],
"models.phi3": ["Phi3Config"],
"models.phi4_multimodal": [
"Phi4MultimodalAudioConfig",
"Phi4MultimodalConfig",
"Phi4MultimodalFeatureExtractor",
"Phi4MultimodalProcessor",
"Phi4MultimodalVisionConfig",
],
"models.phimoe": ["PhimoeConfig"],
"models.phobert": ["PhobertTokenizer"],
"models.pix2struct": [
@@ -1348,6 +1355,7 @@ else:
_import_structure["models.llava"].append("LlavaImageProcessorFast")
_import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
_import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast")
_import_structure["models.phi4_multimodal"].append("Phi4MultimodalImageProcessorFast")
_import_structure["models.pixtral"].append("PixtralImageProcessorFast")
_import_structure["models.qwen2_vl"].append("Qwen2VLImageProcessorFast")
_import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
@@ -2802,6 +2810,17 @@ else:
"LlavaNextPreTrainedModel",
]
)
_import_structure["models.phi4_multimodal"].extend(
[
"Phi4MultimodalForCausalLM",
"Phi4MultimodalPreTrainedModel",
"Phi4MultimodalAudioModel",
"Phi4MultimodalAudioPreTrainedModel",
"Phi4MultimodalModel",
"Phi4MultimodalVisionModel",
"Phi4MultimodalVisionPreTrainedModel",
]
)
_import_structure["models.llava_next_video"].extend(
[
"LlavaNextVideoForConditionalGeneration",
@@ -5914,6 +5933,13 @@ if TYPE_CHECKING:
)
from .models.phi import PhiConfig
from .models.phi3 import Phi3Config
from .models.phi4_multimodal import (
Phi4MultimodalAudioConfig,
Phi4MultimodalConfig,
Phi4MultimodalFeatureExtractor,
Phi4MultimodalProcessor,
Phi4MultimodalVisionConfig,
)
from .models.phimoe import PhimoeConfig
from .models.phobert import PhobertTokenizer
from .models.pix2struct import (
@@ -6587,6 +6613,7 @@ if TYPE_CHECKING:
from .models.llava import LlavaImageProcessorFast
from .models.llava_next import LlavaNextImageProcessorFast
from .models.llava_onevision import LlavaOnevisionImageProcessorFast
from .models.phi4_multimodal import Phi4MultimodalImageProcessorFast
from .models.pixtral import PixtralImageProcessorFast
from .models.qwen2_vl import Qwen2VLImageProcessorFast
from .models.rt_detr import RTDetrImageProcessorFast
@@ -8153,6 +8180,15 @@ if TYPE_CHECKING:
Phi3Model,
Phi3PreTrainedModel,
)
from .models.phi4_multimodal import (
Phi4MultimodalAudioModel,
Phi4MultimodalAudioPreTrainedModel,
Phi4MultimodalForCausalLM,
Phi4MultimodalModel,
Phi4MultimodalPreTrainedModel,
Phi4MultimodalVisionModel,
Phi4MultimodalVisionPreTrainedModel,
)
from .models.phimoe import (
PhimoeForCausalLM,
PhimoeForSequenceClassification,

View File

@@ -17,11 +17,16 @@ and remove unnecessary dependencies.
"""
import warnings
from typing import Optional, Union
from typing import List, Optional, Tuple, Union
import numpy as np
AudioInput = Union[
np.ndarray, "torch.Tensor", List[np.ndarray], Tuple[np.ndarray], List["torch.Tensor"], Tuple["torch.Tensor"] # noqa: F821
]
def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
"""
Convert frequency from hertz to mels.

View File

@@ -212,6 +212,7 @@ from . import (
persimmon,
phi,
phi3,
phi4_multimodal,
phimoe,
phobert,
pix2struct,

View File

@@ -235,6 +235,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("persimmon", "PersimmonConfig"),
("phi", "PhiConfig"),
("phi3", "Phi3Config"),
("phi4_multimodal", "Phi4MultimodalConfig"),
("phimoe", "PhimoeConfig"),
("pix2struct", "Pix2StructConfig"),
("pixtral", "PixtralVisionConfig"),
@@ -587,6 +588,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("persimmon", "Persimmon"),
("phi", "Phi"),
("phi3", "Phi3"),
("phi4_multimodal", "Phi4Multimodal"),
("phimoe", "Phimoe"),
("phobert", "PhoBERT"),
("pix2struct", "Pix2Struct"),

View File

@@ -78,6 +78,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("nat", "ViTFeatureExtractor"),
("owlvit", "OwlViTFeatureExtractor"),
("perceiver", "PerceiverFeatureExtractor"),
("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
("poolformer", "PoolFormerFeatureExtractor"),
("pop2piano", "Pop2PianoFeatureExtractor"),
("regnet", "ConvNextFeatureExtractor"),

View File

@@ -124,6 +124,7 @@ else:
("owlvit", ("OwlViTImageProcessor",)),
("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
("perceiver", ("PerceiverImageProcessor",)),
("phi4_multimodal", "Phi4MultimodalImageProcessorFast"),
("pix2struct", ("Pix2StructImageProcessor",)),
("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
("poolformer", ("PoolFormerImageProcessor",)),

View File

@@ -218,6 +218,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("persimmon", "PersimmonModel"),
("phi", "PhiModel"),
("phi3", "Phi3Model"),
("phi4_multimodal", "Phi4MultimodalModel"),
("phimoe", "PhimoeModel"),
("pixtral", "PixtralVisionModel"),
("plbart", "PLBartModel"),
@@ -566,6 +567,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("persimmon", "PersimmonForCausalLM"),
("phi", "PhiForCausalLM"),
("phi3", "Phi3ForCausalLM"),
("phi4_multimodal", "Phi4MultimodalForCausalLM"),
("phimoe", "PhimoeForCausalLM"),
("plbart", "PLBartForCausalLM"),
("prophetnet", "ProphetNetForCausalLM"),

View File

@@ -91,6 +91,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("owlv2", "Owlv2Processor"),
("owlvit", "OwlViTProcessor"),
("paligemma", "PaliGemmaProcessor"),
("phi4_multimodal", "Phi4MultimodalProcessor"),
("pix2struct", "Pix2StructProcessor"),
("pixtral", "PixtralProcessor"),
("pop2piano", "Pop2PianoProcessor"),

View File

@@ -0,0 +1,32 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_phi4_multimodal import *
from .feature_extraction_phi4_multimodal import *
from .image_processing_phi4_multimodal_fast import *
from .modeling_phi4_multimodal import *
from .processing_phi4_multimodal import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@@ -0,0 +1,482 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_phi4_multimodal.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from ...configuration_utils import PretrainedConfig
class Phi4MultimodalVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Phi4MultimodalVisionModel`]. It is used to instantiate a
Phi4Multimodal vision encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the vision encoder of
[microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 1152):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 4304):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 27):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
Number of channels in the input images.
image_size (`int`, *optional*, defaults to 448):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 14):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
crop_size (`int`, *optional*, defaults to 448):
Crop size for the input images.
image_token_id (`int`, *optional*, defaults to 200010):
The image token id.
feature_layer (`int`, *optional*, defaults to -2):
The index of the layer of the encoder from which to extract image features.
Example:
```python
>>> from transformers import Phi4MultimodalVisionConfig
>>> # Initializing a Phi4MultimodalVisionConfig with microsoft/Phi-4-multimodal-instruct style configuration
>>> configuration = Phi4MultimodalVisionConfig()
```"""
model_type = "phi4_multimodal_vision"
base_config_key = "vision_config"
def __init__(
self,
hidden_size=1152,
intermediate_size=4304,
num_hidden_layers=27,
num_attention_heads=16,
num_channels=3,
image_size=448,
patch_size=14,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
crop_size: int = 448,
image_token_id: int = 200010,
feature_layer: int = -2,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.crop_size = crop_size
self.image_token_id = image_token_id
self.feature_layer = feature_layer
class Phi4MultimodalAudioConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Phi4MultimodalAudioModel`]. It is used to instantiate a
Phi4Multimodal audio encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the audio encoder of
[microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the encoder layers.
intermediate_size (`int`, *optional*, defaults to 1536):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_blocks (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
activation (`str`, *optional*, defaults to `"swish"`):
The non-linear activation function in the MLPs.
chunk_size (`int`, *optional*, defaults to -1):
The chunk size to create the masks.
left_chunk (`int`, *optional*, defaults to 18):
The left chunk to create the masks.
dropout_rate (`float`, *optional*, defaults to 0.0):
The dropout ratio.
ext_pw_out_channel (`int`, *optional*, defaults to 1024):
Number of out channels in the point-wise conv modules.
depthwise_seperable_out_channel (`int`, *optional*, defaults to 1024):
Number of out channels in the depth-wise separable conv modules.
depthwise_multiplier (`int`, *optional*, defaults to 1):
Input size multiplier for the depth-wise separable conv modules.
kernel_size (`int`, *optional*, defaults to 3):
Kernel size for the depth-wise separable conv modules.
conv_activation (`str`, *optional*, defaults to `"swish"`):
The non-linear activation function in the conv modules.
input_size (`int`, *optional*, defaults to 80):
Input size for the audio model.
conv_glu_type (`str`, *optional*, defaults to `"swish"`):
The non-linear activation function in the point-wise conv modules.
time_reduction (`int`, *optional*, defaults to 8):
Time reduction (subsampling factor).
bias_max_distance (`int`, *optional*, defaults to 1000):
Max distance for the relative attention bias module.
bias_symmetric (`bool`, *optional*, defaults to `False`):
Whether the relative attention bias should be symmetric or not.
nemo_activation (`str`, *optional*, defaults to `"relu"`):
The non-linear activation function in the nemo conv modules.
nemo_conv_channels (`int`, *optional*, defaults to 1024):
Number of channels in the nemo conv modules.
downsample_rate (`int`, *optional*, defaults to 1):
Downsample rate for the audio feature extractor.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
audio_token_id (`int`, *optional*, defaults to 200011):
The audio token id.
feature_layer (`int`, *optional*, defaults to -2):
The index of the layer of the encoder from which to extract audio features.
Example:
```python
>>> from transformers import Phi4MultimodalAudioConfig
>>> # Initializing a Phi4MultimodalAudioConfig with microsoft/Phi-4-multimodal-instruct style configuration
>>> configuration = Phi4MultimodalAudioConfig()
```"""
model_type = "phi4_multimodal_audio"
def __init__(
self,
hidden_size: int = 1024,
intermediate_size: int = 1536,
num_blocks: int = 24,
num_attention_heads: int = 16,
activation: str = "swish",
chunk_size: int = -1,
left_chunk: int = 18,
dropout_rate: float = 0.0,
ext_pw_out_channel: int = 1024,
depthwise_seperable_out_channel: int = 1024,
depthwise_multiplier: int = 1,
kernel_size: int = 3,
conv_activation: str = "swish",
input_size: int = 80,
conv_glu_type: str = "swish",
time_reduction: int = 8,
bias_max_distance: int = 1000,
bias_symmetric: bool = False,
nemo_activation: str = "relu",
nemo_conv_channels: int = 1024,
downsample_rate: int = 1,
initializer_range: float = 0.02,
audio_token_id: int = 200011,
feature_layer: int = -2,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.activation = activation
self.chunk_size = chunk_size
self.left_chunk = left_chunk
self.num_blocks = num_blocks
self.dropout_rate = dropout_rate
self.ext_pw_out_channel = ext_pw_out_channel
self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
self.depthwise_multiplier = depthwise_multiplier
self.kernel_size = kernel_size
self.conv_activation = conv_activation
self.input_size = input_size
self.conv_glu_type = conv_glu_type
self.time_reduction = time_reduction
self.bias_max_distance = bias_max_distance
self.bias_symmetric = bias_symmetric
self.nemo_activation = nemo_activation
self.nemo_conv_channels = nemo_conv_channels
self.downsample_rate = downsample_rate
self.audio_token_id = audio_token_id
self.initializer_range = initializer_range
self.feature_layer = feature_layer
if time_reduction % 2 != 0:
raise ValueError("`time_reduction` should be a multiple of 2!")
length = input_size
for _ in range(int(math.log(time_reduction, 2))):
length = math.floor((length - 1) / 2 + 1)
self.nemo_final_size = length
class Phi4MultimodalConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Phi4MultimodalModel`]. It is used to instantiate a
Phi4Multimodal model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the
[microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 200064):
Vocabulary size of the Phi4Multimodal model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`Phi4MultimodalModel`].
hidden_size (`int`, *optional*, defaults to 3072):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 8192):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
resid_pdrop (`float`, *optional*, defaults to 0.0):
Dropout probability for mlp outputs.
embd_pdrop (`float`, *optional*, defaults to 0.0):
The dropout ratio for the embeddings.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio after computing the attention scores.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 131072):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon value used for the RMSNorm.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether or not to tie the input and output word embeddings.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`dict`, *optional*):
The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and
the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
divided by the number of attention heads divided by 2.
partial_rotary_factor (`float`, *optional*, defaults to `1.0`):
Percentage of the query and keys which will have rotary embedding. Must be between 0.0 and 1.0.
bos_token_id (`int`, *optional*, defaults to 199999):
The id of the "beginning-of-sequence" token.
eos_token_id (`int` or `list[int]`, *optional*, defaults to `[199999, 200020]`):
The id of the "end-of-sequence" token.
pad_token_id (`int`, *optional*, defaults to 199999):
The id of the padding token.
original_max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model was trained with. This is used to determine the size of the
original RoPE embeddings when using long scaling.
sliding_window (`int`, *optional*):
Sliding window attention window size. If `None`, no sliding window is applied.
vision_config (`Phi4MultimodalVisionConfig` or `dict`, *optional*):
The vision config for the underlying image embedding model. If not provided, will default to the configuration
used to instantiate a model with an architecture similar to that of
[microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct).
audio_config (`Phi4MultimodalAudioConfig` or `dict`, *optional*):
The audio config for the underlying audio embedding model. If not provided, will default to the configuration
used to instantiate a model with an architecture similar to that of
[microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct).
Example:
```python
>>> from transformers import Phi4MultimodalModel, Phi4MultimodalConfig
>>> # Initializing a Phi4Multimodal style configuration
>>> configuration = Phi4MultimodalConfig.from_pretrained("microsoft/Phi-4-multimodal-instruct")
>>> # Initializing a model from the configuration
>>> model = Phi4MultimodalModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "phi4_multimodal"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.qkv_proj": "colwise_rep", # we need to replicate here due to the slicing of qkv
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the slicing of qkv
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
sub_configs = {"audio_config": Phi4MultimodalAudioConfig, "vision_config": Phi4MultimodalVisionConfig}
def __init__(
self,
vocab_size=200064,
hidden_size=3072,
intermediate_size=8192,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
resid_pdrop=0.0,
embd_pdrop=0.0,
attention_dropout=0.0,
hidden_act="silu",
max_position_embeddings=131072,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
partial_rotary_factor=1,
bos_token_id=199999,
eos_token_id=[199999, 200020],
pad_token_id=199999,
original_max_position_embeddings=4096,
sliding_window=None,
vision_config=None,
audio_config=None,
**kwargs,
):
super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attention_dropout = attention_dropout
self.hidden_act = hidden_act
self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.partial_rotary_factor = partial_rotary_factor
self._rope_scaling_adjustment()
self._rope_scaling_validation()
self.sliding_window = sliding_window
if isinstance(vision_config, dict):
vision_config = Phi4MultimodalVisionConfig(**vision_config)
elif vision_config is None:
vision_config = Phi4MultimodalVisionConfig()
self.vision_config = vision_config
if isinstance(audio_config, dict):
audio_config = Phi4MultimodalAudioConfig(**audio_config)
elif audio_config is None:
audio_config = Phi4MultimodalAudioConfig()
self.audio_config = audio_config
def _rope_scaling_adjustment(self):
"""
Adjust the `type` of the `rope_scaling` configuration for backward compatibility.
"""
if self.rope_scaling is None:
return
rope_scaling_type = self.rope_scaling.get("type", None)
# For backward compatibility if previous version used "su" or "yarn"
if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]:
self.rope_scaling["type"] = "longrope"
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
raise ValueError(
"`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")
if not (
isinstance(rope_scaling_short_factor, list)
and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
):
raise ValueError(
f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
)
rotary_ndims = int(self.hidden_size // self.num_attention_heads * self.partial_rotary_factor)
if not len(rope_scaling_short_factor) == rotary_ndims // 2:
raise ValueError(
f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}"
)
if not (
isinstance(rope_scaling_long_factor, list)
and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
):
raise ValueError(
f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
)
if not len(rope_scaling_long_factor) == rotary_ndims // 2:
raise ValueError(
f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}"
)
__all__ = ["Phi4MultimodalVisionConfig", "Phi4MultimodalAudioConfig", "Phi4MultimodalConfig"]

View File

@@ -0,0 +1,229 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import re
import torch
from peft import LoraConfig
from safetensors.torch import load_file, save_file
from transformers import (
Phi4MultimodalAudioConfig,
Phi4MultimodalConfig,
Phi4MultimodalForCausalLM,
Phi4MultimodalProcessor,
Phi4MultimodalVisionConfig,
)
# fmt: off
STATE_DICT_MAPPING = {
r"^model.embed_tokens_extend.audio_embed.encoder.encoders.(\d+).feed_forward_(in|out).net.0.linear": r"model.embed_tokens_extend.audio_embed.encoder.encoders.\1.feed_forward_\2.gate_up_proj",
r"^model.embed_tokens_extend.audio_embed.encoder.encoders.(\d+).feed_forward_(in|out).net.2": r"model.embed_tokens_extend.audio_embed.encoder.encoders.\1.feed_forward_\2.down_proj",
r"^model.embed_tokens_extend.audio_embed.encoder.encoders.(\d+).self_attn.linear_(q|k|v)": r"model.embed_tokens_extend.audio_embed.encoder.encoders.\1.self_attn.\2_proj",
r"^model.embed_tokens_extend.audio_embed.encoder.encoders.(\d+).self_attn.linear_out": r"model.embed_tokens_extend.audio_embed.encoder.encoders.\1.self_attn.o_proj",
r"^model.embed_tokens_extend.image_embed.img_projection.0": r"model.embed_tokens_extend.image_embed.img_projection_up",
r"^model.embed_tokens_extend.image_embed.img_projection.2": r"model.embed_tokens_extend.image_embed.img_projection_down",
r"^model.embed_tokens_extend.image_embed.glb_GN": r"model.embed_tokens_extend.image_embed.global_img_feature_extensor",
r"^model.embed_tokens_extend.image_embed.sub_GN": r"model.embed_tokens_extend.image_embed.sub_img_feature_extensor",
r"^model.embed_tokens_extend.audio_embed.audio_projection.speech.0": r"model.embed_tokens_extend.audio_embed.up_proj_for_speech",
r"^model.embed_tokens_extend.audio_embed.audio_projection.speech.2": r"model.embed_tokens_extend.audio_embed.down_proj_for_speech",
r"^model.embed_tokens_extend.audio_embed.audio_projection.vision.0": r"model.embed_tokens_extend.audio_embed.up_proj_for_vision_speech",
r"^model.embed_tokens_extend.audio_embed.audio_projection.vision.2": r"model.embed_tokens_extend.audio_embed.down_proj_for_vision_speech",
}
# fmt: on
def map_old_key_to_new(old_key):
"""Map of a key of the original state dict to the equivalent key in HF format"""
for pattern, replacement in STATE_DICT_MAPPING.items():
new_key, n_replace = re.subn(pattern, replacement, old_key)
# Early exit of the loop
if n_replace > 0:
return new_key
# The state dict also contains LoRA adapter keys; they are handled separately in `extract_adapters_data`, so skip them here
if "lora" in old_key:
return None
# This extracts the original weight before adding the lora adapter
if "base_layer." in old_key:
return old_key.replace("base_layer.", "")
# not part of the key mapping, we keep the original name
return old_key
def convert_state_dict(original_state_dict: dict):
"""Convert a state dict file."""
new_dict = {}
for old_key, tensor in original_state_dict.items():
new_key = map_old_key_to_new(old_key)
if new_key is not None:
new_dict[new_key] = tensor
return new_dict
def convert_config(original_config: dict):
# Remove unused args
original_config.pop("_name_or_path", None)
original_config.pop("architectures", None)
original_config.pop("auto_map", None)
original_config.pop("vision_lora", None)
original_config.pop("speech_lora", None)
original_config.pop("transformers_version", None)
original_config.pop("_attn_implementation", None)
embd_layer = original_config.pop("embd_layer")
audio_embd_layer = embd_layer["audio_embd_layer"]
vision_embd_layer = embd_layer["image_embd_layer"]
# Keep only some of the subdict
keep_audio_embd_layer = ["downsample_rate"]
keep_vision_embd_layer = ["crop_size"]
audio_embd_layer = {k: v for k, v in audio_embd_layer.items() if k in keep_audio_embd_layer}
vision_embd_layer = {k: v for k, v in vision_embd_layer.items() if k in keep_vision_embd_layer}
audio_config = original_config.pop("audio_processor")["config"]
# remove keys that are unused by the HF implementation
audio_config.pop("activation_checkpointing", None)
audio_config.pop("cnn_layer_norm", None)
audio_config.pop("input_layer", None)
audio_config.pop("batch_norm", None)
audio_config.pop("encoder_embedding_config", None)
audio_config.pop("ext_pw_kernel_size", None)
audio_config.pop("bias_in_glu", None)
audio_config.pop("causal", None)
# rename to the HF argument names
audio_config["hidden_size"] = audio_config.pop("attention_dim")
audio_config["num_attention_heads"] = audio_config.pop("attention_heads")
audio_config["intermediate_size"] = audio_config.pop("linear_units")
audio_config["nemo_conv_channels"] = audio_config.pop("nemo_conv_settings")["conv_channels"]
audio_config["bias_max_distance"] = audio_config.pop("relative_attention_bias_args")["t5_bias_max_distance"]
# merge in the values kept from embd_layer
audio_config = {**audio_config, **audio_embd_layer}
# Create transformers config objects
audio_config = Phi4MultimodalAudioConfig(**audio_config)
vision_config = Phi4MultimodalVisionConfig(**vision_embd_layer)
# Add 2nd eos to config
original_config["eos_token_id"] = [199999, 200020]
new_config = Phi4MultimodalConfig(**original_config, vision_config=vision_config, audio_config=audio_config)
return new_config
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def convert_and_write_model(input_dir: str, output_dir: str):
"""Convert the model and save it (this implicitly save the config as well)."""
original_config = read_json(os.path.join(input_dir, "config.json"))
config = convert_config(original_config)
full_state_dict = {}
shards = [file for file in os.listdir(input_dir) if file.endswith(".safetensors")]
for shard_file in shards:
original_state_dict = load_file(os.path.join(input_dir, shard_file))
new_dict = convert_state_dict(original_state_dict)
full_state_dict.update(new_dict)
# Load weights into model and resave them
with torch.device("meta"):
model = Phi4MultimodalForCausalLM(config)
missing, unexpected = model.load_state_dict(full_state_dict, strict=False, assign=True)
# The lm_head is missing because it's tied
if missing != ["lm_head.weight"]:
raise ValueError("Missing keys:\n{missing}")
if len(unexpected) > 0:
raise ValueError(f"Unexpected keys:\n{unexpected}")
model.tie_weights()
model.save_pretrained(output_dir)
def convert_and_save_processor(input_dir: str, output_dir: str):
"""Convert the processor."""
processor = Phi4MultimodalProcessor.from_pretrained(input_dir)
del processor.image_processor.auto_map
del processor.audio_processor.auto_map
processor.chat_template = processor.tokenizer.chat_template
processor.tokenizer.extra_special_tokens = {"image_token": "<|endoftext10|>", "audio_token": "<|endoftext11|>"}
processor.save_pretrained(output_dir)
def extract_adapters_data(input_dir: str, output_dir: str):
"""Extract adapters data from the state dict and save weights and configs."""
speech_lora = {}
vision_lora = {}
shards = [file for file in os.listdir(input_dir) if file.endswith(".safetensors")]
for shard_file in shards:
original_state_dict = load_file(os.path.join(input_dir, shard_file))
for k, v in original_state_dict.items():
if "lora" in k:
if "speech" in k:
speech_lora[k.replace("speech.", "")] = v
elif "vision" in k:
vision_lora[k.replace("vision.", "")] = v
# Create and save the lora configs
speech_lora_config = LoraConfig(
r=320,
lora_alpha=640,
target_modules=r"model.layers.\d+.((self_attn.(qkv|o)_proj)|(mlp.(gate_up|down)_proj))",
lora_dropout=0.01,
task_type="CAUSAL_LM",
)
speech_lora_config.save_pretrained(os.path.join(output_dir, "speech-lora"))
vision_lora_config = LoraConfig(
r=256,
lora_alpha=512,
target_modules=r"model.layers.\d+.((self_attn.(qkv|o)_proj)|(mlp.(gate_up|down)_proj))",
lora_dropout=0.0,
task_type="CAUSAL_LM",
)
vision_lora_config.save_pretrained(os.path.join(output_dir, "vision-lora"))
save_file(speech_lora, os.path.join(output_dir, "speech-lora", "adapter_model.safetensors"))
save_file(vision_lora, os.path.join(output_dir, "vision-lora", "adapter_model.safetensors"))
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"input_dir",
help="Location of the model folder containing the weights and configs.",
)
parser.add_argument(
"output_dir",
help="Location to write HF model.",
)
args = parser.parse_args()
# Convert
convert_and_write_model(args.input_dir, args.output_dir)
convert_and_save_processor(args.input_dir, args.output_dir)
extract_adapters_data(args.input_dir, args.output_dir)
if __name__ == "__main__":
main()
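The converter takes two positional arguments. The hypothetical sketch below mirrors `main()` and assumes the script is importable from the working directory; the paths are placeholders:

```python
# Equivalent to: python convert_phi4_multimodal_weights_to_hf.py <input_dir> <output_dir>
from convert_phi4_multimodal_weights_to_hf import (
    convert_and_save_processor,
    convert_and_write_model,
    extract_adapters_data,
)

input_dir = "path/to/original/Phi-4-multimodal-instruct"  # placeholder
output_dir = "path/to/converted/checkpoint"  # placeholder

convert_and_write_model(input_dir, output_dir)  # weights + config
convert_and_save_processor(input_dir, output_dir)  # tokenizer, image and audio processors
extract_adapters_data(input_dir, output_dir)  # speech-lora and vision-lora adapters
```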

View File

@@ -0,0 +1,348 @@
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Feature extractor class for Phi4Multimodal
"""
from typing import Optional, Union
import numpy as np
from ...audio_utils import AudioInput
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...image_processing_utils import BatchFeature
from ...utils import TensorType, is_torch_available, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
# TODO: @eustlb, remove this once #36603 is merged.
def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
"""Create a Mel filter-bank the same as SpeechLib FbankFC.
Args:
sample_rate (int): Sample rate in Hz. number > 0 [scalar]
n_fft (int): FFT size. int > 0 [scalar]
n_mels (int): Number of Mel filters. int > 0 [scalar]
fmin (float): lowest frequency (in Hz). If None use 0.0.
float >= 0 [scalar]
fmax: highest frequency (in Hz). If None use sample_rate / 2.
float >= 0 [scalar]
Returns
out (numpy.ndarray): Mel transform matrix
[shape=(n_mels, 1 + n_fft/2)]
"""
bank_width = int(n_fft // 2 + 1)
if fmax is None:
fmax = sample_rate / 2
if fmin is None:
fmin = 0
assert fmin >= 0, "fmin cannot be negative"
assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, samplerate / 2]"
def mel(f):
return 1127.0 * np.log(1.0 + f / 700.0)
def bin2mel(fft_bin):
return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))
def f2bin(f):
return int((f * n_fft / sample_rate) + 0.5)
# Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
klo = f2bin(fmin) + 1
khi = f2bin(fmax)
khi = max(khi, klo)
# Spec 2: SpeechLib uses triangles in Mel space
mlo = mel(fmin)
mhi = mel(fmax)
m_centers = np.linspace(mlo, mhi, n_mels + 2)
ms = (mhi - mlo) / (n_mels + 1)
matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
for m in range(0, n_mels):
left = m_centers[m]
center = m_centers[m + 1]
right = m_centers[m + 2]
for fft_bin in range(klo, khi):
mbin = bin2mel(fft_bin)
if left < mbin < right:
matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
return matrix
class Phi4MultimodalFeatureExtractor(SequenceFeatureExtractor):
model_input_names = ["audio_input_features", "audio_embed_sizes", "audio_attention_mask"]
def __init__(
self,
feature_size: int = 80,
sampling_rate: int = 16000,
hop_length: int = 160,
n_fft: int = 512,
win_length: int = 400,
preemphasis: float = 0.97,
padding_value: float = 0.0,
audio_compression_rate: int = 8,
audio_downsample_rate: int = 1,
audio_feat_stride: int = 1,
mel_min_frequency: float = 0,
mel_max_frequency: float = 7690,
**kwargs,
):
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
self.hop_length = hop_length
self.n_fft = n_fft
self.win_length = win_length
self.preemphasis = preemphasis
self.padding_value = padding_value
self.audio_compression_rate = audio_compression_rate
self.audio_downsample_rate = audio_downsample_rate
self.audio_feat_stride = audio_feat_stride
# TODO: @eustlb, uncomment and remove speechlib_mel once #36603 is merged.
# self.mel_filters = mel_filter_bank(
# num_frequency_bins=self.n_fft // 2 + 1,
# num_mel_filters=self.feature_size,
# min_frequency=mel_min_frequency,
# max_frequency=mel_max_frequency,
# sampling_rate=self.sampling_rate,
# triangularize_in_mel_space=True,
# mel_scale="kaldi",
# )
self.mel_filters = speechlib_mel(
self.sampling_rate, self.n_fft, self.feature_size, mel_min_frequency, mel_max_frequency
).T
def __call__(
self,
raw_speech: AudioInput,
sampling_rate: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
padding: Optional[str] = "longest",
max_length: Optional[int] = None,
truncation: bool = False,
return_tensors: Optional[Union[str, TensorType]] = None,
return_attention_mask: Optional[bool] = True,
device: Optional[str] = "cpu",
**kwargs,
) -> BatchFeature:
"""
Main method to featurize and prepare one or several audio sequence(s) for the model. The implementation uses PyTorch
for the STFT computation, so `torch` needs to be installed.
Args:
raw_speech (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The sequence or batch of sequences to be processed. Each sequence can be a numpy array or PyTorch tensor.
For batched inputs, sequences can be a list of numpy arrays or PyTorch tensors, or a single numpy array or
PyTorch tensor with first dimension being the batch size.
sampling_rate (`int`, *optional*):
The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
`sampling_rate` at the forward call to prevent silent errors.
pad_to_multiple_of (`int`, *optional*, defaults to None):
If set will pad the sequence to a multiple of the provided value.
padding (`str`, *optional*, defaults to "longest"):
Padding strategy. Can be "longest" to pad to the longest sequence in the batch, or a specific length.
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length.
truncation (`bool`, *optional*, defaults to False):
Activates truncation to cut input sequences longer than *max_length* to *max_length*.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors instead of numpy arrays. Acceptable values are:
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return Numpy `np.ndarray` objects.
- `'tf'`: Return TensorFlow `tf.constant` objects.
return_attention_mask (`bool`, *optional*, defaults to `True`):
Whether to return the extracted audio input features' attention mask.
device (`str`, *optional*, defaults to "cpu"):
Specifies the device for computation of the audio features. (e.g., "cpu", "cuda")
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **audio_input_features** -- Audio features extracted from the raw audio input, shape (batch_size, max_feature_length, feature_size).
- **audio_embed_sizes** -- Number of audio embeddings for each audio sample in the batch, shape (batch_size,).
- **audio_attention_mask** -- Attention mask for the audio input, shape (batch_size, max_feature_length).
If `return_tensors` is not specified, the fields will be PyTorch tensors if PyTorch is available, otherwise NumPy arrays.
"""
if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
f" was sampled with {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
"Failing to do so can result in silent errors that might be hard to debug."
)
# Convert to torch tensor
if isinstance(raw_speech, np.ndarray):
raw_speech = torch.tensor(raw_speech)
elif isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], np.ndarray):
raw_speech = [torch.tensor(speech) for speech in raw_speech]
is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1
if is_batched_torch and len(raw_speech.shape) > 2:
logger.warning(
f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
"We will take the mean of the channels to convert to mono."
)
raw_speech = raw_speech.mean(-1)
is_batched_sequence = isinstance(raw_speech, (list, tuple))
if is_batched_sequence:
if any(len(speech.shape) > 1 for speech in raw_speech):
logger.warning(
f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
"We will take the mean of the channels to convert to mono."
)
raw_speech = [speech.mean(-1) if len(speech.shape) > 1 else speech for speech in raw_speech]
if is_batched_torch or is_batched_sequence:
raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech]
else:
raw_speech = [raw_speech[:, None].to(torch.float32)]
audio_lengths = [len(speech) for speech in raw_speech]
# convert into correct format for padding
batched_speech = BatchFeature(data={"audio_input_features": raw_speech, "audio_lengths": audio_lengths})
padded_inputs = self.pad(
batched_speech,
padding=padding,
max_length=max_length,
truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors="pt",
)
input_features = padded_inputs.audio_input_features.squeeze(-1)
audio_lengths = padded_inputs.audio_lengths
input_features = self._torch_extract_fbank_features(input_features, audio_lengths, device)
feature_lengths = (audio_lengths - self.win_length) // self.hop_length + 1
feature_lengths = feature_lengths * self.audio_feat_stride
audio_embed_sizes = self._compute_audio_embed_size(feature_lengths)
feature_attention_mask = (
torch.arange(0, feature_lengths.max()) if is_torch_available() else np.arange(0, feature_lengths.max())
)
feature_attention_mask = (
feature_attention_mask[None, :] < feature_lengths[:, None] if len(feature_lengths) > 1 else None
)
data = {
"audio_input_features": input_features,
"audio_embed_sizes": audio_embed_sizes,
}
if feature_attention_mask is not None and return_attention_mask:
data["audio_attention_mask"] = feature_attention_mask
return BatchFeature(data=data, tensor_type=return_tensors)
# TODO: @eustlb, move this to audio_utils in a general spectrogram_batch function that handles torch and numpy
def _torch_extract_fbank_features(
self, waveform: "torch.FloatTensor", audio_lengths: "torch.Tensor", device: str = "cpu"
) -> "torch.FloatTensor":
"""
Compute the log mel-scaled spectrogram of batched waveforms using PyTorch's FFT implementation.
Args:
waveform (`torch.FloatTensor` of shape `(batch_size, max_audio_length)`):
The batched waveforms.
audio_lengths (`torch.Tensor` of shape `(batch_size,)`):
The lengths of the waveforms along the max_audio_length dimension.
device (`str`, *optional*, defaults to "cpu"):
The device to run the computation on. (e.g., "cpu", "cuda")
Returns:
`torch.FloatTensor` of shape `(batch_size, max_feature_length, feature_size)`:
The log mel-scaled spectrogram of the batched waveforms.
"""
fft_window = torch.hamming_window(self.win_length, periodic=False, device=device, dtype=torch.float64)
# batched implementation
batch_size = waveform.shape[0]
frames = waveform.unfold(-1, self.win_length, self.hop_length)
# ---
# the unbatched (and unpadded) original implementation skips the last few audio values that can't be included in a frame
# we need to ensure that the corresponding frames for the padded input also mask these values
if batch_size > 1:
frames = frames.clone()
# concerned batch indices
to_mask_batch_idxs = torch.arange(batch_size)[audio_lengths != audio_lengths.max()]
if to_mask_batch_idxs.numel() > 0:
batch_idxs_down = (audio_lengths[to_mask_batch_idxs] - self.win_length) // self.hop_length + 1
batch_idxs_up = audio_lengths[to_mask_batch_idxs] // self.hop_length + 1
offset_idx = batch_idxs_down.min()
max_idx = batch_idxs_up.max()
mask = torch.arange(max_idx - offset_idx, device=device).expand(to_mask_batch_idxs.shape[0], -1)
mask = ((batch_idxs_down - offset_idx).unsqueeze(1) <= mask) & (
mask < (batch_idxs_up - offset_idx).unsqueeze(1)
)
mask = mask.unsqueeze(-1).expand(-1, -1, self.win_length)
masked_frames = frames[to_mask_batch_idxs, offset_idx:max_idx].masked_fill_(mask, 0)
frames[to_mask_batch_idxs, offset_idx:max_idx] = masked_frames
# ---
# apply pre-emphasis first order filter on fft windows
frames_prev = torch.roll(frames, 1, dims=-1)
frames_prev[:, :, 0] = frames_prev[:, :, 1]
frames = (frames - self.preemphasis * frames_prev) * 32768
# apply fft
S = torch.fft.rfft(fft_window * frames.view(-1, self.win_length), n=self.n_fft, dim=1)
S = S.view(frames.shape[0], -1, S.shape[-1])
S = S.to(torch.complex64)
spec = torch.abs(S)
spec_power = spec**2
# apply triangular mel filter bank
mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32)
log_spec = torch.clamp(spec_power @ mel_filters, min=1.0)
log_spec = torch.log(log_spec)
return log_spec
def _compute_audio_embed_size(self, audio_frames):
integer = audio_frames // self.audio_compression_rate
remainder = audio_frames % self.audio_compression_rate
result = integer + (remainder > 0).to(integer.dtype)
integer = result // self.audio_downsample_rate
remainder = result % self.audio_downsample_rate
result = integer + (remainder > 0).to(integer.dtype) # qformer compression
return result
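# Worked example of the double ceil-division above (rates assumed for illustration, not read from
# a checkpoint): with audio_compression_rate=8 and audio_downsample_rate=1, 98 mel frames give
# ceil(98 / 8) = 13, then ceil(13 / 1) = 13 audio embedding positions, i.e. 13 placeholder tokens.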
__all__ = ["Phi4MultimodalFeatureExtractor"]

View File

@ -0,0 +1,263 @@
# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fast image processor class for Phi4Multimodal
"""
import math
from typing import List, Optional, Union
import torch
from torchvision.transforms import functional as F
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
BatchFeature,
DefaultFastImageProcessorKwargs,
Unpack,
convert_to_rgb,
)
from ...image_utils import ImageInput, make_list_of_images, valid_images
from ...utils import TensorType, logging
logger = logging.get_logger(__name__)
class Phi4MultimodalFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
image_size: Optional[int]
patch_size: Optional[int]
dynamic_hd: Optional[int]
class Phi4MultimodalImageProcessorFast(BaseImageProcessorFast):
r"""
Constructs a Phi4Multimodal image processor.
"""
image_size = 448
patch_size = 14
dynamic_hd = 36
image_mean = [0.5, 0.5, 0.5]
image_std = [0.5, 0.5, 0.5]
valid_init_kwargs = Phi4MultimodalFastImageProcessorKwargs
model_input_names = ["image_pixel_values", "image_sizes", "image_attention_mask"]
def __init__(self, **kwargs: Unpack[Phi4MultimodalFastImageProcessorKwargs]):
super().__init__(**kwargs)
def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height):
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * self.image_size * self.image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(self, image, max_num=36, min_num=1):
image_size = self.image_size
patch_size = self.patch_size
mask_size = image_size // patch_size
orig_width, orig_height = image.size
w_crop_num = math.ceil(orig_width / float(image_size))
h_crop_num = math.ceil(orig_height / float(image_size))
if w_crop_num * h_crop_num > max_num:
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = {
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = self.find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
else:
target_width = image_size * w_crop_num
target_height = image_size * h_crop_num
target_aspect_ratio = (w_crop_num, h_crop_num)
# Calculate the ratio
ratio_width = target_width / orig_width
ratio_height = target_height / orig_height
if ratio_width < ratio_height:
new_size = (target_width, int(orig_height * ratio_width))
padding_width = 0
padding_height = target_height - int(orig_height * ratio_width)
else:
new_size = (int(orig_width * ratio_height), target_height)
padding_width = target_width - int(orig_width * ratio_height)
padding_height = 0
attention_mask = torch.ones((int(mask_size * target_aspect_ratio[1]), int(mask_size * target_aspect_ratio[0])))
if padding_width >= patch_size:
attention_mask[:, -math.floor(padding_width / patch_size) :] = 0
if padding_height >= patch_size:
attention_mask[-math.floor(padding_height / patch_size) :, :] = 0
if min(new_size[1], target_height) < 10 or min(new_size[0], target_width) < 10:
raise ValueError(f"the aspect ratio is very extreme {new_size}")
image = F.resize(image, [new_size[1], new_size[0]])
resized_img = F.pad(image, [0, 0, padding_width, padding_height], fill=[255, 255, 255])
return resized_img, attention_mask
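# Worked example for the cropping logic above (image_size=448, patch_size=14, so mask_size=32):
# a 1000x600 input needs ceil(1000/448)=3 by ceil(600/448)=2 crops, i.e. a 1344x896 target canvas.
# Width is the tighter fit (1344/1000 < 896/600), so the image is resized to 1344x806 and padded
# with 90 white rows at the bottom; the resulting 64x96 attention mask then has its last
# floor(90/14)=6 rows zeroed out.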
def pad_to_max_num_crops(self, images, max_crops=5):
"""
images: tensor of shape `(B, 3, H, W)`, with `B <= max_crops`
"""
B, _, H, W = images.shape
if B < max_crops:
pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device)
images = torch.cat([images, pad], dim=0)
return images
def pad_mask_to_max_num_crops(self, masks, max_crops=5):
B, H, W = masks.shape
if B < max_crops:
pad = torch.ones(max_crops - B, H, W, dtype=masks.dtype, device=masks.device)
masks = torch.cat([masks, pad], dim=0)
return masks
def preprocess(
self,
images: ImageInput,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
):
"""
Args:
images (`ImageInput`):
Image to preprocess. Expects a single image or a batch of images with pixel values ranging from 0 to 255
(images are converted to tensors and scaled to [0, 1] internally before normalization).
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
"""
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
images = make_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
images = [convert_to_rgb(image) for image in images]
image_size = self.image_size
patch_size = self.patch_size
mask_size = image_size // patch_size
imgs_and_masks = [self.dynamic_preprocess(image, max_num=self.dynamic_hd) for image in images]
images, image_attention_masks = [x[0] for x in imgs_and_masks], [x[1] for x in imgs_and_masks]
images = [F.to_tensor(image) for image in images]
hd_images = [F.normalize(image, image_mean, image_std) for image in images]
global_image = [
torch.nn.functional.interpolate(
image.unsqueeze(0).float(),
size=(image_size, image_size),
mode="bicubic",
).to(image.dtype)
for image in hd_images
]
shapes = [[image.size(1), image.size(2)] for image in hd_images]
mask_shapes = [[mask.size(0), mask.size(1)] for mask in image_attention_masks]
global_attention_mask = [torch.ones((1, mask_size, mask_size)) for _ in hd_images]
hd_images_reshape = []
for im, (h, w) in zip(hd_images, shapes):
im = im.reshape(1, 3, h // image_size, image_size, w // image_size, image_size)
im = im.permute(0, 2, 4, 1, 3, 5)
im = im.reshape(-1, 3, image_size, image_size)
hd_images_reshape.append(im.contiguous())
attention_masks_reshape = []
for mask, (h, w) in zip(image_attention_masks, mask_shapes):
mask = mask.reshape(h // mask_size, mask_size, w // mask_size, mask_size)
mask = mask.transpose(1, 2)
mask = mask.reshape(-1, mask_size, mask_size)
attention_masks_reshape.append(mask.contiguous())
downsample_attention_masks = []
for mask, (h, w) in zip(attention_masks_reshape, mask_shapes):
mask = mask[:, 0::2, 0::2]
mask = mask.reshape(
h // mask_size, w // mask_size, mask_size // 2 + mask_size % 2, mask_size // 2 + mask_size % 2
)
mask = mask.transpose(1, 2)
mask = mask.reshape(mask.size(0) * mask.size(1), mask.size(2) * mask.size(3))
downsample_attention_masks.append(mask)
num_img_tokens = [
256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16 for mask in downsample_attention_masks
]
hd_images_reshape = [
torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape)
]
hd_masks_reshape = [
torch.cat([_global_mask] + [_mask], dim=0)
for _global_mask, _mask in zip(global_attention_mask, attention_masks_reshape)
]
max_crops = max([img.size(0) for img in hd_images_reshape])
image_transformed = [self.pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape]
image_transformed = torch.stack(image_transformed, dim=0)
mask_transformed = [self.pad_mask_to_max_num_crops(mask, max_crops) for mask in hd_masks_reshape]
mask_transformed = torch.stack(mask_transformed, dim=0)
returned_input_image_embeds = image_transformed
returned_image_sizes = torch.tensor(shapes, dtype=torch.long)
returned_image_attention_mask = mask_transformed
returned_num_img_tokens = num_img_tokens
data = {
"image_pixel_values": returned_input_image_embeds,
"image_sizes": returned_image_sizes,
"image_attention_mask": returned_image_attention_mask,
"num_img_tokens": returned_num_img_tokens,
}
return BatchFeature(data=data, tensor_type=return_tensors)
__all__ = ["Phi4MultimodalImageProcessorFast"]

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,194 @@
# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Phi4Multimodal
"""
import re
from typing import List, Optional, Union
from ...audio_utils import AudioInput
from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import TextInput
from ...utils import logging
logger = logging.get_logger(__name__)
class Phi4MultimodalProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"audio_kwargs": {
"device": "cpu",
},
}
class Phi4MultimodalProcessor(ProcessorMixin):
r"""
Constructs a Phi4Multimodal processor which wraps an image processor, an audio processor, and a GPT tokenizer into a single processor.
[`Phi4MultimodalProcessor`] offers all the functionalities of [`Phi4MultimodalImageProcessorFast`], [`Phi4MultimodalFeatureExtractor`] and [`GPT2TokenizerFast`]. See the
[`~Phi4MultimodalProcessor.__call__`] and [`~Phi4MultimodalProcessor.decode`] for more information.
Args:
image_processor (`Phi4MultimodalImageProcessorFast`):
The image processor to use for images.
audio_processor (`Phi4MultimodalFeatureExtractor`):
The audio processor to use for audio inputs.
tokenizer (`GPT2TokenizerFast`):
The tokenizer to use for text.
fake_image_token_pattern (`str`, *optional*, defaults to `r"<\|image_\d+\|>"`):
The fake image token pattern.
fake_audio_token_pattern (`str`, *optional*, defaults to `r"<\|audio_\d+\|>"`):
The fake audio token pattern.
"""
attributes = ["image_processor", "audio_processor", "tokenizer"]
tokenizer_class = "GPT2TokenizerFast"
image_processor_class = "Phi4MultimodalImageProcessorFast"
audio_processor_class = "Phi4MultimodalFeatureExtractor"
valid_kwargs = ["chat_template", "fake_image_token_pattern", "fake_audio_token_pattern"]
def __init__(
self,
image_processor,
audio_processor,
tokenizer,
fake_image_token_pattern: str = r"<\|image_\d+\|>",
fake_audio_token_pattern: str = r"<\|audio_\d+\|>",
**kwargs,
):
super().__init__(image_processor, audio_processor, tokenizer, **kwargs)
self.fake_image_token_pattern = fake_image_token_pattern
self.fake_audio_token_pattern = fake_audio_token_pattern
def __call__(
self,
text: Union[TextInput, List[TextInput]],
images: Optional[ImageInput] = None,
audios: Optional[AudioInput] = None,
**kwargs: Unpack[ProcessingKwargs],
) -> BatchFeature:
"""
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
and `kwargs` arguments to GPT2Tokenizer's [`~GPT2Tokenizer.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
Phi4MultimodalImageProcessorFast's [`~Phi4MultimodalImageProcessorFast.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
audios (`List[Union[np.ndarray, torch.Tensor]]`):
List of the audios to be prepared.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
- **image_pixel_values** -- Pixel values to be fed to a model.
- **image_sizes** -- List of tuples specifying the size of each image in `image_pixel_values`.
- **image_attention_mask** -- List of attention masks for each image in `image_pixel_values`.
- **audio_input_features** -- Audio features to be fed to a model.
- **audio_embed_sizes** -- List of integers specifying the size of each audio in `audio_input_features`.
"""
output_kwargs = self._merge_kwargs(Phi4MultimodalProcessorKwargs, self.tokenizer.init_kwargs, **kwargs)
image_kwargs = output_kwargs["images_kwargs"]
audio_kwargs = output_kwargs["audio_kwargs"]
text_kwargs = output_kwargs["text_kwargs"]
image_inputs = self.image_processor(images, **image_kwargs) if images is not None else {}
audio_inputs = self.audio_processor(audios, **audio_kwargs) if audios is not None else {}
# We pop here for images as we don't need it later
num_img_tokens = image_inputs.pop("num_img_tokens", [])
audio_embed_sizes = audio_inputs.get("audio_embed_sizes", [])
# Replace certain special tokens for compatibility
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) or not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string or a list of strings")
image_token = self.tokenizer.image_token
audio_token = self.tokenizer.audio_token
processed_text = [re.sub(self.fake_image_token_pattern, image_token, t) for t in text]
processed_text = [re.sub(self.fake_audio_token_pattern, audio_token, t) for t in processed_text]
# Check that the number of special tokens is sound
concatenated_prompt = "".join(processed_text)
if concatenated_prompt.count(self.tokenizer.image_token) != len(num_img_tokens):
raise ValueError(
"You should add as much image tokens `<|image_i|>` in your prompt as you pass `images` to the processor"
)
if concatenated_prompt.count(self.tokenizer.audio_token) != len(audio_embed_sizes):
raise ValueError(
"You should add as much audio tokens `<|audio_i|>` in your prompt as you pass `audios` to the processor"
)
# Expand each image/audio placeholder to the appropriate number of tokens (the replacement count varies per image/audio)
image_count_iter = iter(num_img_tokens)
audio_count_iter = iter(audio_embed_sizes)
processed_text = [
re.sub(re.escape(image_token), lambda _: image_token * next(image_count_iter), t) for t in processed_text
]
processed_text = [
re.sub(re.escape(audio_token), lambda _: audio_token * next(audio_count_iter), t) for t in processed_text
]
text_inputs = self.tokenizer(processed_text, **text_kwargs)
# prepare batch feature
data = {
**text_inputs,
**image_inputs,
**audio_inputs,
}
return BatchFeature(data=data)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
audio_processor_input_names = self.audio_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + audio_processor_input_names))
__all__ = ["Phi4MultimodalProcessor"]

View File

@ -7746,6 +7746,55 @@ class Phi3PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class Phi4MultimodalAudioModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Phi4MultimodalAudioPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Phi4MultimodalForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Phi4MultimodalModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Phi4MultimodalPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Phi4MultimodalVisionModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Phi4MultimodalVisionPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PhimoeForCausalLM(metaclass=DummyObject):
_backends = ["torch"]

View File

@ -93,6 +93,13 @@ class LlavaOnevisionImageProcessorFast(metaclass=DummyObject):
requires_backends(self, ["torchvision"])
class Phi4MultimodalImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])
class PixtralImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]

View File

View File

@ -0,0 +1,405 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
import requests
from parameterized import parameterized
from transformers import (
AutoModelForCausalLM,
AutoProcessor,
GenerationConfig,
Phi4MultimodalAudioConfig,
Phi4MultimodalConfig,
Phi4MultimodalForCausalLM,
Phi4MultimodalModel,
Phi4MultimodalVisionConfig,
is_torch_available,
is_vision_available,
)
from transformers.testing_utils import (
require_soundfile,
require_torch,
slow,
torch_device,
)
from transformers.utils import is_soundfile_available
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
if is_soundfile_available():
import soundfile
class Phi4MultimodalModelTester:
def __init__(
self,
parent,
batch_size=2,
seq_length=12,
image_seq_length=275,
audio_seq_length=8,
is_training=True,
num_hidden_layers=2,
vocab_size=49,
hidden_size=32,
intermediate_size=64,
num_attention_heads=8,
num_key_value_heads=4,
bos_token_id=0,
eos_token_id=0,
pad_token_id=0,
image_token_id=1,
audio_token_id=2,
image_size=16,
audio_size=12,
audio_config=Phi4MultimodalAudioConfig(
num_blocks=2,
hidden_size=32,
num_attention_heads=8,
intermediate_size=48,
depthwise_seperable_out_channel=128,
nemo_conv_channels=128,
),
vision_config=Phi4MultimodalVisionConfig(
num_hidden_layers=2,
hidden_size=32,
intermediate_size=64,
num_attention_heads=8,
crop_size=16,
),
):
self.parent = parent
self.num_hidden_layers = num_hidden_layers
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.bos_token_id = bos_token_id
self.pad_token_id = pad_token_id
self.eos_token_id = eos_token_id
self.image_token_id = image_token_id
self.audio_token_id = audio_token_id
self.audio_config = audio_config
self.vision_config = vision_config
self.is_training = is_training
self.batch_size = batch_size
self.seq_length = seq_length + image_seq_length + audio_seq_length
self.image_seq_length = image_seq_length
self.audio_seq_length = audio_seq_length
self.image_size = image_size
self.audio_size = audio_size
self.num_channels = 3
def get_config(self):
return Phi4MultimodalConfig(
num_hidden_layers=self.num_hidden_layers,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
vision_config=self.vision_config,
audio_config=self.audio_config,
)
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
# The shapes correspond to the inputs for an image of size 16x16
image_pixel_values = floats_tensor([self.batch_size, 2, self.num_channels, self.image_size, self.image_size])
image_attention_mask = torch.ones(self.batch_size, 2, 1, 1)
image_sizes = torch.tensor(
[[self.image_size, self.image_size]] * self.batch_size, dtype=torch.long, device=torch_device
)
# Feature sizes returned for an audio input of 10000 samples
audio_input_features = floats_tensor([self.batch_size, 61, 80])
audio_embed_sizes = torch.tensor([self.audio_seq_length] * self.batch_size, dtype=torch.long)
input_ids[input_ids == self.pad_token_id] = self.pad_token_id + 1 # random value but not pad token
input_ids[-1, 0] = self.pad_token_id # pad the first text token of the last batch example
input_ids[:, -self.image_seq_length - self.audio_seq_length : -self.audio_seq_length] = self.image_token_id
input_ids[:, -self.audio_seq_length :] = self.audio_token_id
attention_mask = torch.ones_like(input_ids)
attention_mask[-1, 0] = 0 # mask the first text token of the last batch example
config = self.get_config()
return (
config,
input_ids,
attention_mask,
image_pixel_values,
image_attention_mask,
image_sizes,
audio_input_features,
audio_embed_sizes,
)
def prepare_config_and_inputs_for_common(self):
(
config,
input_ids,
attention_mask,
image_pixel_values,
image_attention_mask,
image_sizes,
audio_input_features,
audio_embed_sizes,
) = self.prepare_config_and_inputs()
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"image_pixel_values": image_pixel_values,
"image_attention_mask": image_attention_mask,
"image_sizes": image_sizes,
"audio_input_features": audio_input_features,
"audio_embed_sizes": audio_embed_sizes,
}
return config, inputs_dict
def create_and_check_model(self, config, input_ids, attention_mask):
model = Phi4MultimodalForCausalLM(config=config)
model.to(torch_device)
model.eval()
with torch.autocast(device_type="cuda", dtype=torch.float16):
logits = model(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
)["logits"]
self.parent.assertEqual(logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertFalse(torch.isnan(logits).any().item())
@require_torch
class Phi4MultimodalModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
"""
Model tester for `Phi4Multimodal`.
"""
all_model_classes = (Phi4MultimodalForCausalLM, Phi4MultimodalModel) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
_is_composite = True
def setUp(self):
self.model_tester = Phi4MultimodalModelTester(self)
self.config_tester = ConfigTester(self, config_class=Phi4MultimodalConfig)
@unittest.skip(reason="Unstable test")
def test_initialization(self):
pass
@unittest.skip(reason="Right padding not supported")
def test_flash_attn_2_inference_equivalence_right_padding(self):
pass
@unittest.skip(reason="This one tries to use right padding as well")
def test_eager_matches_fa2_generate(self):
pass
@unittest.skip(reason="Depending on input modalities, some params may not have gradients")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="Depending on input modalities, some params may not have gradients")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="Depending on input modalities, some params may not have gradients")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Test tries to instantiate dynamic cache with an arg")
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip(reason="Test is only for old attention format")
def test_sdpa_can_dispatch_composite_models(self):
pass
@unittest.skip(reason="Static cache supported only for text-only inputs (not images or audios)")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip(reason="Static cache supported only for text-only inputs (not images or audios)")
def test_generate_with_static_cache(self):
pass
@unittest.skip(
reason="Supported only for text-only inputs (otherwise dynamic control flows for multimodal inputs)"
)
def test_generate_compilation_all_outputs(self):
pass
@unittest.skip(
reason="Supported only for text-only inputs (otherwise dynamic control flows for multimodal inputs)"
)
def test_generate_compile_model_forward(self):
pass
@parameterized.expand([("random",), ("same",)])
@unittest.skip(reason="`image_attention_mask` has a specific shape")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip(reason="`image_attention_mask` has a specific shape")
def test_assisted_decoding_sample(self):
pass
@unittest.skip(reason="`image_attention_mask` has a specific shape")
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass
@unittest.skip(reason="Cannot unpad inputs for all modalities so easily")
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip(reason="Dynamo error")
def test_flex_attention_with_grads(self):
pass
@require_torch
@slow
class Phi4MultimodalIntegrationTest(unittest.TestCase):
checkpoint_path = "microsoft/Phi-4-multimodal-instruct"
image_url = "https://www.ilankelman.org/stopsigns/australia.jpg"
audio_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav"
def setUp(self):
self.processor = AutoProcessor.from_pretrained(self.checkpoint_path)
self.generation_config = GenerationConfig(max_new_tokens=20, do_sample=False)
self.user_token = "<|user|>"
self.assistant_token = "<|assistant|>"
self.end_token = "<|end|>"
self.image = Image.open(requests.get(self.image_url, stream=True).raw)
with tempfile.NamedTemporaryFile(mode="w+b", suffix=".wav") as tmp:
tmp.write(requests.get(self.audio_url, stream=True).raw.data)
tmp.flush()
tmp.seek(0)
self.audio, self.sampling_rate = soundfile.read(tmp.name)
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
def test_text_only_generation(self):
model = AutoModelForCausalLM.from_pretrained(
self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
)
prompt = f"{self.user_token}What is the answer for 1+1? Explain it.{self.end_token}{self.assistant_token}"
inputs = self.processor(prompt, images=None, return_tensors="pt").to(torch_device)
output = model.generate(
**inputs,
generation_config=self.generation_config,
)
output = output[:, inputs["input_ids"].shape[1] :]
response = self.processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
EXPECTED_RESPONSE = "The answer for 1+1 is 2. This is because when you add one to another"
self.assertEqual(response, EXPECTED_RESPONSE)
def test_vision_text_generation(self):
model = AutoModelForCausalLM.from_pretrained(
self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
)
prompt = f"{self.user_token}<|image_1|>What is shown in this image?{self.end_token}{self.assistant_token}"
inputs = self.processor(prompt, images=self.image, return_tensors="pt").to(torch_device)
output = model.generate(
**inputs,
generation_config=self.generation_config,
)
output = output[:, inputs["input_ids"].shape[1] :]
response = self.processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
EXPECTED_RESPONSE = "The image shows a vibrant scene at a street intersection in a city with a Chinese-influenced architectural"
self.assertEqual(response, EXPECTED_RESPONSE)
def test_multi_image_vision_text_generation(self):
model = AutoModelForCausalLM.from_pretrained(
self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
)
images = []
placeholder = ""
for i in range(1, 5):
url = f"https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-{i}-2048.jpg"
images.append(Image.open(requests.get(url, stream=True).raw))
placeholder += f"<|image_{i}|>"
prompt = f"{self.user_token}{placeholder}Summarize the deck of slides.{self.end_token}{self.assistant_token}"
inputs = self.processor(prompt, images, return_tensors="pt").to(torch_device)
output = model.generate(
**inputs,
generation_config=self.generation_config,
)
output = output[:, inputs["input_ids"].shape[1] :]
response = self.processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
EXPECTED_RESPONSE = "The presentation provides an overview of Microsoft Azure, a cloud computing platform by Microsoft, and its various services"
self.assertEqual(response, EXPECTED_RESPONSE)
@require_soundfile
def test_audio_text_generation(self):
model = AutoModelForCausalLM.from_pretrained(
self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
)
prompt = f"{self.user_token}<|audio_1|>What is happening in this audio?{self.end_token}{self.assistant_token}"
inputs = self.processor(prompt, audios=self.audio, sampling_rate=self.sampling_rate, return_tensors="pt").to(
torch_device
)
output = model.generate(
**inputs,
generation_config=self.generation_config,
)
output = output[:, inputs["input_ids"].shape[1] :]
response = self.processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
# Yes, this is truly the expected response, even though the model does process the audio file correctly
EXPECTED_RESPONSE = "I'm sorry, but I can't listen to audio. However, if you describe the audio to me,"
self.assertEqual(response, EXPECTED_RESPONSE)

View File

@ -524,6 +524,7 @@ OBJECTS_TO_IGNORE = [
"TimeSeriesTransformerConfig",
"TokenClassificationPipeline",
"TrOCRConfig",
"Phi4MultimodalProcessor",
"TrainerState",
"TrainingArguments",
"TrajectoryTransformerConfig",

View File

@ -89,6 +89,8 @@ PRIVATE_MODELS = [
"SmolVLMVisionTransformer",
"AriaTextForCausalLM",
"AriaTextModel",
"Phi4MultimodalAudioModel",
"Phi4MultimodalVisionModel",
]
# Update this list for models that are not tested with a comment explaining the reason it should not be.