<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# LLaVa-NeXT-Video

## Overview

The LLaVa-NeXT-Video model was proposed in [LLaVA-NeXT: A Strong Zero-shot Video Understanding Model](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/) by Yuanhan Zhang, Bo Li, Haotian Liu, Yong Jae Lee, Liangke Gui, Di Fu, Jiashi Feng, Ziwei Liu, Chunyuan Li. LLaVa-NeXT-Video improves upon [LLaVa-NeXT](llava_next) by fine-tuning on a mix of video and image data, which increases the model's performance on videos.

[LLaVA-NeXT](llava_next) surprisingly has strong performance in understanding video content in a zero-shot fashion thanks to the AnyRes technique that it uses. The AnyRes technique naturally represents a high-resolution image as multiple images. This technique is naturally generalizable to videos because a video can be considered as a set of frames (similar to a set of images in LLaVa-NeXT). The current version of LLaVA-NeXT makes use of AnyRes and trains with supervised fine-tuning (SFT) on top of LLaVA-NeXT on video data to achieve better video understanding capabilities. The model is currently SOTA among open-source models on the [VideoMME bench](https://arxiv.org/abs/2405.21075).

The introduction from the blog is the following:

On January 30, 2024, we released LLaVA-NeXT, an open-source Large Multimodal Model (LMM) that has been trained exclusively on text-image data. With the proposed AnyRes technique, it boosts capabilities in reasoning, OCR, and world knowledge, demonstrating remarkable performance across a spectrum of image-based multimodal understanding tasks, and even exceeding Gemini-Pro on several image benchmarks, e.g. MMMU and MathVista.

**In today’s exploration, we delve into the performance of LLaVA-NeXT within the realm of video understanding tasks. We reveal that LLaVA-NeXT surprisingly has strong performance in understanding video content. The current version of LLaVA-NeXT for videos has several improvements:

- Zero-shot video representation capabilities with AnyRes: The AnyRes technique naturally represents a high-resolution image into multiple images that a pre-trained ViT is able to digest, and forms them into a concatenated sequence. This technique is naturally generalizable to represent videos (consisting of multiple frames), allowing the image-only-trained LLaVA-Next model to perform surprisingly well on video tasks. Notably, this is the first time that LMMs show strong zero-shot modality transfer ability.
- Inference with length generalization improves on longer videos. The linear scaling technique enables length generalization, allowing LLaVA-NeXT to effectively handle long videos beyond the limitation of the "max_token_length" of the LLM.
- Strong video understanding ability. (1) LLaVA-Next-Image, which combines the above two techniques, yields superior zero-shot performance than open-source LMMs tuned on videos. (2) LLaVA-Next-Video, further supervised fine-tuning (SFT) LLaVA-Next-Image on video data, achieves better video understanding capabilities compared to LLaVA-Next-Image. (3) LLaVA-Next-Video-DPO, which aligns the model response with AI feedback using direct preference optimization (DPO), shows a significant performance boost.
- Efficient deployment and inference with SGLang. It allows 5x faster inference on video tasks, allowing more scalable serving such as million-level video re-captioning. See instructions in our repo.**

This model was contributed by [RaushanTurganbay](https://huggingface.co/RaushanTurganbay).
The original code can be found [here](https://github.com/LLaVA-VL/LLaVA-NeXT/tree/inference).

## Usage tips

- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating (see the batched-generation sketch after the chat template example below).

- Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use the processor's `apply_chat_template` to format your prompts correctly. Below is an example of how to do that.

We will use [LLaVA-NeXT-Video-7B-hf](https://huggingface.co/llava-hf/LLaVA-NeXT-Video-7B-hf) and a conversation history of videos and images. Each content field has to be a list of dicts, as follows:

```python
from transformers import LlavaNextVideoProcessor

processor = LlavaNextVideoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")

conversation = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."},
        ],
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What’s shown in this image?"},
            {"type": "image"},
        ],
    },
    {
        "role": "assistant",
        "content": [{"type": "text", "text": "This image shows a red stop sign."},]
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Why is this video funny?"},
            {"type": "video"},
        ],
    },
]

text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# Note that the template simply formats your prompt, you still have to tokenize it and obtain pixel values for your visuals
print(text_prompt)
```

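As mentioned in the usage tips, batched generation works best with left padding. The following is a minimal sketch, not an official recipe: it assumes a `model` and `processor` loaded as in the "Single Media Mode" example below, and two already-decoded frame arrays `video_1` and `video_2` (illustrative names).

```python
# Minimal sketch of batched generation with left padding.
# Assumes `model`, `processor`, `video_1`, and `video_2` are prepared as in the
# "Single Media Mode" example below; variable names are illustrative only.
processor.tokenizer.padding_side = "left"

conversations = [
    [{"role": "user", "content": [{"type": "text", "text": "Why is this video funny?"}, {"type": "video"}]}],
    [{"role": "user", "content": [{"type": "text", "text": "Describe this video."}, {"type": "video"}]}],
]
prompts = [processor.apply_chat_template(c, add_generation_prompt=True) for c in conversations]

# Pad the shorter prompt on the left so generation starts from aligned positions
inputs = processor(text=prompts, videos=[video_1, video_2], padding=True, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=60)
print(processor.batch_decode(out, skip_special_tokens=True))
```
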
## Usage example

### Single Media Mode

The model can accept both images and videos as input. Here's example code for inference in half-precision (`torch.float16`):

```python
import av
import torch
import numpy as np
from huggingface_hub import hf_hub_download
from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor


def read_video_pyav(container, indices):
    '''
    Decode the video with PyAV decoder.
    Args:
        container (`av.container.input.InputContainer`): PyAV container.
        indices (`List[int]`): List of frame indices to decode.
    Returns:
        result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
    '''
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format="rgb24") for x in frames])


# Load the model in half-precision
model = LlavaNextVideoForConditionalGeneration.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf", torch_dtype=torch.float16, device_map="auto")
processor = LlavaNextVideoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")

# Load the video as an np.array, sampling uniformly 8 frames (can sample more for longer videos)
video_path = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
container = av.open(video_path)
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / 8).astype(int)
video = read_video_pyav(container, indices)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Why is this video funny?"},
            {"type": "video"},
        ],
    },
]

prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(text=prompt, videos=video, return_tensors="pt")

out = model.generate(**inputs, max_new_tokens=60)
processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)
```

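The same checkpoint also handles plain image inputs in single media mode. Below is a minimal sketch that reuses the `model` and `processor` loaded above; the COCO image URL and the question are only illustrative.

```python
# Minimal sketch of image-only inference, reusing `model` and `processor` from above.
# The image URL and prompt are example inputs, not part of the official docs.
from PIL import Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is shown in this image?"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

out = model.generate(**inputs, max_new_tokens=60)
processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True)
```
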
### Mixed Media Mode

The model can also generate from interleaved image-video inputs. However, note that it was not trained in an interleaved image-video setting, which might affect performance. Below is an example usage for mixed media input; add the following lines to the above code snippet:

```python
from PIL import Image
import requests

# Generate from image and video mixed inputs
# Load an image and write a new prompt
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "How many cats are there in the image?"},
            {"type": "image"},
        ],
    },
    {
        "role": "assistant",
        "content": [{"type": "text", "text": "There are two cats"}],
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Why is this video funny?"},
            {"type": "video"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(text=prompt, images=image, videos=video, padding=True, return_tensors="pt")

# Generate
generate_ids = model.generate(**inputs, max_new_tokens=50)
processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
```

## Model optimization

### Quantization using Bitsandbytes for memory efficiency

The model can be loaded in lower bits, significantly reducing memory burden while maintaining the performance of the original model. This allows for efficient deployment in resource-constrained cases.

First, make sure to install bitsandbytes by running `pip install bitsandbytes` and to have access to a CUDA compatible GPU device. Load the quantized model by simply adding [`BitsAndBytesConfig`](../main_classes/quantization#transformers.BitsAndBytesConfig) as shown below:

```python
import torch
from transformers import BitsAndBytesConfig, LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor

# specify how to quantize the model
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = LlavaNextVideoForConditionalGeneration.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf", quantization_config=quantization_config, device_map="auto")
```

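Generation with the quantized model then works the same way as the half-precision examples above. A minimal sketch, assuming the `processor`, `video`, and `conversation` objects from the "Single Media Mode" example are already defined:

```python
# Minimal sketch of generation with the 4-bit model; assumes `processor`, `video`,
# and `conversation` are defined as in the "Single Media Mode" example above.
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(text=prompt, videos=video, return_tensors="pt").to(model.device)

out = model.generate(**inputs, max_new_tokens=60)
print(processor.batch_decode(out, skip_special_tokens=True)[0])
```
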
### Flash-Attention 2 to speed up generation

Additionally, we can greatly speed up model inference by using [Flash Attention](../perf_train_gpu_one.md#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model.

First, make sure to install the latest version of Flash Attention 2:

```bash
pip install -U flash-attn --no-build-isolation
```

Also, you should have hardware that is compatible with FlashAttention-2. Read more about it in the official documentation of the [flash attention repository](https://github.com/Dao-AILab/flash-attention). FlashAttention-2 can only be used when a model is loaded in `torch.float16` or `torch.bfloat16`.

To load and run a model using FlashAttention-2, simply add `attn_implementation="flash_attention_2"` when loading the model as follows:

```python
import torch
from transformers import LlavaNextVideoForConditionalGeneration

model = LlavaNextVideoForConditionalGeneration.from_pretrained(
    "llava-hf/LLaVA-NeXT-Video-7B-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to(0)
```

## LlavaNextVideoConfig

[[autodoc]] LlavaNextVideoConfig

## LlavaNextVideoProcessor

[[autodoc]] LlavaNextVideoProcessor

## LlavaNextVideoImageProcessor

[[autodoc]] LlavaNextVideoImageProcessor

## LlavaNextVideoForConditionalGeneration

[[autodoc]] LlavaNextVideoForConditionalGeneration
    - forward