mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add Phi4 multimodal (#36939)
* raw start * update * update * add to imports * update * up * simplify configs * clean configs * style * typos * Update convert_phi4_multimodal_weights_to_hf.py * Update convert_phi4_multimodal_weights_to_hf.py * fix * up * up * up * Update convert_phi4_multimodal_weights_to_hf.py * Update convert_phi4_multimodal_weights_to_hf.py * up * up * up * Update feature_extraction_phi4_multimodal.py * up * up * up * up * up * simplify configs * typo * cut code * typo * typo * typo * re * typo * up * up * up * add tests * fix * fix * Update test_modeling_phi4_multimodal.py * up * Update test_modeling_phi4_multimodal.py * doc * fix * up * up * up * up * up * up * simplify * up * simplify * config docstrings * cleanup * clean * typo * typo * fix * Update phi4_multimodal.md * fix * fix * Update test_modeling_phi4_multimodal.py * update * simplify reshapes and permutes * up * simplify special tokens * simplify processor a lot * Update processing_phi4_multimodal.py * Update processing_phi4_multimodal.py * switch to fast processor * image processor * Update image_processing_phi4_multimodal_fast.py * add lora extraction to converter * Update convert_phi4_multimodal_weights_to_hf.py * Update __init__.py * add AudioInput type in audio_utils * rewrite feature_extraction: support torch batched FFT * input_audio_embeds -> audio_input_features, input_image_embeds -> image_pixel_values * test update * not mono channel warning update * remove auto maps from processor * kargs dispatch in processor * simplify kwargs dispatch * simplify merging * remove default sampling rate * style * Update test_modeling_phi4_multimodal.py * update doc * doc * torch only feature extractor * make fake tokens adjustable * Update feature_extraction_phi4_multimodal.py * fix * Update processing_phi4_multimodal.py * simplify mask * last touch * fix copies * style * Update audio_utils.py * style * Update feature_extraction_phi4_multimodal.py * Update __init__.py * docstrings * copies * fix all checks * back to fix-copies * trigger CIs * Update feature_extraction_phi4_multimodal.py * improve tests with multimodal inputs * trigger CIs --------- Co-authored-by: Eustache Le Bihan <eulebihan@gmail.com>
This commit is contained in:
parent
47e5432805
commit
4303d88c09
@ -583,6 +583,8 @@
|
||||
title: Phi
|
||||
- local: model_doc/phi3
|
||||
title: Phi-3
|
||||
- local: model_doc/phi4_multimodal
|
||||
title: Phi4 Multimodal
|
||||
- local: model_doc/phimoe
|
||||
title: PhiMoE
|
||||
- local: model_doc/phobert
|
||||
|
149
docs/source/en/model_doc/phi4_multimodal.md
Normal file
149
docs/source/en/model_doc/phi4_multimodal.md
Normal file
@ -0,0 +1,149 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
-->
|
||||
|
||||
# Phi4 Multimodal
|
||||
|
||||
## Overview
|
||||
|
||||
Phi4 Multimodal is a lightweight open multimodal foundation model that leverages the language, vision, and speech research and datasets used for Phi-3.5 and 4.0 models. The model processes text, image, and audio inputs, generating text outputs, and comes with 128K token context length. The model underwent an enhancement process, incorporating both supervised fine-tuning, direct preference optimization and RLHF (Reinforcement Learning from Human Feedback) to support precise instruction adherence and safety measures. The languages that each modal supports are the following:
|
||||
|
||||
- Text: Arabic, Chinese, Czech, Danish, Dutch, English, Finnish, French, German, Hebrew, Hungarian, Italian, Japanese, Korean, Norwegian, Polish, Portuguese, Russian, Spanish, Swedish, Thai, Turkish, Ukrainian
|
||||
- Vision: English
|
||||
- Audio: English, Chinese, German, French, Italian, Japanese, Spanish, Portuguese
|
||||
|
||||
This model was contributed by [Cyril Vallez](https://huggingface.co/cyrilvallez). The most recent code can be
|
||||
found [here](https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py).
|
||||
|
||||
|
||||
## Usage tips
|
||||
|
||||
`Phi4-multimodal-instruct` can be found on the [Huggingface Hub](https://huggingface.co/microsoft/Phi-4-multimodal-instruct)
|
||||
|
||||
In the following, we demonstrate how to use it for inference depending on the input modalities (text, image, audio).
|
||||
|
||||
```python
|
||||
import requests
|
||||
import torch
|
||||
import os
|
||||
import io
|
||||
from PIL import Image
|
||||
import soundfile as sf
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
|
||||
from urllib.request import urlopen
|
||||
|
||||
|
||||
# Define model path
|
||||
model_path = "microsoft/Phi-4-multimodal-instruct"
|
||||
device = "cuda:0"
|
||||
|
||||
# Load model and processor
|
||||
processor = AutoProcessor.from_pretrained(model_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, torch_dtype=torch.float16)
|
||||
|
||||
# Optional: load the adapters (note that without them, the base model will very likely not work well)
|
||||
model.load_adapter(model_path, adapter_name="speech", device_map=device, adapter_kwargs={"subfolder": 'speech-lora'})
|
||||
model.load_adapter(model_path, adapter_name="vision", device_map=device, adapter_kwargs={"subfolder": 'vision-lora'})
|
||||
|
||||
# Define prompt structure
|
||||
user_prompt = '<|user|>'
|
||||
assistant_prompt = '<|assistant|>'
|
||||
prompt_suffix = '<|end|>'
|
||||
|
||||
# Part 1: Image Processing
|
||||
model.set_adapter("vision") # if loaded, activate the vision adapter
|
||||
print("\n--- IMAGE PROCESSING ---")
|
||||
image_url = 'https://www.ilankelman.org/stopsigns/australia.jpg'
|
||||
prompt = f'{user_prompt}<|image_1|>What is shown in this image?{prompt_suffix}{assistant_prompt}'
|
||||
print(f'>>> Prompt\n{prompt}')
|
||||
|
||||
# Download and open image
|
||||
image = Image.open(requests.get(image_url, stream=True).raw)
|
||||
inputs = processor(text=prompt, images=image, return_tensors='pt').to(device)
|
||||
|
||||
# Generate response
|
||||
generate_ids = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=1000,
|
||||
do_sample=False,
|
||||
)
|
||||
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
|
||||
response = processor.batch_decode(
|
||||
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)[0]
|
||||
print(f'>>> Response\n{response}')
|
||||
|
||||
# Part 2: Audio Processing
|
||||
model.set_adapter("speech") # if loaded, activate the speech adapter
|
||||
print("\n--- AUDIO PROCESSING ---")
|
||||
audio_url = "https://upload.wikimedia.org/wikipedia/commons/b/b0/Barbara_Sahakian_BBC_Radio4_The_Life_Scientific_29_May_2012_b01j5j24.flac"
|
||||
speech_prompt = "Transcribe the audio to text, and then translate the audio to French. Use <sep> as a separator between the original transcript and the translation."
|
||||
prompt = f'{user_prompt}<|audio_1|>{speech_prompt}{prompt_suffix}{assistant_prompt}'
|
||||
print(f'>>> Prompt\n{prompt}')
|
||||
|
||||
# Downlowd and open audio file
|
||||
audio, sample_rate = sf.read(io.BytesIO(urlopen(audio_url).read()))
|
||||
|
||||
# Process with the model
|
||||
inputs = processor(text=prompt, audios=audio, sample_rate=sample_rate, return_tensors='pt').to(device)
|
||||
|
||||
generate_ids = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=1000,
|
||||
do_sample=False,
|
||||
)
|
||||
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
|
||||
response = processor.batch_decode(
|
||||
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)[0]
|
||||
print(f'>>> Response\n{response}')
|
||||
```
|
||||
|
||||
## Phi4MultimodalFeatureExtractor
|
||||
|
||||
[[autodoc]] Phi4MultimodalFeatureExtractor
|
||||
|
||||
## Phi4MultimodalImageProcessorFast
|
||||
|
||||
[[autodoc]] Phi4MultimodalImageProcessorFast
|
||||
|
||||
## Phi4MultimodalProcessor
|
||||
|
||||
[[autodoc]] Phi4MultimodalProcessor
|
||||
|
||||
## Phi4MultimodalAudioConfig
|
||||
|
||||
[[autodoc]] Phi4MultimodalAudioConfig
|
||||
|
||||
## Phi4MultimodalVisionConfig
|
||||
|
||||
[[autodoc]] Phi4MultimodalVisionConfig
|
||||
|
||||
## Phi4MultimodalConfig
|
||||
|
||||
[[autodoc]] Phi4MultimodalConfig
|
||||
|
||||
## Phi4MultimodalAudioModel
|
||||
|
||||
[[autodoc]] Phi4MultimodalAudioModel
|
||||
|
||||
## Phi4MultimodalVisionModel
|
||||
|
||||
[[autodoc]] Phi4MultimodalVisionModel
|
||||
|
||||
## Phi4MultimodalModel
|
||||
|
||||
[[autodoc]] Phi4MultimodalModel
|
||||
- forward
|
||||
|
||||
## Phi4MultimodalForCausalLM
|
||||
|
||||
[[autodoc]] Phi4MultimodalForCausalLM
|
||||
- forward
|
@ -699,6 +699,13 @@ _import_structure = {
|
||||
"models.persimmon": ["PersimmonConfig"],
|
||||
"models.phi": ["PhiConfig"],
|
||||
"models.phi3": ["Phi3Config"],
|
||||
"models.phi4_multimodal": [
|
||||
"Phi4MultimodalAudioConfig",
|
||||
"Phi4MultimodalConfig",
|
||||
"Phi4MultimodalFeatureExtractor",
|
||||
"Phi4MultimodalProcessor",
|
||||
"Phi4MultimodalVisionConfig",
|
||||
],
|
||||
"models.phimoe": ["PhimoeConfig"],
|
||||
"models.phobert": ["PhobertTokenizer"],
|
||||
"models.pix2struct": [
|
||||
@ -1348,6 +1355,7 @@ else:
|
||||
_import_structure["models.llava"].append("LlavaImageProcessorFast")
|
||||
_import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
|
||||
_import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast")
|
||||
_import_structure["models.phi4_multimodal"].append("Phi4MultimodalImageProcessorFast")
|
||||
_import_structure["models.pixtral"].append("PixtralImageProcessorFast")
|
||||
_import_structure["models.qwen2_vl"].append("Qwen2VLImageProcessorFast")
|
||||
_import_structure["models.rt_detr"].append("RTDetrImageProcessorFast")
|
||||
@ -2802,6 +2810,17 @@ else:
|
||||
"LlavaNextPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.phi4_multimodal"].extend(
|
||||
[
|
||||
"Phi4MultimodalForCausalLM",
|
||||
"Phi4MultimodalPreTrainedModel",
|
||||
"Phi4MultimodalAudioModel",
|
||||
"Phi4MultimodalAudioPreTrainedModel",
|
||||
"Phi4MultimodalModel",
|
||||
"Phi4MultimodalVisionModel",
|
||||
"Phi4MultimodalVisionPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.llava_next_video"].extend(
|
||||
[
|
||||
"LlavaNextVideoForConditionalGeneration",
|
||||
@ -5914,6 +5933,13 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .models.phi import PhiConfig
|
||||
from .models.phi3 import Phi3Config
|
||||
from .models.phi4_multimodal import (
|
||||
Phi4MultimodalAudioConfig,
|
||||
Phi4MultimodalConfig,
|
||||
Phi4MultimodalFeatureExtractor,
|
||||
Phi4MultimodalProcessor,
|
||||
Phi4MultimodalVisionConfig,
|
||||
)
|
||||
from .models.phimoe import PhimoeConfig
|
||||
from .models.phobert import PhobertTokenizer
|
||||
from .models.pix2struct import (
|
||||
@ -6587,6 +6613,7 @@ if TYPE_CHECKING:
|
||||
from .models.llava import LlavaImageProcessorFast
|
||||
from .models.llava_next import LlavaNextImageProcessorFast
|
||||
from .models.llava_onevision import LlavaOnevisionImageProcessorFast
|
||||
from .models.phi4_multimodal import Phi4MultimodalImageProcessorFast
|
||||
from .models.pixtral import PixtralImageProcessorFast
|
||||
from .models.qwen2_vl import Qwen2VLImageProcessorFast
|
||||
from .models.rt_detr import RTDetrImageProcessorFast
|
||||
@ -8153,6 +8180,15 @@ if TYPE_CHECKING:
|
||||
Phi3Model,
|
||||
Phi3PreTrainedModel,
|
||||
)
|
||||
from .models.phi4_multimodal import (
|
||||
Phi4MultimodalAudioModel,
|
||||
Phi4MultimodalAudioPreTrainedModel,
|
||||
Phi4MultimodalForCausalLM,
|
||||
Phi4MultimodalModel,
|
||||
Phi4MultimodalPreTrainedModel,
|
||||
Phi4MultimodalVisionModel,
|
||||
Phi4MultimodalVisionPreTrainedModel,
|
||||
)
|
||||
from .models.phimoe import (
|
||||
PhimoeForCausalLM,
|
||||
PhimoeForSequenceClassification,
|
||||
|
@ -17,11 +17,16 @@ and remove unnecessary dependencies.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
AudioInput = Union[
|
||||
np.ndarray, "torch.Tensor", List[np.ndarray], Tuple[np.ndarray], List["torch.Tensor"], Tuple["torch.Tensor"] # noqa: F821
|
||||
]
|
||||
|
||||
|
||||
def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
|
||||
"""
|
||||
Convert frequency from hertz to mels.
|
||||
|
@ -212,6 +212,7 @@ from . import (
|
||||
persimmon,
|
||||
phi,
|
||||
phi3,
|
||||
phi4_multimodal,
|
||||
phimoe,
|
||||
phobert,
|
||||
pix2struct,
|
||||
|
@ -235,6 +235,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("persimmon", "PersimmonConfig"),
|
||||
("phi", "PhiConfig"),
|
||||
("phi3", "Phi3Config"),
|
||||
("phi4_multimodal", "Phi4MultimodalConfig"),
|
||||
("phimoe", "PhimoeConfig"),
|
||||
("pix2struct", "Pix2StructConfig"),
|
||||
("pixtral", "PixtralVisionConfig"),
|
||||
@ -587,6 +588,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("persimmon", "Persimmon"),
|
||||
("phi", "Phi"),
|
||||
("phi3", "Phi3"),
|
||||
("phi4_multimodal", "Phi4Multimodal"),
|
||||
("phimoe", "Phimoe"),
|
||||
("phobert", "PhoBERT"),
|
||||
("pix2struct", "Pix2Struct"),
|
||||
|
@ -78,6 +78,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("nat", "ViTFeatureExtractor"),
|
||||
("owlvit", "OwlViTFeatureExtractor"),
|
||||
("perceiver", "PerceiverFeatureExtractor"),
|
||||
("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
|
||||
("poolformer", "PoolFormerFeatureExtractor"),
|
||||
("pop2piano", "Pop2PianoFeatureExtractor"),
|
||||
("regnet", "ConvNextFeatureExtractor"),
|
||||
|
@ -124,6 +124,7 @@ else:
|
||||
("owlvit", ("OwlViTImageProcessor",)),
|
||||
("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
|
||||
("perceiver", ("PerceiverImageProcessor",)),
|
||||
("phi4_multimodal", "Phi4MultimodalImageProcessorFast"),
|
||||
("pix2struct", ("Pix2StructImageProcessor",)),
|
||||
("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
|
||||
("poolformer", ("PoolFormerImageProcessor",)),
|
||||
|
@ -218,6 +218,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("persimmon", "PersimmonModel"),
|
||||
("phi", "PhiModel"),
|
||||
("phi3", "Phi3Model"),
|
||||
("phi4_multimodal", "Phi4MultimodalModel"),
|
||||
("phimoe", "PhimoeModel"),
|
||||
("pixtral", "PixtralVisionModel"),
|
||||
("plbart", "PLBartModel"),
|
||||
@ -566,6 +567,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("persimmon", "PersimmonForCausalLM"),
|
||||
("phi", "PhiForCausalLM"),
|
||||
("phi3", "Phi3ForCausalLM"),
|
||||
("phi4_multimodal", "Phi4MultimodalForCausalLM"),
|
||||
("phimoe", "PhimoeForCausalLM"),
|
||||
("plbart", "PLBartForCausalLM"),
|
||||
("prophetnet", "ProphetNetForCausalLM"),
|
||||
|
@ -91,6 +91,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("owlv2", "Owlv2Processor"),
|
||||
("owlvit", "OwlViTProcessor"),
|
||||
("paligemma", "PaliGemmaProcessor"),
|
||||
("phi4_multimodal", "Phi4MultimodalProcessor"),
|
||||
("pix2struct", "Pix2StructProcessor"),
|
||||
("pixtral", "PixtralProcessor"),
|
||||
("pop2piano", "Pop2PianoProcessor"),
|
||||
|
32
src/transformers/models/phi4_multimodal/__init__.py
Normal file
32
src/transformers/models/phi4_multimodal/__init__.py
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_phi4_multimodal import *
|
||||
from .feature_extraction_phi4_multimodal import *
|
||||
from .image_processing_phi4_multimodal_fast import *
|
||||
from .modeling_phi4_multimodal import *
|
||||
from .processing_phi4_multimodal import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
@ -0,0 +1,482 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_phi4_multimodal.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class Phi4MultimodalVisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Phi4MultimodalVisionModel`]. It is used to instantiate a
|
||||
Phi4Multimodal vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the vision encoder of
|
||||
[microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 1152):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
intermediate_size (`int`, *optional*, defaults to 4304):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 27):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
Number of channels in the input images.
|
||||
image_size (`int`, *optional*, defaults to 448):
|
||||
The size (resolution) of each image.
|
||||
patch_size (`int`, *optional*, defaults to 14):
|
||||
The size (resolution) of each patch.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
crop_size (`int`, *optional*, defaults to 448):
|
||||
Crop size for the input images.
|
||||
image_token_id (`int`, *optional*, defaults to 200010):
|
||||
The image token id.
|
||||
feature_layer (`int`, *optional*, defaults to -2):
|
||||
The index of the layer of the encoder from which to extract image features.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Phi4MultimodalVisionConfig
|
||||
|
||||
>>> # Initializing a Phi4MultimodalVisionConfig with microsoft/Phi-4-multimodal-instruct style configuration
|
||||
>>> configuration = Phi4MultimodalVisionConfig()
|
||||
```"""
|
||||
|
||||
model_type = "phi4_multimodal_vision"
|
||||
base_config_key = "vision_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=1152,
|
||||
intermediate_size=4304,
|
||||
num_hidden_layers=27,
|
||||
num_attention_heads=16,
|
||||
num_channels=3,
|
||||
image_size=448,
|
||||
patch_size=14,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
layer_norm_eps=1e-6,
|
||||
attention_dropout=0.0,
|
||||
crop_size: int = 448,
|
||||
image_token_id: int = 200010,
|
||||
feature_layer: int = -2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_channels = num_channels
|
||||
self.patch_size = patch_size
|
||||
self.image_size = image_size
|
||||
self.attention_dropout = attention_dropout
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.hidden_act = hidden_act
|
||||
self.crop_size = crop_size
|
||||
self.image_token_id = image_token_id
|
||||
self.feature_layer = feature_layer
|
||||
|
||||
|
||||
class Phi4MultimodalAudioConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Phi4MultimodalAudioModel`]. It is used to instantiate a
|
||||
Phi4Multimodal audio encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the audio encoder of
|
||||
[microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 1024):
|
||||
Dimensionality of the encoder layers.
|
||||
intermediate_size (`int`, *optional*, defaults to 1536):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
num_blocks (`int`, *optional*, defaults to 24):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
activation (`str`, *optional*, defaults to `"swish"`):
|
||||
The non-linear activation function in the MLPs.
|
||||
chunk_size (`int`, *optional*, defaults to -1):
|
||||
The chunk size to create the masks.
|
||||
left_chunk (`int`, *optional*, defaults to 18):
|
||||
The left chunk to create the masks.
|
||||
dropout_rate (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio.
|
||||
ext_pw_out_channel (`int`, *optional*, defaults to 1024):
|
||||
Number of out channels in the point-wise conv modules.
|
||||
depthwise_seperable_out_channel (`int`, *optional*, defaults to 1024):
|
||||
Number of out channels in the depth-wise separable conv modules.
|
||||
depthwise_multiplier (`int`, *optional*, defaults to 1):
|
||||
Input size multiplier for the depth-wise separable conv modules.
|
||||
kernel_size (`int`, *optional*, defaults to 3):
|
||||
Kernel size for the depth-wise separable conv modules.
|
||||
conv_activation (`str`, *optional*, defaults to `"swish"`):
|
||||
The non-linear activation function in the conv modules.
|
||||
input_size (`int`, *optional*, defaults to 80):
|
||||
Input size for the audio model.
|
||||
conv_glu_type (`str`, *optional*, defaults to `"swish"`):
|
||||
The non-linear activation function in the point-wise conv modules.
|
||||
time_reduction (`int`, *optional*, defaults to 8):
|
||||
Time reduction (subsampling factor).
|
||||
bias_max_distance (`int`, *optional*, defaults to 1000):
|
||||
Max distance for the relative attention bias module.
|
||||
bias_symmetric (`bool`, *optional*, defaults to `False`):
|
||||
Whether the relative attention bias should be symmetric or not.
|
||||
nemo_activation (`str`, *optional*, defaults to `"relu"`):
|
||||
The non-linear activation function in the nemo conv modules.
|
||||
nemo_conv_channels (`int`, *optional*, defaults to 1024):
|
||||
Number of channels in the nemo conv modules.
|
||||
downsample_rate (`int`, *optional*, defaults to 1):
|
||||
Downsample rate for the audio feature extractor.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
audio_token_id (`int`, *optional*, defaults to 200011):
|
||||
The audio token id.
|
||||
feature_layer (`int`, *optional*, defaults to -2):
|
||||
The index of the layer of the encoder from which to extract audio features.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Phi4MultimodalAudioConfig
|
||||
|
||||
>>> # Initializing a Phi4MultimodalAudioConfig with microsoft/Phi-4-multimodal-instruct style configuration
|
||||
>>> configuration = Phi4MultimodalAudioConfig()
|
||||
```"""
|
||||
|
||||
model_type = "phi4_multimodal_audio"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 1024,
|
||||
intermediate_size: int = 1536,
|
||||
num_blocks: int = 24,
|
||||
num_attention_heads: int = 16,
|
||||
activation: str = "swish",
|
||||
chunk_size: int = -1,
|
||||
left_chunk: int = 18,
|
||||
dropout_rate: float = 0.0,
|
||||
ext_pw_out_channel: int = 1024,
|
||||
depthwise_seperable_out_channel: int = 1024,
|
||||
depthwise_multiplier: int = 1,
|
||||
kernel_size: int = 3,
|
||||
conv_activation: str = "swish",
|
||||
input_size: int = 80,
|
||||
conv_glu_type: str = "swish",
|
||||
time_reduction: int = 8,
|
||||
bias_max_distance: int = 1000,
|
||||
bias_symmetric: bool = False,
|
||||
nemo_activation: str = "relu",
|
||||
nemo_conv_channels: int = 1024,
|
||||
downsample_rate: int = 1,
|
||||
initializer_range: float = 0.02,
|
||||
audio_token_id: int = 200011,
|
||||
feature_layer: int = -2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.activation = activation
|
||||
self.chunk_size = chunk_size
|
||||
self.left_chunk = left_chunk
|
||||
self.num_blocks = num_blocks
|
||||
self.dropout_rate = dropout_rate
|
||||
self.ext_pw_out_channel = ext_pw_out_channel
|
||||
self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
|
||||
self.depthwise_multiplier = depthwise_multiplier
|
||||
self.kernel_size = kernel_size
|
||||
self.conv_activation = conv_activation
|
||||
self.input_size = input_size
|
||||
self.conv_glu_type = conv_glu_type
|
||||
self.time_reduction = time_reduction
|
||||
self.bias_max_distance = bias_max_distance
|
||||
self.bias_symmetric = bias_symmetric
|
||||
self.nemo_activation = nemo_activation
|
||||
self.nemo_conv_channels = nemo_conv_channels
|
||||
self.downsample_rate = downsample_rate
|
||||
self.audio_token_id = audio_token_id
|
||||
self.initializer_range = initializer_range
|
||||
self.feature_layer = feature_layer
|
||||
|
||||
if time_reduction % 2 != 0:
|
||||
raise ValueError("`time_reduction` should be a multiple of 2!")
|
||||
length = input_size
|
||||
for _ in range(int(math.log(time_reduction, 2))):
|
||||
length = math.floor((length - 1) / 2 + 1)
|
||||
self.nemo_final_size = length
|
||||
|
||||
|
||||
class Phi4MultimodalConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Phi4MultimodalModel`]. It is used to instantiate a
|
||||
Phi4Multimodal model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the
|
||||
[microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 200064):
|
||||
Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Phi3Model`].
|
||||
hidden_size (`int`, *optional*, defaults to 3072):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 8192):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
resid_pdrop (`float`, *optional*, defaults to 0.0):
|
||||
Dropout probability for mlp outputs.
|
||||
embd_pdrop (`int`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the embeddings.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio after computing the attention scores.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 131072):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon value used for the RMSNorm.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`dict`, *optional*):
|
||||
The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
|
||||
contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and
|
||||
the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
|
||||
divided by the number of attention heads divided by 2.
|
||||
partial_rotary_factor (`float`, *optional*, defaults to `1.0`):
|
||||
Percentage of the query and keys which will have rotary embedding. Must be between 0.0 and 1.0.
|
||||
bos_token_id (`int`, *optional*, defaults to 199999):
|
||||
The id of the "beginning-of-sequence" token.
|
||||
eos_token_id (`int` or `list[int]`, *optional*, defaults to `[199999, 200020]`):
|
||||
The id of the "end-of-sequence" token.
|
||||
pad_token_id (`int`, *optional*, defaults to 199999):
|
||||
The id of the padding token.
|
||||
original_max_position_embeddings (`int`, *optional*, defaults to 4096):
|
||||
The maximum sequence length that this model was trained with. This is used to determine the size of the
|
||||
original RoPE embeddings when using long scaling.
|
||||
sliding_window (`int`, *optional*):
|
||||
Sliding window attention window size. If `None`, no sliding window is applied.
|
||||
vision_config (`Phi4MultimodalVisionConfig` or `dict`, *optional*):
|
||||
The vision config for the underlying image embedding model. If not provided, will default to the configuration
|
||||
used to instantiate a model similar in architecture as
|
||||
[microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct).
|
||||
audio_config (`Phi4MultimodalAudioConfig` or `dict`, *optional*):
|
||||
The audio config for the underlying audio embedding model. If not provided, will default to the configuration
|
||||
used to instantiate a model similar in architecture as
|
||||
[microsoft/Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct).
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Phi4MultimodalModel, Phi4MultimodalConfig
|
||||
|
||||
>>> # Initializing a Phi4Multimodal style configuration
|
||||
>>> configuration = Phi4MultimodalConfig.from_pretrained("microsoft/Phi-4-multimodal-instruct")
|
||||
|
||||
>>> # Initializing a model from the configuration
|
||||
>>> model = Phi4MultimodalModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "phi4_multimodal"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.qkv_proj": "colwise_rep", # we need to replicate here due to the slicing of qkv
|
||||
"layers.*.self_attn.o_proj": "rowwise_rep", # we need to replicate here due to the slicing of qkv
|
||||
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
|
||||
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
sub_configs = {"audio_config": Phi4MultimodalAudioConfig, "vision_config": Phi4MultimodalVisionConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=200064,
|
||||
hidden_size=3072,
|
||||
intermediate_size=8192,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=8,
|
||||
resid_pdrop=0.0,
|
||||
embd_pdrop=0.0,
|
||||
attention_dropout=0.0,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=131072,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-5,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
partial_rotary_factor=1,
|
||||
bos_token_id=199999,
|
||||
eos_token_id=[199999, 200020],
|
||||
pad_token_id=199999,
|
||||
original_max_position_embeddings=4096,
|
||||
sliding_window=None,
|
||||
vision_config=None,
|
||||
audio_config=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.resid_pdrop = resid_pdrop
|
||||
self.embd_pdrop = embd_pdrop
|
||||
self.attention_dropout = attention_dropout
|
||||
self.hidden_act = hidden_act
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.original_max_position_embeddings = original_max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.partial_rotary_factor = partial_rotary_factor
|
||||
self._rope_scaling_adjustment()
|
||||
self._rope_scaling_validation()
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
if isinstance(vision_config, dict):
|
||||
vision_config = Phi4MultimodalVisionConfig(**vision_config)
|
||||
elif vision_config is None:
|
||||
Phi4MultimodalVisionConfig()
|
||||
self.vision_config = vision_config
|
||||
|
||||
if isinstance(audio_config, dict):
|
||||
audio_config = Phi4MultimodalAudioConfig(**audio_config)
|
||||
elif vision_config is None:
|
||||
audio_config = Phi4MultimodalAudioConfig()
|
||||
self.audio_config = audio_config
|
||||
|
||||
def _rope_scaling_adjustment(self):
|
||||
"""
|
||||
Adjust the `type` of the `rope_scaling` configuration for backward compatibility.
|
||||
"""
|
||||
if self.rope_scaling is None:
|
||||
return
|
||||
|
||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||
|
||||
# For backward compatibility if previous version used "su" or "yarn"
|
||||
if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]:
|
||||
self.rope_scaling["type"] = "longrope"
|
||||
|
||||
def _rope_scaling_validation(self):
|
||||
"""
|
||||
Validate the `rope_scaling` configuration.
|
||||
"""
|
||||
if self.rope_scaling is None:
|
||||
return
|
||||
|
||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
|
||||
raise ValueError(
|
||||
"`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, "
|
||||
f"got {self.rope_scaling}"
|
||||
)
|
||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||
rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
|
||||
rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
|
||||
if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
|
||||
raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")
|
||||
if not (
|
||||
isinstance(rope_scaling_short_factor, list)
|
||||
and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
|
||||
)
|
||||
rotary_ndims = int(self.hidden_size // self.num_attention_heads * self.partial_rotary_factor)
|
||||
if not len(rope_scaling_short_factor) == rotary_ndims // 2:
|
||||
raise ValueError(
|
||||
f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}"
|
||||
)
|
||||
if not (
|
||||
isinstance(rope_scaling_long_factor, list)
|
||||
and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
|
||||
)
|
||||
if not len(rope_scaling_long_factor) == rotary_ndims // 2:
|
||||
raise ValueError(
|
||||
f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}"
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Phi4MultimodalVisionConfig", "Phi4MultimodalAudioConfig", "Phi4MultimodalConfig"]
|
@ -0,0 +1,229 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
from peft import LoraConfig
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from transformers import (
|
||||
Phi4MultimodalAudioConfig,
|
||||
Phi4MultimodalConfig,
|
||||
Phi4MultimodalForCausalLM,
|
||||
Phi4MultimodalProcessor,
|
||||
Phi4MultimodalVisionConfig,
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
STATE_DICT_MAPPING = {
|
||||
r"^model.embed_tokens_extend.audio_embed.encoder.encoders.(\d+).feed_forward_(in|out).net.0.linear": r"model.embed_tokens_extend.audio_embed.encoder.encoders.\1.feed_forward_\2.gate_up_proj",
|
||||
r"^model.embed_tokens_extend.audio_embed.encoder.encoders.(\d+).feed_forward_(in|out).net.2": r"model.embed_tokens_extend.audio_embed.encoder.encoders.\1.feed_forward_\2.down_proj",
|
||||
|
||||
r"^model.embed_tokens_extend.audio_embed.encoder.encoders.(\d+).self_attn.linear_(q|k|v)": r"model.embed_tokens_extend.audio_embed.encoder.encoders.\1.self_attn.\2_proj",
|
||||
r"^model.embed_tokens_extend.audio_embed.encoder.encoders.(\d+).self_attn.linear_out": r"model.embed_tokens_extend.audio_embed.encoder.encoders.\1.self_attn.o_proj",
|
||||
|
||||
r"^model.embed_tokens_extend.image_embed.img_projection.0": r"model.embed_tokens_extend.image_embed.img_projection_up",
|
||||
r"^model.embed_tokens_extend.image_embed.img_projection.2": r"model.embed_tokens_extend.image_embed.img_projection_down",
|
||||
|
||||
r"^model.embed_tokens_extend.image_embed.glb_GN": r"model.embed_tokens_extend.image_embed.global_img_feature_extensor",
|
||||
r"^model.embed_tokens_extend.image_embed.sub_GN": r"model.embed_tokens_extend.image_embed.sub_img_feature_extensor",
|
||||
|
||||
r"^model.embed_tokens_extend.audio_embed.audio_projection.speech.0": r"model.embed_tokens_extend.audio_embed.up_proj_for_speech",
|
||||
r"^model.embed_tokens_extend.audio_embed.audio_projection.speech.2": r"model.embed_tokens_extend.audio_embed.down_proj_for_speech",
|
||||
r"^model.embed_tokens_extend.audio_embed.audio_projection.vision.0": r"model.embed_tokens_extend.audio_embed.up_proj_for_vision_speech",
|
||||
r"^model.embed_tokens_extend.audio_embed.audio_projection.vision.2": r"model.embed_tokens_extend.audio_embed.down_proj_for_vision_speech",
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
|
||||
def map_old_key_to_new(old_key):
|
||||
"""Map of a key of the original state dict to the equivalent key in HF format"""
|
||||
for pattern, replacement in STATE_DICT_MAPPING.items():
|
||||
new_key, n_replace = re.subn(pattern, replacement, old_key)
|
||||
# Early exit of the loop
|
||||
if n_replace > 0:
|
||||
return new_key
|
||||
|
||||
# The state dict contains lora keys....
|
||||
if "lora" in old_key:
|
||||
return None
|
||||
# This extracts the original weight before adding the lora adapter
|
||||
if "base_layer." in old_key:
|
||||
return old_key.replace("base_layer.", "")
|
||||
|
||||
# not part of the key mapping, we keep the original name
|
||||
return old_key
|
||||
|
||||
|
||||
def convert_state_dict(original_state_dict: dict):
|
||||
"""Convert a state dict file."""
|
||||
new_dict = {}
|
||||
for old_key, tensor in original_state_dict.items():
|
||||
new_key = map_old_key_to_new(old_key)
|
||||
if new_key is not None:
|
||||
new_dict[new_key] = tensor
|
||||
return new_dict
|
||||
|
||||
|
||||
def convert_config(original_config: dict):
|
||||
# Remove unused args
|
||||
original_config.pop("_name_or_path", None)
|
||||
original_config.pop("architectures", None)
|
||||
original_config.pop("auto_map", None)
|
||||
original_config.pop("vision_lora", None)
|
||||
original_config.pop("speech_lora", None)
|
||||
original_config.pop("transformers_version", None)
|
||||
original_config.pop("_attn_implementation", None)
|
||||
|
||||
embd_layer = original_config.pop("embd_layer")
|
||||
audio_embd_layer = embd_layer["audio_embd_layer"]
|
||||
vision_embd_layer = embd_layer["image_embd_layer"]
|
||||
|
||||
# Keep only some of the subdict
|
||||
keep_audio_embd_layer = ["downsample_rate"]
|
||||
keep_vision_embd_layer = ["crop_size"]
|
||||
audio_embd_layer = {k: v for k, v in audio_embd_layer.items() if k in keep_audio_embd_layer}
|
||||
vision_embd_layer = {k: v for k, v in vision_embd_layer.items() if k in keep_vision_embd_layer}
|
||||
|
||||
audio_config = original_config.pop("audio_processor")["config"]
|
||||
# remove
|
||||
audio_config.pop("activation_checkpointing", None)
|
||||
audio_config.pop("cnn_layer_norm", None)
|
||||
audio_config.pop("input_layer", None)
|
||||
audio_config.pop("batch_norm", None)
|
||||
audio_config.pop("encoder_embedding_config", None)
|
||||
audio_config.pop("ext_pw_kernel_size", None)
|
||||
audio_config.pop("bias_in_glu", None)
|
||||
audio_config.pop("causal", None)
|
||||
# rename
|
||||
audio_config["hidden_size"] = audio_config.pop("attention_dim")
|
||||
audio_config["num_attention_heads"] = audio_config.pop("attention_heads")
|
||||
audio_config["intermediate_size"] = audio_config.pop("linear_units")
|
||||
audio_config["nemo_conv_channels"] = audio_config.pop("nemo_conv_settings")["conv_channels"]
|
||||
audio_config["bias_max_distance"] = audio_config.pop("relative_attention_bias_args")["t5_bias_max_distance"]
|
||||
# add
|
||||
audio_config = {**audio_config, **audio_embd_layer}
|
||||
|
||||
# Create transformers config objects
|
||||
audio_config = Phi4MultimodalAudioConfig(**audio_config)
|
||||
vision_config = Phi4MultimodalVisionConfig(**vision_embd_layer)
|
||||
|
||||
# Add 2nd eos to config
|
||||
original_config["eos_token_id"] = [199999, 200020]
|
||||
|
||||
new_config = Phi4MultimodalConfig(**original_config, vision_config=vision_config, audio_config=audio_config)
|
||||
return new_config
|
||||
|
||||
|
||||
def read_json(path):
|
||||
with open(path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def convert_and_write_model(input_dir: str, output_dir: str):
|
||||
"""Convert the model and save it (this implicitly save the config as well)."""
|
||||
original_config = read_json(os.path.join(input_dir, "config.json"))
|
||||
config = convert_config(original_config)
|
||||
|
||||
full_state_dict = {}
|
||||
shards = [file for file in os.listdir(input_dir) if file.endswith(".safetensors")]
|
||||
for shard_file in shards:
|
||||
original_state_dict = load_file(os.path.join(input_dir, shard_file))
|
||||
new_dict = convert_state_dict(original_state_dict)
|
||||
full_state_dict.update(new_dict)
|
||||
|
||||
# Load weights into model and resave them
|
||||
with torch.device("meta"):
|
||||
model = Phi4MultimodalForCausalLM(config)
|
||||
missing, unexpected = model.load_state_dict(full_state_dict, strict=False, assign=True)
|
||||
# The lm_head is missing because it's tied
|
||||
if missing != ["lm_head.weight"]:
|
||||
raise ValueError("Missing keys:\n{missing}")
|
||||
if len(unexpected) > 0:
|
||||
raise ValueError(f"Unexpected keys:\n{unexpected}")
|
||||
|
||||
model.tie_weights()
|
||||
model.save_pretrained(output_dir)
|
||||
|
||||
|
||||
def convert_and_save_processor(input_dir: str, output_dir: str):
|
||||
"""Convert the processor."""
|
||||
processor = Phi4MultimodalProcessor.from_pretrained(input_dir)
|
||||
del processor.image_processor.auto_map
|
||||
del processor.audio_processor.auto_map
|
||||
processor.chat_template = processor.tokenizer.chat_template
|
||||
processor.tokenizer.extra_special_tokens = {"image_token": "<|endoftext10|>", "audio_token": "<|endoftext11|>"}
|
||||
processor.save_pretrained(output_dir)
|
||||
|
||||
|
||||
def extract_adapters_data(input_dir: str, output_dir: str):
|
||||
"""Extract adapters data from the state dict and save weights and configs."""
|
||||
speech_lora = {}
|
||||
vision_lora = {}
|
||||
shards = [file for file in os.listdir(input_dir) if file.endswith(".safetensors")]
|
||||
for shard_file in shards:
|
||||
original_state_dict = load_file(os.path.join(input_dir, shard_file))
|
||||
for k, v in original_state_dict.items():
|
||||
if "lora" in k:
|
||||
if "speech" in k:
|
||||
speech_lora[k.replace("speech.", "")] = v
|
||||
elif "vision" in k:
|
||||
vision_lora[k.replace("vision.", "")] = v
|
||||
|
||||
# Create and save the lora configs
|
||||
speech_lora_config = LoraConfig(
|
||||
r=320,
|
||||
lora_alpha=640,
|
||||
target_modules=r"model.layers.\d+.((self_attn.(qkv|o)_proj)|(mlp.(gate_up|down)_proj))",
|
||||
lora_dropout=0.01,
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
speech_lora_config.save_pretrained(os.path.join(output_dir, "speech-lora"))
|
||||
vision_lora_config = LoraConfig(
|
||||
r=256,
|
||||
lora_alpha=512,
|
||||
target_modules=r"model.layers.\d+.((self_attn.(qkv|o)_proj)|(mlp.(gate_up|down)_proj))",
|
||||
lora_dropout=0.0,
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
vision_lora_config.save_pretrained(os.path.join(output_dir, "vision-lora"))
|
||||
|
||||
save_file(speech_lora, os.path.join(output_dir, "speech-lora", "adapter_model.safetensors"))
|
||||
save_file(vision_lora, os.path.join(output_dir, "vision-lora", "adapter_model.safetensors"))
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"input_dir",
|
||||
help="Location of the model folder containing the weights and configs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"output_dir",
|
||||
help="Location to write HF model.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Convert
|
||||
convert_and_write_model(args.input_dir, args.output_dir)
|
||||
convert_and_save_processor(args.input_dir, args.output_dir)
|
||||
extract_adapters_data(args.input_dir, args.output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,348 @@
|
||||
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Processor class for Phi4Multimodal
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...audio_utils import AudioInput
|
||||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...utils import TensorType, is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: @eustlb, remove this once #36603 is merged.
|
||||
def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
|
||||
"""Create a Mel filter-bank the same as SpeechLib FbankFC.
|
||||
|
||||
Args:
|
||||
sample_rate (int): Sample rate in Hz. number > 0 [scalar]
|
||||
n_fft (int): FFT size. int > 0 [scalar]
|
||||
n_mel (int): Mel filter size. int > 0 [scalar]
|
||||
fmin (float): lowest frequency (in Hz). If None use 0.0.
|
||||
float >= 0 [scalar]
|
||||
fmax: highest frequency (in Hz). If None use sample_rate / 2.
|
||||
float >= 0 [scalar]
|
||||
|
||||
Returns
|
||||
out (numpy.ndarray): Mel transform matrix
|
||||
[shape=(n_mels, 1 + n_fft/2)]
|
||||
"""
|
||||
|
||||
bank_width = int(n_fft // 2 + 1)
|
||||
if fmax is None:
|
||||
fmax = sample_rate / 2
|
||||
if fmin is None:
|
||||
fmin = 0
|
||||
assert fmin >= 0, "fmin cannot be negtive"
|
||||
assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, samplerate / 2]"
|
||||
|
||||
def mel(f):
|
||||
return 1127.0 * np.log(1.0 + f / 700.0)
|
||||
|
||||
def bin2mel(fft_bin):
|
||||
return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))
|
||||
|
||||
def f2bin(f):
|
||||
return int((f * n_fft / sample_rate) + 0.5)
|
||||
|
||||
# Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1]
|
||||
klo = f2bin(fmin) + 1
|
||||
khi = f2bin(fmax)
|
||||
|
||||
khi = max(khi, klo)
|
||||
|
||||
# Spec 2: SpeechLib uses trianges in Mel space
|
||||
mlo = mel(fmin)
|
||||
mhi = mel(fmax)
|
||||
m_centers = np.linspace(mlo, mhi, n_mels + 2)
|
||||
ms = (mhi - mlo) / (n_mels + 1)
|
||||
|
||||
matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
|
||||
for m in range(0, n_mels):
|
||||
left = m_centers[m]
|
||||
center = m_centers[m + 1]
|
||||
right = m_centers[m + 2]
|
||||
for fft_bin in range(klo, khi):
|
||||
mbin = bin2mel(fft_bin)
|
||||
if left < mbin < right:
|
||||
matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
|
||||
|
||||
return matrix
|
||||
|
||||
|
||||
class Phi4MultimodalFeatureExtractor(SequenceFeatureExtractor):
|
||||
model_input_names = ["audio_input_features", "audio_embed_sizes", "audio_attention_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_size: int = 80,
|
||||
sampling_rate: int = 16000,
|
||||
hop_length: int = 160,
|
||||
n_fft: int = 512,
|
||||
win_length: int = 400,
|
||||
preemphasis: float = 0.97,
|
||||
padding_value: float = 0.0,
|
||||
audio_compression_rate: int = 8,
|
||||
audio_downsample_rate: int = 1,
|
||||
audio_feat_stride: int = 1,
|
||||
mel_min_frequency: float = 0,
|
||||
mel_max_frequency: float = 7690,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
|
||||
|
||||
self.hop_length = hop_length
|
||||
self.n_fft = n_fft
|
||||
self.win_length = win_length
|
||||
self.preemphasis = preemphasis
|
||||
self.padding_value = padding_value
|
||||
self.audio_compression_rate = audio_compression_rate
|
||||
self.audio_downsample_rate = audio_downsample_rate
|
||||
self.audio_feat_stride = audio_feat_stride
|
||||
|
||||
# TODO: @eustlb, uncomment and remove speechlib_mel once #36603 is merged.
|
||||
# self.mel_filters = mel_filter_bank(
|
||||
# num_frequency_bins=self.n_fft // 2 + 1,
|
||||
# num_mel_filters=self.feature_size,
|
||||
# min_frequency=mel_min_frequency,
|
||||
# max_frequency=mel_max_frequency,
|
||||
# sampling_rate=self.sampling_rate,
|
||||
# triangularize_in_mel_space=True,
|
||||
# mel_scale="kaldi",
|
||||
# )
|
||||
self.mel_filters = speechlib_mel(
|
||||
self.sampling_rate, self.n_fft, self.feature_size, mel_min_frequency, mel_max_frequency
|
||||
).T
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
raw_speech: AudioInput,
|
||||
sampling_rate: Optional[int] = None,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
padding: Optional[str] = "longest",
|
||||
max_length: Optional[int] = None,
|
||||
truncation: bool = False,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
return_attention_mask: Optional[bool] = True,
|
||||
device: Optional[str] = "cpu",
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to featurize and prepare for the model one or several audio sequence(s). Implementation uses PyTorch for
|
||||
the STFT computation if available, otherwise a slower NumPy based one.
|
||||
|
||||
Args:
|
||||
raw_speech (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The sequence or batch of sequences to be processed. Each sequence can be a numpy array or PyTorch tensor.
|
||||
For batched inputs, sequences can be a list of numpy arrays or PyTorch tensors, or a single numpy array or
|
||||
PyTorch tensor with first dimension being the batch size.
|
||||
sampling_rate (`int`, *optional*):
|
||||
The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
|
||||
`sampling_rate` at the forward call to prevent silent errors.
|
||||
pad_to_multiple_of (`int`, *optional*, defaults to None):
|
||||
If set will pad the sequence to a multiple of the provided value.
|
||||
padding (`str`, *optional*, defaults to "longest"):
|
||||
Padding strategy. Can be "longest" to pad to the longest sequence in the batch, or a specific length.
|
||||
max_length (`int`, *optional*):
|
||||
Maximum length of the returned list and optionally padding length.
|
||||
truncation (`bool`, *optional*, defaults to False):
|
||||
Activates truncation to cut input sequences longer than *max_length* to *max_length*.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors instead of numpy arrays. Acceptable values are:
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return Numpy `np.ndarray` objects.
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
return_attention_mask (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return the extracted audio input features' attention mask.
|
||||
device (`str`, *optional*, defaults to "cpu"):
|
||||
Specifies the device for computation of the audio features. (e.g., "cpu", "cuda")
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
- **audio_input_features** -- Audio features extracted from the raw audio input, shape (batch_size, max_feature_length, feature_size).
|
||||
- **audio_lengths** -- Length of each audio sample in the batch, shape (batch_size,).
|
||||
- **audio_attention_mask** -- Attention mask for the audio input, shape (batch_size, max_feature_length).
|
||||
If `return_tensors` is not specified, the fields will be PyTorch tensors if PyTorch is available, otherwise NumPy arrays.
|
||||
"""
|
||||
if sampling_rate is not None:
|
||||
if sampling_rate != self.sampling_rate:
|
||||
raise ValueError(
|
||||
f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
|
||||
f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
|
||||
f" was sampled with {self.sampling_rate} and not {sampling_rate}."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
|
||||
"Failing to do so can result in silent errors that might be hard to debug."
|
||||
)
|
||||
|
||||
# Convert to torch tensor
|
||||
if isinstance(raw_speech, np.ndarray):
|
||||
raw_speech = torch.tensor(raw_speech)
|
||||
elif isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], np.ndarray):
|
||||
raw_speech = [torch.tensor(speech) for speech in raw_speech]
|
||||
|
||||
is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1
|
||||
if is_batched_torch and len(raw_speech.shape) > 2:
|
||||
logger.warning(
|
||||
f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
|
||||
"We will take the mean of the channels to convert to mono."
|
||||
)
|
||||
raw_speech = raw_speech.mean(-1)
|
||||
|
||||
is_batched_sequence = isinstance(raw_speech, (list, tuple))
|
||||
if is_batched_sequence:
|
||||
for speech in raw_speech:
|
||||
if len(speech.shape) > 1:
|
||||
logger.warning(
|
||||
f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
|
||||
"We will take the mean of the channels to convert to mono."
|
||||
)
|
||||
speech = speech.mean(-1)
|
||||
|
||||
if is_batched_torch or is_batched_sequence:
|
||||
raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech]
|
||||
else:
|
||||
raw_speech = [raw_speech[:, None].to(torch.float32)]
|
||||
|
||||
audio_lengths = [len(speech) for speech in raw_speech]
|
||||
|
||||
# convert into correct format for padding
|
||||
batched_speech = BatchFeature(data={"audio_input_features": raw_speech, "audio_lengths": audio_lengths})
|
||||
padded_inputs = self.pad(
|
||||
batched_speech,
|
||||
padding=padding,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_features = padded_inputs.audio_input_features.squeeze(-1)
|
||||
audio_lengths = padded_inputs.audio_lengths
|
||||
|
||||
input_features = self._torch_extract_fbank_features(input_features, audio_lengths, device)
|
||||
|
||||
feature_lengths = (audio_lengths - self.win_length) // self.hop_length + 1
|
||||
feature_lengths = feature_lengths * self.audio_feat_stride
|
||||
audio_embed_sizes = self._compute_audio_embed_size(feature_lengths)
|
||||
|
||||
feature_attention_mask = (
|
||||
torch.arange(0, feature_lengths.max()) if is_torch_available() else np.arange(0, feature_lengths.max())
|
||||
)
|
||||
feature_attention_mask = (
|
||||
feature_attention_mask[None, :] < feature_lengths[:, None] if len(feature_lengths) > 1 else None
|
||||
)
|
||||
|
||||
data = {
|
||||
"audio_input_features": input_features,
|
||||
"audio_embed_sizes": audio_embed_sizes,
|
||||
}
|
||||
if feature_attention_mask is not None and return_attention_mask:
|
||||
data["audio_attention_mask"] = feature_attention_mask
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
# TODO; @eustlb, move this to audio_utils in a general spectogram_batch function that handles torch and numpy
|
||||
def _torch_extract_fbank_features(
|
||||
self, waveform: "torch.FloatTensor", audio_lengths: "torch.Tensor", device: str = "cpu"
|
||||
) -> "torch.FloatTensor":
|
||||
"""
|
||||
Compute the log mel-scaled spectrogram of batched waveforms using PyTorch's FFT implementation.
|
||||
|
||||
Args:
|
||||
waveform (torch.FloatTensor` of shape `(batch_size, max_audio_length)`):
|
||||
The batched waveforms.
|
||||
audio_lengths (`torch.Tensor` of shape `(batch_size,)`):
|
||||
The lengths of the waveforms along the max_audio_length dimension.
|
||||
device (`str`, *optional*, defaults to "cpu"):
|
||||
The device to run the computation on. (e.g., "cpu", "cuda")
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor` of shape `(batch_size, max_feature_length, feature_size)`:
|
||||
The log mel-scaled spectrogram of the batched waveforms.
|
||||
"""
|
||||
fft_window = torch.hamming_window(self.win_length, periodic=False, device=device, dtype=torch.float64)
|
||||
|
||||
# batched implementation
|
||||
batch_size = waveform.shape[0]
|
||||
frames = waveform.unfold(-1, self.win_length, self.hop_length)
|
||||
|
||||
# ---
|
||||
# the unbatched (and unpaded) original implementation skips last few audio values that can't be included in a frame
|
||||
# we need to ensure that the corresponding frames for the padded input also mask these values
|
||||
if batch_size > 1:
|
||||
frames = frames.clone()
|
||||
# concerned batch indices
|
||||
to_mask_batch_idxs = torch.arange(batch_size)[audio_lengths != audio_lengths.max()]
|
||||
if to_mask_batch_idxs.numel() > 0:
|
||||
batch_idxs_down = (audio_lengths[to_mask_batch_idxs] - self.win_length) // self.hop_length + 1
|
||||
batch_idxs_up = audio_lengths[to_mask_batch_idxs] // self.hop_length + 1
|
||||
offset_idx = batch_idxs_down.min()
|
||||
max_idx = batch_idxs_up.max()
|
||||
|
||||
mask = torch.arange(max_idx - offset_idx, device=device).expand(to_mask_batch_idxs.shape[0], -1)
|
||||
mask = ((batch_idxs_down - offset_idx).unsqueeze(1) <= mask) & (
|
||||
mask < (batch_idxs_up - offset_idx).unsqueeze(1)
|
||||
)
|
||||
mask = mask.unsqueeze(-1).expand(-1, -1, self.win_length)
|
||||
masked_frames = frames[to_mask_batch_idxs, offset_idx:max_idx].masked_fill_(mask, 0)
|
||||
frames[to_mask_batch_idxs, offset_idx:max_idx] = masked_frames
|
||||
# ---
|
||||
|
||||
# apply pre-emphasis first order filter on fft windows
|
||||
frames_prev = torch.roll(frames, 1, dims=-1)
|
||||
frames_prev[:, :, 0] = frames_prev[:, :, 1]
|
||||
frames = (frames - self.preemphasis * frames_prev) * 32768
|
||||
|
||||
# apply fft
|
||||
S = torch.fft.rfft(fft_window * frames.view(-1, self.win_length), n=self.n_fft, dim=1)
|
||||
S = S.view(frames.shape[0], -1, S.shape[-1])
|
||||
S = S.to(torch.complex64)
|
||||
|
||||
spec = torch.abs(S)
|
||||
spec_power = spec**2
|
||||
|
||||
# apply triangular mel filter bank
|
||||
mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32)
|
||||
log_spec = torch.clamp(spec_power @ mel_filters, min=1.0)
|
||||
log_spec = torch.log(log_spec)
|
||||
|
||||
return log_spec
|
||||
|
||||
def _compute_audio_embed_size(self, audio_frames):
|
||||
integer = audio_frames // self.audio_compression_rate
|
||||
remainder = audio_frames % self.audio_compression_rate
|
||||
result = integer + (remainder > 0).to(integer.dtype)
|
||||
|
||||
integer = result // self.audio_downsample_rate
|
||||
remainder = result % self.audio_downsample_rate
|
||||
result = integer + (remainder > 0).to(integer.dtype) # qformer compression
|
||||
|
||||
return result
|
||||
|
||||
|
||||
__all__ = ["Phi4MultimodalFeatureExtractor"]
|
@ -0,0 +1,263 @@
|
||||
# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Processor class for Phi4Multimodal
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
from ...image_processing_utils_fast import (
|
||||
BaseImageProcessorFast,
|
||||
BatchFeature,
|
||||
DefaultFastImageProcessorKwargs,
|
||||
Unpack,
|
||||
convert_to_rgb,
|
||||
)
|
||||
from ...image_utils import ImageInput, make_list_of_images, valid_images
|
||||
from ...utils import TensorType, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Phi4MultimodalFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||
image_size: Optional[int]
|
||||
patch_size: Optional[int]
|
||||
dynamic_hd: Optional[int]
|
||||
|
||||
|
||||
class Phi4MultimodalImageProcessorFast(BaseImageProcessorFast):
|
||||
r"""
|
||||
Constructs a Phi4Multimodal image processor.
|
||||
"""
|
||||
|
||||
image_size = 448
|
||||
patch_size = 14
|
||||
dynamic_hd = 36
|
||||
image_mean = [0.5, 0.5, 0.5]
|
||||
image_std = [0.5, 0.5, 0.5]
|
||||
valid_init_kwargs = Phi4MultimodalFastImageProcessorKwargs
|
||||
model_input_names = ["image_pixel_values", "image_sizes", "image_attention_mask"]
|
||||
|
||||
def __init__(self, **kwargs: Unpack[Phi4MultimodalFastImageProcessorKwargs]):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height):
|
||||
best_ratio_diff = float("inf")
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
for ratio in target_ratios:
|
||||
target_aspect_ratio = ratio[0] / ratio[1]
|
||||
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
||||
if ratio_diff < best_ratio_diff:
|
||||
best_ratio_diff = ratio_diff
|
||||
best_ratio = ratio
|
||||
elif ratio_diff == best_ratio_diff:
|
||||
if area > 0.5 * self.image_size * self.image_size * ratio[0] * ratio[1]:
|
||||
best_ratio = ratio
|
||||
return best_ratio
|
||||
|
||||
def dynamic_preprocess(self, image, max_num=36, min_num=1):
|
||||
image_size = self.image_size
|
||||
patch_size = self.patch_size
|
||||
mask_size = image_size // patch_size
|
||||
orig_width, orig_height = image.size
|
||||
|
||||
w_crop_num = math.ceil(orig_width / float(image_size))
|
||||
h_crop_num = math.ceil(orig_height / float(image_size))
|
||||
if w_crop_num * h_crop_num > max_num:
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = {
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
}
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = self.find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = image_size * target_aspect_ratio[0]
|
||||
target_height = image_size * target_aspect_ratio[1]
|
||||
else:
|
||||
target_width = image_size * w_crop_num
|
||||
target_height = image_size * h_crop_num
|
||||
target_aspect_ratio = (w_crop_num, h_crop_num)
|
||||
|
||||
# Calculate the ratio
|
||||
ratio_width = target_width / orig_width
|
||||
ratio_height = target_height / orig_height
|
||||
if ratio_width < ratio_height:
|
||||
new_size = (target_width, int(orig_height * ratio_width))
|
||||
padding_width = 0
|
||||
padding_height = target_height - int(orig_height * ratio_width)
|
||||
else:
|
||||
new_size = (int(orig_width * ratio_height), target_height)
|
||||
padding_width = target_width - int(orig_width * ratio_height)
|
||||
padding_height = 0
|
||||
|
||||
attention_mask = torch.ones((int(mask_size * target_aspect_ratio[1]), int(mask_size * target_aspect_ratio[0])))
|
||||
if padding_width >= patch_size:
|
||||
attention_mask[:, -math.floor(padding_width / patch_size) :] = 0
|
||||
if padding_height >= patch_size:
|
||||
attention_mask[-math.floor(padding_height / patch_size) :, :] = 0
|
||||
|
||||
if min(new_size[1], target_height) < 10 or min(new_size[0], target_width) < 10:
|
||||
raise ValueError(f"the aspect ratio is very extreme {new_size}")
|
||||
|
||||
image = F.resize(image, [new_size[1], new_size[0]])
|
||||
resized_img = F.pad(image, [0, 0, padding_width, padding_height], fill=[255, 255, 255])
|
||||
|
||||
return resized_img, attention_mask
|
||||
|
||||
def pad_to_max_num_crops(self, images, max_crops=5):
|
||||
"""
|
||||
images: B x 3 x H x W, B<=max_crops
|
||||
"""
|
||||
B, _, H, W = images.shape
|
||||
if B < max_crops:
|
||||
pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device)
|
||||
images = torch.cat([images, pad], dim=0)
|
||||
return images
|
||||
|
||||
def pad_mask_to_max_num_crops(self, masks, max_crops=5):
|
||||
B, H, W = masks.shape
|
||||
if B < max_crops:
|
||||
pad = torch.ones(max_crops - B, H, W, dtype=masks.dtype, device=masks.device)
|
||||
masks = torch.cat([masks, pad], dim=0)
|
||||
return masks
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
"""
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
|
||||
images = make_list_of_images(images)
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
images = [convert_to_rgb(image) for image in images]
|
||||
|
||||
image_size = self.image_size
|
||||
patch_size = self.patch_size
|
||||
mask_size = image_size // patch_size
|
||||
imgs_and_masks = [self.dynamic_preprocess(image, max_num=self.dynamic_hd) for image in images]
|
||||
images, image_attention_masks = [x[0] for x in imgs_and_masks], [x[1] for x in imgs_and_masks]
|
||||
|
||||
images = [F.to_tensor(image) for image in images]
|
||||
hd_images = [F.normalize(image, image_mean, image_std) for image in images]
|
||||
global_image = [
|
||||
torch.nn.functional.interpolate(
|
||||
image.unsqueeze(0).float(),
|
||||
size=(image_size, image_size),
|
||||
mode="bicubic",
|
||||
).to(image.dtype)
|
||||
for image in hd_images
|
||||
]
|
||||
|
||||
shapes = [[image.size(1), image.size(2)] for image in hd_images]
|
||||
mask_shapes = [[mask.size(0), mask.size(1)] for mask in image_attention_masks]
|
||||
global_attention_mask = [torch.ones((1, mask_size, mask_size)) for _ in hd_images]
|
||||
|
||||
hd_images_reshape = []
|
||||
for im, (h, w) in zip(hd_images, shapes):
|
||||
im = im.reshape(1, 3, h // image_size, image_size, w // image_size, image_size)
|
||||
im = im.permute(0, 2, 4, 1, 3, 5)
|
||||
im = im.reshape(-1, 3, image_size, image_size)
|
||||
hd_images_reshape.append(im.contiguous())
|
||||
|
||||
attention_masks_reshape = []
|
||||
for mask, (h, w) in zip(image_attention_masks, mask_shapes):
|
||||
mask = mask.reshape(h // mask_size, mask_size, w // mask_size, mask_size)
|
||||
mask = mask.transpose(1, 2)
|
||||
mask = mask.reshape(-1, mask_size, mask_size)
|
||||
attention_masks_reshape.append(mask.contiguous())
|
||||
|
||||
downsample_attention_masks = []
|
||||
for mask, (h, w) in zip(attention_masks_reshape, mask_shapes):
|
||||
mask = mask[:, 0::2, 0::2]
|
||||
mask = mask.reshape(
|
||||
h // mask_size, w // mask_size, mask_size // 2 + mask_size % 2, mask_size // 2 + mask_size % 2
|
||||
)
|
||||
mask = mask.transpose(1, 2)
|
||||
mask = mask.reshape(mask.size(0) * mask.size(1), mask.size(2) * mask.size(3))
|
||||
downsample_attention_masks.append(mask)
|
||||
|
||||
num_img_tokens = [
|
||||
256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16 for mask in downsample_attention_masks
|
||||
]
|
||||
|
||||
hd_images_reshape = [
|
||||
torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape)
|
||||
]
|
||||
hd_masks_reshape = [
|
||||
torch.cat([_global_mask] + [_mask], dim=0)
|
||||
for _global_mask, _mask in zip(global_attention_mask, attention_masks_reshape)
|
||||
]
|
||||
max_crops = max([img.size(0) for img in hd_images_reshape])
|
||||
image_transformed = [self.pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape]
|
||||
image_transformed = torch.stack(image_transformed, dim=0)
|
||||
mask_transformed = [self.pad_mask_to_max_num_crops(mask, max_crops) for mask in hd_masks_reshape]
|
||||
mask_transformed = torch.stack(mask_transformed, dim=0)
|
||||
|
||||
returned_input_image_embeds = image_transformed
|
||||
returned_image_sizes = torch.tensor(shapes, dtype=torch.long)
|
||||
returned_image_attention_mask = mask_transformed
|
||||
returned_num_img_tokens = num_img_tokens
|
||||
|
||||
data = {
|
||||
"image_pixel_values": returned_input_image_embeds,
|
||||
"image_sizes": returned_image_sizes,
|
||||
"image_attention_mask": returned_image_attention_mask,
|
||||
"num_img_tokens": returned_num_img_tokens,
|
||||
}
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
|
||||
__all__ = ["Phi4MultimodalImageProcessorFast"]
|
2316
src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py
Normal file
2316
src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py
Normal file
File diff suppressed because it is too large
Load Diff
1851
src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py
Normal file
1851
src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,194 @@
|
||||
# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Processor class for Phi4Multimodal
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from ...audio_utils import AudioInput
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import TextInput
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Phi4MultimodalProcessorKwargs(ProcessingKwargs, total=False):
|
||||
_defaults = {
|
||||
"audio_kwargs": {
|
||||
"device": "cpu",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Phi4MultimodalProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Phi4Multimodal processor which raps an image processor, a audio processor, and a GPT tokenizer into a single processor.
|
||||
|
||||
[`Phi4MultimodalProcessor`] offers all the functionalities of [`Phi4MultimodalImageProcessorFast`] and [`GPT2Tokenizer`]. See the
|
||||
[`~Phi4MultimodalProcessor.__call__`] and [`~Phi4MultimodalProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
image_processor (`Phi4MultimodalImageProcessorFast`):
|
||||
The image processor to use for images.
|
||||
audio_processor (`Phi4MultimodalFeatureExtractor`):
|
||||
The audio processor to use for audio inputs.
|
||||
tokenizer (`GPT2TokenizerFast`):
|
||||
The tokenizer to use for text.
|
||||
fake_image_token_pattern (`str`, *optional*, defaults to `r"<\|image_\d+\|>"`):
|
||||
The fake image token pattern.
|
||||
fake_audio_token_pattern (`str`, *optional*, defaults to `r"<\|audio_\d+\|>"`):
|
||||
The fake audio token pattern.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "audio_processor", "tokenizer"]
|
||||
tokenizer_class = "GPT2TokenizerFast"
|
||||
image_processor_class = "Phi4MultimodalImageProcessorFast"
|
||||
audio_processor_class = "Phi4MultimodalFeatureExtractor"
|
||||
valid_kwargs = ["chat_template", "fake_image_token_pattern", "fake_audio_token_pattern"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_processor,
|
||||
audio_processor,
|
||||
tokenizer,
|
||||
fake_image_token_pattern: str = r"<\|image_\d+\|>",
|
||||
fake_audio_token_pattern: str = r"<\|audio_\d+\|>",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(image_processor, audio_processor, tokenizer, **kwargs)
|
||||
self.fake_image_token_pattern = fake_image_token_pattern
|
||||
self.fake_audio_token_pattern = fake_audio_token_pattern
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[TextInput, List[TextInput]],
|
||||
images: Optional[ImageInput] = None,
|
||||
audios: Optional[AudioInput] = None,
|
||||
**kwargs: Unpack[ProcessingKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and image(s). This method forards the `text`
|
||||
and `kwargs` arguments to GPT2Tokenizer's [`~GPT2Tokenizer.__call__`] if `text` is not `None` to encode
|
||||
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
|
||||
Phi4MultimodalImageProcessorFast's [`~Phi4MultimodalImageProcessorFast.__call__`] if `images` is not `None`. Please refer to the doctsring
|
||||
of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
audios (`List[Union[np.ndarray, torch.Tensor]]`):
|
||||
List of the audios to be prepared.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
|
||||
- **input_image_embeds** -- Pixel values to be fed to a model.
|
||||
- **image_sizes** -- List of tuples specifying the size of each image in `input_image_embeds`.
|
||||
- **image_attention_mask** -- List of attention masks for each image in `input_image_embeds`.
|
||||
- **input_audio_embeds** -- Audio embeddings to be fed to a model.
|
||||
- **audio_embed_sizes** -- List of integers specifying the size of each audio in `input_audio_embeds`.
|
||||
"""
|
||||
|
||||
output_kwargs = self._merge_kwargs(Phi4MultimodalProcessorKwargs, self.tokenizer.init_kwargs, **kwargs)
|
||||
image_kwargs = output_kwargs["images_kwargs"]
|
||||
audio_kwargs = output_kwargs["audio_kwargs"]
|
||||
text_kwargs = output_kwargs["text_kwargs"]
|
||||
|
||||
image_inputs = self.image_processor(images, **image_kwargs) if images is not None else {}
|
||||
audio_inputs = self.audio_processor(audios, **audio_kwargs) if audios is not None else {}
|
||||
|
||||
# We pop here for images as we don't need it later
|
||||
num_img_tokens = image_inputs.pop("num_img_tokens", [])
|
||||
audio_embed_sizes = audio_inputs.get("audio_embed_sizes", [])
|
||||
|
||||
# Replace certain special tokens for compatibility
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not isinstance(text, list) and not isinstance(text[0], str):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
image_token = self.tokenizer.image_token
|
||||
audio_token = self.tokenizer.audio_token
|
||||
processed_text = [re.sub(self.fake_image_token_pattern, image_token, t) for t in text]
|
||||
processed_text = [re.sub(self.fake_audio_token_pattern, audio_token, t) for t in processed_text]
|
||||
|
||||
# Check that the number of special tokens is sound
|
||||
concatenated_prompt = "".join(processed_text)
|
||||
if concatenated_prompt.count(self.tokenizer.image_token) != len(num_img_tokens):
|
||||
raise ValueError(
|
||||
"You should add as much image tokens `<|image_i|>` in your prompt as you pass `images` to the processor"
|
||||
)
|
||||
if concatenated_prompt.count(self.tokenizer.audio_token) != len(audio_embed_sizes):
|
||||
raise ValueError(
|
||||
"You should add as much audio tokens `<|audio_i|>` in your prompt as you pass `audios` to the processor"
|
||||
)
|
||||
|
||||
# Add appropriate number of image/audio tokens (note that the count of replacement is dynamic)
|
||||
image_count_iter = iter(num_img_tokens)
|
||||
audio_count_iter = iter(audio_embed_sizes)
|
||||
processed_text = [
|
||||
re.sub(re.escape(image_token), lambda _: image_token * next(image_count_iter), t) for t in processed_text
|
||||
]
|
||||
processed_text = [
|
||||
re.sub(re.escape(audio_token), lambda _: audio_token * next(audio_count_iter), t) for t in processed_text
|
||||
]
|
||||
|
||||
text_inputs = self.tokenizer(processed_text, **text_kwargs)
|
||||
|
||||
# prepare batch feature
|
||||
data = {
|
||||
**text_inputs,
|
||||
**image_inputs,
|
||||
**audio_inputs,
|
||||
}
|
||||
|
||||
return BatchFeature(data=data)
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
audio_processor_input_names = self.audio_processor.model_input_names
|
||||
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + audio_processor_input_names))
|
||||
|
||||
|
||||
__all__ = ["Phi4MultimodalProcessor"]
|
@ -7746,6 +7746,55 @@ class Phi3PreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Phi4MultimodalAudioModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Phi4MultimodalAudioPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Phi4MultimodalForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Phi4MultimodalModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Phi4MultimodalPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Phi4MultimodalVisionModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Phi4MultimodalVisionPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class PhimoeForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -93,6 +93,13 @@ class LlavaOnevisionImageProcessorFast(metaclass=DummyObject):
|
||||
requires_backends(self, ["torchvision"])
|
||||
|
||||
|
||||
class Phi4MultimodalImageProcessorFast(metaclass=DummyObject):
|
||||
_backends = ["torchvision"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torchvision"])
|
||||
|
||||
|
||||
class PixtralImageProcessorFast(metaclass=DummyObject):
|
||||
_backends = ["torchvision"]
|
||||
|
||||
|
0
tests/models/phi4_multimodal/__init__.py
Normal file
0
tests/models/phi4_multimodal/__init__.py
Normal file
405
tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py
Normal file
405
tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py
Normal file
@ -0,0 +1,405 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoProcessor,
|
||||
GenerationConfig,
|
||||
Phi4MultimodalAudioConfig,
|
||||
Phi4MultimodalConfig,
|
||||
Phi4MultimodalForCausalLM,
|
||||
Phi4MultimodalModel,
|
||||
Phi4MultimodalVisionConfig,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
require_soundfile,
|
||||
require_torch,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_soundfile_available
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
if is_soundfile_available():
|
||||
import soundfile
|
||||
|
||||
|
||||
class Phi4MultimodalModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=2,
|
||||
seq_length=12,
|
||||
image_seq_length=275,
|
||||
audio_seq_length=8,
|
||||
is_training=True,
|
||||
num_hidden_layers=2,
|
||||
vocab_size=49,
|
||||
hidden_size=32,
|
||||
intermediate_size=64,
|
||||
num_attention_heads=8,
|
||||
num_key_value_heads=4,
|
||||
bos_token_id=0,
|
||||
eos_token_id=0,
|
||||
pad_token_id=0,
|
||||
image_token_id=1,
|
||||
audio_token_id=2,
|
||||
image_size=16,
|
||||
audio_size=12,
|
||||
audio_config=Phi4MultimodalAudioConfig(
|
||||
num_blocks=2,
|
||||
hidden_size=32,
|
||||
num_attention_heads=8,
|
||||
intermediate_size=48,
|
||||
depthwise_seperable_out_channel=128,
|
||||
nemo_conv_channels=128,
|
||||
),
|
||||
vision_config=Phi4MultimodalVisionConfig(
|
||||
num_hidden_layers=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=64,
|
||||
num_attention_heads=8,
|
||||
crop_size=16,
|
||||
),
|
||||
):
|
||||
self.parent = parent
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.bos_token_id = bos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.image_token_id = image_token_id
|
||||
self.audio_token_id = audio_token_id
|
||||
self.audio_config = audio_config
|
||||
self.vision_config = vision_config
|
||||
|
||||
self.is_training = is_training
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length + image_seq_length + audio_seq_length
|
||||
self.image_seq_length = image_seq_length
|
||||
self.audio_seq_length = audio_seq_length
|
||||
self.image_size = image_size
|
||||
self.audio_size = audio_size
|
||||
self.num_channels = 3
|
||||
|
||||
def get_config(self):
|
||||
return Phi4MultimodalConfig(
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=self.intermediate_size,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
num_key_value_heads=self.num_key_value_heads,
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
vision_config=self.vision_config,
|
||||
audio_config=self.audio_config,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
# The shapes corresponds to the inputs for image of size 16x16
|
||||
image_pixel_values = floats_tensor([self.batch_size, 2, self.num_channels, self.image_size, self.image_size])
|
||||
image_attention_mask = torch.ones(self.batch_size, 2, 1, 1)
|
||||
image_sizes = torch.tensor(
|
||||
[[self.image_size, self.image_size]] * self.batch_size, dtype=torch.long, device=torch_device
|
||||
)
|
||||
|
||||
# Feature sizes returned by an audio of size 10000
|
||||
audio_input_features = floats_tensor([self.batch_size, 61, 80])
|
||||
audio_embed_sizes = torch.tensor([self.audio_seq_length] * self.batch_size, dtype=torch.long)
|
||||
|
||||
input_ids[input_ids == self.pad_token_id] = self.pad_token_id + 1 # random value but not pad token
|
||||
input_ids[-1, 0] = self.pad_token_id # mask the last text token
|
||||
input_ids[:, -self.image_seq_length - self.audio_seq_length : -self.audio_seq_length] = self.image_token_id
|
||||
input_ids[:, -self.audio_seq_length :] = self.audio_token_id
|
||||
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
attention_mask[-1, 0] = 0 # mask the last text token
|
||||
config = self.get_config()
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
image_pixel_values,
|
||||
image_attention_mask,
|
||||
image_sizes,
|
||||
audio_input_features,
|
||||
audio_embed_sizes,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
image_pixel_values,
|
||||
image_attention_mask,
|
||||
image_sizes,
|
||||
audio_input_features,
|
||||
audio_embed_sizes,
|
||||
) = self.prepare_config_and_inputs()
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"image_pixel_values": image_pixel_values,
|
||||
"image_attention_mask": image_attention_mask,
|
||||
"image_sizes": image_sizes,
|
||||
"audio_input_features": audio_input_features,
|
||||
"audio_embed_sizes": audio_embed_sizes,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_model(self, config, input_ids, attention_mask):
|
||||
model = Phi4MultimodalForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
||||
logits = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True,
|
||||
)["logits"]
|
||||
self.parent.assertEqual(logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
self.parent.assertFalse(torch.isnan(logits).any().item())
|
||||
|
||||
|
||||
@require_torch
|
||||
class Phi4MultimodalModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
Model tester for `Phi4Multimodal`.
|
||||
"""
|
||||
|
||||
all_model_classes = (Phi4MultimodalForCausalLM, Phi4MultimodalModel) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
_is_composite = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = Phi4MultimodalModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=Phi4MultimodalConfig)
|
||||
|
||||
@unittest.skip(reason="Unstable test")
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Right padding not supported")
|
||||
def test_flash_attn_2_inference_equivalence_right_padding(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="This one tries to use right padding as well")
|
||||
def test_eager_matches_fa2_generate(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Depending on input modalities, some params may not have gradients")
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Depending on input modalities, some params may not have gradients")
|
||||
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Depending on input modalities, some params may not have gradients")
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Test tries to instantiate dynamic cache with an arg")
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Test is only for old attention format")
|
||||
def test_sdpa_can_dispatch_composite_models(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Static cache supported only for text-only inputs (not images or audios)")
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Static cache supported only for text-only inputs (not images or audios)")
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Supported only for text-only inputs (otherwise dynamic control flows for multimodal inputs)"
|
||||
)
|
||||
def test_generate_compilation_all_outputs(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Supported only for text-only inputs (otherwise dynamic control flows for multimodal inputs)"
|
||||
)
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("random",), ("same",)])
|
||||
@unittest.skip(reason="`image_attention_mask` has a specific shape")
|
||||
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="`image_attention_mask` has a specific shape")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="`image_attention_mask` has a specific shape")
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Cannot unpad inputs for all modalities so easily")
|
||||
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Dynamo error")
|
||||
def test_flex_attention_with_grads(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
class Phi4MultimodalIntegrationTest(unittest.TestCase):
|
||||
checkpoint_path = "microsoft/Phi-4-multimodal-instruct"
|
||||
image_url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
audio_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav"
|
||||
|
||||
def setUp(self):
|
||||
self.processor = AutoProcessor.from_pretrained(self.checkpoint_path)
|
||||
self.generation_config = GenerationConfig(max_new_tokens=20, do_sample=False)
|
||||
self.user_token = "<|user|>"
|
||||
self.assistant_token = "<|assistant|>"
|
||||
self.end_token = "<|end|>"
|
||||
self.image = Image.open(requests.get(self.image_url, stream=True).raw)
|
||||
with tempfile.NamedTemporaryFile(mode="w+b", suffix=".wav") as tmp:
|
||||
tmp.write(requests.get(self.audio_url, stream=True).raw.data)
|
||||
tmp.flush()
|
||||
tmp.seek(0)
|
||||
self.audio, self.sampling_rate = soundfile.read(tmp.name)
|
||||
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_text_only_generation(self):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
|
||||
)
|
||||
|
||||
prompt = f"{self.user_token}What is the answer for 1+1? Explain it.{self.end_token}{self.assistant_token}"
|
||||
inputs = self.processor(prompt, images=None, return_tensors="pt").to(torch_device)
|
||||
|
||||
output = model.generate(
|
||||
**inputs,
|
||||
generation_config=self.generation_config,
|
||||
)
|
||||
output = output[:, inputs["input_ids"].shape[1] :]
|
||||
response = self.processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
|
||||
EXPECTED_RESPONSE = "The answer for 1+1 is 2. This is because when you add one to another"
|
||||
|
||||
self.assertEqual(response, EXPECTED_RESPONSE)
|
||||
|
||||
def test_vision_text_generation(self):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
|
||||
)
|
||||
|
||||
prompt = f"{self.user_token}<|image_1|>What is shown in this image?{self.end_token}{self.assistant_token}"
|
||||
inputs = self.processor(prompt, images=self.image, return_tensors="pt").to(torch_device)
|
||||
|
||||
output = model.generate(
|
||||
**inputs,
|
||||
generation_config=self.generation_config,
|
||||
)
|
||||
output = output[:, inputs["input_ids"].shape[1] :]
|
||||
response = self.processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
|
||||
EXPECTED_RESPONSE = "The image shows a vibrant scene at a street intersection in a city with a Chinese-influenced architectural"
|
||||
|
||||
self.assertEqual(response, EXPECTED_RESPONSE)
|
||||
|
||||
def test_multi_image_vision_text_generation(self):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
|
||||
)
|
||||
|
||||
images = []
|
||||
placeholder = ""
|
||||
for i in range(1, 5):
|
||||
url = f"https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-{i}-2048.jpg"
|
||||
images.append(Image.open(requests.get(url, stream=True).raw))
|
||||
placeholder += f"<|image_{i}|>"
|
||||
|
||||
prompt = f"{self.user_token}{placeholder}Summarize the deck of slides.{self.end_token}{self.assistant_token}"
|
||||
inputs = self.processor(prompt, images, return_tensors="pt").to(torch_device)
|
||||
|
||||
output = model.generate(
|
||||
**inputs,
|
||||
generation_config=self.generation_config,
|
||||
)
|
||||
output = output[:, inputs["input_ids"].shape[1] :]
|
||||
response = self.processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
|
||||
EXPECTED_RESPONSE = "The presentation provides an overview of Microsoft Azure, a cloud computing platform by Microsoft, and its various services"
|
||||
|
||||
self.assertEqual(response, EXPECTED_RESPONSE)
|
||||
|
||||
@require_soundfile
|
||||
def test_audio_text_generation(self):
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device
|
||||
)
|
||||
|
||||
prompt = f"{self.user_token}<|audio_1|>What is happening in this audio?{self.end_token}{self.assistant_token}"
|
||||
inputs = self.processor(prompt, audios=self.audio, sampling_rate=self.sampling_rate, return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
output = model.generate(
|
||||
**inputs,
|
||||
generation_config=self.generation_config,
|
||||
)
|
||||
output = output[:, inputs["input_ids"].shape[1] :]
|
||||
response = self.processor.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
|
||||
# Yes, it is truly the expected response... Even though the model correctly treats the audio file
|
||||
EXPECTED_RESPONSE = "I'm sorry, but I can't listen to audio. However, if you describe the audio to me,"
|
||||
|
||||
self.assertEqual(response, EXPECTED_RESPONSE)
|
@ -524,6 +524,7 @@ OBJECTS_TO_IGNORE = [
|
||||
"TimeSeriesTransformerConfig",
|
||||
"TokenClassificationPipeline",
|
||||
"TrOCRConfig",
|
||||
"Phi4MultimodalProcessor",
|
||||
"TrainerState",
|
||||
"TrainingArguments",
|
||||
"TrajectoryTransformerConfig",
|
||||
|
@ -89,6 +89,8 @@ PRIVATE_MODELS = [
|
||||
"SmolVLMVisionTransformer",
|
||||
"AriaTextForCausalLM",
|
||||
"AriaTextModel",
|
||||
"Phi4MultimodalAudioModel",
|
||||
"Phi4MultimodalVisionModel",
|
||||
]
|
||||
|
||||
# Update this list for models that are not tested with a comment explaining the reason it should not be.
|
||||
|
Loading…
Reference in New Issue
Block a user