transformers/docs/source/en/model_doc/phi4_multimodal.md
Cyril Vallez 4303d88c09
Add Phi4 multimodal (#36939)
* raw start

* update

* update

* add to imports

* update

* up

* simplify configs

* clean configs

* style

* typos

* Update convert_phi4_multimodal_weights_to_hf.py

* Update convert_phi4_multimodal_weights_to_hf.py

* fix

* up

* up

* up

* Update convert_phi4_multimodal_weights_to_hf.py

* Update convert_phi4_multimodal_weights_to_hf.py

* up

* up

* up

* Update feature_extraction_phi4_multimodal.py

* up

* up

* up

* up

* up

* simplify configs

* typo

* cut code

* typo

* typo

* typo

* re

* typo

* up

* up

* up

* add tests

* fix

* fix

* Update test_modeling_phi4_multimodal.py

* up

* Update test_modeling_phi4_multimodal.py

* doc

* fix

* up

* up

* up

* up

* up

* up

* simplify

* up

* simplify

* config docstrings

* cleanup

* clean

* typo

* typo

* fix

* Update phi4_multimodal.md

* fix

* fix

* Update test_modeling_phi4_multimodal.py

* update

* simplify reshapes and permutes

* up

* simplify special tokens

* simplify processor a lot

* Update processing_phi4_multimodal.py

* Update processing_phi4_multimodal.py

* switch to fast processor

* image processor

* Update image_processing_phi4_multimodal_fast.py

* add lora extraction to converter

* Update convert_phi4_multimodal_weights_to_hf.py

* Update __init__.py

* add AudioInput type in audio_utils

* rewrite feature_extraction: support torch batched FFT

* input_audio_embeds -> audio_input_features, input_image_embeds -> image_pixel_values

* test update

* not mono channel warning update

* remove auto maps from processor

* kargs dispatch in processor

* simplify kwargs dispatch

* simplify merging

* remove default sampling rate

* style

* Update test_modeling_phi4_multimodal.py

* update doc

* doc

* torch only feature extractor

* make fake tokens adjustable

* Update feature_extraction_phi4_multimodal.py

* fix

* Update processing_phi4_multimodal.py

* simplify mask

* last touch

* fix copies

* style

* Update audio_utils.py

* style

* Update feature_extraction_phi4_multimodal.py

* Update __init__.py

* docstrings

* copies

* fix all checks

* back to fix-copies

* trigger CIs

* Update feature_extraction_phi4_multimodal.py

* improve tests with multimodal inputs

* trigger CIs

---------

Co-authored-by: Eustache Le Bihan <eulebihan@gmail.com>
2025-03-25 09:55:21 +01:00

5.6 KiB

Phi4 Multimodal

Overview

Phi4 Multimodal is a lightweight open multimodal foundation model that leverages the language, vision, and speech research and datasets used for Phi-3.5 and 4.0 models. The model processes text, image, and audio inputs, generating text outputs, and comes with 128K token context length. The model underwent an enhancement process, incorporating both supervised fine-tuning, direct preference optimization and RLHF (Reinforcement Learning from Human Feedback) to support precise instruction adherence and safety measures. The languages that each modal supports are the following:

  • Text: Arabic, Chinese, Czech, Danish, Dutch, English, Finnish, French, German, Hebrew, Hungarian, Italian, Japanese, Korean, Norwegian, Polish, Portuguese, Russian, Spanish, Swedish, Thai, Turkish, Ukrainian
  • Vision: English
  • Audio: English, Chinese, German, French, Italian, Japanese, Spanish, Portuguese

This model was contributed by Cyril Vallez. The most recent code can be found here.

Usage tips

Phi4-multimodal-instruct can be found on the Huggingface Hub

In the following, we demonstrate how to use it for inference depending on the input modalities (text, image, audio).

import requests
import torch
import os
import io
from PIL import Image
import soundfile as sf
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from urllib.request import urlopen


# Define model path
model_path = "microsoft/Phi-4-multimodal-instruct"
device = "cuda:0"

# Load model and processor
processor = AutoProcessor.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device,  torch_dtype=torch.float16)

# Optional: load the adapters (note that without them, the base model will very likely not work well)
model.load_adapter(model_path, adapter_name="speech", device_map=device, adapter_kwargs={"subfolder": 'speech-lora'})
model.load_adapter(model_path, adapter_name="vision", device_map=device, adapter_kwargs={"subfolder": 'vision-lora'})

# Define prompt structure
user_prompt = '<|user|>'
assistant_prompt = '<|assistant|>'
prompt_suffix = '<|end|>'

# Part 1: Image Processing
model.set_adapter("vision") # if loaded, activate the vision adapter
print("\n--- IMAGE PROCESSING ---")
image_url = 'https://www.ilankelman.org/stopsigns/australia.jpg'
prompt = f'{user_prompt}<|image_1|>What is shown in this image?{prompt_suffix}{assistant_prompt}'
print(f'>>> Prompt\n{prompt}')

# Download and open image
image = Image.open(requests.get(image_url, stream=True).raw)
inputs = processor(text=prompt, images=image, return_tensors='pt').to(device)

# Generate response
generate_ids = model.generate(
    **inputs,
    max_new_tokens=1000,
    do_sample=False,
)
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
response = processor.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
print(f'>>> Response\n{response}')

# Part 2: Audio Processing
model.set_adapter("speech") # if loaded, activate the speech adapter
print("\n--- AUDIO PROCESSING ---")
audio_url = "https://upload.wikimedia.org/wikipedia/commons/b/b0/Barbara_Sahakian_BBC_Radio4_The_Life_Scientific_29_May_2012_b01j5j24.flac"
speech_prompt = "Transcribe the audio to text, and then translate the audio to French. Use <sep> as a separator between the original transcript and the translation."
prompt = f'{user_prompt}<|audio_1|>{speech_prompt}{prompt_suffix}{assistant_prompt}'
print(f'>>> Prompt\n{prompt}')

# Downlowd and open audio file
audio, sample_rate = sf.read(io.BytesIO(urlopen(audio_url).read()))

# Process with the model
inputs = processor(text=prompt, audios=audio, sample_rate=sample_rate, return_tensors='pt').to(device)

generate_ids = model.generate(
    **inputs,
    max_new_tokens=1000,
    do_sample=False,
)
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
response = processor.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
print(f'>>> Response\n{response}')

Phi4MultimodalFeatureExtractor

autodoc Phi4MultimodalFeatureExtractor

Phi4MultimodalImageProcessorFast

autodoc Phi4MultimodalImageProcessorFast

Phi4MultimodalProcessor

autodoc Phi4MultimodalProcessor

Phi4MultimodalAudioConfig

autodoc Phi4MultimodalAudioConfig

Phi4MultimodalVisionConfig

autodoc Phi4MultimodalVisionConfig

Phi4MultimodalConfig

autodoc Phi4MultimodalConfig

Phi4MultimodalAudioModel

autodoc Phi4MultimodalAudioModel

Phi4MultimodalVisionModel

autodoc Phi4MultimodalVisionModel

Phi4MultimodalModel

autodoc Phi4MultimodalModel - forward

Phi4MultimodalForCausalLM

autodoc Phi4MultimodalForCausalLM - forward