Merge branch 'huggingface:main' into add_plm

Jiang Jiwen 2025-05-08 11:54:04 +08:00 committed by GitHub
commit d88b28348c
194 changed files with 18087 additions and 4169 deletions

View File

@ -29,7 +29,7 @@ jobs:
run_models_gpu:
name: " "
runs-on:
group: aws-g4dn-2xlarge-cache
group: aws-g4dn-4xlarge-cache
container:
image: ${{ inputs.docker }}
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/

View File

@ -28,7 +28,7 @@ jobs:
matrix:
split_keys: ${{ fromJson(inputs.split_keys) }}
runs-on:
group: aws-g4dn-2xlarge-cache
group: aws-g4dn-4xlarge-cache
container:
image: huggingface/transformers-all-latest-gpu
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/

View File

@ -15,7 +15,7 @@ jobs:
setup:
name: Setup
runs-on:
group: aws-g4dn-2xlarge-cache
group: aws-g4dn-4xlarge-cache
container:
image: huggingface/transformers-all-latest-gpu
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/

View File

@ -107,7 +107,7 @@ jobs:
run: |
echo "${{ inputs.machine_type }}"
if [ "${{ inputs.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
if [ "${{ inputs.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ inputs.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
machine_type=multi-gpu

View File

@ -185,7 +185,7 @@ jobs:
fail-fast: false
matrix:
folders: ${{ fromJson(needs.get-tests.outputs.models) }}
machine_type: [aws-g4dn-2xlarge-cache, aws-g4dn-12xlarge-cache]
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:
@ -239,7 +239,7 @@ jobs:
shell: bash
run: |
echo "${{ matrix.machine_type }}"
if [ "${{ matrix.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
machine_type=multi-gpu
@ -292,7 +292,7 @@ jobs:
fail-fast: false
matrix:
folders: ${{ fromJson(needs.get-tests.outputs.quantizations) }}
machine_type: [aws-g4dn-2xlarge-cache, aws-g4dn-12xlarge-cache]
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:
@ -338,7 +338,7 @@ jobs:
shell: bash
run: |
echo "${{ matrix.machine_type }}"
if [ "${{ matrix.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
machine_type=multi-gpu

View File

@ -49,7 +49,7 @@ jobs:
name: Setup
strategy:
matrix:
machine_type: [aws-g4dn-2xlarge-cache, aws-g4dn-12xlarge-cache]
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:
@ -107,7 +107,7 @@ jobs:
strategy:
fail-fast: false
matrix:
machine_type: [aws-g4dn-2xlarge-cache, aws-g4dn-12xlarge-cache]
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
slice_id: ${{ fromJSON(needs.setup.outputs.slice_ids) }}
uses: ./.github/workflows/model_jobs.yml
with:
@ -125,7 +125,7 @@ jobs:
strategy:
fail-fast: false
matrix:
machine_type: [aws-g4dn-2xlarge-cache, aws-g4dn-12xlarge-cache]
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
slice_id: [0, 1]
uses: ./.github/workflows/model_jobs.yml
with:
@ -143,7 +143,7 @@ jobs:
strategy:
fail-fast: false
matrix:
machine_type: [aws-g4dn-2xlarge-cache, aws-g4dn-12xlarge-cache]
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:
@ -177,7 +177,7 @@ jobs:
run: |
echo "${{ matrix.machine_type }}"
if [ "${{ matrix.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
machine_type=multi-gpu
@ -211,7 +211,7 @@ jobs:
strategy:
fail-fast: false
matrix:
machine_type: [aws-g4dn-2xlarge-cache, aws-g4dn-12xlarge-cache]
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:
@ -246,7 +246,7 @@ jobs:
run: |
echo "${{ matrix.machine_type }}"
if [ "${{ matrix.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
machine_type=multi-gpu
@ -280,7 +280,7 @@ jobs:
strategy:
fail-fast: false
matrix:
machine_type: [aws-g4dn-2xlarge-cache]
machine_type: [aws-g4dn-4xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:
@ -314,7 +314,7 @@ jobs:
run: |
echo "${{ matrix.machine_type }}"
if [ "${{ matrix.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
machine_type=multi-gpu
@ -349,7 +349,7 @@ jobs:
strategy:
fail-fast: false
matrix:
machine_type: [aws-g4dn-2xlarge-cache, aws-g4dn-12xlarge-cache]
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:
@ -411,7 +411,7 @@ jobs:
run: |
echo "${{ matrix.machine_type }}"
if [ "${{ matrix.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
machine_type=multi-gpu
@ -448,7 +448,7 @@ jobs:
fail-fast: false
matrix:
folders: ${{ fromJson(needs.setup.outputs.quantization_matrix) }}
machine_type: [aws-g4dn-2xlarge-cache, aws-g4dn-12xlarge-cache]
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:
@ -491,7 +491,7 @@ jobs:
run: |
echo "${{ matrix.machine_type }}"
if [ "${{ matrix.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
machine_type=multi-gpu

View File

@ -35,7 +35,7 @@ jobs:
shell: bash
run: |
if [[ "${{ github.event.inputs.num_gpus }}" == "single" && "${{ github.event.inputs.runner_type }}" == "t4" ]]; then
echo "RUNNER=aws-g4dn-2xlarge-cache" >> $GITHUB_ENV
echo "RUNNER=aws-g4dn-4xlarge-cache" >> $GITHUB_ENV
elif [[ "${{ github.event.inputs.num_gpus }}" == "multi" && "${{ github.event.inputs.runner_type }}" == "t4" ]]; then
echo "RUNNER=aws-g4dn-12xlarge-cache" >> $GITHUB_ENV
elif [[ "${{ github.event.inputs.num_gpus }}" == "single" && "${{ github.event.inputs.runner_type }}" == "a10" ]]; then

View File

@ -495,6 +495,8 @@
title: Granite
- local: model_doc/granitemoe
title: GraniteMoe
- local: model_doc/granitemoehybrid
title: GraniteMoeHybrid
- local: model_doc/granitemoeshared
title: GraniteMoeShared
- local: model_doc/helium
@ -823,6 +825,8 @@
title: Bark
- local: model_doc/clap
title: CLAP
- local: model_doc/csm
title: CSM
- local: model_doc/dac
title: dac
- local: model_doc/encodec

View File

@ -108,7 +108,7 @@ If in doubt about what args/kwargs a given model sends to the attention function
## Accessing current available implementations
Most of the time, you will simply need to `register` a new function. If, however, you need to access an existing one,
and/or perform a few checks, the prefered way is to use the global `ALL_ATTENTION_FUNCTIONS`. It behaves the same way you
and/or perform a few checks, the preferred way is to use the global `ALL_ATTENTION_FUNCTIONS`. It behaves the same way you
would expect from a usual Python dictionary:
```python
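# The original snippet is truncated in this diff view; below is a minimal sketch of the
# dictionary-like behaviour described above (import path assumed to be `transformers.modeling_utils`).
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

print(list(ALL_ATTENTION_FUNCTIONS.keys()))      # e.g. ['flash_attention_2', 'flex_attention', 'sdpa', ...]
print("sdpa" in ALL_ATTENTION_FUNCTIONS)         # membership test, as with a plain dict
sdpa_forward = ALL_ATTENTION_FUNCTIONS["sdpa"]   # fetch an existing implementation by name
```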

View File

@ -102,6 +102,10 @@ response = processor.decode(output_ids, skip_special_tokens=True)
[[autodoc]] AriaTextModel
## AriaModel
[[autodoc]] AriaModel
## AriaTextForCausalLM
[[autodoc]] AriaTextForCausalLM

View File

@ -237,6 +237,10 @@ for i, output in enumerate(batch_outputs):
[[autodoc]] AyaVisionConfig
## AyaVisionModel
[[autodoc]] AyaVisionModel
## AyaVisionForConditionalGeneration
[[autodoc]] AyaVisionForConditionalGeneration

View File

@ -150,6 +150,11 @@ If you're interested in submitting a resource to be included here, please feel f
[[autodoc]] BeitImageProcessor
- preprocess
- post_process_semantic_segmentation
## BeitImageProcessorFast
[[autodoc]] BeitImageProcessorFast
- preprocess
- post_process_semantic_segmentation
<frameworkcontent>
<pt>

View File

@ -0,0 +1,377 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Csm
## Overview
The Conversational Speech Model (CSM) is the first open-source contextual text-to-speech model [released by Sesame](https://www.sesame.com/research/crossing_the_uncanny_valley_of_voice). It is designed to generate natural-sounding speech with or without conversational context. This context typically consists of multi-turn dialogue between speakers, represented as sequences of text and corresponding spoken audio.
**Model Architecture:**
CSM is composed of two LLaMA-style auto-regressive transformer decoders: a backbone decoder that predicts the first codebook token and a depth decoder that generates the remaining tokens. It uses the pretrained codec model [Mimi](./mimi.md), introduced by Kyutai, to encode speech into discrete codebook tokens and decode them back into audio.
The original csm-1b checkpoint is available under the [Sesame](https://huggingface.co/sesame/csm-1b) organization on Hugging Face.
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/eustlb/documentation-images/resolve/main/csm_architecture.png"/>
</div>
## Usage Tips
### Without Conversational Context
CSM can be used to simply generate speech from a text prompt:
```python
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor
model_id = "eustlb/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"
# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
# prepare the inputs
text = "[0]The past is just a story we tell ourselves." # `[0]` for speaker id 0
inputs = processor(text, add_special_tokens=True).to(device)
# another equivalent way to prepare the inputs
conversation = [
{"role": "0", "content": [{"type": "text", "text": "The past is just a story we tell ourselves."}]},
]
inputs = processor.apply_chat_template(
conversation,
tokenize=True,
return_dict=True,
).to(device)
# infer the model
audio = model.generate(**inputs, output_audio=True)
processor.save_audio(audio, "example_without_context.wav")
```
### With Conversational Context
CSM can be used to generate speech given a conversation, allowing consistency in the voices and content-aware generation:
```python
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset, Audio
model_id = "eustlb/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"
# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
# prepare the inputs
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
# ensure the audio is 24kHz
ds = ds.cast_column("audio", Audio(sampling_rate=24000))
conversation = []
# 1. context
for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
conversation.append(
{
"role": f"{speaker_id}",
"content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
}
)
# 2. text prompt
conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})
inputs = processor.apply_chat_template(
conversation,
tokenize=True,
return_dict=True,
).to(device)
# infer the model
audio = model.generate(**inputs, output_audio=True)
processor.save_audio(audio, "example_with_context.wav")
```
### Batched Inference
CSM supports batched inference!
```python
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset, Audio
model_id = "eustlb/csm-1b"
device = "cuda" if torch.cuda.is_available() else "cpu"
# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
# prepare the inputs
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
# ensure the audio is 24kHz
ds = ds.cast_column("audio", Audio(sampling_rate=24000))
# here a batch with two prompts
conversation = [
[
{
"role": f"{ds[0]['speaker_id']}",
"content": [
{"type": "text", "text": ds[0]["text"]},
{"type": "audio", "path": ds[0]["audio"]["array"]},
],
},
{
"role": f"{ds[1]['speaker_id']}",
"content": [
{"type": "text", "text": ds[1]["text"]},
],
},
],
[
{
"role": f"{ds[0]['speaker_id']}",
"content": [
{"type": "text", "text": ds[0]["text"]},
],
}
],
]
inputs = processor.apply_chat_template(
conversation,
tokenize=True,
return_dict=True,
).to(device)
audio = model.generate(**inputs, output_audio=True)
processor.save_audio(audio, [f"speech_batch_idx_{i}.wav" for i in range(len(audio))])
```
### Making The Model Go Brrr
CSM supports full-graph compilation with CUDA graphs!
```python
import torch
import copy
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset
model_id = "eustlb/csm-1b"
device = "cuda"
# set logs to ensure no recompilation and graph breaks
torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)
# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
# use a static cache, which automatically enables torch.compile with fullgraph and reduce-overhead
model.generation_config.max_length = 250 # big enough to avoid recompilation
model.generation_config.max_new_tokens = None # would take precedence over max_length
model.generation_config.cache_implementation = "static"
model.depth_decoder.generation_config.cache_implementation = "static"
# generation kwargs
gen_kwargs = {
"do_sample": False,
"depth_decoder_do_sample": False,
"temperature": 1.0,
"depth_decoder_temperature": 1.0,
}
# Define a timing context manager
class TimerContext:
def __init__(self, name="Execution"):
self.name = name
self.start_event = None
self.end_event = None
def __enter__(self):
# Use CUDA events for more accurate GPU timing
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
self.start_event.record()
return self
def __exit__(self, *args):
self.end_event.record()
torch.cuda.synchronize()
elapsed_time = self.start_event.elapsed_time(self.end_event) / 1000.0
print(f"{self.name} time: {elapsed_time:.4f} seconds")
# prepare the inputs
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
conversation = [
{
"role": f"{ds[0]['speaker_id']}",
"content": [
{"type": "text", "text": ds[0]["text"]},
{"type": "audio", "path": ds[0]["audio"]["array"]},
],
},
{
"role": f"{ds[1]['speaker_id']}",
"content": [
{"type": "text", "text": ds[1]["text"]},
{"type": "audio", "path": ds[1]["audio"]["array"]},
],
},
{
"role": f"{ds[2]['speaker_id']}",
"content": [
{"type": "text", "text": ds[2]["text"]},
],
},
]
padded_inputs_1 = processor.apply_chat_template(
conversation,
tokenize=True,
return_dict=True,
).to(device)
print("\n" + "="*50)
print("First generation - compiling and recording CUDA graphs...")
with TimerContext("First generation"):
_ = model.generate(**padded_inputs_1, **gen_kwargs)
print("="*50)
print("\n" + "="*50)
print("Second generation - fast !!!")
with TimerContext("Second generation"):
_ = model.generate(**padded_inputs_1, **gen_kwargs)
print("="*50)
# now with different inputs
conversation = [
{
"role": f"{ds[0]['speaker_id']}",
"content": [
{"type": "text", "text": ds[2]["text"]},
{"type": "audio", "path": ds[2]["audio"]["array"]},
],
},
{
"role": f"{ds[1]['speaker_id']}",
"content": [
{"type": "text", "text": ds[3]["text"]},
{"type": "audio", "path": ds[3]["audio"]["array"]},
],
},
{
"role": f"{ds[2]['speaker_id']}",
"content": [
{"type": "text", "text": ds[4]["text"]},
],
},
]
padded_inputs_2 = processor.apply_chat_template(
conversation,
tokenize=True,
return_dict=True,
).to(device)
print("\n" + "="*50)
print("Generation with other inputs!")
with TimerContext("Generation with different inputs"):
_ = model.generate(**padded_inputs_2, **gen_kwargs)
print("="*50)
```
### Training
CSM Transformers integration supports training!
```python
from transformers import CsmForConditionalGeneration, AutoProcessor
from datasets import load_dataset, Audio
model_id = "eustlb/csm-1b"
device = "cuda"
# load the model and the processor
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=device)
model.train()
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
# ensure the audio is 24kHz
ds = ds.cast_column("audio", Audio(sampling_rate=24000))
conversation = []
# context
for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
conversation.append(
{
"role": f"{speaker_id}",
"content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
}
)
inputs = processor.apply_chat_template(
conversation,
tokenize=True,
return_dict=True,
output_labels=True,
).to(device)
out = model(**inputs)
out.loss.backward()
```
This model was contributed by [Eustache Le Bihan](https://huggingface.co/eustlb).
The original code can be found [here](https://github.com/SesameAILabs/csm).
## CsmConfig
[[autodoc]] CsmConfig
## CsmDepthDecoderConfig
[[autodoc]] CsmDepthDecoderConfig
## CsmProcessor
[[autodoc]] CsmProcessor
- __call__
## CsmForConditionalGeneration
[[autodoc]] CsmForConditionalGeneration
- forward
- generate
## CsmDepthDecoderForCausalLM
[[autodoc]] CsmDepthDecoderForCausalLM
## CsmDepthDecoderModel
[[autodoc]] CsmDepthDecoderModel
## CsmBackboneModel
[[autodoc]] CsmBackboneModel

View File

@ -174,6 +174,10 @@ for i, image in enumerate(images['pixel_values']):
[[autodoc]] Emu3TextModel
- forward
## Emu3Model
[[autodoc]] Emu3Model
## Emu3ForCausalLM
[[autodoc]] Emu3ForCausalLM

View File

@ -103,6 +103,10 @@ The `LlamaTokenizer` is used as it is a standard wrapper around sentencepiece.
[[autodoc]] FuyuConfig
## FuyuModel
[[autodoc]] FuyuModel
## FuyuForCausalLM
[[autodoc]] FuyuForCausalLM

View File

@ -254,6 +254,10 @@ visualizer("<img>What is shown in this image?")
[[autodoc]] Gemma3TextModel
- forward
## Gemma3Model
[[autodoc]] Gemma3Model
## Gemma3ForCausalLM
[[autodoc]] Gemma3ForCausalLM

View File

@ -277,6 +277,10 @@ alt="drawing" width="600"/>
[[autodoc]] GotOcr2Processor
## GotOcr2Model
[[autodoc]] GotOcr2Model
## GotOcr2ForConditionalGeneration
[[autodoc]] GotOcr2ForConditionalGeneration

View File

@ -0,0 +1,64 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# GraniteMoeHybrid
## Overview
The `GraniteMoeHybrid` model builds on top of `GraniteMoeSharedModel` and `Bamba`. Its decoding layers consist of state space layers or MoE attention layers with shared experts. By default, the attention layers do not use positional encoding.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model_path = "ibm-granite/granite-4.0-tiny-preview"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# drop device_map if running on CPU
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
model.eval()
# change input text as desired
prompt = "Write a code to find the maximum value in a list of numbers."
# tokenize the text
input_tokens = tokenizer(prompt, return_tensors="pt")
# generate output tokens
output = model.generate(**input_tokens, max_new_tokens=100)
# decode output tokens into text
output = tokenizer.batch_decode(output)
# loop over the batch to print, in this example the batch size is 1
for i in output:
print(i)
```
This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co/SukritiSharma) and [Alexander Brooks](https://huggingface.co/abrooks9944).
## GraniteMoeHybridConfig
[[autodoc]] GraniteMoeHybridConfig
## GraniteMoeHybridModel
[[autodoc]] GraniteMoeHybridModel
- forward
## GraniteMoeHybridForCausalLM
[[autodoc]] GraniteMoeHybridForCausalLM
- forward

View File

@ -69,6 +69,10 @@ The attributes can be obtained from model config, as `model.config.num_query_tok
[[autodoc]] InstructBlipQFormerModel
- forward
## InstructBlipModel
[[autodoc]] InstructBlipModel
## InstructBlipForConditionalGeneration
[[autodoc]] InstructBlipForConditionalGeneration

View File

@ -73,6 +73,10 @@ The attributes can be obtained from model config, as `model.config.num_query_tok
[[autodoc]] InstructBlipVideoQFormerModel
- forward
## InstructBlipVideoModel
[[autodoc]] InstructBlipVideoModel
- forward
## InstructBlipVideoForConditionalGeneration
[[autodoc]] InstructBlipVideoForConditionalGeneration

View File

@ -340,6 +340,11 @@ This example showcases how to handle a batch of chat conversations with interlea
[[autodoc]] InternVLVisionModel
- forward
## InternVLModel
[[autodoc]] InternVLModel
- forward
## InternVLForConditionalGeneration
[[autodoc]] InternVLForConditionalGeneration

View File

@ -256,6 +256,10 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] LlavaProcessor
## LlavaModel
[[autodoc]] LlavaModel
## LlavaForConditionalGeneration
[[autodoc]] LlavaForConditionalGeneration

View File

@ -315,6 +315,10 @@ model = AutoModelForImageTextToText.from_pretrained(
[[autodoc]] LlavaNextProcessor
## LlavaNextModel
[[autodoc]] LlavaNextModel
## LlavaNextForConditionalGeneration
[[autodoc]] LlavaNextForConditionalGeneration

View File

@ -262,6 +262,10 @@ model = LlavaNextVideoForConditionalGeneration.from_pretrained(
[[autodoc]] LlavaNextVideoImageProcessor
## LlavaNextVideoModel
[[autodoc]] LlavaNextVideoModel
## LlavaNextVideoForConditionalGeneration
[[autodoc]] LlavaNextVideoForConditionalGeneration

View File

@ -313,6 +313,10 @@ model = LlavaOnevisionForConditionalGeneration.from_pretrained(
[[autodoc]] LlavaOnevisionVideoProcessor
## LlavaOnevisionModel
[[autodoc]] LlavaOnevisionModel
## LlavaOnevisionForConditionalGeneration
[[autodoc]] LlavaOnevisionForConditionalGeneration

View File

@ -227,6 +227,9 @@ This example also how to use `BitsAndBytes` to load the model in 4bit quantizati
[[autodoc]] Mistral3Config
## Mistral3Model
[[autodoc]] Mistral3Model
## Mistral3ForConditionalGeneration

View File

@ -130,6 +130,10 @@ print(processor.decode(output[0], skip_special_tokens=True))
[[autodoc]] MllamaTextModel
- forward
## MllamaModel
[[autodoc]] MllamaModel
## MllamaForCausalLM
[[autodoc]] MllamaForCausalLM

View File

@ -174,6 +174,10 @@ visualizer("<img> What is in this image?")
[[autodoc]] PaliGemmaProcessor
## PaliGemmaModel
[[autodoc]] PaliGemmaModel
## PaliGemmaForConditionalGeneration
[[autodoc]] PaliGemmaForConditionalGeneration

View File

@ -240,6 +240,10 @@ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
[[autodoc]] Qwen2_5_VLProcessor
## Qwen2_5_VLTextModel
[[autodoc]] Qwen2_5_VLTextModel
- forward
## Qwen2_5_VLModel

View File

@ -296,6 +296,11 @@ model = Qwen2VLForConditionalGeneration.from_pretrained(
[[autodoc]] Qwen2VLProcessor
## Qwen2VLTextModel
[[autodoc]] Qwen2VLTextModel
- forward
## Qwen2VLModel
[[autodoc]] Qwen2VLModel

View File

@ -50,6 +50,11 @@ A demo Space for image super-resolution with SwinSR can be found [here](https://
[[autodoc]] Swin2SRImageProcessor
- preprocess
## Swin2SRImageProcessorFast
[[autodoc]] Swin2SRImageProcessorFast
- preprocess
## Swin2SRConfig
[[autodoc]] Swin2SRConfig

View File

@ -215,6 +215,10 @@ model = VideoLlavaForConditionalGeneration.from_pretrained(
[[autodoc]] VideoLlavaProcessor
## VideoLlavaModel
[[autodoc]] VideoLlavaModel
## VideoLlavaForConditionalGeneration
[[autodoc]] VideoLlavaForConditionalGeneration

View File

@ -101,6 +101,10 @@ A chat between a curious human and an artificial intelligence assistant. The ass
[[autodoc]] VipLlavaConfig
## VipLlavaModel
[[autodoc]] VipLlavaModel
## VipLlavaForConditionalGeneration
[[autodoc]] VipLlavaForConditionalGeneration

View File

@ -44,7 +44,7 @@ Place all inputs on the same device as the model.
from transformers import BitsAndBytesConfig, AutoTokenizer, AutoModelForCausalLM
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
tokenizer = AutoTokenizer("meta-llama/Llama-3.1-8B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", device_map="auto", quantization_config=quantization_config)
prompt = "Hello, my llama is cute"
@ -196,7 +196,7 @@ model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", device_m
input_text = "Hello, my llama is cute"
inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
with sdpa_kernel(SDPBackend.FLASH_ATTENTION)::
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

View File

@ -51,7 +51,7 @@ By default, vLLM serves the native implementation and if it doesn't exist, it fa
```shell
vllm serve Qwen/Qwen2.5-1.5B-Instruct \
--task generate \
--model-impl transformers \
--model-impl transformers
```
Add the `trust-remote-code` parameter to enable loading a remote code model.
@ -60,5 +60,5 @@ Add the `trust-remote-code` parameter to enable loading a remote code model.
vllm serve Qwen/Qwen2.5-1.5B-Instruct \
--task generate \
--model-impl transformers \
--trust-remote-code \
--trust-remote-code
```

View File

@ -19,7 +19,7 @@ Questa guida si concentra su come addestrare in maniera efficiente grandi modell
## Mixed precision con IPEX
IPEX è ottimizzato per CPU con AVX-512 o superiore, e funziona per le CPU con solo AVX2. Pertanto, si prevede che le prestazioni saranno più vantaggiose per le le CPU Intel con AVX-512 o superiori, mentre le CPU con solo AVX2 (ad esempio, le CPU AMD o le CPU Intel più vecchie) potrebbero ottenere prestazioni migliori con IPEX, ma non sono garantite. IPEX offre ottimizzazioni delle prestazioni per l'addestramento della CPU sia con Float32 che con BFloat16. L'uso di BFloat16 è l'argomento principale delle seguenti sezioni.
IPEX è ottimizzato per CPU con AVX-512 o superiore, e funziona per le CPU con solo AVX2. Pertanto, si prevede che le prestazioni saranno più vantaggiose per le CPU Intel con AVX-512 o superiori, mentre le CPU con solo AVX2 (ad esempio, le CPU AMD o le CPU Intel più vecchie) potrebbero ottenere prestazioni migliori con IPEX, ma non sono garantite. IPEX offre ottimizzazioni delle prestazioni per l'addestramento della CPU sia con Float32 che con BFloat16. L'uso di BFloat16 è l'argomento principale delle seguenti sezioni.
Il tipo di dati a bassa precisione BFloat16 è stato supportato in modo nativo su 3rd Generation Xeon® Scalable Processors (aka Cooper Lake) con AVX512 e sarà supportata dalla prossima generazione di Intel® Xeon® Scalable Processors con Intel® Advanced Matrix Extensions (Intel® AMX) instruction set con prestazioni ulteriormente migliorate. L'Auto Mixed Precision per il backende della CPU è stato abilitato da PyTorch-1.10. allo stesso tempo, il supporto di Auto Mixed Precision con BFloat16 per CPU e l'ottimizzazione degli operatori BFloat16 è stata abilitata in modo massiccio in Intel® Extension per PyTorch, and parzialmente aggiornato al branch master di PyTorch. Gli utenti possono ottenere prestazioni migliori ed users experience con IPEX Auto Mixed Precision..

View File

@ -105,6 +105,11 @@ BEiT の使用を開始するのに役立つ公式 Hugging Face およびコミ
[[autodoc]] BeitImageProcessor
- preprocess
## BeitImageProcessorFast
[[autodoc]] BeitImageProcessorFast
- preprocess
- post_process_semantic_segmentation
## BeitModel

View File

@ -98,7 +98,7 @@
- local: generation_strategies
title: 텍스트 생성 전략 사용자 정의
- local: serving
title: 모델 서빙하기
title: 모델 서빙하기
title: 생성
- isExpanded: false
sections:

View File

@ -466,7 +466,12 @@ setup(
package_data={"": ["**/*.cu", "**/*.cpp", "**/*.cuh", "**/*.h", "**/*.pyx", "py.typed"]},
zip_safe=False,
extras_require=extras,
entry_points={"console_scripts": ["transformers=transformers.commands.transformers_cli:main", "transformers-cli=transformers.commands.transformers_cli:main_cli"]},
entry_points={
"console_scripts": [
"transformers=transformers.commands.transformers_cli:main",
"transformers-cli=transformers.commands.transformers_cli:main_cli",
]
},
python_requires=">=3.9.0",
install_requires=list(install_requires),
classifiers=[

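Not part of the diff, just context: both console scripts reformatted above resolve through the standard `console_scripts` entry-point group, so after installation they can be inspected like this (Python 3.10+ `importlib.metadata` API):

```python
from importlib.metadata import entry_points

# Print the two scripts declared in setup.py and the callables they resolve to.
for ep in entry_points(group="console_scripts"):
    if ep.name in ("transformers", "transformers-cli"):
        print(ep.name, "->", ep.value)
```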
View File

@ -24,7 +24,12 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import requests
from .utils import is_librosa_available, requires_backends
from .utils import (
is_librosa_available,
is_numpy_array,
is_torch_tensor,
requires_backends,
)
if is_librosa_available():
@ -69,6 +74,36 @@ AudioInput = Union[
]
def is_valid_audio(audio):
return is_numpy_array(audio) or is_torch_tensor(audio)
def is_valid_list_of_audio(audio):
return audio and all(is_valid_audio(audio_i) for audio_i in audio)
def make_list_of_audio(
audio: Union[list[AudioInput], AudioInput],
) -> AudioInput:
"""
Ensure that the output is a list of audio.
Args:
audio (`Union[List[AudioInput], AudioInput]`):
The input audio.
Returns:
list: A list of audio.
"""
# If it's a list of audios, it's already in the right format
if isinstance(audio, (list, tuple)) and is_valid_list_of_audio(audio):
return audio
# If it's a single audio, wrap it in a list
if is_valid_audio(audio):
return [audio]
raise ValueError("Invalid input type. Must be a single audio or a list of audio")
def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
"""
Convert frequency from hertz to mels.

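For context, a minimal usage sketch of the `make_list_of_audio` helper added above (not part of the commit; it assumes the helper is importable from `transformers.audio_utils`):

```python
import numpy as np
from transformers.audio_utils import make_list_of_audio

single = np.zeros(16000, dtype=np.float32)   # one mono waveform
batch = [np.zeros(16000), np.zeros(24000)]   # already a list of waveforms

print(len(make_list_of_audio(single)))  # 1 -> a single array is wrapped in a list
print(len(make_list_of_audio(batch)))   # 2 -> a valid list is returned unchanged
```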
View File

@ -73,7 +73,7 @@ class MaxLengthCriteria(StoppingCriteria):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
cur_len = input_ids.shape[-1]
cur_len = input_ids.shape[1]
is_done = cur_len >= self.max_length
if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:
logger.warning_once(
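A small illustration (not from the commit) of the `shape[-1]` → `shape[1]` change: both expressions agree for 2-D `input_ids`, but once `input_ids` carries an extra trailing dimension (for example per-codebook audio tokens, as in the CSM model added in this commit), only `shape[1]` still returns the sequence length:

```python
import torch

ids_2d = torch.zeros(2, 10, dtype=torch.long)      # (batch, seq_len)
ids_3d = torch.zeros(2, 10, 32, dtype=torch.long)  # (batch, seq_len, num_codebooks)

print(ids_2d.shape[-1], ids_2d.shape[1])  # 10 10 -> identical for 2-D inputs
print(ids_3d.shape[-1], ids_3d.shape[1])  # 32 10 -> shape[-1] would report the codebook count
```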

View File

@ -563,7 +563,7 @@ class GenerationMixin:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
else:
batch_size, sequence_length = model_inputs[input_ids_key].shape
batch_size, sequence_length = model_inputs[input_ids_key].shape[:2]
# Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
# the 4D causal mask exists, it should be present in the base model (XXXModel class) or in its decoder.
@ -968,7 +968,7 @@ class GenerationMixin:
atm_translator = AssistantVocabTranslatorCache.get_translator(
target_tokenizer,
assistant_tokenizer,
self.config.vocab_size,
self.config.get_text_config().vocab_size,
assistant_model=assistant_model,
assistant_prune_lm_head=True, # prune LM head of assistant model
)
@ -1234,7 +1234,9 @@ class GenerationMixin:
# Watermarking should be after all logits processing is finished (see #34630)
if generation_config.watermarking_config is not None:
processors.append(
generation_config.watermarking_config.construct_processor(self.config.vocab_size, device)
generation_config.watermarking_config.construct_processor(
self.config.get_text_config().vocab_size, device
)
)
# `LogitNormalization` should always be the last logit processor, when present
@ -1412,7 +1414,7 @@ class GenerationMixin:
# 3. Optionally normalize the logits (across the vocab dimension)
if normalize_logits:
scores = scores.reshape(-1, self.config.vocab_size, scores.shape[-1])
scores = scores.reshape(-1, self.config.get_text_config().vocab_size, scores.shape[-1])
scores = torch.nn.functional.log_softmax(scores, dim=1)
scores = scores.reshape(-1, scores.shape[-1])
@ -1426,7 +1428,7 @@ class GenerationMixin:
beam_indices[beam_indices_mask] = 0
# 6. multiply beam_indices with vocab size to gather correctly from scores
beam_sequence_indices = beam_indices * self.config.vocab_size
beam_sequence_indices = beam_indices * self.config.get_text_config().vocab_size
# 7. Define which indices contributed to scores
cut_idx = sequences.shape[-1] - max_beam_length
@ -1706,7 +1708,7 @@ class GenerationMixin:
return generation_config, model_kwargs
def _get_initial_cache_position(self, input_ids, model_kwargs):
def _get_initial_cache_position(self, seq_length, device, model_kwargs):
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
# `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
@ -1716,7 +1718,7 @@ class GenerationMixin:
torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
)
else:
cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
cache_position = torch.ones(seq_length, dtype=torch.int64, device=device).cumsum(0) - 1
past_length = 0
if model_kwargs.get("past_key_values") is not None:
@ -2330,7 +2332,7 @@ class GenerationMixin:
streamer.put(input_ids.cpu())
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_length = input_ids.shape[-1]
input_ids_length = input_ids.shape[1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
@ -2803,9 +2805,9 @@ class GenerationMixin:
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# keep track of which sequences are already finished
batch_size = input_ids.shape[0]
batch_size, cur_length = input_ids.shape[:2]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = self._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)
this_peer_finished = False
@ -3014,9 +3016,9 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
batch_size = input_ids.shape[0]
batch_size, cur_len = input_ids.shape[:2]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
# Create cosine_matrix_mask based on the attention_mask
cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
@ -3426,10 +3428,10 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
batch_size, cur_len = input_ids.shape
batch_size, cur_len = input_ids.shape[:2]
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
model_forward = self.__call__
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
@ -3832,7 +3834,7 @@ class GenerationMixin:
num_beams = generation_config.num_beams
num_return_sequences = generation_config.num_return_sequences
batch_size_unflattened, cur_len = input_ids.shape
batch_size_unflattened, cur_len = input_ids.shape[:2]
batch_size = batch_size_unflattened // num_beams
# TODO (joao): standardize special cases
if self.__class__.__name__ == "MoshiDepthDecoder":
@ -3855,7 +3857,7 @@ class GenerationMixin:
dim=0,
).to(input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
# (joao) feature lost in the refactor. Probably won't implement, hurts readability with minimal gains (there
# are newer low-memory alternatives like the offloaded cache)
@ -4154,7 +4156,7 @@ class GenerationMixin:
device = input_ids.device
batch_beam_size, cur_len = input_ids.shape
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
if return_dict_in_generate and output_scores:
beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
@ -4188,7 +4190,7 @@ class GenerationMixin:
this_peer_finished = False
decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# predicted tokens in cur_len step
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
@ -4442,8 +4444,8 @@ class GenerationMixin:
batch_size = len(constrained_beam_scorer._beam_hyps)
num_beams = constrained_beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
batch_beam_size, cur_len = input_ids.shape[:2]
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
if num_beams * batch_size != batch_beam_size:
raise ValueError(
@ -4475,7 +4477,7 @@ class GenerationMixin:
this_peer_finished = False
decoder_prompt_len = input_ids.shape[-1] # record the prompt length of decoder
decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@ -4696,14 +4698,14 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
batch_size = input_ids.shape[0]
batch_size, cur_len = input_ids.shape[:2]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
this_peer_finished = False
is_first_iteration = True # to preserve the same API in the output as other generation methods
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
cur_len = input_ids.shape[-1]
cur_len = input_ids.shape[1]
# 1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
@ -4793,7 +4795,7 @@ class GenerationMixin:
input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
if streamer is not None:
streamer.put(valid_tokens.cpu())
new_cur_len = input_ids.shape[-1]
new_cur_len = input_ids.shape[1]
# 4.2. Discard past key values relative to unused assistant tokens
new_cache_size = new_cur_len - 1

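For context on the recurring `self.config.vocab_size` → `self.config.get_text_config().vocab_size` change above: `get_text_config()` returns the nested text configuration for multimodal models and the config itself for text-only models, so the vocabulary size is resolved consistently in both cases. A minimal sketch (the checkpoint names are only examples):

```python
from transformers import AutoConfig

multimodal = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
print(multimodal.get_text_config().vocab_size)  # resolves to multimodal.text_config.vocab_size

text_only = AutoConfig.from_pretrained("gpt2")
print(text_only.get_text_config().vocab_size)   # get_text_config() returns the config itself
```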
View File

@ -158,4 +158,5 @@ LOSS_MAPPING = {
"RTDetrForObjectDetection": RTDetrForObjectDetectionLoss,
"RTDetrV2ForObjectDetection": RTDetrForObjectDetectionLoss,
"DFineForObjectDetection": DFineForObjectDetectionLoss,
"CsmForConditionalGeneration": ForCausalLMLoss,
}

View File

@ -114,6 +114,7 @@ from .utils import (
is_torch_npu_available,
is_torch_sdpa_available,
is_torch_xla_available,
is_torch_xpu_available,
logging,
replace_return_docstrings,
strtobool,
@ -215,6 +216,28 @@ TORCH_INIT_FUNCTIONS = {
"kaiming_normal": nn.init.kaiming_normal,
}
# DO NOT MODIFY, KEPT FOR BC ONLY
VLMS = [
"aria",
"aya_vision",
"emu3",
"fuyu",
"got_ocr2",
"gemma3",
"internvl",
"llava",
"llava_next",
"llava_next_video",
"llava_onevision",
"mistral3",
"mllama",
"paligemma",
"qwen2_vl",
"qwem2_5_vl",
"video_llava",
"vipllava",
]
@contextmanager
def no_init_weights():
@ -1274,7 +1297,7 @@ def _get_device_map(
)
if device_map != "sequential":
max_memory = get_balanced_memory(
inferred_max_memory = get_balanced_memory(
model,
dtype=target_dtype,
low_zero=(device_map == "balanced_low_0"),
@ -1282,10 +1305,23 @@ def _get_device_map(
**device_map_kwargs,
)
else:
max_memory = get_max_memory(max_memory)
inferred_max_memory = get_max_memory(max_memory)
if hf_quantizer is not None:
max_memory = hf_quantizer.adjust_max_memory(max_memory)
device_map_kwargs["max_memory"] = max_memory
inferred_max_memory = hf_quantizer.adjust_max_memory(inferred_max_memory)
# `inferred_max_memory` contains non-reserved memory. There may be *unused* reserved memory in the GPU,
# which we can use to allocate parameters.
for device_name in inferred_max_memory.keys():
if isinstance(device_name, int): # it's a GPU device
if is_torch_xpu_available():
unused_memory = torch.xpu.memory_reserved(device_name) - torch.xpu.memory_allocated(device_name)
else:
unused_memory = torch.cuda.memory_reserved(device_name) - torch.cuda.memory_allocated(device_name)
inferred_max_memory[device_name] += unused_memory
# respect the `max_memory` passed by the user
if max_memory is not None and device_name in max_memory:
inferred_max_memory[device_name] = min(inferred_max_memory[device_name], max_memory[device_name])
device_map_kwargs["max_memory"] = inferred_max_memory
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
@ -1764,6 +1800,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
main_input_name = "input_ids"
model_tags = None
_checkpoint_conversion_mapping = {} # used for BC support in VLMs, not meant to be used by new models
_auto_class = None
_no_split_modules = None
_skip_keys_device_placement = None
@ -3470,6 +3508,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
module_map[name + f".{key}"] = module
state_dict = model_to_save.state_dict()
if any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS):
reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()}
original_state_dict = {}
for key, value in state_dict.items():
for pattern, replacement in reverse_key_mapping.items():
replacement = replacement.lstrip("^") # strip off un-needed chars and patterns
replacement = re.sub(r"\(.*?\)", "", pattern)
key, n_replace = re.subn(pattern, replacement, key)
# Early exit of the loop
if n_replace > 0:
break
original_state_dict[key] = value
state_dict = original_state_dict
# Translate state_dict from smp to hf if saving with smp >= 1.10
if IS_SAGEMAKER_MP_POST_1_10:
for smp_to_hf, _ in smp.state.module_manager.translate_functions:
@ -4057,7 +4110,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
tp_size = kwargs.pop("tp_size", None)
key_mapping = kwargs.pop("key_mapping", None)
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
if any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
key_mapping = kwargs.pop("key_mapping", cls._checkpoint_conversion_mapping)
else:
key_mapping = kwargs.pop("key_mapping", None)
# Not used anymore -- remove them from the kwargs
_ = kwargs.pop("resume_download", None)
_ = kwargs.pop("trust_remote_code", None)
@ -5979,6 +6038,9 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
# Note that we use an absolute value instead of device proportion here, as a 8GiB device could still allocate too much
# if using e.g. 90% of device size, while a 140GiB device would allocate too little
byte_count = min(byte_count, max(0, int(device_memory - 1.2 * 1024**3)))
# If there is *unused* reserved cuda memory, we can skip/reduce the allocation.
unused_memory = torch.cuda.memory_reserved(index) - torch.cuda.memory_allocated(index)
byte_count = max(0, byte_count - unused_memory)
# Allocate memory
_ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)

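Both memory-related hunks above use the same quantity: memory the CUDA caching allocator has reserved but that is not currently backing any tensor, which can therefore still serve new allocations. A minimal sketch (not from the commit) of how that number is obtained:

```python
import torch

if torch.cuda.is_available():
    device_index = 0
    unused = torch.cuda.memory_reserved(device_index) - torch.cuda.memory_allocated(device_index)
    print(f"unused reserved memory on cuda:{device_index}: {unused / 1024**2:.1f} MiB")
```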
View File

@ -68,6 +68,7 @@ if TYPE_CHECKING:
from .convnextv2 import *
from .cpm import *
from .cpmant import *
from .csm import *
from .ctrl import *
from .cvt import *
from .d_fine import *
@ -129,6 +130,7 @@ if TYPE_CHECKING:
from .granite import *
from .granite_speech import *
from .granitemoe import *
from .granitemoehybrid import *
from .granitemoeshared import *
from .grounding_dino import *
from .groupvit import *

View File

@ -18,12 +18,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, select_best_resolution
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution
from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
from ...image_utils import (
ChannelDimension,
@ -71,23 +70,6 @@ def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> Li
return patches
def _get_patch_output_size(image, target_resolution, input_data_format):
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
target_height, target_width = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
return new_height, new_width
class AriaImageProcessor(BaseImageProcessor):
"""
A vision processor for the Aria model that handles image preprocessing.
@ -375,7 +357,7 @@ class AriaImageProcessor(BaseImageProcessor):
Returns:
np.array: The resized and padded image.
"""
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
@ -389,12 +371,12 @@ class AriaImageProcessor(BaseImageProcessor):
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
return padded_image

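The `divmod` change above fixes padding for odd size differences: previously the integer division left the padded patch one pixel short of the target, while the remainder now goes to the bottom/right pad. A small illustration (not from the commit):

```python
target_width, new_width = 337, 334            # difference of 3 pixels
paste_x, r_x = divmod(target_width - new_width, 2)
print(paste_x, paste_x + r_x)                 # 1 2 -> pad 1 px left, 2 px right (3 in total)

# The old code padded (target_width - new_width) // 2 = 1 px on both sides, i.e. only 2 of the 3 px.
```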
View File

@ -42,7 +42,7 @@ from ...utils import (
replace_return_docstrings,
)
from ...utils.import_utils import is_torch_available
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_aria import AriaConfig, AriaTextConfig
@ -58,7 +58,9 @@ if is_torch_flex_attn_available():
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "AriaTextConfig"
_CONFIG_FOR_DOC = "AriaConfig"
@use_kernel_forward_from_hub("RMSNorm")
@ -659,7 +661,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
"""
config_class = AriaConfig
config_class = AriaTextConfig
base_model_prefix = "model"
_no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
supports_gradient_checkpointing = True
@ -706,8 +708,8 @@ ARIA_TEXT_START_DOCSTRING = r"""
ARIA_TEXT_START_DOCSTRING,
)
class AriaPreTrainedModel(PreTrainedModel):
config_class = AriaTextConfig
base_model_prefix = "model"
config_class = AriaConfig
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["AriaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
@ -1097,6 +1099,9 @@ class AriaTextModel(AriaTextPreTrainedModel):
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
_CONFIG_FOR_TEXT_DOC = "AriaTextConfig"
class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
"""
Aria model for causal language modeling tasks.
@ -1112,7 +1117,6 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
config_class = AriaTextConfig
def __init__(self, config: AriaTextConfig):
super().__init__(config)
@ -1141,9 +1145,8 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.model
@can_return_tuple
@add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_TEXT_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@ -1255,7 +1258,7 @@ class AriaCausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
@ -1267,6 +1270,39 @@ class AriaCausalLMOutputWithPast(ModelOutput):
image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class AriaModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for Aria outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
ARIA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor`, *optional*):
@ -1320,30 +1356,131 @@ ARIA_START_DOCSTRING = r"""
@add_start_docstrings(
"""Aria model for conditional generation tasks.
This model combines a vision tower, a multi-modal projector, and a language model
to perform tasks that involve both image and text inputs.""",
"""The Aria model which consists of a vision backbone and a language model, without a language modeling head.""",
ARIA_START_DOCSTRING,
)
class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
config_class = AriaConfig
_supports_flash_attn_2 = False
_supports_flex_attn = False
_supports_sdpa = False
_tied_weights_keys = ["language_model.lm_head.weight"]
class AriaModel(AriaPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: AriaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = AriaProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"
self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_image_features(
self,
pixel_values: torch.FloatTensor,
pixel_mask: torch.FloatTensor = None,
vision_feature_layer: int = -1,
):
"""
Obtains image last hidden states from the vision tower and applies the multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
pixel_mask (`torch.FloatTensor`, *optional*):
The tensors corresponding to the input image mask.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
image_outputs = self.vision_tower(
pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
)
image_attn_mask = None
if patch_attention_mask is not None:
flattened_mask = patch_attention_mask.flatten(1)
image_attn_mask = torch.logical_not(flattened_mask)
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
return image_features
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
pixel_mask: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, AriaModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text and images
if pixel_values is not None and inputs_embeds.shape[1] != 1:
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
else:
image_embeds = input_ids == self.config.image_token_id
special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
image_features = self.get_image_features(
pixel_values=pixel_values,
pixel_mask=pixel_mask,
vision_feature_layer=self.config.vision_feature_layer,
)
n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
n_image_features = n_images * n_features_per_image
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
)
output = AriaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values if use_cache else None,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
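The merge step in `forward` relies on the number of `image_token_id` placeholders in `input_ids` matching the number of projected image feature vectors; `masked_scatter` then writes the features into those positions. A small self-contained sketch of that check and scatter, with made-up shapes and token ids:
```python
import torch

image_token_id, hidden_size = 9, 4
# Two text tokens, then three image placeholder tokens, then one text token.
input_ids = torch.tensor([[1, 2, 9, 9, 9, 3]])
inputs_embeds = torch.zeros(1, 6, hidden_size)

# One image whose projector output has three feature vectors.
image_features = torch.arange(3 * hidden_size, dtype=torch.float).reshape(1, 3, hidden_size)

image_token_positions = input_ids == image_token_id
n_image_tokens = image_token_positions.sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
    raise ValueError(f"tokens: {n_image_tokens}, features {n_image_features}")

special_image_mask = image_token_positions.unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

print(inputs_embeds[0, 2:5])  # the three placeholder positions now hold the image features
```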
def _create_patch_attention_mask(self, pixel_mask):
if pixel_mask is None:
return None
@ -1360,51 +1497,61 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
)
return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
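The reduction on the last line marks a patch as valid when at least one pixel inside it is non-masked. A toy check, assuming `patches_subgrid` is laid out as `(batch, patches_h, patches_w, patch_h, patch_w)` (the unfolding that builds it falls outside this hunk, so that layout is an assumption):
```python
import torch

# Assumed layout: (batch, patches_h, patches_w, patch_h, patch_w) of 0/1 pixel validity.
patches_subgrid = torch.zeros(1, 2, 2, 3, 3)
patches_subgrid[0, 0, 0] = 1          # a fully valid patch
patches_subgrid[0, 0, 1, 0, 0] = 1    # a single valid pixel is enough

patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
print(patch_attention_mask)
# tensor([[[ True,  True],
#          [False, False]]])
```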
@add_start_docstrings(
"""Aria model for conditional generation tasks.
This model combines a vision tower, a multi-modal projector, and a language model
to perform tasks that involve both image and text inputs.""",
ARIA_START_DOCSTRING,
)
class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: AriaConfig):
super().__init__(config)
self.model = AriaModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
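`_checkpoint_conversion_mapping` lets checkpoints saved before the `AriaModel`/`lm_head` split load into the new layout by regex-renaming state-dict prefixes. A standalone sketch of what those renames do to a few hypothetical keys (illustration only; the actual remapping is performed by the library's loading code and may differ in detail):
```python
import re

# The regex -> replacement pairs mirror `_checkpoint_conversion_mapping` above.
mapping = {
    r"^language_model.model": "model.language_model",
    r"^vision_tower": "model.vision_tower",
    r"^multi_modal_projector": "model.multi_modal_projector",
    r"^language_model.lm_head": "lm_head",
}

# Hypothetical old-style keys, for illustration only.
old_keys = [
    "language_model.model.layers.0.self_attn.q_proj.weight",
    "vision_tower.encoder.layers.0.mlp.fc1.weight",
    "language_model.lm_head.weight",
]

for key in old_keys:
    new_key = key
    for pattern, replacement in mapping.items():
        new_key = re.sub(pattern, replacement, new_key)
    print(f"{key} -> {new_key}")
# language_model.model.layers.0.self_attn.q_proj.weight -> model.language_model.layers.0.self_attn.q_proj.weight
# vision_tower.encoder.layers.0.mlp.fc1.weight -> model.vision_tower.encoder.layers.0.mlp.fc1.weight
# language_model.lm_head.weight -> lm_head.weight
```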
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
def get_decoder(self):
return self.language_model.get_decoder()
@property
def vision_tower(self):
return self.model.vision_tower
def get_image_features(
self,
pixel_values: torch.FloatTensor,
pixel_mask: Optional[torch.FloatTensor] = None,
vision_feature_layer: int = -1,
):
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
image_outputs = self.vision_tower(
pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
)
image_attn_mask = None
if patch_attention_mask is not None:
flattened_mask = patch_attention_mask.flatten(1)
image_attn_mask = torch.logical_not(flattened_mask)
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
return image_features
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
pixel_mask: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@ -1413,10 +1560,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**loss_kwargs,
) -> AriaCausalLMOutputWithPast:
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@ -1482,37 +1630,12 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text and images
if pixel_values is not None and inputs_embeds.shape[1] != 1:
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
else:
image_embeds = input_ids == self.config.image_token_id
special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
image_features = self.get_image_features(
pixel_values=pixel_values,
pixel_mask=pixel_mask,
vision_feature_layer=self.config.vision_feature_layer,
)
n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
n_image_features = n_images * n_features_per_image
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs: CausalLMOutputWithPast = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_mask=pixel_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@ -1520,11 +1643,14 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
logits_to_keep=logits_to_keep,
return_dict=return_dict,
cache_position=cache_position,
)
logits = outputs.logits
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
@ -1552,7 +1678,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
logits_to_keep=None,
**kwargs,
):
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -1570,11 +1696,67 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
return model_inputs
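`logits_to_keep` in the refactored `forward` controls how many trailing positions get projected through `lm_head`: an integer keeps only the last `n` positions (the common case during generation), while a tensor selects explicit indices. A toy illustration of the slicing, with arbitrary sizes:
```python
import torch

hidden_states = torch.randn(1, 10, 8)        # (batch, seq_len, hidden)
lm_head = torch.nn.Linear(8, 32, bias=False)  # hypothetical head size

logits_to_keep = 1                            # keep only the last position
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = lm_head(hidden_states[:, slice_indices, :])
print(logits.shape)  # torch.Size([1, 1, 32])

logits_to_keep = torch.tensor([3, 7, 9])      # or an explicit index tensor
logits = lm_head(hidden_states[:, logits_to_keep, :])
print(logits.shape)  # torch.Size([1, 3, 32])
```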
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
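To make the mask construction above concrete, here is a tiny worked case: three new tokens written into cache slots 2-4 of a static cache of length 5, with the first cache slot marked as padding in the 2D mask (sizes and dtype are arbitrary):
```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length, batch_size = 3, 5, 1
cache_position = torch.arange(2, 5)               # the new tokens occupy cache slots 2, 3, 4
attention_mask = torch.tensor([[0, 1, 1, 1, 1]])  # cache slot 0 is padding

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)

print((causal_mask == 0).int()[0, 0])  # 1 marks positions each query token may attend to
# tensor([[0, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 1, 1, 1, 1]], dtype=torch.int32)
```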
__all__ = [
"AriaForConditionalGeneration",
"AriaPreTrainedModel",
"AriaTextPreTrainedModel",
"AriaTextModel",
"AriaModel",
"AriaTextForCausalLM",
]

View File

@ -12,15 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from ...activations import ACT2FN
from ...configuration_utils import PretrainedConfig
from ...generation import GenerationMixin
from ...image_processing_utils import BaseImageProcessor, BatchFeature, select_best_resolution
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution
from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
from ...image_utils import (
ChannelDimension,
@ -50,7 +48,7 @@ from ...utils import (
replace_return_docstrings,
)
from ...utils.import_utils import is_torch_available
from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
from ..auto import CONFIG_MAPPING, AutoConfig, AutoTokenizer
from ..llama.configuration_llama import LlamaConfig
from ..llama.modeling_llama import (
LlamaDecoderLayer,
@ -60,7 +58,12 @@ from ..llama.modeling_llama import (
LlamaPreTrainedModel,
LlamaRMSNorm,
)
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast
from ..llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
LlavaModel,
LlavaModelOutputWithPast,
)
from ..llava_next.image_processing_llava_next import divide_to_patches
@ -71,6 +74,11 @@ if is_torch_available():
from torch import nn
_CONFIG_FOR_DOC = "AriaConfig"
_CONFIG_FOR_TEXT_DOC = "AriaTextConfig"
ARIA_TEXT_INPUTS_DOCSTRING = None
def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert):
"""
Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts.
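The docstring describes a per-expert loop over grouped tokens. A minimal sketch of what such a sequential grouped GEMM can look like, written from that description only (the actual body lies outside this hunk) and assuming tokens arrive already sorted by expert:
```python
import torch

def sequential_experts_gemm_sketch(token_states, expert_weights, tokens_per_expert):
    # token_states:      (num_tokens, in_dim), tokens grouped by expert
    # expert_weights:    (num_experts, in_dim, out_dim)
    # tokens_per_expert: (num_experts,) token counts per expert
    output = token_states.new_zeros(token_states.shape[0], expert_weights.shape[-1])
    start = 0
    for expert_idx, num_tokens in enumerate(tokens_per_expert.tolist()):
        end = start + num_tokens
        if num_tokens > 0:
            output[start:end] = token_states[start:end] @ expert_weights[expert_idx]
        start = end
    return output

tokens = torch.randn(6, 8)
weights = torch.randn(3, 8, 16)
counts = torch.tensor([2, 0, 4])
print(sequential_experts_gemm_sketch(tokens, weights, counts).shape)  # torch.Size([6, 16])
```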
@ -461,23 +469,6 @@ class AriaProjector(nn.Module):
return out
def _get_patch_output_size(image, target_resolution, input_data_format):
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
target_height, target_width = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
return new_height, new_width
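The helper above (now replaced by the shared `get_patch_output_size` import) picks the smaller scale factor so the image fits inside the target while keeping its aspect ratio. A worked run of the same arithmetic, with made-up sizes:
```python
import math

# Hypothetical sizes: a 480x640 (H x W) image resized to fit a 490x490 patch.
original_height, original_width = 480, 640
target_height, target_width = 490, 490

scale_w = target_width / original_width    # 0.765625
scale_h = target_height / original_height  # ~1.0208

if scale_w < scale_h:
    new_width = target_width
    new_height = min(math.ceil(original_height * scale_w), target_height)
else:
    new_height = target_height
    new_width = min(math.ceil(original_width * scale_h), target_width)

print(new_height, new_width)  # 368 490 -> width fills the target, height keeps the aspect ratio
```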
class AriaImageProcessor(BaseImageProcessor):
"""
A vision processor for the Aria model that handles image preprocessing.
@ -765,7 +756,7 @@ class AriaImageProcessor(BaseImageProcessor):
Returns:
np.array: The resized and padded image.
"""
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
@ -779,12 +770,12 @@ class AriaImageProcessor(BaseImageProcessor):
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
return padded_image
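The change above switches from symmetric integer padding to `divmod`, so an odd leftover is no longer dropped: the extra row or column goes to the bottom/right side. A quick check of the arithmetic with hypothetical sizes:
```python
# Pad a 368x490 resized patch back to 490x490.
target_height, target_width = 490, 490
new_height, new_width = 368, 490

paste_y, r_y = divmod(target_height - new_height, 2)  # (61, 0)
paste_x, r_x = divmod(target_width - new_width, 2)    # (0, 0)
print((paste_y, paste_y + r_y), (paste_x, paste_x + r_x))  # (61, 61) (0, 0)

# Odd case: a leftover of 121 rows splits into 60 on top and 61 on the bottom,
# whereas the old floor-division-only padding would have come up one row short.
print(divmod(121, 2))  # (60, 1)
```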
@ -1241,7 +1232,7 @@ class AriaTextPreTrainedModel(PreTrainedModel):
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
"""
config_class = AriaConfig
config_class = AriaTextConfig
base_model_prefix = "model"
_no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
supports_gradient_checkpointing = True
@ -1267,6 +1258,8 @@ class AriaTextPreTrainedModel(PreTrainedModel):
class AriaPreTrainedModel(LlamaPreTrainedModel):
config_class = AriaConfig
base_model_prefix = ""
_supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing)
_supports_attention_backend = False
@ -1310,7 +1303,6 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
"""
_tied_weights_keys = ["lm_head.weight"]
config_class = AriaTextConfig
def __init__(self, config: AriaTextConfig):
super().__init__(config)
@ -1321,11 +1313,20 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_TEXT_DOC)
def forward(self, **super_kwargs):
super().forward(self, **super_kwargs)
class AriaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
class AriaModelOutputWithPast(LlavaModelOutputWithPast):
pass
ARIA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor`, *optional*):
@ -1378,30 +1379,10 @@ ARIA_START_DOCSTRING = r"""
"""
@add_start_docstrings(
"""Aria model for conditional generation tasks.
This model combines a vision tower, a multi-modal projector, and a language model
to perform tasks that involve both image and text inputs.""",
ARIA_START_DOCSTRING,
)
class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
config_class = AriaConfig
_supports_flash_attn_2 = False
_supports_flex_attn = False
_supports_sdpa = False
_tied_weights_keys = ["language_model.lm_head.weight"]
class AriaModel(LlavaModel):
def __init__(self, config: AriaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = AriaProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"
self.post_init()
def _create_patch_attention_mask(self, pixel_mask):
if pixel_mask is None:
@ -1419,30 +1400,27 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
)
return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def get_image_features(
self,
pixel_values: torch.FloatTensor,
pixel_mask: Optional[torch.FloatTensor] = None,
pixel_mask: torch.FloatTensor = None,
vision_feature_layer: int = -1,
):
"""
Obtains the image's last hidden states from the vision tower and applies the multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
pixel_mask (`torch.FloatTensor`, *optional*):
The tensors corresponding to the input image mask.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
image_outputs = self.vision_tower(
pixel_values, patch_attention_mask=patch_attention_mask, output_hidden_states=True
@ -1456,14 +1434,94 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
return image_features
@can_return_tuple
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_mask: Optional[torch.LongTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
pixel_mask: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, AriaModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text and images
if pixel_values is not None and inputs_embeds.shape[1] != 1:
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
else:
image_embeds = input_ids == self.config.image_token_id
special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
image_features = self.get_image_features(
pixel_values=pixel_values,
pixel_mask=pixel_mask,
vision_feature_layer=self.config.vision_feature_layer,
)
n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
n_image_features = n_images * n_features_per_image
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
)
output = AriaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values if use_cache else None,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
@add_start_docstrings(
"""Aria model for conditional generation tasks.
This model combines a vision tower, a multi-modal projector, and a language model
to perform tasks that involve both image and text inputs.""",
ARIA_START_DOCSTRING,
)
class AriaForConditionalGeneration(LlavaForConditionalGeneration):
@can_return_tuple
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
pixel_mask: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@ -1472,10 +1530,11 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**loss_kwargs,
) -> AriaCausalLMOutputWithPast:
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@ -1541,37 +1600,12 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text and images
if pixel_values is not None and inputs_embeds.shape[1] != 1:
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
else:
image_embeds = input_ids == self.config.image_token_id
special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
image_features = self.get_image_features(
pixel_values=pixel_values,
pixel_mask=pixel_mask,
vision_feature_layer=self.config.vision_feature_layer,
)
n_images, n_features_per_image = image_features.shape[0], image_features.shape[1]
n_image_features = n_images * n_features_per_image
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs: CausalLMOutputWithPast = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_mask=pixel_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@ -1579,11 +1613,14 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
logits_to_keep=logits_to_keep,
return_dict=return_dict,
cache_position=cache_position,
)
logits = outputs.logits
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
@ -1611,7 +1648,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
logits_to_keep=None,
**kwargs,
):
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -1639,5 +1676,6 @@ __all__ = [
"AriaPreTrainedModel",
"AriaTextPreTrainedModel",
"AriaTextModel",
"AriaModel",
"AriaTextForCausalLM",
]

View File

@ -80,6 +80,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("convnext", "ConvNextConfig"),
("convnextv2", "ConvNextV2Config"),
("cpmant", "CpmAntConfig"),
("csm", "CsmConfig"),
("ctrl", "CTRLConfig"),
("cvt", "CvtConfig"),
("d_fine", "DFineConfig"),
@ -146,6 +147,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("granite", "GraniteConfig"),
("granite_speech", "GraniteSpeechConfig"),
("granitemoe", "GraniteMoeConfig"),
("granitemoehybrid", "GraniteMoeHybridConfig"),
("granitemoeshared", "GraniteMoeSharedConfig"),
("granitevision", "LlavaNextConfig"),
("graphormer", "GraphormerConfig"),
@ -437,6 +439,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("convnextv2", "ConvNeXTV2"),
("cpm", "CPM"),
("cpmant", "CPM-Ant"),
("csm", "CSM"),
("ctrl", "CTRL"),
("cvt", "CvT"),
("d_fine", "D-FINE"),
@ -510,6 +513,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("granite", "Granite"),
("granite_speech", "GraniteSpeech"),
("granitemoe", "GraniteMoeMoe"),
("granitemoehybrid", "GraniteMoeHybrid"),
("granitemoeshared", "GraniteMoeSharedMoe"),
("granitevision", "LLaVA-NeXT"),
("graphormer", "Graphormer"),

View File

@ -57,8 +57,8 @@ else:
IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
[
("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
("aria", ("AriaImageProcessor",)),
("beit", ("BeitImageProcessor",)),
("aria", ("AriaImageProcessor")),
("beit", ("BeitImageProcessor", "BeitImageProcessorFast")),
("bit", ("BitImageProcessor", "BitImageProcessorFast")),
("blip", ("BlipImageProcessor", "BlipImageProcessorFast")),
("blip-2", ("BlipImageProcessor", "BlipImageProcessorFast")),
@ -71,7 +71,7 @@ else:
("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("cvt", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("data2vec-vision", ("BeitImageProcessor",)),
("data2vec-vision", ("BeitImageProcessor", "BeitImageProcessorFast")),
("deformable_detr", ("DeformableDetrImageProcessor", "DeformableDetrImageProcessorFast")),
("deit", ("DeiTImageProcessor", "DeiTImageProcessorFast")),
("depth_anything", ("DPTImageProcessor",)),
@ -150,7 +150,7 @@ else:
("superglue", ("SuperGlueImageProcessor",)),
("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
("swin2sr", ("Swin2SRImageProcessor",)),
("swin2sr", ("Swin2SRImageProcessor", "Swin2SRImageProcessorFast")),
("swinv2", ("ViTImageProcessor", "ViTImageProcessorFast")),
("table-transformer", ("DetrImageProcessor",)),
("timesformer", ("VideoMAEImageProcessor",)),

View File

@ -35,10 +35,11 @@ MODEL_MAPPING_NAMES = OrderedDict(
("albert", "AlbertModel"),
("align", "AlignModel"),
("altclip", "AltCLIPModel"),
("aria", "AriaForConditionalGeneration"),
("aria", "AriaModel"),
("aria_text", "AriaTextModel"),
("audio-spectrogram-transformer", "ASTModel"),
("autoformer", "AutoformerModel"),
("aya_vision", "AyaVisionModel"),
("bamba", "BambaModel"),
("bark", "BarkModel"),
("bart", "BartModel"),
@ -78,6 +79,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("convnext", "ConvNextModel"),
("convnextv2", "ConvNextV2Model"),
("cpmant", "CpmAntModel"),
("csm", "CsmForConditionalGeneration"),
("ctrl", "CTRLModel"),
("cvt", "CvtModel"),
("d_fine", "DFineModel"),
@ -107,6 +109,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("efficientformer", "EfficientFormerModel"),
("efficientnet", "EfficientNetModel"),
("electra", "ElectraModel"),
("emu3", "Emu3Model"),
("encodec", "EncodecModel"),
("ernie", "ErnieModel"),
("ernie_m", "ErnieMModel"),
@ -120,14 +123,16 @@ MODEL_MAPPING_NAMES = OrderedDict(
("focalnet", "FocalNetModel"),
("fsmt", "FSMTModel"),
("funnel", ("FunnelModel", "FunnelBaseModel")),
("fuyu", "FuyuModel"),
("gemma", "GemmaModel"),
("gemma2", "Gemma2Model"),
("gemma3", "Gemma3Model"),
("gemma3_text", "Gemma3TextModel"),
("git", "GitModel"),
("glm", "GlmModel"),
("glm4", "Glm4Model"),
("glpn", "GLPNModel"),
("got_ocr2", "GotOcr2ForConditionalGeneration"),
("got_ocr2", "GotOcr2Model"),
("gpt-sw3", "GPT2Model"),
("gpt2", "GPT2Model"),
("gpt_bigcode", "GPTBigCodeModel"),
@ -138,6 +143,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
("granite", "GraniteModel"),
("granitemoe", "GraniteMoeModel"),
("granitemoehybrid", "GraniteMoeHybridModel"),
("granitemoeshared", "GraniteMoeSharedModel"),
("graphormer", "GraphormerModel"),
("grounding-dino", "GroundingDinoModel"),
@ -154,6 +160,9 @@ MODEL_MAPPING_NAMES = OrderedDict(
("ijepa", "IJepaModel"),
("imagegpt", "ImageGPTModel"),
("informer", "InformerModel"),
("instructblip", "InstructBlipModel"),
("instructblipvideo", "InstructBlipVideoModel"),
("internvl", "InternVLModel"),
("internvl_vision", "InternVLVisionModel"),
("jamba", "JambaModel"),
("janus", "JanusModel"),
@ -168,6 +177,10 @@ MODEL_MAPPING_NAMES = OrderedDict(
("lilt", "LiltModel"),
("llama", "LlamaModel"),
("llama4", "Llama4ForConditionalGeneration"),
("llava", "LlavaModel"),
("llava_next", "LlavaNextModel"),
("llava_next_video", "LlavaNextVideoModel"),
("llava_onevision", "LlavaOnevisionModel"),
("longformer", "LongformerModel"),
("longt5", "LongT5Model"),
("luke", "LukeModel"),
@ -187,8 +200,10 @@ MODEL_MAPPING_NAMES = OrderedDict(
("mgp-str", "MgpstrForSceneTextRecognition"),
("mimi", "MimiModel"),
("mistral", "MistralModel"),
("mistral3", "Mistral3Model"),
("mixtral", "MixtralModel"),
("mlcd", "MLCDVisionModel"),
("mllama", "MllamaModel"),
("mobilebert", "MobileBertModel"),
("mobilenet_v1", "MobileNetV1Model"),
("mobilenet_v2", "MobileNetV2Model"),
@ -219,6 +234,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("opt", "OPTModel"),
("owlv2", "Owlv2Model"),
("owlvit", "OwlViTModel"),
("paligemma", "PaliGemmaModel"),
("patchtsmixer", "PatchTSMixerModel"),
("patchtst", "PatchTSTModel"),
("pegasus", "PegasusModel"),
@ -239,11 +255,11 @@ MODEL_MAPPING_NAMES = OrderedDict(
("qdqbert", "QDQBertModel"),
("qwen2", "Qwen2Model"),
("qwen2_5_vl", "Qwen2_5_VLModel"),
("qwen2_5_vl_text", "Qwen2_5_VLModel"),
("qwen2_5_vl_text", "Qwen2_5_VLTextModel"),
("qwen2_audio_encoder", "Qwen2AudioEncoder"),
("qwen2_moe", "Qwen2MoeModel"),
("qwen2_vl", "Qwen2VLModel"),
("qwen2_vl_text", "Qwen2VLModel"),
("qwen2_vl_text", "Qwen2VLTextModel"),
("qwen3", "Qwen3Model"),
("qwen3_moe", "Qwen3MoeModel"),
("recurrent_gemma", "RecurrentGemmaModel"),
@ -305,8 +321,10 @@ MODEL_MAPPING_NAMES = OrderedDict(
("unispeech-sat", "UniSpeechSatModel"),
("univnet", "UnivNetModel"),
("van", "VanModel"),
("video_llava", "VideoLlavaModel"),
("videomae", "VideoMAEModel"),
("vilt", "ViltModel"),
("vipllava", "VipLlavaModel"),
("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
("visual_bert", "VisualBertModel"),
("vit", "ViTModel"),
@ -559,6 +577,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("gptj", "GPTJForCausalLM"),
("granite", "GraniteForCausalLM"),
("granitemoe", "GraniteMoeForCausalLM"),
("granitemoehybrid", "GraniteMoeHybridForCausalLM"),
("granitemoeshared", "GraniteMoeSharedForCausalLM"),
("helium", "HeliumForCausalLM"),
("jamba", "JambaForCausalLM"),
@ -878,6 +897,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
("llama4", "Llama4ForConditionalGeneration"),
("llava", "LlavaForConditionalGeneration"),
("llava_next", "LlavaNextForConditionalGeneration"),
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
("mistral3", "Mistral3ForConditionalGeneration"),
("mllama", "MllamaForConditionalGeneration"),
@ -1448,6 +1468,7 @@ MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
[
# Model for Text-To-Waveform mapping
("bark", "BarkModel"),
("csm", "CsmForConditionalGeneration"),
("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"),
("musicgen", "MusicgenForConditionalGeneration"),
("musicgen_melody", "MusicgenMelodyForConditionalGeneration"),

View File

@ -27,15 +27,16 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_outputs import ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_aya_vision import AyaVisionConfig
@ -115,9 +116,8 @@ AYA_VISION_START_DOCSTRING = r"""
)
class AyaVisionPreTrainedModel(PreTrainedModel):
config_class = AyaVisionConfig
base_model_prefix = "model"
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["AyaVisionVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@ -169,7 +169,7 @@ class AyaVisionCausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
@ -181,7 +181,40 @@ class AyaVisionCausalLMOutputWithPast(ModelOutput):
image_hidden_states: Optional[torch.FloatTensor] = None
AYA_VISION_INPUTS_DOCSTRING = """
@dataclass
class AyaVisionModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for AyaVision outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
AYA_VISION_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@ -193,8 +226,8 @@ AYA_VISION_INPUTS_DOCSTRING = """
[What are input IDs?](../glossary#input-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The tensors corresponding to the input images. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`GotOcr2ImageProcessor.__call__`] for details. [`AyaVisionProcessor`] uses
[`GotOcr2ImageProcessor`] for processing images.
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`AyaVisionProcessor`] uses
[`CLIPImageProcessor`] for processing images).
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@ -259,23 +292,18 @@ AYA_VISION_INPUTS_DOCSTRING = """
@add_start_docstrings(
"""The AyaVision model which consists of a vision backbone and a language model.""",
"""The AyaVision model which consists of a vision backbone and a language model, without a language modeling head.""",
AYA_VISION_START_DOCSTRING,
)
class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
class AyaVisionModel(AyaVisionPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: AyaVisionConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = AyaVisionMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@ -284,18 +312,6 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@ -307,7 +323,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
@ -342,6 +358,140 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
@add_start_docstrings_to_model_forward(AYA_VISION_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, AyaVisionModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (input_ids == self.config.image_token_id).sum()
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = AyaVisionModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
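When only `inputs_embeds` are passed, the placeholder positions can no longer be read from `input_ids`, so the branch above recovers them by comparing each embedding vector against the embedding of `image_token_id`. A toy version of that comparison, with arbitrary sizes:
```python
import torch
from torch import nn

vocab_size, hidden_size, image_token_id = 12, 4, 9
embed_tokens = nn.Embedding(vocab_size, hidden_size)

input_ids = torch.tensor([[1, 9, 9, 3]])
inputs_embeds = embed_tokens(input_ids)

# Recover the placeholder positions from the embeddings alone.
image_token_embedding = embed_tokens(torch.tensor(image_token_id))
special_image_mask = inputs_embeds == image_token_embedding

# Every hidden dimension matches at the placeholder positions.
print(special_image_mask.all(dim=-1))
# tensor([[False,  True,  True, False]])
```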
@add_start_docstrings(
"""The AyaVision model which consists of a vision backbone and a language model.""",
AYA_VISION_START_DOCSTRING,
)
class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: AyaVisionConfig):
super().__init__(config)
self.model = AyaVisionModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings_to_model_forward(AYA_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AyaVisionCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@ -410,7 +560,6 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
>>> gen_tokens = model.generate(**inputs, max_new_tokens=300, do_sample=True, temperature=0.3)
>>> processor.tokenizer.decode(gen_tokens[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -425,73 +574,32 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
image_sizes=image_sizes,
**lm_kwargs,
)
logits = outputs[0]
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return AyaVisionCausalLMOutputWithPast(
loss=loss,
@ -499,7 +607,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@ -515,7 +623,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -532,15 +640,60 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
return model_inputs
def tie_weights(self):
return self.language_model.tie_weights()
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel"]
__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"]

View File

@ -22,6 +22,7 @@ from torch import nn
from transformers.models.llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
LlavaModel,
LlavaPreTrainedModel,
)
@ -88,21 +89,8 @@ class AyaVisionMultiModalProjector(nn.Module):
return image_features
AYA_VISION_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`AyaVisionConfig`] or [`AyaVisionVisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
AYA_VISION_START_DOCSTRING = None
AYA_VISION_INPUTS_DOCSTRING = None
@add_start_docstrings(
@ -133,81 +121,8 @@ class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
AYA_VISION_INPUTS_DOCSTRING = """
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
The tensors corresponding to the input images. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`GotOcr2ImageProcessor.__call__`] for details. [`AyaVisionProcessor`] uses
[`GotOcr2ImageProcessor`] for processing images.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
vision_feature_layer (`Union[int, List[int]], *optional*, defaults to -2`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
class AyaVisionModel(LlavaModel):
pass
@add_start_docstrings(
@ -215,16 +130,6 @@ AYA_VISION_INPUTS_DOCSTRING = """
AYA_VISION_START_DOCSTRING,
)
class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration):
def tie_weights(self):
return self.language_model.tie_weights()
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@ -312,4 +217,4 @@ class AyaVisionForConditionalGeneration(LlavaForConditionalGeneration):
)
__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel"]
__all__ = ["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel", "AyaVisionModel"]

View File

@ -854,6 +854,7 @@ class BambaMixer(nn.Module):
# Init cache
if ssm_state is not None and cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
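# mark the cache as holding a valid SSM state so subsequent decoding steps reuse it instead of re-initializing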
cache_params.has_previous_state = True
scan_output = self.norm(y, gate)

View File

@ -651,6 +651,7 @@ class BambaMixer(nn.Module):
# Init cache
if ssm_state is not None and cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
cache_params.has_previous_state = True
scan_output = self.norm(y, gate)

View File

@ -21,6 +21,7 @@ if TYPE_CHECKING:
from .configuration_beit import *
from .feature_extraction_beit import *
from .image_processing_beit import *
from .image_processing_beit_fast import *
from .modeling_beit import *
from .modeling_flax_beit import *
else:

View File

@ -0,0 +1,284 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for Beit."""
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torchvision.transforms import functional as F
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
is_torch_tensor,
make_list_of_images,
pil_torch_interpolation_mapping,
validate_kwargs,
)
from ...processing_utils import Unpack
from ...utils import TensorType, add_start_docstrings
from ...utils.deprecation import deprecate_kwarg
class BeitFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
do_reduce_labels: Optional[bool]
@add_start_docstrings(
"Constructs a fast Beit image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
is used for background, and background itself is not included in all classes of a dataset (e.g.
ADE20k). The background label will be replaced by 255.
""",
)
class BeitImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
size = {"height": 224, "width": 224}
default_to_square = True
crop_size = {"height": 224, "width": 224}
do_resize = True
do_center_crop = False
do_rescale = True
do_normalize = True
do_reduce_labels = False
valid_kwargs = BeitFastImageProcessorKwargs
@classmethod
def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
"""
Overrides the `from_dict` method from the base class to preserve support for the deprecated `reduce_labels` key in old configs
"""
image_processor_dict = image_processor_dict.copy()
if "reduce_labels" in image_processor_dict:
image_processor_dict["do_reduce_labels"] = image_processor_dict.pop("reduce_labels")
return super().from_dict(image_processor_dict, **kwargs)
def reduce_label(self, labels: list["torch.Tensor"]):
for idx in range(len(labels)):
label = labels[idx]
label = torch.where(label == 0, torch.tensor(255, dtype=label.dtype), label)
label = label - 1
label = torch.where(label == 254, torch.tensor(255, dtype=label.dtype), label)
labels[idx] = label
return labels
def _preprocess(
self,
images: list["torch.Tensor"],
do_reduce_labels: bool,
do_resize: bool,
size: SizeDict,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
if do_reduce_labels:
images = self.reduce_label(images)
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
resized_images_grouped[shape] = stacked_images
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:
stacked_images = self.center_crop(stacked_images, crop_size)
# Fused rescale and normalize
stacked_images = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images
processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return processed_images
def _preprocess_segmentation_maps(
self,
segmentation_maps,
**kwargs,
):
"""Preprocesses a single segmentation map."""
processed_segmentation_maps = []
for segmentation_map in segmentation_maps:
segmentation_map = self._process_image(
segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST
)
if segmentation_map.ndim == 2:
segmentation_map = segmentation_map[None, ...]
processed_segmentation_maps.append(segmentation_map)
kwargs["do_normalize"] = False
kwargs["do_rescale"] = False
kwargs["input_data_format"] = ChannelDimension.FIRST
processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs)
processed_segmentation_maps = processed_segmentation_maps.squeeze(1)
processed_segmentation_maps = processed_segmentation_maps.to(torch.int64)
return processed_segmentation_maps
def __call__(self, images, segmentation_maps=None, **kwargs):
# Overrides the `__call__` method of the `Preprocessor` class such that the images and segmentation maps can both
# be passed in as positional arguments.
return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)
@deprecate_kwarg("reduce_labels", new_name="do_reduce_labels", version="4.41.0")
@add_start_docstrings(
"Constructs a fast Beit image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
is used for background, and background itself is not included in all classes of a dataset (e.g.
ADE20k). The background label will be replaced by 255.
""",
)
def preprocess(
self,
images: ImageInput,
segmentation_maps: Optional[ImageInput] = None,
**kwargs: Unpack[DefaultFastImageProcessorKwargs],
) -> BatchFeature:
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
# Extract parameters that are only used for preparing the input images
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
# Prepare input images
images = self._prepare_input_images(
images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
)
# Prepare segmentation maps
if segmentation_maps is not None:
segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2)
# Update kwargs that need further processing before being validated
kwargs = self._further_process_kwargs(**kwargs)
# Validate kwargs
self._validate_preprocess_kwargs(**kwargs)
# torch resize uses interpolation instead of resample
resample = kwargs.pop("resample")
kwargs["interpolation"] = (
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
)
# Pop kwargs that are not needed in _preprocess
kwargs.pop("default_to_square")
kwargs.pop("data_format")
images = self._preprocess(
images=images,
**kwargs,
)
data = {"pixel_values": images}
if segmentation_maps is not None:
segmentation_maps = self._preprocess_segmentation_maps(
segmentation_maps=segmentation_maps,
**kwargs,
)
data["labels"] = segmentation_maps
return BatchFeature(data=data)
def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
"""
Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
Args:
outputs ([`BeitForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
predictions will not be resized.
Returns:
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
"""
# TODO: add support for other frameworks
logits = outputs.logits
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
if is_torch_tensor(target_sizes):
target_sizes = target_sizes.numpy()
semantic_segmentation = []
for idx in range(len(logits)):
resized_logits = torch.nn.functional.interpolate(
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = logits.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
__all__ = ["BeitImageProcessorFast"]

View File

@ -175,8 +175,8 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
self.vocab_size = config.vlm_config.text_config.vocab_size
vlm = AutoModelForImageTextToText.from_config(config.vlm_config)
if vlm.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"vlm.language_model.{k}" for k in vlm.language_model._tied_weights_keys]
if vlm._tied_weights_keys is not None:
self._tied_weights_keys = [f"vlm.{k}" for k in vlm._tied_weights_keys]
self.vlm = vlm
self.embedding_dim = self.config.embedding_dim
@ -246,25 +246,25 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
)
def get_input_embeddings(self):
return self.vlm.language_model.get_input_embeddings()
return self.vlm.get_input_embeddings()
def set_input_embeddings(self, value):
self.vlm.language_model.set_input_embeddings(value)
self.vlm.set_input_embeddings(value)
def get_output_embeddings(self):
return self.vlm.language_model.get_output_embeddings()
return self.vlm.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.vlm.language_model.set_output_embeddings(new_embeddings)
self.vlm.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.vlm.language_model.set_decoder(decoder)
self.vlm.set_decoder(decoder)
def get_decoder(self):
return self.vlm.language_model.get_decoder()
return self.vlm.get_decoder()
def tie_weights(self):
return self.vlm.language_model.tie_weights()
return self.vlm.tie_weights()
def resize_token_embeddings(
self,
@ -272,7 +272,7 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
pad_to_multiple_of: Optional[int] = None,
mean_resizing: bool = True,
) -> nn.Embedding:
model_embeds = self.vlm.language_model.resize_token_embeddings(
model_embeds = self.vlm.resize_token_embeddings(
new_num_tokens=new_num_tokens,
pad_to_multiple_of=pad_to_multiple_of,
mean_resizing=mean_resizing,

View File

@ -0,0 +1,28 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_csm import *
from .modeling_csm import *
from .processing_csm import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@ -0,0 +1,440 @@
# coding=utf-8
# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
logger = logging.get_logger(__name__)
class CsmDepthDecoderConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`CsmDepthDecoderModel`]. It is used to instantiate a CSM depth decoder
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield
a similar configuration to that of the csm-1b.
e.g. [eustlb/csm-1b](https://huggingface.co/eustlb/csm-1b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
num_codebooks (`int`, *optional*, defaults to 32):
Number of codebooks used in the underlying codec model responsible for tokenizing the audio.
backbone_hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations of the backbone model used with this depth decoder.
vocab_size (`int`, *optional*, defaults to 2051):
Vocabulary size of the CsmDepthDecoder model. Defines the number of different audio tokens that can be represented by each codebook.
hidden_size (`int`, *optional*, defaults to 1024):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 8192):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 4):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 2):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 33):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 2050):
Padding token id.
bos_token_id (`int`, *optional*):
Beginning of stream token id.
eos_token_id (`int`, *optional*):
End of stream token id.
rope_theta (`float`, *optional*, defaults to 500000):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
head_dim (`int`, *optional*):
The attention head dimension. If None, it will default to hidden_size // num_attention_heads
```python
>>> from transformers import CsmDepthDecoderModel, CsmDepthDecoderConfig
>>> # Initializing a CsmDepthDecoder configuration
>>> configuration = CsmDepthDecoderConfig()
>>> # Initializing a model from the configuration
>>> model = CsmDepthDecoderModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "csm_depth_decoder_model"
base_config_key = "depth_decoder_config"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
num_codebooks=32,
backbone_hidden_size=2048,
vocab_size=2051,
hidden_size=1024,
intermediate_size=8192,
num_hidden_layers=4,
num_attention_heads=8,
num_key_value_heads=2,
hidden_act="silu",
max_position_embeddings=33,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=None,
eos_token_id=None,
rope_theta=500000,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
head_dim=None,
**kwargs,
):
if kwargs.pop("tie_word_embeddings", False):
raise ValueError("`tie_word_embeddings=True` is not supported for CsmDepthDecoderConfig")
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=False,
**kwargs,
)
self.num_codebooks = num_codebooks
self.vocab_size = vocab_size
self.backbone_hidden_size = backbone_hidden_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, copy it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
class CsmConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`CsmForConditionalGeneration`]. It is used to instantiate a CSM
model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the csm-1b.
e.g. [eustlb/csm-1b](https://huggingface.co/eustlb/csm-1b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
num_codebooks (`int`, *optional*, defaults to 32):
Number of codebooks used in the underlying codec model responsible for tokenizing the audio.
vocab_size (`int`, *optional*, defaults to 2051):
Vocabulary size of the Csm model. Defines the number of different audio tokens that can be represented by each codebook.
text_vocab_size (`int`, *optional*, defaults to 128256):
Vocabulary size of the text input for the Csm model. Defines the number of different text tokens that can be represented.
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations of the backbone model.
intermediate_size (`int`, *optional*, defaults to 8192):
Dimension of the MLP representations of the backbone model.
num_hidden_layers (`int`, *optional*, defaults to 16):
Number of hidden layers in the backbone model Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the backbone model Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf).
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the backbone model Transformer decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 128002):
Padding token id.
codebook_pad_token_id (`int`, *optional*, defaults to 2050):
Padding token id for codebook tokens.
codebook_eos_token_id (`int`, *optional*, defaults to 0):
End of stream token id for codebook tokens.
bos_token_id (`int`, *optional*, defaults to 128000):
Beginning of stream token id.
eos_token_id (`int`, *optional*):
End of stream token id.
audio_token_id (`int`, *optional*, defaults to 128002):
Audio token id in the text input.
audio_eos_token_id (`int`, *optional*, defaults to 128003):
End of stream token id for audio in the text input.
rope_theta (`float`, *optional*, defaults to 500000):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*, defaults to `{'factor': 32.0, 'high_freq_factor': 0.5, 'low_freq_factor': 0.125, 'original_max_position_embeddings': 1024, 'rope_type': 'llama3'}`):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
head_dim (`int`, *optional*):
The attention head dimension. If None, it will default to hidden_size // num_attention_heads
tie_codebooks_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie the codebook tokens embeddings of the backbone model to the codebook tokens embeddings of the depth decoder.
depth_decoder_config (`CsmDepthDecoderConfig`, *optional*):
Configuration for the depth decoder.
codec_config (`PretrainedConfig`, *optional*):
Configuration for the codec.
```python
>>> from transformers import CsmForConditionalGeneration, CsmConfig
>>> # Initializing a CsmConfig
>>> configuration = CsmConfig()
>>> # Initializing a model
>>> model = CsmForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "csm"
base_config_key = "csm_config"
keys_to_ignore_at_inference = ["past_key_values"]
sub_configs = {
"codec_config": AutoConfig,
"depth_decoder_config": CsmDepthDecoderConfig,
}
def __init__(
self,
num_codebooks=32,
vocab_size=2051,
text_vocab_size=128256,
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=16,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=128002,
codebook_pad_token_id=2050,
codebook_eos_token_id=0,
bos_token_id=128000,
eos_token_id=None,
audio_token_id=128002,
audio_eos_token_id=128003,
rope_theta=500000,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
head_dim=None,
tie_codebooks_embeddings=True,
depth_decoder_config=None,
codec_config=None,
**kwargs,
):
if kwargs.pop("tie_word_embeddings", False):
raise ValueError("`tie_word_embeddings=True` is not supported for CsmConfig")
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=False,
**kwargs,
)
if depth_decoder_config is None:
self.depth_decoder_config = CsmDepthDecoderConfig()
logger.info("depth_decoder_config is None, using default depth decoder config.")
elif isinstance(depth_decoder_config, dict):
self.depth_decoder_config = CsmDepthDecoderConfig(**depth_decoder_config)
elif isinstance(depth_decoder_config, CsmDepthDecoderConfig):
self.depth_decoder_config = depth_decoder_config
if codec_config is None:
self.codec_config = AutoConfig.for_model("mimi")
logger.info("codec_config is None, using default audio encoder config.")
elif isinstance(codec_config, dict):
self.codec_config = AutoConfig.for_model(**codec_config)
elif isinstance(codec_config, PretrainedConfig):
self.codec_config = codec_config
self.text_vocab_size = text_vocab_size
self.num_codebooks = num_codebooks
self.audio_token_id = audio_token_id
self.audio_eos_token_id = audio_eos_token_id
self.codebook_pad_token_id = codebook_pad_token_id
self.codebook_eos_token_id = codebook_eos_token_id
self.tie_codebooks_embeddings = tie_codebooks_embeddings
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, copy it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
__all__ = [
"CsmDepthDecoderConfig",
"CsmConfig",
]
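A short, hedged sketch of how the nested sub-configs above can be supplied once this PR's exports are in place; values are illustrative and omitted fields fall back to the documented defaults:

from transformers import CsmConfig, CsmDepthDecoderConfig

# sub-configs may be passed as config objects or as plain dicts
depth_decoder_config = CsmDepthDecoderConfig(num_hidden_layers=2, hidden_size=1024)
config = CsmConfig(depth_decoder_config=depth_decoder_config)  # codec_config defaults to a Mimi config
print(config.depth_decoder_config.num_hidden_layers)  # 2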

View File

@ -0,0 +1,339 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import os
import re
import torch
from tokenizers.processors import TemplateProcessing
from transformers import (
AutoFeatureExtractor,
AutoTokenizer,
CsmConfig,
CsmDepthDecoderConfig,
CsmForConditionalGeneration,
CsmProcessor,
MimiModel,
)
from transformers.utils.hub import cached_file
# fmt: off
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
r"backbone\.layers\.(\d+)": r"backbone_model.layers.\1",
r"decoder\.layers\.(\d+)": r"depth_decoder.model.layers.\1",
r"attn": r"self_attn",
r"output_proj": r"o_proj",
r"w1": r"gate_proj",
r"w2": r"down_proj",
r"w3": r"up_proj",
r"text_embeddings": r"embed_text_tokens",
r"audio_embeddings": r"backbone_model.embed_tokens.embed_audio_tokens",
r"codebook0_head": r"lm_head",
r"audio_head": r"depth_decoder.codebooks_head.weight",
r"projection": r"depth_decoder.model.inputs_embeds_projector",
r"sa_norm.scale": r"input_layernorm.weight",
r"mlp_norm.scale": r"post_attention_layernorm.weight",
r"decoder.norm.scale": r"depth_decoder.model.norm.weight",
r"backbone.norm.scale": r"backbone_model.norm.weight",
}
# fmt: on
def permute_for_rope(input_tensor, n_heads, dim1, dim2):
"""
When you go from the complex RoPE formulation to the sin and cos one, you need
to permute the query and key weights (to avoid doing it on the fly).
"""
input_tensor = input_tensor.reshape(dim1, dim2)
input_tensor = input_tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
input_tensor = input_tensor.transpose(1, 2).reshape(dim1, dim2)
return input_tensor
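# Illustrative shape walk-through (hedged, dims taken from the config above): for a 2048x2048 q_proj
# weight and 32 heads, the view gives (32, 32, 2, 2048); the transpose/reshape turns each head's
# interleaved (real, imag) row ordering into the half-split ordering expected by the sin/cos RoPE
# implementation, while keeping the overall (2048, 2048) shape unchanged.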
def convert_key(key, mapping):
for pattern, replacement in mapping.items():
key = re.sub(pattern, replacement, key)
return key
def write_model(
input_path_or_repo,
model_name,
codec_model_path_or_repo,
output_dir,
safe_serialization=True,
):
print("Converting the model.")
os.makedirs(output_dir, exist_ok=True)
codec_model = MimiModel.from_pretrained(codec_model_path_or_repo)
codec_model.config._attn_implementation_autoset = False
# prepare rope scaling args: the model uses originally
# 1 - for the depth decoder
# rope_theta=500000,
# rope_scaling={
# "factor": 32.0,
# "high_freq_factor": 4.0,
# "low_freq_factor": 1.0,
# "original_max_position_embeddings": 8192,
# "rope_type": "llama3",
# },
# 2 - for the backbone
# rope_theta=500000,
# rope_scaling={
# "factor": 32.0,
# "high_freq_factor": 4.0,
# "low_freq_factor": 1.0,
# "original_max_position_embeddings": 8192,
# "rope_type": "llama3",
# },
#
# Yet we want to use max_position_embeddings=32 for the depth decoder and 2048 for the backbone.
# This would throw a warning since we would have original_max_position_embeddings >= max_position_embeddings,
# therefore we convert the values to equivalent ones
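# (the frequency factors below are scaled by the same ratio as original_max_position_embeddings:
# depth decoder: 8192 -> 16 is a 1/512 rescale, so 4.0 / 512 = 0.0078125 and 1.0 / 512 = 0.001953125;
# backbone: 8192 -> 1024 is a 1/8 rescale, so 4.0 / 8 = 0.5 and 1.0 / 8 = 0.125)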
depth_decoder_config = CsmDepthDecoderConfig(
rope_scaling={
"factor": 32.0,
"high_freq_factor": 0.0078125,
"low_freq_factor": 0.001953125,
"original_max_position_embeddings": 16,
"rope_type": "llama3",
},
)
config = CsmConfig(
codec_config=codec_model.config,
depth_decoder_config=depth_decoder_config,
rope_scaling={
"factor": 32.0,
"high_freq_factor": 0.5,
"low_freq_factor": 0.125,
"original_max_position_embeddings": 1024,
"rope_type": "llama3",
},
)
params = {
"backbone": {
"num_attention_heads": config.num_attention_heads,
"num_key_value_heads": config.num_key_value_heads,
"dim_per_head": config.head_dim,
"key_value_dim": config.head_dim * config.num_key_value_heads,
"dim": config.hidden_size,
},
"depth_decoder": {
"num_attention_heads": config.depth_decoder_config.num_attention_heads,
"num_key_value_heads": config.depth_decoder_config.num_key_value_heads,
"dim_per_head": config.depth_decoder_config.head_dim,
"key_value_dim": config.depth_decoder_config.head_dim * config.depth_decoder_config.num_key_value_heads,
"dim": config.depth_decoder_config.hidden_size,
},
}
model_path = cached_file(
input_path_or_repo,
model_name,
)
print(f"Fetching all parameters from the checkpoint at {model_path}...")
loaded = torch.load(model_path, map_location="cpu")
print("Converting model...")
state_dict = {}
# -----------------------
# convert parameter names
# -----------------------
# Add codec_model. prefix to every key in the codec model state dict
codec_state_dict = {f"codec_model.{k}": v for k, v in codec_model.state_dict().items()}
state_dict.update(codec_state_dict)
for key, value in loaded.items():
new_key = convert_key(key, ORIGINAL_TO_CONVERTED_KEY_MAPPING)
current_parameter = value
# Post-process the current_parameter.
if re.search("(k|q)_proj.weight", new_key):
params_keys = "backbone" if "backbone" in new_key else "depth_decoder"
if "q_proj" in new_key:
num_heads = params[params_keys]["num_attention_heads"]
dim_per_head = params[params_keys]["dim_per_head"]
param_dim = params[params_keys]["dim"]
dim = params[params_keys]["dim"]
else:
num_heads = params[params_keys]["num_key_value_heads"]
dim_per_head = params[params_keys]["dim_per_head"]
param_dim = params[params_keys]["key_value_dim"]
dim = params[params_keys]["dim"]
current_parameter = permute_for_rope(value, num_heads, param_dim, dim)
state_dict[new_key] = current_parameter.reshape(num_heads * dim_per_head, dim)
state_dict[new_key] = current_parameter
# add the depth decoder embed audio tokens weights, later tied to the backbone embed audio tokens weights
state_dict["depth_decoder.model.embed_tokens.weight"] = state_dict[
"backbone_model.embed_tokens.embed_audio_tokens.weight"
].clone()
del loaded
gc.collect()
# -------------------------
# load the weights and save
# -------------------------
print("Loading the checkpoint in a Csm model.")
with torch.device("meta"):
model = CsmForConditionalGeneration(config)
model.load_state_dict(state_dict, strict=True, assign=True)
print("Checkpoint loaded successfully.")
del model.config._name_or_path
# default generation config
model.generation_config._from_model_config = False
model.generation_config.max_new_tokens = 125
model.generation_config.do_sample = True
model.generation_config.top_k = 50
model.generation_config.temperature = 0.9
model.generation_config.depth_decoder_do_sample = True
model.generation_config.depth_decoder_top_k = 50
model.generation_config.depth_decoder_temperature = 0.9
print("Saving the model.")
model.save_pretrained(output_dir, safe_serialization=safe_serialization)
del state_dict, model
# Safety check: reload the converted model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
CsmForConditionalGeneration.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto")
print("Model reloaded successfully.")
def write_tokenizer(output_dir):
# from https://github.com/SesameAILabs/csm/blob/2d720827843b653c4d67bb4445b1c0a4f59e646f/generator.py#L22-L36
def load_llama3_tokenizer():
"""
https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
"""
tokenizer_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
bos = tokenizer.bos_token
eos = tokenizer.eos_token
tokenizer._tokenizer.post_processor = TemplateProcessing(
single=f"{bos}:0 $A:0 {eos}:0",
pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
)
return tokenizer
tokenizer = load_llama3_tokenizer()
tokenizer.pad_token = tokenizer.eos_token
tokenizer.save_pretrained(output_dir)
# manually modify in tokenizer_config.json
# "128002": {
# "content": "<|AUDIO|>",
# ...
# }
# "128003": {
# "content": "<|audio_eos|>",
# ...
# }
print(
"Tokenizer saved successfully. Please manually modify in tokenizer_config.json AND tokenizer.json as follows: "
)
print("""
# "128002": {
# "content": "<|AUDIO|>",
# ...
# }
# "128003": {
# "content": "<|audio_eos|>",
# ...
# }
""")
def write_processor(output_dir, codec_model_path_or_repo):
chat_template = "\n{%- for message in messages %}\n {#-- Validate role is a stringified integer --#}\n {%- if not message['role'] is string or not message['role'].isdigit() %}\n {{- raise_exception(\"The role must be an integer or a stringified integer (e.g. '0') designating the speaker id\") }}\n {%- endif %}\n\n {#-- Validate content is a list --#}\n {%- set content = message['content'] %}\n {%- if content is not iterable or content is string %}\n {{- raise_exception(\"The content must be a list\") }}\n {%- endif %}\n\n {#-- Collect content types --#}\n {%- set content_types = content | map(attribute='type') | list %}\n {%- set is_last = loop.last %}\n\n {#-- Last message validation --#}\n {%- if is_last %}\n {%- if 'text' not in content_types %}\n {{- raise_exception(\"The last message must include one item of type 'text'\") }}\n {%- elif (content_types | select('equalto', 'text') | list | length > 1) or (content_types | select('equalto', 'audio') | list | length > 1) %}\n {{- raise_exception(\"At most two items are allowed in the last message: one 'text' and one 'audio'\") }}\n {%- endif %}\n\n {#-- All other messages validation --#}\n {%- else %}\n {%- if content_types | select('equalto', 'text') | list | length != 1\n or content_types | select('equalto', 'audio') | list | length != 1 %}\n {{- raise_exception(\"Each message (except the last) must contain exactly one 'text' and one 'audio' item\") }}\n {%- elif content_types | reject('in', ['text', 'audio']) | list | length > 0 %}\n {{- raise_exception(\"Only 'text' and 'audio' types are allowed in content\") }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n\n{%- for message in messages %}\n {{- bos_token }}\n {{- '[' + message['role'] + ']' }}\n {{- message['content'][0]['text'] }}\n {{- eos_token }}\n {%- if message['content']|length > 1 %}\n {{- '<|AUDIO|><|audio_eos|>' }}\n {%- endif %}\n{%- endfor %}\n"
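# Hedged example of a conversation accepted by the chat template above: roles are stringified speaker ids,
# and every message except the last pairs exactly one 'text' item with one 'audio' item (the payload key
# used for the audio item, 'path', is illustrative; the template only checks the 'type' field), e.g.
#   [{"role": "0", "content": [{"type": "text", "text": "Hello."}, {"type": "audio", "path": "a.wav"}]},
#    {"role": "1", "content": [{"type": "text", "text": "Hi there."}]}]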
tokenizer = AutoTokenizer.from_pretrained(output_dir)
feature_extractor = AutoFeatureExtractor.from_pretrained(codec_model_path_or_repo)
processor = CsmProcessor(
tokenizer=tokenizer,
feature_extractor=feature_extractor,
chat_template=chat_template,
)
processor.save_pretrained(output_dir)
print("Processor saved successfully.")
def main():
parser = argparse.ArgumentParser(description="Convert Csm weights to HuggingFace format")
parser.add_argument(
"--input_path_or_repo",
type=str,
required=True,
help="Path or repo containing Csm weights",
)
parser.add_argument(
"--model_name",
type=str,
required=True,
help="Name of the model in input_path_or_repo",
)
parser.add_argument(
"--codec_model_path_or_repo",
type=str,
required=True,
help="Path or repo containing the codec model",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--safe_serialization", action="store_true", default=True, help="Whether or not to save using `safetensors`."
)
args = parser.parse_args()
write_model(
args.input_path_or_repo,
args.model_name,
args.codec_model_path_or_repo,
output_dir=args.output_dir,
safe_serialization=args.safe_serialization,
)
write_tokenizer(args.output_dir)
write_processor(args.output_dir, args.codec_model_path_or_repo)
if __name__ == "__main__":
main()
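A hedged invocation example for the script above; the file name and argument values are illustrative, only the flags come from the argparse definition:

python convert_csm_weights_to_hf.py \
    --input_path_or_repo sesame/csm-1b \
    --model_name ckpt.pt \
    --codec_model_path_or_repo kyutai/mimi \
    --output_dir ./csm-1b-hf \
    --safe_serialization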

View File

@ -0,0 +1,491 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from ...generation import (
GenerateDecoderOnlyOutput,
GenerationConfig,
GenerationMixin,
GenerationMode,
)
from ...generation.logits_process import LogitsProcessorList
from ...generation.stopping_criteria import MaxLengthCriteria, StoppingCriteriaList
from ...generation.utils import GenerateNonBeamOutput
from ...utils import logging
if TYPE_CHECKING:
from ...generation.streamers import BaseStreamer
logger = logging.get_logger(__name__)
@dataclass
class CsmGenerateOutput(GenerateDecoderOnlyOutput):
"""
Outputs of CsmForConditionalGeneration.generate.
Args:
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
if all batches finished early due to the `eos_token_id`.
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True`):
Returns the model cache, used to speed up decoding. Different models have a different cache format, check the model's documentation.
audio (`list(torch.FloatTensor)` of length `batch_size`):
The generated audio.
"""
audio: Optional[List[torch.Tensor]] = None
class CsmGenerationMixin(GenerationMixin):
def _get_stopping_criteria(
self,
*args,
**kwargs,
) -> StoppingCriteriaList:
criteria = super()._get_stopping_criteria(*args, **kwargs)
kept_criteria = StoppingCriteriaList()
for criterion in criteria:
if not isinstance(criterion, MaxLengthCriteria):
logger.warning(
f"Csm does not support {criterion.__class__.__name__} stopping criteria, it will be ignored."
)
else:
kept_criteria.append(criterion)
return kept_criteria
def _prepare_generation_config(
self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Dict
) -> Tuple[GenerationConfig, Dict]:
"""
This method overrides [~generation.utils.GenerationMixin._prepare_generation_config].
It ensures that the depth decoder generation config is initialized and that args passed as depth_decoder_* are properly handled.
"""
# extract depth decoder kwargs and remove them from the main kwargs
depth_decoder_kwargs = {
k[len("depth_decoder_") :]: v for k, v in kwargs.items() if k.startswith("depth_decoder_")
}
# remove the depth decoder keys from the original kwargs
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("depth_decoder_")}
# initialize the generation config
generation_config, model_kwargs = super()._prepare_generation_config(
generation_config, use_model_defaults, **kwargs
)
self.depth_decoder.generation_config.update(**depth_decoder_kwargs)
# ensure the depth decoder generation config is valid
depth_decoder_min_new_tokens = getattr(self.depth_decoder.generation_config, "min_new_tokens") or (
self.config.num_codebooks - 1
)
depth_decoder_max_new_tokens = getattr(self.depth_decoder.generation_config, "max_new_tokens") or (
self.config.num_codebooks - 1
)
if {depth_decoder_min_new_tokens, depth_decoder_max_new_tokens} != {self.config.num_codebooks - 1}:
raise ValueError(
f"depth_decoder_generation_config's min_new_tokens ({depth_decoder_min_new_tokens}) and max_new_tokens ({depth_decoder_max_new_tokens}) must be equal to self.config.num_codebooks - 1 ({self.config.num_codebooks - 1})"
)
elif self.depth_decoder.generation_config.return_dict_in_generate:
logger.warning(
"depth_decoder_generation_config.return_dict_in_generate is set to True, but this will be ignored as the depth decoder model does not return a dictionary in generate"
)
self.depth_decoder.generation_config.return_dict_in_generate = False
self.depth_decoder.generation_config.min_new_tokens = depth_decoder_min_new_tokens
self.depth_decoder.generation_config.max_new_tokens = depth_decoder_max_new_tokens
# Monkey patch the get_generation_mode method to support CSM model
original_get_generation_mode = generation_config.get_generation_mode
def patched_get_generation_mode(assistant_model=None):
generation_mode = original_get_generation_mode(assistant_model)
if generation_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE]:
raise ValueError(
f"Generation mode {generation_mode} is not supported for CSM model. Please set generation parameters to use greedy or sampling generation."
)
return generation_mode
generation_config.get_generation_mode = patched_get_generation_mode
return generation_config, model_kwargs
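# Example of the kwarg routing above (hedged): a call such as
#   model.generate(**inputs, do_sample=True, temperature=0.9, depth_decoder_do_sample=True, depth_decoder_temperature=0.6)
# sends the "depth_decoder_"-prefixed values to self.depth_decoder.generation_config, while the
# remaining kwargs configure the backbone generation as usual.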
def _sample(
self,
input_ids: torch.LongTensor,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
streamer: Optional["BaseStreamer"],
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
"""
This method overrides [`~generation.utils.GenerationMixin._sample`].
To ease maintenance, modifications are marked with the comment "Csm specific".
Indeed, the Csm model requires a custom generation sampling step:
1. Infer the backbone model to sample the first codebook token
2. Call generate on the depth decoder with the first codebook token as `input_ids` to sample the next codebook tokens
3. Use these generated codebook tokens as `input_ids` to sample the next first codebook token using the backbone model
4. Repeat until a stopping criterion is met
Csm supports two stopping criteria:
- stop when the generated sequence reaches `max_length`
- stop when all the generated codebook tokens are the `codebook_eos_token_id`
"""
# init values
# *************** Csm specific ***************
pad_token_id = self.config.codebook_pad_token_id
has_eos_stopping_criteria = generation_config._eos_token_tensor is not None
# ============================================
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
do_sample = generation_config.do_sample
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# keep track of which sequences are already finished
batch_size, cur_len = input_ids.shape[:2]
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
# *************** Csm specific ***************
if input_ids.ndim == 2 and model_kwargs.get("inputs_embeds") is None:
# in the case where the passed input_ids correspond to text tokens, i.e. don't have a third dimension for codebook ids,
# we need to subtract the input length from the MaxLengthCriteria stopping criterion, as such inputs are not returned
for criterion in stopping_criteria:
if isinstance(criterion, MaxLengthCriteria):
criterion.max_length -= cur_len
# ============================================
model_forward = self.__call__
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
if compile_forward:
os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = self.get_compiled_call(generation_config.compile_config)
is_prefill = True
while self._has_unfinished_sequences(
this_peer_finished,
synced_gpus,
device=input_ids.device,
):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
# *************** Csm specific ***************
model_inputs.update({"output_hidden_states": True})
# ============================================
if is_prefill:
outputs = self(**model_inputs, return_dict=True)
is_prefill = False
else:
outputs = model_forward(**model_inputs, return_dict=True)
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
)
if synced_gpus and this_peer_finished:
continue
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].clone().float()
next_token_logits = next_token_logits.to(input_ids.device)
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_logits:
raw_logits += (next_token_logits,)
if output_attentions:
decoder_attentions += (outputs.attentions,)
if output_hidden_states:
decoder_hidden_states += (outputs.hidden_states,)
# token selection
if do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
# TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
# *************** Csm specific ***************
# infer the depth decoder
first_codebook_ids = next_tokens[:, None]
# add a placeholder at position 0 that will be replaced by the backbone_last_hidden_state
depth_decoder_input_ids = nn.functional.pad(first_codebook_ids, (1, 0), value=0)
backbone_last_hidden_state = outputs.hidden_states[-1][:, -1, :]
depth_decoder_outputs = self.depth_decoder.generate(
input_ids=depth_decoder_input_ids, backbone_last_hidden_state=backbone_last_hidden_state.clone()
)
codebook_ids = (
depth_decoder_outputs
if isinstance(depth_decoder_outputs, torch.Tensor)
else depth_decoder_outputs.sequences
)
# remove the placeholder at position 0
codebook_ids = codebook_ids[:, 1:]
next_tokens = codebook_ids
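# shape note (illustrative): next_tokens is (batch_size, num_codebooks), i.e. the first codebook
# token sampled by the backbone followed by the num_codebooks - 1 tokens generated by the depth decoder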
# finished sentences should have their next token be a padding token
if has_eos_stopping_criteria:
next_tokens = next_tokens * unfinished_sequences.unsqueeze(-1) + pad_token_id * (
1 - unfinished_sequences.unsqueeze(-1)
)
# update generated ids, model inputs, and length for next step
if input_ids.ndim == 2:
input_ids = next_tokens[:, None, :]
else:
input_ids = torch.cat([input_ids, next_tokens[:, None, :]], dim=1)
# ============================================
if streamer is not None:
streamer.put(next_tokens.cpu())
# *************** Csm specific ***************
# for the eos stopping criterion, it is expected that the eos token is the same for each codebook
unfinished_sequences = unfinished_sequences & ~(
input_ids[:, -1, :-1] == self.config.codebook_eos_token_id
).all(-1)
# ============================================
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
cur_len += 1
# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
del outputs
# *************** Csm specific ***************
del depth_decoder_outputs
# ============================================
if streamer is not None:
streamer.end()
if return_dict_in_generate:
return GenerateDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
)
else:
return input_ids
def generate(
self,
input_ids: Optional[torch.Tensor] = None,
input_values: Optional[torch.Tensor] = None,
input_values_cutoffs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
synced_gpus: Optional[bool] = None,
streamer: Optional["BaseStreamer"] = None,
output_audio: Optional[bool] = False,
**kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
r"""
This method overrides [`~generation.utils.GenerationMixin.generate`] to match the specifics of the Csm model.
Indeed, the Csm model requires a custom generation sampling step:
1. Infer the backbone model to sample the first codebook token
2. Call generate on the depth decoder with the first codebook token as `input_ids` to sample the next codebook tokens
3. Use these generated codebook tokens as `input_ids` to sample the next first codebook token using the backbone model
4. Repeat until a stopping criterion is met
<Tip warning={true}>
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate(), e.g. `.generate(inputs, do_sample=True)`.
</Tip>
Parameters:
input_ids (`torch.Tensor` of shape `(batch_size, seq_length)`, *optional*):
The sequence used as a prompt for the backbone model.
input_values (`torch.Tensor` of shape `(batch_size, channels, max_concatenated_audio_length)`, *optional*):
The batched audio input values, where each batch entry contains the concatenation of all audio segments for that entry.
These values will be encoded into codebook tokens using the codec model and merged with the text input ids provided in `input_ids`.
input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`, *optional*):
Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
`input_values_cutoffs` would be: `[[l1, 2 * l1], [l2, -1]]`.
generation_config ([`~generation.GenerationConfig`], *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logits processor is passed that is already created with the arguments or a
generation config, an error is thrown. This feature is intended for advanced users.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criterion is passed that is already created with the arguments or a
generation config, an error is thrown. If your stopping criteria depend on the `scores` input, make
sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
intended for advanced users.
synced_gpus (`bool`, *optional*):
Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid
deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
output_audio (`bool`, *optional*):
Whether to return the generated audio.
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. Depth decoder specific kwargs should be prefixed with *depth_decoder_*.
Return:
[`CsmGenerateOutput`] or `torch.LongTensor` or `List[torch.FloatTensor]`: A [`CsmGenerateOutput`]
(if `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `torch.LongTensor` when `output_audio=False`
or a `List[torch.FloatTensor]` otherwise.
Example:
```python
>>> import torch
>>> from transformers import CsmProcessor, CsmForConditionalGeneration
>>> from datasets import load_dataset, Audio
>>> model_id = "eustlb/csm-1b"
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
>>> processor = CsmProcessor.from_pretrained(model_id)
>>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
>>> # ensure the audio is 24kHz
>>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
>>> conversation = []
>>> # prepare a conversation with text and corresponding audio
>>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
... conversation.append(
... {
... "role": f"{speaker_id}",
... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
... }
... )
>>> # text prompt
>>> conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})
>>> inputs = processor.apply_chat_template(
... conversation,
... tokenize=True,
... return_dict=True,
... ).to(torch_device)
>>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
>>> audio = model.generate(**inputs, output_audio=True)
>>> processor.save_audio(audio, "output.wav")
```
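Depth decoder specific generation parameters can be passed with the `depth_decoder_` prefix (a minimal
illustrative call; the parameter values are arbitrary):
```python
>>> audio = model.generate(**inputs, output_audio=True, do_sample=True, depth_decoder_do_sample=False)
```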
"""
generate_output = super().generate(
input_ids=input_ids,
input_values=input_values,
input_values_cutoffs=input_values_cutoffs,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
synced_gpus=synced_gpus,
streamer=streamer,
**kwargs,
)
generate_returned_dict = not isinstance(generate_output, torch.Tensor)
audio = None
if output_audio:
generated_audio_codes = generate_output.sequences if generate_returned_dict else generate_output
# infer the codec model
audio = []
with torch.no_grad():
# =======================================
# TODO: @eustlb, this should be batched !!!
# but requires making sure batched inference of the codec model works as intended
for audio_codes_batch in generated_audio_codes:
eos_idxs = (audio_codes_batch == self.config.codebook_eos_token_id).all(dim=-1).nonzero()
if eos_idxs.numel() != 0:
cutoff_idx = eos_idxs.min()
else:
cutoff_idx = audio_codes_batch.shape[0]
audio_codes_batch = audio_codes_batch[:cutoff_idx]
codec_decode_output = self.codec_model.decode(audio_codes_batch.transpose(0, 1).unsqueeze(0))
audio.append(codec_decode_output.audio_values[0, 0])
# =======================================
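# note (assuming the default 24 kHz codec): `audio` is a list of 1D waveforms, one per batch entry,
# each cut at the first all-codebook EOS frame when one is found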
if generate_returned_dict:
return CsmGenerateOutput(audio=audio, **generate_output)
elif output_audio:
return audio
else:
return generate_output

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -0,0 +1,364 @@
# coding=utf-8
# Copyright 2025 Sesame and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
from ...utils import is_soundfile_available, is_torch_available
if is_torch_available():
import torch
if is_soundfile_available():
import soundfile as sf
from ...audio_utils import AudioInput, make_list_of_audio
from ...feature_extraction_utils import BatchFeature
from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import (
PreTokenizedInput,
TextInput,
)
class CsmAudioKwargs(AudioKwargs, total=False):
encoded_length_kwargs: Optional[Dict[str, Any]]
class CsmProcessorKwargs(ProcessingKwargs, total=False):
audio_kwargs: CsmAudioKwargs
_defaults = {
"text_kwargs": {
"padding": True,
"padding_side": "left",
"add_special_tokens": False,
},
"audio_kwargs": {
"encoded_length_kwargs": {
"kernel_sizes": [7, 3, 1, 8, 3, 1, 10, 3, 1, 12, 3, 1, 16, 3, 4],
"strides": [1, 1, 1, 4, 1, 1, 5, 1, 1, 6, 1, 1, 8, 1, 2],
"dilations": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
"use_causal_conv": True,
},
"sampling_rate": 24000,
},
"common_kwargs": {"return_tensors": "pt"},
}
class CsmProcessor(ProcessorMixin):
r"""
Constructs a Csm processor which wraps [`EncodecFeatureExtractor`] and
[`PretrainedTokenizerFast`] into a single processor that inherits both the audio feature extraction and
tokenizer functionalities. See the [`~CsmProcessor.__call__`] for more
information.
The preferred way of passing kwargs is as a dictionary per modality, see usage example below.
```python
from transformers import CsmProcessor
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
audio = ds[0]["audio"]["array"]
processor = CsmProcessor.from_pretrained("eustlb/csm-1b")
processor(
text=["<|begin_of_text|>[0]What are you working on?<|end_of_text|><|AUDIO|><|audio_eos|><|begin_of_text|>[1]I'm figuring out my budget.<|end_of_text|>"],
audio=audio,
text_kwargs = {"padding": False},
audio_kwargs = {"sampling_rate": 16000},
common_kwargs = {"return_tensors": "pt"},
)
# this should error out because EncodecFeatureExtractor expects a 24kHz audio :)
```
Args:
feature_extractor ([`EncodecFeatureExtractor`]):
The feature extractor is a required input.
tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""
attributes = ["feature_extractor", "tokenizer"]
valid_kwargs = ["chat_template"]
feature_extractor_class = "EncodecFeatureExtractor"
tokenizer_class = "PreTrainedTokenizerFast"
def __init__(
self,
feature_extractor,
tokenizer,
chat_template=None,
):
if not hasattr(tokenizer, "audio_token"):
self.audio_token = "<|AUDIO|>"
self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
else:
self.audio_token = tokenizer.audio_token
self.audio_token_id = tokenizer.audio_token_id
if not hasattr(tokenizer, "audio_eos_token"):
self.audio_eos_token = "<|audio_eos|>"
self.audio_eos_token_id = tokenizer.convert_tokens_to_ids(self.audio_eos_token)
else:
self.audio_eos_token = tokenizer.audio_eos_token
self.audio_eos_token_id = tokenizer.audio_eos_token_id
super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
@staticmethod
def _get_encoded_length(audio_length, kernel_sizes=None, strides=None, dilations=None, use_causal_conv=None):
"""
Compute the length of the encoded audio sequence.
Args:
audio_length (int): The length of the audio sequence.
kernel_sizes (List[int]): The kernel sizes for the convolutional layers.
strides (List[int]): The strides for the convolutional layers.
dilations (List[int]): The dilations for the convolutional layers.
use_causal_conv (bool): Whether to use causal convolutions.
"""
cur_length = audio_length
if kernel_sizes is None or strides is None or dilations is None or use_causal_conv is None:
return cur_length
for kernel_size, stride, dilation in zip(kernel_sizes, strides, dilations):
effective_kernel_size = (kernel_size - 1) * dilation + 1
padding_total = kernel_size - stride
padding_right = padding_total // 2
padding_left = padding_total - padding_right
n_frames = (cur_length - effective_kernel_size + padding_total) / stride + 1
n_frames = math.ceil(n_frames) - 1
ideal_length = n_frames * stride + kernel_size - padding_total
extra_padding = ideal_length - cur_length
if use_causal_conv:
padding_left = padding_total
padding_right = extra_padding
else:
padding_left = padding_left
padding_right = padding_right + extra_padding
cur_length = cur_length + padding_left + padding_right
cur_length = (cur_length - dilation * (kernel_size - 1) - 1) // stride + 1
return cur_length
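# Note: `_get_encoded_length` is what `__call__` below uses to map each raw 24 kHz waveform length to
# its number of codec frames, i.e. the number of `<|AUDIO|>` placeholder tokens the text is expanded with
# (see the expansion loop in `__call__`).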
def save_audio(
self,
audio: AudioInput,
saving_path: Union[str, Path, List[Union[str, Path]]],
**kwargs: Unpack[CsmProcessorKwargs],
):
# TODO: @eustlb, this should be in AudioProcessor
if not is_soundfile_available():
raise ImportError("Please install `soundfile` to save audio files.")
# ensure correct audio input
audio = make_list_of_audio(audio)
# ensure correct saving path
if isinstance(saving_path, (str, Path)):
saving_path = [saving_path]
elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)):
raise ValueError("Invalid input path. Please provide a string, or a list of strings")
if len(audio) != len(saving_path):
raise ValueError("The number of audio and saving paths must be the same")
output_kwargs = self._merge_kwargs(
CsmProcessorKwargs,
**kwargs,
)
audio_kwargs = output_kwargs["audio_kwargs"]
sampling_rate = audio_kwargs["sampling_rate"]
for audio_value, p in zip(audio, saving_path):
if isinstance(audio_value, torch.Tensor):
audio_value = audio_value.cpu().float().numpy()
sf.write(p, audio_value, sampling_rate)
def __call__(
self,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]],
audio: Optional[AudioInput] = None,
output_labels: Optional[bool] = False,
depth_decoder_labels_ratio: Optional[float] = 1.0,
**kwargs: Unpack[CsmProcessorKwargs],
):
r"""
Main method to prepare text(s) and audio to be fed as input to the model. This method forwards the `text`
arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode
the text. To prepare the audio, this method forwards the `audio` arguments to
EncodecFeatureExtractor's [`~EncodecFeatureExtractor.__call__`]. Please refer
to the docstring of the above two methods for more information.
Args:
audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The audio or batch of audio to be prepared. Each audio can be a NumPy array or PyTorch
tensor.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
output_labels (`bool`, *optional*, defaults to `False`):
Whether to return labels for training. Indices will be in `[config.audio_token_id, -100, -101]`.
- `config.audio_token_id` indicates an audio frame (considering sequence length elements as frames)
- `-100` will be ignored in the loss computation
- `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)
depth_decoder_labels_ratio (`float`, *optional*, defaults to `1.0`):
The ratio of audio frames to keep for the depth decoder labels.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **input_values** -- List of audio values to be fed to a model. Returned when `audio` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **labels** -- List of labels for the audio frames. Returned when `output_labels=True`.
"""
output_kwargs = self._merge_kwargs(
CsmProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
text_kwargs = output_kwargs["text_kwargs"]
audio_kwargs = output_kwargs["audio_kwargs"]
common_kwargs = output_kwargs["common_kwargs"]
return_tensors = common_kwargs.pop("return_tensors", None)
if return_tensors != "pt":
raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
if isinstance(text, str):
text = [text]
elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
n_audio_in_text = [t.count(self.audio_token) for t in text]
n_audio = 0
if audio is not None:
audio = make_list_of_audio(audio)
n_audio = len(audio)
if sum(n_audio_in_text) > 0 and n_audio != sum(n_audio_in_text):
if audio is None:
raise ValueError("No audio were provided, but there are audio tokens in the prompt")
else:
raise ValueError(
f"The number of audio tokens in each text ({n_audio_in_text}) should be the same as the "
f"number of provided audios ({n_audio})."
)
if audio is not None:
encoded_length_kwargs = audio_kwargs.pop("encoded_length_kwargs", {})
num_audio_tokens_list = [
self._get_encoded_length(audio_array.shape[-1], **encoded_length_kwargs) for audio_array in audio
]
num_audio_tokens_list_copy = num_audio_tokens_list.copy()
# expand the text to repeat the audio token for the corresponding number of frames
expanded_text = []
for sample in text:
replace_str = []
while self.audio_token in sample:
num_audio_tokens = num_audio_tokens_list_copy.pop(0)
expanded_audio_token = self.audio_token * num_audio_tokens
replace_str.append(expanded_audio_token)
sample = sample.replace(self.audio_token, "<placeholder>", 1)
while "<placeholder>" in sample:
sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
expanded_text.append(sample)
text = expanded_text
encoding = self.tokenizer(text, **text_kwargs)
data = {}
data.update(encoding)
if audio is not None:
audio_kwargs.pop("return_attention_mask", None) # not supported by the feature extractor
concatenated_audio, input_values_cutoffs = [], []
offset = 0
for n_audio in n_audio_in_text:
if n_audio == 0:
concatenated_audio.append(np.zeros(0))
input_values_cutoffs.append(torch.tensor([-1]))
else:
concatenated_audio.append(
np.concatenate(
[
el.cpu().numpy() if isinstance(el, torch.Tensor) else el
for el in audio[offset : offset + n_audio]
],
axis=-1,
)
)
input_values_cutoffs.append(
torch.tensor([el.shape[-1] for el in audio[offset : offset + n_audio]]).cumsum(dim=-1)
)
offset += n_audio
audio_inputs = self.feature_extractor(concatenated_audio, **audio_kwargs)
audio_inputs.pop("padding_mask", None) # not applicable here
data.update(audio_inputs)
# pad and stack the audio cut idxs
max_len = max(cut_idxs.shape[-1] for cut_idxs in input_values_cutoffs)
input_values_cutoffs = [
torch.nn.functional.pad(cut_idxs, (0, max_len - cut_idxs.shape[-1]), value=-1)
for cut_idxs in input_values_cutoffs
]
data["input_values_cutoffs"] = torch.stack(input_values_cutoffs, dim=0)
if output_labels:
audio_frame_idxs = (data["input_ids"] == self.audio_token_id).nonzero()
n_audio_frames = audio_frame_idxs.shape[0]
if depth_decoder_labels_ratio <= 1.0:
rand_idxs = torch.randperm(n_audio_frames)[: int(n_audio_frames * (1 - depth_decoder_labels_ratio))]
skip_frames_idxs = audio_frame_idxs[rand_idxs]
else:
skip_frames_idxs = audio_frame_idxs
labels = torch.where(data["input_ids"] == self.audio_token_id, data["input_ids"], -100)
labels[skip_frames_idxs[:, 0], skip_frames_idxs[:, 1]] = -101
data["labels"] = labels
return BatchFeature(data=data, tensor_type=return_tensors)
__all__ = ["CsmProcessor"]

View File

@ -33,7 +33,7 @@ class DFineConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`DFineModel`]. It is used to instantiate a D-FINE
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of D-FINE-X-COCO "[ustc-community/dfine_x_coco"](https://huggingface.co/ustc-community/dfine_x_coco").
defaults will yield a similar configuration to that of D-FINE-X-COCO [ustc-community/dfine-xlarge-coco](https://huggingface.co/ustc-community/dfine-xlarge-coco).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

View File

@ -1790,8 +1790,8 @@ class DFineForObjectDetection(DFinePreTrainedModel):
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = load_image(url)
>>> image_processor = AutoImageProcessor.from_pretrained("ustc-community/dfine_x_coco")
>>> model = DFineForObjectDetection.from_pretrained("ustc-community/dfine_x_coco")
>>> image_processor = AutoImageProcessor.from_pretrained("ustc-community/dfine-xlarge-coco")
>>> model = DFineForObjectDetection.from_pretrained("ustc-community/dfine-xlarge-coco")
>>> # prepare image for the model
>>> inputs = image_processor(images=image, return_tensors="pt")

View File

@ -52,7 +52,7 @@ class DFineConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`DFineModel`]. It is used to instantiate a D-FINE
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of D-FINE-X-COCO "[ustc-community/dfine_x_coco"](https://huggingface.co/ustc-community/dfine_x_coco").
defaults will yield a similar configuration to that of D-FINE-X-COCO [ustc-community/dfine-xlarge-coco](https://huggingface.co/ustc-community/dfine-xlarge-coco).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
@ -922,8 +922,8 @@ class DFineForObjectDetection(RTDetrForObjectDetection, DFinePreTrainedModel):
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = load_image(url)
>>> image_processor = AutoImageProcessor.from_pretrained("ustc-community/dfine_x_coco")
>>> model = DFineForObjectDetection.from_pretrained("ustc-community/dfine_x_coco")
>>> image_processor = AutoImageProcessor.from_pretrained("ustc-community/dfine-xlarge-coco")
>>> model = DFineForObjectDetection.from_pretrained("ustc-community/dfine-xlarge-coco")
>>> # prepare image for the model
>>> inputs = image_processor(images=image, return_tensors="pt")

View File

@ -773,7 +773,8 @@ class Dinov2WithRegistersForImageClassification(Dinov2WithRegistersPreTrainedMod
sequence_output = outputs[0] # batch_size, sequence_length, hidden_size
cls_token = sequence_output[:, 0]
patch_tokens = sequence_output[:, 1:]
# cls and register tokens should not be included in patch tokens variable
patch_tokens = sequence_output[:, 1 + self.config.num_register_tokens :]
linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
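# e.g. (illustrative, assuming the default num_register_tokens=4): the token layout is
# [CLS, reg_1..reg_4, patch tokens], so patch tokens start at index 1 + 4 = 5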

View File

@ -19,6 +19,7 @@ from typing import Optional, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ....transformers.models.dinov2.modeling_dinov2 import (
Dinov2Backbone,
@ -29,7 +30,7 @@ from ....transformers.models.dinov2.modeling_dinov2 import (
Dinov2PreTrainedModel,
)
from ...configuration_utils import PretrainedConfig
from ...modeling_outputs import BackboneOutput
from ...modeling_outputs import BackboneOutput, ImageClassifierOutput
from ...utils import logging, torch_int
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
@ -314,7 +315,76 @@ class Dinov2WithRegistersModel(Dinov2Model):
class Dinov2WithRegistersForImageClassification(Dinov2ForImageClassification):
pass
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, ImageClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.dinov2_with_registers(
pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0] # batch_size, sequence_length, hidden_size
cls_token = sequence_output[:, 0]
# cls and register tokens should not be included in patch tokens variable
patch_tokens = sequence_output[:, 1 + self.config.num_register_tokens :]
linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
logits = self.classifier(linear_input)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class Dinov2WithRegistersBackbone(Dinov2Backbone):

View File

@ -156,14 +156,18 @@ class DonutProcessor(ProcessorMixin):
output = {}
while tokens:
start_token = re.search(r"<s_(.*?)>", tokens, re.IGNORECASE)
if start_token is None:
# We want r"<s_(.*?)>" but without ReDOS risk, so do it manually in two parts
potential_start = re.search(r"<s_", tokens, re.IGNORECASE)
if potential_start is None:
break
key = start_token.group(1)
start_token = tokens[potential_start.start() :]
if ">" not in start_token:
break
start_token = start_token[: start_token.index(">") + 1]
key = start_token[len("<s_") : -len(">")]
key_escaped = re.escape(key)
end_token = re.search(rf"</s_{key_escaped}>", tokens, re.IGNORECASE)
start_token = start_token.group()
if end_token is None:
tokens = tokens.replace(start_token, "")
else:

View File

@ -1778,13 +1778,16 @@ EMU3_INPUTS_DOCSTRING = r"""
"""
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["text_model.lm_head.weight"]
_supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compilable
class Emu3Model(Emu3PreTrainedModel):
_checkpoint_conversion_mapping = {"text_model.model": "text_model"}
_supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable
def __init__(self, config):
super().__init__(config)
self.text_model = Emu3ForCausalLM._from_config(config.text_config)
self.text_model = Emu3TextModel._from_config(config.text_config)
if self.text_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"text_model.{k}" for k in self.text_model._tied_weights_keys]
self.vqmodel = Emu3VQVAE(config.vq_config)
self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)
@ -1833,14 +1836,12 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
image = self.vqmodel.decode(image_tokens)
return image
@can_return_tuple
@add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
@ -1848,10 +1849,99 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
return outputs
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
base_model_prefix = ""
_checkpoint_conversion_mapping = {
"^text_model.model": "model.text_model",
"^vqmodel": "model.vqmodel",
"^text_model.lm_head": "lm_head",
}
_supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable
def __init__(self, config):
super().__init__(config)
self.model = Emu3Model(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
# Make modules available through the conditional class for BC
@property
def text_model(self):
return self.model.text_model
@property
def vqmodel(self):
return self.model.vqmodel
@can_return_tuple
@add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> CausalLMOutputWithPast:
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@ -1906,25 +1996,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
return self.text_model(
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@ -1933,8 +2007,25 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
@ -1968,5 +2059,68 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
return model_inputs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
__all__ = ["Emu3ForConditionalGeneration", "Emu3ForCausalLM", "Emu3TextModel", "Emu3PreTrainedModel", "Emu3VQVAE"]
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
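# Illustrative shape note: for batch_size=2, sequence_length=3, target_length=8 the returned mask has
# shape (2, 1, 3, 8), with `min_dtype` at future / padded positions and 0 elsewhere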
__all__ = [
"Emu3ForConditionalGeneration",
"Emu3ForCausalLM",
"Emu3TextModel",
"Emu3PreTrainedModel",
"Emu3VQVAE",
"Emu3Model",
]

View File

@ -1121,13 +1121,16 @@ class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin):
super().forward()
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["text_model.lm_head.weight"]
_supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compilable
class Emu3Model(Emu3PreTrainedModel):
_checkpoint_conversion_mapping = {"text_model.model": "text_model"}
_supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable
def __init__(self, config):
super().__init__(config)
self.text_model = Emu3ForCausalLM._from_config(config.text_config)
self.text_model = Emu3TextModel._from_config(config.text_config)
if self.text_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"text_model.{k}" for k in self.text_model._tied_weights_keys]
self.vqmodel = Emu3VQVAE(config.vq_config)
self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)
@ -1176,14 +1179,12 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
image = self.vqmodel.decode(image_tokens)
return image
@can_return_tuple
@add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
@ -1191,10 +1192,99 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
return outputs
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
base_model_prefix = ""
_checkpoint_conversion_mapping = {
"^text_model.model": "model.text_model",
"^vqmodel": "model.vqmodel",
"^text_model.lm_head": "lm_head",
}
_supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable
def __init__(self, config):
super().__init__(config)
self.model = Emu3Model(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
# Make modules available through the conditional class for BC
@property
def text_model(self):
return self.model.text_model
@property
def vqmodel(self):
return self.model.vqmodel
@can_return_tuple
@add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> CausalLMOutputWithPast:
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@ -1249,25 +1339,9 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
return self.text_model(
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@ -1276,8 +1350,25 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
@ -1311,6 +1402,62 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
return model_inputs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = [
"Emu3ForConditionalGeneration",
@ -1318,4 +1465,5 @@ __all__ = [
"Emu3TextModel",
"Emu3PreTrainedModel",
"Emu3VQVAE",
"Emu3Model",
]

View File

@ -21,9 +21,9 @@ import torch.utils.checkpoint
from torch import nn
from ...generation import GenerationMixin
from ...modeling_outputs import CausalLMOutputWithPast
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...models.auto.modeling_auto import AutoModelForCausalLM
from ...models.auto.modeling_auto import AutoModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_fuyu import FuyuConfig
@ -143,18 +143,17 @@ FUYU_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"Fuyu Model with a language modeling head on top for causal language model conditioned on image patches and text.",
"""The Fuyu model which consists of a vision backbone and a language model, without a language modeling head.""",
FUYU_START_DOCSTRING,
)
class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
class FuyuModel(FuyuPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: FuyuConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.language_model = AutoModel.from_config(config.text_config)
self.vision_embed_tokens = nn.Linear(
config.patch_size * config.patch_size * config.num_channels, config.hidden_size
)
@ -169,18 +168,6 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def gather_continuous_embeddings(
self,
word_embeddings: torch.Tensor,
@ -224,56 +211,21 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
return output_embeddings
@add_start_docstrings_to_model_forward(FUYU_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
image_patches: Optional[
torch.Tensor
] = None, # [batch_size, num_total_patches, patch_size x patch_size x num_channels]
image_patches_indices: Optional[torch.Tensor] = None,
input_ids: torch.LongTensor = None,
image_patches: torch.Tensor = None, # [batch_size, num_total_patches, patch_size x patch_size x num_channels]
image_patches_indices: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
Returns:
Examples:
```python
>>> from transformers import FuyuProcessor, FuyuForCausalLM
>>> from PIL import Image
>>> import requests
>>> processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
>>> model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b")
>>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> prompt = "Generate a coco-style caption.\n"
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> outputs = model(**inputs)
>>> generated_ids = model.generate(**inputs, max_new_tokens=7)
>>> generation_text = processor.batch_decode(generated_ids[:, -7:], skip_special_tokens=True)
>>> print(generation_text[0])
A blue bus parked on the side of a road.
```"""
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -327,7 +279,6 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
labels=labels,
use_cache=use_cache,
return_dict=return_dict,
**kwargs,
@ -335,6 +286,139 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
return outputs
@add_start_docstrings(
"Fuyu Model with a language modeling head on top for causal language model conditioned on image patches and text.",
FUYU_START_DOCSTRING,
)
class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_embed_tokens": "model.vision_embed_tokens",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: FuyuConfig):
super().__init__(config)
self.model = FuyuModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.set_decoder(decoder)
def get_decoder(self):
return self.model.get_decoder()
@add_start_docstrings_to_model_forward(FUYU_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
image_patches: torch.Tensor = None, # [batch_size, num_total_patches, patch_size x patch_size x num_channels]
image_patches_indices: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Optional[int] = 0,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
Returns:
Examples:
```python
>>> from transformers import FuyuProcessor, FuyuForCausalLM
>>> from PIL import Image
>>> import requests
>>> processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")
>>> model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b")
>>> url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> prompt = "Generate a coco-style caption.\n"
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> outputs = model(**inputs)
>>> generated_ids = model.generate(**inputs, max_new_tokens=7)
>>> generation_text = processor.batch_decode(generated_ids[:, -7:], skip_special_tokens=True)
>>> print(generation_text[0])
A blue bus parked on the side of a road.
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
image_patches=image_patches,
image_patches_indices=image_patches_indices,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
return_dict=return_dict,
# don't pass kwargs because Persimmon-backbone doesn't accept FA2 kwargs yet, TODO: raushan
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
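A minimal sketch (dummy tensors, invented sizes) of the `logits_to_keep` slicing above: during generation only the last position's logits are needed, so the projection through the vocabulary-sized `lm_head` can be restricted to that position, while `logits_to_keep=0` keeps every position for loss computation.

```python
import torch

batch, seq_len, hidden, vocab = 2, 16, 8, 32        # toy sizes
hidden_states = torch.randn(batch, seq_len, hidden)
lm_head = torch.nn.Linear(hidden, vocab, bias=False)

logits_to_keep = 1                                   # typical decoding step
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = lm_head(hidden_states[:, slice_indices, :])
print(logits.shape)                                  # torch.Size([2, 1, 32])

# logits_to_keep = 0 gives slice(0, None), i.e. all positions are projected (training path).
assert slice(-0, None) == slice(0, None)
```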
def prepare_inputs_for_generation(
self,
input_ids,
@ -373,4 +457,4 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
return reordered_past
__all__ = ["FuyuForCausalLM", "FuyuPreTrainedModel"]
__all__ = ["FuyuForCausalLM", "FuyuPreTrainedModel", "FuyuModel"]
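As a rough sketch only (not the library's loading logic; the key names below are illustrative), a regex mapping like `_checkpoint_conversion_mapping` above can rename old-layout checkpoint keys into the new `FuyuModel`/`lm_head` layout:

```python
import re

mapping = {
    "^language_model.model": "model.language_model",
    "^vision_embed_tokens": "model.vision_embed_tokens",
    "^language_model.lm_head": "lm_head",
}

old_keys = [
    "language_model.model.layers.0.self_attn.q_proj.weight",  # hypothetical key names
    "language_model.lm_head.weight",
    "vision_embed_tokens.weight",
]

for key in old_keys:
    new_key = key
    for pattern, replacement in mapping.items():
        new_key = re.sub(pattern, replacement, new_key)
    print(f"{key} -> {new_key}")
# language_model.model.layers.0.self_attn.q_proj.weight -> model.language_model.layers.0.self_attn.q_proj.weight
# language_model.lm_head.weight -> lm_head.weight
# vision_embed_tokens.weight -> model.vision_embed_tokens.weight
```

The real loading path in the library is more involved, but the renaming idea is the same.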

View File

@ -32,11 +32,12 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
@ -46,7 +47,7 @@ from ...utils import (
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
@ -60,6 +61,39 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Gemma3Config"
@dataclass
class Gemma3ModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for Gemma3 outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class Gemma3CausalLMOutputWithPast(ModelOutput):
"""
@ -88,7 +122,7 @@ class Gemma3CausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
"""
@ -480,7 +514,7 @@ GEMMA3_START_DOCSTRING = r"""
)
class Gemma3PreTrainedModel(PreTrainedModel):
config_class = Gemma3Config
base_model_prefix = "language_model"
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = [
"Gemma3DecoderLayer",
@ -1066,20 +1100,19 @@ class Gemma3MultiModalProjector(nn.Module):
@add_start_docstrings(
"""The GEMMA3 model which consists of a vision backbone and a language model.""",
"""Base Gemma3 model which consists of a vision backbone and a language model withou language modeling head.""",
GEMMA3_START_DOCSTRING,
)
class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
class Gemma3Model(Gemma3PreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: Gemma3Config):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.multi_modal_projector = Gemma3MultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
language_model = AutoModelForCausalLM.from_config(config=config.text_config)
if language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
language_model = AutoModel.from_config(config=config.text_config)
self.language_model = language_model
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
@ -1091,18 +1124,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def _update_causal_mask(
self,
attention_mask,
@ -1188,11 +1209,10 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
@can_return_tuple
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
@ -1203,6 +1223,145 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**lm_kwargs,
) -> Union[Tuple, Gemma3ModelOutputWithPast]:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
is_training = token_type_ids is not None and labels is not None
# Replace image id with PAD if the image token is OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
)
outputs = self.language_model(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
return Gemma3ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values if use_cache else None,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
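A toy, self-contained illustration of the `masked_scatter` merge performed in the forward above (placeholder id `99` and all sizes are invented): embedding slots holding the image placeholder token are overwritten, in order, by the projected image features.

```python
import torch

hidden = 4
input_ids = torch.tensor([[5, 7, 99, 99, 8]])   # 99 plays the role of config.image_token_id
inputs_embeds = torch.zeros(1, 5, hidden)       # stand-in for the text embeddings
image_features = torch.ones(1, 2, hidden)       # two image "tokens" worth of projected features

special_image_mask = (input_ids == 99).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

print(inputs_embeds[0, :, 0])  # tensor([0., 0., 1., 1., 0.]); only the placeholder slots changed
```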
@add_start_docstrings(
"""The Gemma3 model which consists of a vision backbone and a language model.""",
GEMMA3_START_DOCSTRING,
)
class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: Gemma3Config):
super().__init__(config)
self.model = Gemma3Model(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
@ -1260,80 +1419,34 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
```
"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
is_training = token_type_ids is not None and labels is not None
# Replace image id with PAD if the image token is OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# mask out pad-token-ids in labels for BC
if labels is not None and self.pad_token_id in labels:
logger.warning_once(
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
)
labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
)
outputs: CausalLMOutputWithPast = self.language_model(
attention_mask=causal_mask,
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
logits = outputs.logits
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
@ -1356,13 +1469,17 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
flat_labels = shift_labels.view(-1).to(shift_logits.device)
loss = loss_fct(flat_logits, flat_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Gemma3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
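The loss fragments above follow the usual shifted next-token cross entropy. A simplified, hedged sketch with toy tensors, ignoring the float upcast and other bookkeeping visible in the fragments:

```python
import torch
import torch.nn as nn

vocab = 10
logits = torch.randn(1, 4, vocab)
labels = torch.tensor([[2, 5, -100, 7]])         # -100 marks ignored positions

shift_logits = logits[..., :-1, :].contiguous()  # position t predicts token t+1
shift_labels = labels[..., 1:].contiguous()

loss_fct = nn.CrossEntropyLoss()                 # ignore_index defaults to -100
flat_logits = shift_logits.view(-1, vocab)
flat_labels = shift_labels.view(-1)
loss = loss_fct(flat_logits, flat_labels)
print(loss)
```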
def prepare_inputs_for_generation(
@ -1381,7 +1498,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
**kwargs,
):
# Overwritten -- custom `position_ids` and `pixel_values` handling
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -1401,15 +1518,73 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
is_training = token_type_ids is not None and labels is not None
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
causal_mask = self._update_causal_mask(
causal_mask = self.model._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
)
model_inputs["attention_mask"] = causal_mask
return model_inputs
def tie_weights(self):
return self.language_model.tie_weights()
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["Gemma3PreTrainedModel", "Gemma3TextModel", "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration"]
__all__ = [
"Gemma3PreTrainedModel",
"Gemma3TextModel",
"Gemma3ForCausalLM",
"Gemma3ForConditionalGeneration",
"Gemma3Model",
]
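Since the output projection is now an explicit `lm_head` declared in `_tied_weights_keys`, here is a hedged, minimal reminder of what weight tying means in isolation (plain PyTorch, toy sizes; in the library the tying only happens when the config enables shared word embeddings):

```python
import torch.nn as nn

vocab, hidden = 32, 8
embed = nn.Embedding(vocab, hidden)               # input embedding matrix, shape (vocab, hidden)
lm_head = nn.Linear(hidden, vocab, bias=False)    # output projection, weight shape (vocab, hidden)

lm_head.weight = embed.weight                     # tie: both modules share one Parameter
assert lm_head.weight.data_ptr() == embed.weight.data_ptr()
```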

View File

@ -26,11 +26,12 @@ import torch.utils.checkpoint
from ...cache_utils import Cache, HybridCache, StaticCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
@ -50,7 +51,12 @@ from ..gemma2.modeling_gemma2 import (
apply_rotary_pos_emb,
eager_attention_forward,
)
from ..paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
from ..paligemma.modeling_paligemma import (
PaligemmaCausalLMOutputWithPast,
PaliGemmaForConditionalGeneration,
PaliGemmaModel,
PaligemmaModelOutputWithPast,
)
from ..siglip import SiglipVisionConfig
@ -302,43 +308,13 @@ class Gemma3Config(PretrainedConfig):
@dataclass
class Gemma3CausalLMOutputWithPast(ModelOutput):
"""
Base class for Gemma3 causal language model (or autoregressive) outputs.
class Gemma3ModelOutputWithPast(PaligemmaModelOutputWithPast):
pass
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class Gemma3CausalLMOutputWithPast(PaligemmaCausalLMOutputWithPast):
pass
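The `pass`-bodied subclasses above are the modular-transformers way of reusing the PaliGemma output definitions under Gemma3 names: a dataclass subclass with an empty body inherits every parent field. A small, generic illustration (names invented):

```python
from dataclasses import dataclass, fields
from typing import Optional

@dataclass
class ParentOutput:
    loss: Optional[float] = None
    logits: Optional[list] = None

@dataclass
class ChildOutput(ParentOutput):   # empty body: same fields, new name
    pass

print([f.name for f in fields(ChildOutput)])  # ['loss', 'logits']
```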
class Gemma3TextScaledWordEmbedding(nn.Embedding):
@ -545,7 +521,7 @@ GEMMA3_START_DOCSTRING = None
class Gemma3PreTrainedModel(Gemma2PreTrainedModel):
base_model_prefix = "language_model"
base_model_prefix = ""
_no_split_modules = [
"Gemma3DecoderLayer",
"SiglipVisionEmbeddings",
@ -755,10 +731,7 @@ class Gemma3MultiModalProjector(nn.Module):
return projected_vision_outputs.type_as(vision_outputs)
class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
def tie_weights(self):
return self.language_model.tie_weights()
class Gemma3Model(PaliGemmaModel):
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""
Projects the last hidden state from the vision model into language model space.
@ -844,11 +817,10 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
@can_return_tuple
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
@ -859,6 +831,106 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**lm_kwargs,
) -> Union[Tuple, Gemma3ModelOutputWithPast]:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
is_training = token_type_ids is not None and labels is not None
# Replace image id with PAD if the image token is OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
)
outputs = self.language_model(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
return Gemma3ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values if use_cache else None,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
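For completeness, a tiny sketch of the `(input_ids is None) ^ (inputs_embeds is not None)` validation used at the top of the forward above: the XOR raises unless exactly one of the two inputs is provided.

```python
def check_inputs(input_ids, inputs_embeds):
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

check_inputs(input_ids=[1, 2, 3], inputs_embeds=None)      # ok: ids only
check_inputs(input_ids=None, inputs_embeds="some embeds")  # ok: embeds only
# check_inputs(None, None) or check_inputs([1], "some embeds") would raise ValueError
```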
@add_start_docstrings(
"""The Gemma3 model which consists of a vision backbone and a language model.""",
GEMMA3_START_DOCSTRING,
)
class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
@ -916,80 +988,34 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
```
"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
is_training = token_type_ids is not None and labels is not None
# Replace image id with PAD if the image token is OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# mask out pad-token-ids in labels for BC
if labels is not None and self.pad_token_id in labels:
logger.warning_once(
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
)
labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
causal_mask = self._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
)
outputs: CausalLMOutputWithPast = self.language_model(
attention_mask=causal_mask,
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
logits = outputs.logits
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
@ -1012,13 +1038,17 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
flat_labels = shift_labels.view(-1).to(shift_logits.device)
loss = loss_fct(flat_logits, flat_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Gemma3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@ -1037,7 +1067,7 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
**kwargs,
):
# Overwritten -- custom `position_ids` and `pixel_values` handling
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -1057,7 +1087,7 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
is_training = token_type_ids is not None and labels is not None
if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
causal_mask = self._update_causal_mask(
causal_mask = self.model._update_causal_mask(
attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
)
model_inputs["attention_mask"] = causal_mask
@ -1072,4 +1102,5 @@ __all__ = [
"Gemma3TextModel",
"Gemma3ForCausalLM",
"Gemma3ForConditionalGeneration",
"Gemma3Model",
]

View File

@ -28,11 +28,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_outputs import CausalLMOutputWithPast
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_outputs import ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
@ -40,7 +38,7 @@ from ...utils import (
can_return_tuple,
replace_return_docstrings,
)
from ..auto import AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_got_ocr2 import GotOcr2Config, GotOcr2VisionConfig
@ -545,7 +543,7 @@ class GotOcr2CausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
@ -557,6 +555,39 @@ class GotOcr2CausalLMOutputWithPast(ModelOutput):
image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class GotOcr2ModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for GotOcr2 outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
GOT_OCR2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@ -575,14 +606,13 @@ GOT_OCR2_START_DOCSTRING = r"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
"The bare GotOcr2 Model outputting raw hidden-states without any specific head on top.",
GOT_OCR2_START_DOCSTRING,
)
class GotOcr2PreTrainedModel(PreTrainedModel):
config_class = GotOcr2Config
base_model_prefix = "model"
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["GotOcr2VisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@ -680,23 +710,18 @@ GOT_OCR2_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"""The GOT_OCR2 model which consists of a vision backbone and a language model.""",
"""The GotOcr2 model which consists of a vision backbone and a language model, without a language modeling head.""",
GOT_OCR2_START_DOCSTRING,
)
class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
class GotOcr2Model(GotOcr2PreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: GotOcr2Config):
super().__init__(config)
self.vision_tower = GotOcr2VisionEncoder(config.vision_config)
self.multi_modal_projector = GotOcr2MultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = config.pad_token_id
self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@ -705,18 +730,6 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@ -732,13 +745,124 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
image_outputs = self.vision_tower(pixel_values).last_hidden_state
return self.multi_modal_projector(image_outputs)
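A toy sketch of the idea behind `get_image_features` above (invented sizes, a plain `nn.Linear` standing in for the real multi-modal projector): the vision tower's last hidden state is projected into the language model's hidden size so it can later be scattered into the text embeddings.

```python
import torch
import torch.nn as nn

vision_hidden, text_hidden, num_patches = 6, 8, 4                # toy sizes, not the real configs
vision_last_hidden = torch.randn(1, num_patches, vision_hidden)  # stand-in for vision_tower(...).last_hidden_state
projector = nn.Linear(vision_hidden, text_hidden)                # stand-in for the multi-modal projector
image_features = projector(vision_last_hidden)
print(image_features.shape)  # torch.Size([1, 4, 8])
```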
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, GotOcr2ModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
)
output = GotOcr2ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
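The `return_dict` fallback above relies on `ModelOutput.to_tuple()`, which drops `None` fields and returns the remaining values as a plain tuple. A small, hedged check:

```python
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast

out = BaseModelOutputWithPast(
    last_hidden_state=torch.zeros(1, 3, 4),
    past_key_values=None,   # None fields are omitted from the tuple
    hidden_states=None,
    attentions=None,
)
print(len(out.to_tuple()))      # 1
print(out.to_tuple()[0].shape)  # torch.Size([1, 3, 4])
```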
@add_start_docstrings(
"""The GOT_OCR2 model which consists of a vision backbone and a language model.""",
GOT_OCR2_START_DOCSTRING,
)
class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: GotOcr2Config):
super().__init__(config)
self.model = GotOcr2Model(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
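The properties above keep the pre-refactor attribute paths working; a generic sketch of the delegation pattern (toy classes, not library code):

```python
class Inner:
    def __init__(self):
        self.language_model = "text decoder"
        self.vision_tower = "vision encoder"


class Wrapper:
    def __init__(self):
        self.model = Inner()

    @property
    def language_model(self):          # old attribute path still resolves
        return self.model.language_model

    @property
    def vision_tower(self):
        return self.model.vision_tower


wrapper = Wrapper()
print(wrapper.language_model)  # "text decoder", same access pattern as before the split
```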
@can_return_tuple
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@ -747,9 +871,10 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> GotOcr2CausalLMOutputWithPast:
) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@ -794,37 +919,15 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
"You should keep in mind what features from the module should be used, especially
when you're planning to sell a template."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs: CausalLMOutputWithPast = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@ -832,29 +935,18 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
)
logits = outputs.logits
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return GotOcr2CausalLMOutputWithPast(
loss=loss,
@ -862,7 +954,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@ -878,7 +970,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -895,5 +987,60 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
return model_inputs
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
__all__ = ["GotOcr2PreTrainedModel", "GotOcr2ForConditionalGeneration"]
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with a static cache, the mask should be as long as the static cache
to account for the zero padding, i.e. the part of the cache that has not been filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
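For reference, here is a minimal sketch (toy sizes, not part of the diff) of what the static method above produces: a 2D padding mask is expanded into an additive 4D causal mask in which padded and future positions hold the dtype's most negative value.

```python
import torch
from transformers import GotOcr2ForConditionalGeneration  # assumes the class above is importable in your environment

mask_2d = torch.tensor([[1, 1, 1, 0]])  # (batch, key_value_length); last token is padding
causal_4d = GotOcr2ForConditionalGeneration._prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=mask_2d,
    sequence_length=4,
    target_length=4,
    dtype=torch.float32,
    cache_position=torch.arange(4),
    batch_size=1,
)
print(causal_4d.shape)  # torch.Size([1, 1, 4, 4]); 0.0 where attention is allowed, finfo.min elsewhere
```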
__all__ = ["GotOcr2PreTrainedModel", "GotOcr2Model", "GotOcr2ForConditionalGeneration"]

View File

@ -14,16 +14,17 @@
# limitations under the License.
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
LlavaModel,
LlavaModelOutputWithPast,
LlavaPreTrainedModel,
)
from transformers.models.sam.modeling_sam import SamMLPBlock, SamVisionAttention, SamVisionEncoder, SamVisionLayer
@ -36,7 +37,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ..auto import CONFIG_MAPPING, AutoConfig, AutoModelForCausalLM
from ..auto import CONFIG_MAPPING, AutoConfig
if is_vision_available():
@ -278,6 +279,10 @@ class GotOcr2CausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
class GotOcr2ModelOutputWithPast(LlavaModelOutputWithPast):
pass
class GotOcr2PreTrainedModel(LlavaPreTrainedModel):
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
@ -368,22 +373,11 @@ GOT_OCR2_INPUTS_DOCSTRING = r"""
"""
class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
class GotOcr2Model(LlavaModel):
def __init__(self, config: GotOcr2Config):
super().__init__(config)
self.vision_tower = GotOcr2VisionEncoder(config.vision_config)
self.multi_modal_projector = GotOcr2MultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = config.pad_token_id
self.post_init()
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@ -399,13 +393,81 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
image_outputs = self.vision_tower(pixel_values).last_hidden_state
return self.multi_modal_projector(image_outputs)
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, GotOcr2ModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
)
output = GotOcr2ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
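A minimal standalone sketch (toy shapes, hypothetical token id) of the placeholder-replacement step performed in the forward above: positions holding the image token id are overwritten with projected image features via `masked_scatter`.

```python
import torch

image_token_id = 3                                   # hypothetical id for the image placeholder token
input_ids = torch.tensor([[1, 3, 3, 2]])             # two image-placeholder tokens
inputs_embeds = torch.randn(1, 4, 8)                 # (batch, seq_len, hidden)
image_features = torch.randn(2, 8)                   # one projected feature vector per placeholder

special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features.to(inputs_embeds.dtype))
print(inputs_embeds.shape)                           # torch.Size([1, 4, 8]); positions 1 and 2 now hold image features
```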
class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
@can_return_tuple
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@ -414,9 +476,10 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> GotOcr2CausalLMOutputWithPast:
) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@ -461,37 +524,15 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
"You should keep in mind what features from the module should be used, especially
when you're planning to sell a template."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs: CausalLMOutputWithPast = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@ -499,29 +540,18 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
)
logits = outputs.logits
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return GotOcr2CausalLMOutputWithPast(
loss=loss,
@ -529,7 +559,7 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
@ -537,5 +567,6 @@ __all__ = [
"GotOcr2VisionConfig",
"GotOcr2Config",
"GotOcr2PreTrainedModel",
"GotOcr2Model",
"GotOcr2ForConditionalGeneration",
]

View File

@ -1111,7 +1111,7 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel, GenerationMixin):
)
return model_inputs
def _get_initial_cache_position(self, input_ids, model_kwargs):
def _get_initial_cache_position(self, seq_length, device, model_kwargs):
"""
Calculates `cache_position` for the pre-fill stage based on the prompt's sequence length and, optionally, the past length.
Since GPTBigCode is special, the method is overridden here; other models use the implementation from `generation/utils.py`.
@ -1125,8 +1125,8 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel, GenerationMixin):
if "inputs_embeds" in model_kwargs:
cur_len = model_kwargs["inputs_embeds"].shape[1]
else:
cur_len = input_ids.shape[-1]
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
cur_len = seq_length
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=device)
return model_kwargs
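A minimal sketch (hypothetical lengths) of the `cache_position` this override computes during pre-fill:

```python
import torch

past_length, seq_length = 0, 4                       # fresh 4-token prompt, empty cache
cache_position = torch.arange(past_length, seq_length, device="cpu")
print(cache_position)                                # tensor([0, 1, 2, 3])
```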
@add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)

View File

@ -166,6 +166,8 @@ class GraniteMoeConfig(PretrainedConfig):
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
# this model has rope embedding type, hardcoded for BC
self.position_embedding_type = "rope"
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout

View File

@ -13,25 +13,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutputWithPast,
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
@ -439,10 +438,9 @@ class GraniteMoeAttention(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # None or rope embeddings
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@ -455,260 +453,75 @@ class GraniteMoeAttention(nn.Module):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
cos, sin = position_embeddings if position_embeddings is not None else (None, None)
if position_embeddings is not None:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteFlashAttention2 with Granite->GraniteMoe
# TODO cyril: modular
class GraniteMoeFlashAttention2(GraniteMoeAttention):
"""
GraniteMoe flash attention module. This module inherits from `GraniteMoeAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (GraniteMoeRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
target_dtype = self.q_proj.weight.dtype
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
softmax_scale=self.scaling,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteSdpaAttention with Granite->GraniteMoe
# TODO cyril: modular
class GraniteMoeSdpaAttention(GraniteMoeAttention):
"""
GraniteMoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`GraniteMoeAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
# Adapted from GraniteMoeAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"GraniteMoeModel is using GraniteMoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
bsz, q_len, _ = hidden_states.size()
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
scale=self.scaling,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
return attn_output, attn_weights
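A standalone sketch (toy shapes, not the library API) of the computation `eager_attention_forward` performs: grouped-query key/value heads are repeated to the full head count, then masked scaled dot-product attention is applied.

```python
import torch

bsz, n_heads, n_kv_heads, q_len, head_dim = 1, 4, 2, 3, 8
scaling = head_dim**-0.5
query = torch.randn(bsz, n_heads, q_len, head_dim)
key = torch.randn(bsz, n_kv_heads, q_len, head_dim)
value = torch.randn(bsz, n_kv_heads, q_len, head_dim)

# equivalent of repeat_kv: each KV head serves n_heads // n_kv_heads query heads
key = key.repeat_interleave(n_heads // n_kv_heads, dim=1)
value = value.repeat_interleave(n_heads // n_kv_heads, dim=1)

causal = torch.triu(torch.full((q_len, q_len), torch.finfo(query.dtype).min), diagonal=1)
attn_weights = torch.softmax(query @ key.transpose(2, 3) * scaling + causal, dim=-1)
attn_output = (attn_weights @ value).transpose(1, 2).reshape(bsz, q_len, n_heads * head_dim)
print(attn_output.shape)  # torch.Size([1, 3, 32])
```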
GRANITEMOE_ATTENTION_CLASSES = {
"eager": GraniteMoeAttention,
"flash_attention_2": GraniteMoeFlashAttention2,
"sdpa": GraniteMoeSdpaAttention,
}
class GraniteMoeDecoderLayer(nn.Module):
class GraniteMoeDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: GraniteMoeConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GRANITEMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.self_attn = GraniteMoeAttention(config=config, layer_idx=layer_idx)
self.block_sparse_moe = GraniteMoeMoE(config)
self.input_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -827,13 +640,12 @@ class GraniteMoePreTrainedModel(PreTrainedModel):
_supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, GraniteMoeRMSNorm):
@ -947,8 +759,8 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
# rope
self.rotary_emb = GraniteMoeRotaryEmbedding(config)
self.position_embedding_type = config.position_embedding_type
self.rotary_emb = GraniteMoeRotaryEmbedding(config) if self.position_embedding_type == "rope" else None
# Initialize weights and apply final processing
self.post_init()
@ -1019,8 +831,10 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
# embed positions
hidden_states = inputs_embeds
position_embeddings = None
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
if self.rotary_emb is not None:
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
@ -1032,31 +846,17 @@ class GraniteMoeModel(GraniteMoePreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
output_router_logits,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
output_router_logits=output_router_logits,
position_embeddings=position_embeddings,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
output_router_logits=output_router_logits,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
@ -1265,6 +1065,7 @@ class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin):
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
@ -1273,6 +1074,13 @@ class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin):
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
@ -1315,8 +1123,10 @@ class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin):
cache_position=cache_position,
)
# Only compute necessary logits
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
logits = logits / self.config.logits_scaling
loss = None
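A minimal sketch (toy tensors, hypothetical vocab size) of what the `logits_to_keep` slicing above does: during generation only the last position's logits are materialized, avoiding a full `(batch, seq_len, vocab)` tensor.

```python
import torch

hidden_states = torch.randn(2, 10, 16)                # (batch, seq_len, hidden)
lm_head = torch.nn.Linear(16, 100, bias=False)        # hypothetical small vocabulary

logits_to_keep = 1                                     # 0 would keep every position (documented special case)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = lm_head(hidden_states[:, slice_indices, :])
print(logits.shape)                                    # torch.Size([2, 1, 100])
```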

View File

@ -0,0 +1,29 @@
# coding=utf-8
# Copyright 2025 IBM and the HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_granitemoehybrid import *
from .modeling_granitemoehybrid import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@ -0,0 +1,256 @@
# coding=utf-8
# Copyright 2025 IBM and the HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GraniteMoeHybrid model configuration"""
from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging
logger = logging.get_logger(__name__)
class GraniteMoeHybridConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`GraniteMoeHybridModel`]. It is used to
instantiate a GraniteMoeHybrid model according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the GraniteMoeHybrid model. Defines the number of different tokens that
can be represented by the `input_ids` passed when calling [`GraniteMoeHybridModel`].
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
Only relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
embedding_multiplier (`float`, *optional*, defaults to 1.0): embedding multiplier.
logits_scaling (`float`, *optional*, defaults to 1.0): divisor for output logits.
residual_multiplier (`float`, *optional*, defaults to 1.0): residual multiplier.
attention_multiplier (`float`, *optional*, defaults to 1.0): attention multiplier.
num_local_experts (`int`, *optional*, defaults to 8): total number of experts.
num_experts_per_tok (`int`, *optional*, defaults to 2): number of experts per token.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabling this will also
allow the model to output the auxiliary loss.
router_aux_loss_coef (`float`, *optional*, defaults to 0.001): router auxiliary loss coefficient.
shared_intermediate_size (`int`, *optional*, defaults to 1024): intermediate size for shared experts.
position_embedding_type (`str`, *optional*): Positional embedding
type to be used; defaults to None. Allowed options: `[None, "rope"]`
layer_types (`List`, *optional*): list of strings to be used as layer types.
Allowed choices: "mamba", "attention".
mamba_n_heads (`int`, *optional*, defaults to 128):
The number of mamba heads used.
mamba_n_groups (`int`, *optional*, defaults to 1):
The number of mamba groups used.
mamba_d_state (`int`, *optional*, defaults to 256):
The dimension of the mamba latent state space.
mamba_d_head (`int`, *optional*, defaults to `"auto"`):
Head embedding dimension size.
mamba_d_conv (`int`, *optional*, defaults to 4):
The size of the mamba convolution kernel.
mamba_expand (`int`, *optional*, defaults to 2):
Expanding factor (relative to hidden_size) used to determine the mamba intermediate size.
mamba_chunk_size (`int`, *optional*, defaults to 256):
The chunks in which to break the sequence when doing prefill/training.
mamba_conv_bias (`bool`, *optional*, defaults to `True`):
Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
mamba_proj_bias (`bool`, *optional*, defaults to `False`):
Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"])
of the mamba mixer block.
```python
>>> from transformers import GraniteMoeHybridModel, GraniteMoeHybridConfig
>>> # Initializing a GraniteMoeHybrid config
>>> configuration = GraniteMoeHybridConfig()
>>> # Initializing a model from the configuration
>>> model = GraniteMoeHybridModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "granitemoehybrid"
attribute_map = {
"layers_block_type": "layer_types",
}
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
embedding_multiplier=1.0,
logits_scaling=1.0,
residual_multiplier=1.0,
attention_multiplier=1.0,
num_local_experts=8,
num_experts_per_tok=2,
output_router_logits=False,
router_aux_loss_coef=0.001,
shared_intermediate_size=1024,
position_embedding_type=None,
layer_types=None,
mamba_n_heads=128,
mamba_n_groups=1,
mamba_d_state=256,
mamba_d_head="auto",
mamba_d_conv=4,
mamba_expand=2,
mamba_chunk_size=256,
mamba_conv_bias=True,
mamba_proj_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.embedding_multiplier = embedding_multiplier
self.logits_scaling = logits_scaling
self.residual_multiplier = residual_multiplier
self.attention_multiplier = attention_multiplier
self.attention_dropout = attention_dropout
self.num_local_experts = num_local_experts
self.num_experts_per_tok = num_experts_per_tok
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.shared_intermediate_size = shared_intermediate_size
self.position_embedding_type = position_embedding_type
mamba_intermediate = mamba_expand * hidden_size
if layer_types is not None and any(layer_type not in ["mamba", "attention"] for layer_type in layer_types):
raise ValueError("layer_types must be a list strings in [`mamba` `attention`]")
if mamba_intermediate % mamba_n_heads != 0:
raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size")
# for the mamba_v2, must satisfy the following
if mamba_d_head == "auto":
mamba_d_head = mamba_intermediate // mamba_n_heads
if mamba_d_head * mamba_n_heads != mamba_intermediate:
raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size")
self.mamba_n_heads = mamba_n_heads
self.mamba_d_head = mamba_d_head
self.mamba_n_groups = mamba_n_groups
self.mamba_d_state = mamba_d_state
self.mamba_d_conv = mamba_d_conv
self.mamba_chunk_size = mamba_chunk_size
self.mamba_conv_bias = mamba_conv_bias
self.mamba_proj_bias = mamba_proj_bias
self.mamba_expand = mamba_expand
self.layer_types = layer_types
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
if self.position_embedding_type == "rope":
rope_config_validation(self)
# overwritten because `HybridMambaAttentionDynamicCache` relies on it
@property
def layers_block_type(self):
return self.layer_types if self.layer_types else ["mamba"] * self.num_hidden_layers
__all__ = ["GraniteMoeHybridConfig"]
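A small usage sketch (hypothetical layer layout) showing how the defaults above resolve: with `mamba_d_head="auto"`, the head size becomes `(mamba_expand * hidden_size) // mamba_n_heads = (2 * 4096) // 128 = 64`, and `layers_block_type` returns the explicit `layer_types` when one is given.

```python
from transformers import GraniteMoeHybridConfig  # top-level export, as in the docstring example above

config = GraniteMoeHybridConfig(layer_types=["mamba", "attention"] * 16)  # matches the default 32 layers
print(config.mamba_d_head)            # 64
print(config.layers_block_type[:4])   # ['mamba', 'attention', 'mamba', 'attention']
```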

File diff suppressed because it is too large

View File

@ -0,0 +1,510 @@
# coding=utf-8
# Copyright 2025 IBM and the HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from ...cache_utils import Cache
from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
)
from ..bamba.configuration_bamba import BambaConfig
from ..bamba.modeling_bamba import (
BambaMixer,
BambaRMSNormGated,
HybridMambaAttentionDynamicCache,
)
from ..granitemoeshared.modeling_granitemoeshared import (
GraniteMoeSharedAttention,
GraniteMoeSharedDecoderLayer,
GraniteMoeSharedForCausalLM,
GraniteMoeSharedMLP,
GraniteMoeSharedModel,
GraniteMoeSharedPreTrainedModel,
)
from .configuration_granitemoehybrid import GraniteMoeHybridConfig
logger = logging.get_logger(__name__)
class GraniteMoeHybridAttention(GraniteMoeSharedAttention):
def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
super().__init__(config, layer_idx)
class GraniteMoeHybridMambaLayer(BambaMixer):
def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
super().__init__(BambaConfig(config), layer_idx)
class GraniteMoeHybridRMSNormGated(BambaRMSNormGated):
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps)
class GraniteMoeHybridMLP(GraniteMoeSharedMLP):
def __init__(self, config: GraniteMoeHybridConfig):
super().__init__(config)
class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer):
def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int):
super().__init__(config, layer_idx)
self.shared_mlp = GraniteMoeHybridMLP(config)
# Either attention or mamba will be initialized, depending on the layer type.
self.self_attn = None
self.mamba = None
if config.layers_block_type[layer_idx] == "mamba":
self.mamba = GraniteMoeHybridMambaLayer(config, layer_idx)
else:
self.self_attn = GraniteMoeHybridAttention(config, layer_idx)
self.layer_type = config.layers_block_type[layer_idx]
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
output_router_logits: Optional[bool] = False,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
output_router_logits (`bool`, *optional*):
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
should not be returned during inference.
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if self.mamba is not None:
hidden_states = self.mamba(
hidden_states=hidden_states,
cache_position=cache_position,
cache_params=past_key_value,
attention_mask=attention_mask,
)
# No attention weights for state space layers
self_attn_weights = None
else:
hidden_states, self_attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states * self.residual_multiplier
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
moe_hidden_states, router_logits = self.block_sparse_moe(hidden_states)
hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
hidden_states = residual + hidden_states * self.residual_multiplier
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (past_key_value,)
if output_router_logits:
outputs += (router_logits,)
return outputs
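A tiny sketch (hypothetical 4-layer layout) of the per-layer dispatch set up in `__init__` and used in the forward above: each decoder layer carries either a mamba mixer or self-attention, and the model later feeds it the matching 2D mamba mask or 4D causal mask.

```python
# hypothetical config.layers_block_type; mirrors the constructor and mask-selection logic above
layer_types = ["mamba", "attention", "mamba", "attention"]
for idx, layer_type in enumerate(layer_types):
    mixer = "GraniteMoeHybridMambaLayer" if layer_type == "mamba" else "GraniteMoeHybridAttention"
    mask = "2D mamba mask" if layer_type == "mamba" else "4D causal mask"
    print(idx, mixer, mask)
```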
GRANITEMOEHYBRID_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`GraniteMoeHybridConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare GraniteMoeHybrid Model outputting raw hidden-states without any specific head on top.",
GRANITEMOEHYBRID_START_DOCSTRING,
)
class GraniteMoeHybridPreTrainedModel(GraniteMoeSharedPreTrainedModel):
config_class = GraniteMoeHybridConfig
_no_split_modules = ["GraniteMoeHybridDecoderLayer"]
_is_stateful = True
def _init_weights(self, module):
super()._init_weights(module)
# Initialize Mamba modules
if isinstance(module, (nn.Conv1d)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, GraniteMoeHybridMambaLayer):
module.dt_bias.data.fill_(1.0)
module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
module.D.data.fill_(1.0)
elif isinstance(module, GraniteMoeHybridRMSNormGated):
module.weight.data.fill_(1.0)
GRANITEMOEHYBRID_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@add_start_docstrings(
"The bare GraniteMoeHybrid Model outputting raw hidden-states without any specific head on top.",
GRANITEMOEHYBRID_START_DOCSTRING,
)
class GraniteMoeHybridModel(GraniteMoeSharedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers.
Each layer is a [`GraniteMoeHybridDecoderLayer`]
Args:
config: GraniteMoeHybridConfig
"""
def __init__(self, config: GraniteMoeHybridConfig):
super().__init__(config)
self.layers = nn.ModuleList(
[GraniteMoeHybridDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
@can_return_tuple
@add_start_docstrings_to_model_forward(GRANITEMOEHYBRID_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = inputs_embeds * self.embedding_multiplier
# overwritten because `HybridMambaAttentionDynamicCache` is needed
if use_cache and past_key_values is None:
logger.warning_once(
"GraniteMoeHybrid requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. "
"Because one was not provided, no cache will be returned."
)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
mamba_mask = self._update_mamba_mask(attention_mask, cache_position)
# embed positions
hidden_states = inputs_embeds
position_embeddings = None
# create position embeddings to be shared across the decoder layers
if self.rotary_emb is not None:
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None
next_decoder_cache = None
for decoder_layer in self.layers:
# Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=layer_mask,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
output_router_logits=output_router_logits,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
if layer_outputs[1] is not None:
# append attention weights only for attention layers; Mamba layers return `None` as the attention weights
all_self_attns += (layer_outputs[1],)
if output_router_logits:
if layer_outputs[-1] is not None:
# append router logits only for expert layers; regular MLP layers return `None` as the router logits
all_router_logits += (layer_outputs[-1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
router_logits=all_router_logits,
)
def _update_mamba_mask(self, attention_mask, cache_position):
"""
No need for zeroing states when
1. Cached forward
2. Attending to all inputs
"""
mamba_mask = attention_mask
if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
mamba_mask = None
return mamba_mask
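# --- Editorial sketch (not part of this diff) of the rule implemented in `_update_mamba_mask`
# above: the 2D mask passed to Mamba layers is dropped when decoding from a cache
# (`cache_position[0] > 0`) or when no token is padded.
import torch

def _mamba_mask_sketch(attention_mask, cache_position):
    if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
        return None
    return attention_mask

full_mask = torch.ones(2, 4)
padded_mask = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]])
assert _mamba_mask_sketch(full_mask, torch.arange(4)) is None           # nothing padded -> drop the mask
assert _mamba_mask_sketch(padded_mask, torch.arange(4)) is padded_mask  # prefill with padding -> keep it
assert _mamba_mask_sketch(padded_mask, torch.tensor([7])) is None       # cached decoding step -> drop it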
class GraniteMoeHybridForCausalLM(GraniteMoeSharedForCausalLM):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: GraniteMoeHybridConfig):
super().__init__(config)
self.model = GraniteMoeHybridModel(config)
# Initialize weights and apply final processing
self.post_init()
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
**kwargs,
):
# Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
empty_past_kv = past_key_values is None
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want a dummy token in that case.
# (we can't check exception 3 while compiling)
if not empty_past_kv:
if (
inputs_embeds is not None # Exception 1
or cache_position[-1] >= input_ids.shape[1] # Exception 3
):
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
else:
past_key_values = HybridMambaAttentionDynamicCache(
self.config, input_ids.shape[0], self.dtype, device=self.device
)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if not empty_past_kv:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and empty_past_kv:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"cache_position": cache_position,
}
)
return model_inputs
def _supports_default_dynamic_cache(self) -> bool:
"""
Function overwritten as this class uses its own `HybridMambaAttentionDynamicCache`
and does not need to initialize the cache in advance to save memory
(no back-and-forth `to_legacy_cache`/`from_legacy_cache` conversion is performed
for `HybridMambaAttentionDynamicCache`).
"""
return False
__all__ = ["GraniteMoeHybridForCausalLM", "GraniteMoeHybridModel", "GraniteMoeHybridPreTrainedModel"]
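# --- Editorial usage sketch (not part of this diff): `generate()` builds the
# `HybridMambaAttentionDynamicCache` internally through `prepare_inputs_for_generation` above,
# so callers never construct the cache themselves. The checkpoint id below is a placeholder,
# not a name taken from this diff.
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "ibm-granite/some-granite-moe-hybrid-checkpoint"  # hypothetical model id
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

inputs = tokenizer("The capital of France is", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=20, use_cache=True)
print(tokenizer.decode(generated[0], skip_special_tokens=True))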

View File

@ -169,6 +169,8 @@ class GraniteMoeSharedConfig(PretrainedConfig):
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
# this model has rope embedding type, hardcoded for BC
self.position_embedding_type = "rope"
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout

View File

@ -19,7 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@ -29,10 +29,10 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
@ -300,6 +300,33 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
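# --- Editorial shape check (not part of this diff) for `eager_attention_forward` above. It only
# needs `repeat_kv` from this file plus an object exposing `num_key_value_groups` and `training`,
# so a SimpleNamespace stands in for the attention module; all sizes are illustrative.
import types
import torch

dummy_module = types.SimpleNamespace(num_key_value_groups=4, training=False)
query = torch.randn(2, 8, 5, 64)   # (batch, num_heads, seq_len, head_dim), 8 query heads
key = torch.randn(2, 2, 5, 64)     # 2 key/value heads, repeated 4x inside the function
value = torch.randn(2, 2, 5, 64)
attn_output, attn_weights = eager_attention_forward(
    dummy_module, query, key, value, attention_mask=None, scaling=64**-0.5
)
print(attn_output.shape, attn_weights.shape)  # torch.Size([2, 5, 8, 64]) torch.Size([2, 8, 5, 5])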
# copied from transformers.models.granite.modeling_granite.GraniteAttention with Granite->GraniteMoeShared
# no longer copied after attention refactors
class GraniteMoeSharedAttention(nn.Module):
@ -343,10 +370,9 @@ class GraniteMoeSharedAttention(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # None or rope embeddings
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
@ -359,262 +385,48 @@ class GraniteMoeSharedAttention(nn.Module):
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
cos, sin = position_embeddings if position_embeddings is not None else (None, None)
if position_embeddings is not None:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteFlashAttention2 with Granite->GraniteMoeShared
# TODO cyril: modular
class GraniteMoeSharedFlashAttention2(GraniteMoeSharedAttention):
"""
GraniteMoeShared flash attention module. This module inherits from `GraniteMoeSharedAttention`, as the weights of the module stay
untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
# to be able to avoid many of these transpose/reshape/view.
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.attention_dropout if self.training else 0.0
# In PEFT, the layer norms are usually cast to float32 for training stability,
# so the input hidden states get silently cast to float32. Hence, we need to
# cast them back to the correct dtype just to be sure everything works as expected.
# This might slow down training & inference, so it is recommended not to cast the LayerNorms
# in fp32. (GraniteMoeSharedRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
target_dtype = self.q_proj.weight.dtype
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
softmax_scale=self.scaling,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteSdpaAttention with Granite->GraniteMoeShared
# TODO cyril: modular
class GraniteMoeSharedSdpaAttention(GraniteMoeSharedAttention):
"""
GraniteMoeShared attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`GraniteMoeSharedAttention`, as the weights of the module stay untouched. The only changes are in the forward pass to adapt to
the SDPA API.
"""
# Adapted from GraniteMoeSharedAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"GraniteMoeSharedModel is using GraniteMoeSharedSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
scale=self.scaling,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
GRANITEMOESHARED_ATTENTION_CLASSES = {
"eager": GraniteMoeSharedAttention,
"flash_attention_2": GraniteMoeSharedFlashAttention2,
"sdpa": GraniteMoeSharedSdpaAttention,
}
class GraniteMoeSharedDecoderLayer(nn.Module):
class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GRANITEMOESHARED_ATTENTION_CLASSES[config._attn_implementation](
config=config, layer_idx=layer_idx
)
self.self_attn = GraniteMoeSharedAttention(config=config, layer_idx=layer_idx)
self.block_sparse_moe = GraniteMoeSharedMoE(config)
self.input_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -632,7 +444,7 @@ class GraniteMoeSharedDecoderLayer(nn.Module):
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
output_router_logits: Optional[bool] = False,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@ -739,13 +551,12 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel):
_supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, GraniteMoeSharedRMSNorm):
@ -893,8 +704,8 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
# rope
self.rotary_emb = GraniteMoeSharedRotaryEmbedding(config)
self.position_embedding_type = config.position_embedding_type
self.rotary_emb = GraniteMoeSharedRotaryEmbedding(config) if self.position_embedding_type == "rope" else None
# Initialize weights and apply final processing
self.post_init()
@ -965,8 +776,10 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
# embed positions
hidden_states = inputs_embeds
position_embeddings = None
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
if self.rotary_emb is not None:
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
@ -978,31 +791,17 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
output_router_logits,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
output_router_logits=output_router_logits,
position_embeddings=position_embeddings,
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
output_router_logits=output_router_logits,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
@ -1291,6 +1090,7 @@ class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMix
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
@ -1299,6 +1099,13 @@ class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMix
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
@ -1341,8 +1148,10 @@ class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMix
cache_position=cache_position,
)
# Only compute necessary logits
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
logits = logits / self.config.logits_scaling
loss = None
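# --- Editorial illustration (not part of this diff) of the `logits_to_keep` slicing added above:
# an int keeps only the last N positions, a 1D tensor keeps explicit indices (useful for packed
# sequences). All tensor sizes are illustrative.
import torch

hidden_states = torch.randn(2, 10, 16)    # (batch, seq_len, hidden_size)
logits_to_keep = 1
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)   # torch.Size([2, 1, 16]) -> only the last position

keep_indices = torch.tensor([0, 4, 9])            # tensor form: explicit positions to keep
print(hidden_states[:, keep_indices, :].shape)    # torch.Size([2, 3, 16])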

View File

@ -16,7 +16,6 @@
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
@ -78,7 +77,7 @@ class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer):
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
output_router_logits: Optional[bool] = False,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""

View File

@ -41,7 +41,7 @@ from ...utils import (
replace_return_docstrings,
torch_int,
)
from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM
from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from .configuration_instructblip import InstructBlipConfig, InstructBlipQFormerConfig, InstructBlipVisionConfig
@ -315,6 +315,9 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
config_class = InstructBlipConfig
base_model_prefix = "blip"
supports_gradient_checkpointing = True
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False # not all LM backbones support it (e.g. T5)
_no_split_modules = [
"InstructBlipQFormerEmbeddings",
@ -339,7 +342,7 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
elif isinstance(module, InstructBlipVisionEmbeddings):
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
elif isinstance(module, InstructBlipForConditionalGeneration):
elif isinstance(module, (InstructBlipForConditionalGeneration, InstructBlipModel)):
module.query_tokens.data.zero_()
@ -1274,6 +1277,156 @@ class InstructBlipQFormerModel(InstructBlipPreTrainedModel):
)
@add_start_docstrings(
"""
InstructBLIP base Model consisting of language model, qformer and vision encoder.
""",
INSTRUCTBLIP_START_DOCSTRING,
)
class InstructBlipModel(InstructBlipPreTrainedModel):
main_input_name = "pixel_values"
_keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
def __init__(self, config: InstructBlipConfig):
super().__init__(config)
self.vision_model = InstructBlipVisionModel(config.vision_config)
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
self.qformer = InstructBlipQFormerModel(config.qformer_config)
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
self.language_model = AutoModel.from_config(config.text_config)
if self.language_model._no_split_modules is not None:
self._no_split_modules.extend(self.language_model._no_split_modules)
if self.language_model._keep_in_fp32_modules is not None:
self._keep_in_fp32_modules.extend(self.language_model._keep_in_fp32_modules)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def _tie_weights(self):
if not self.config.use_decoder_only_language_model:
self.language_model.encoder.embed_tokens = self.language_model.shared
self.language_model.decoder.embed_tokens = self.language_model.shared
def _preprocess_accelerate(self):
r"""
Some pre-processing hacks to make the model `accelerate` compatible. Check
https://github.com/huggingface/transformers/pull/21707 for more details.
"""
hf_device_map = self.hf_device_map
if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
# warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`.
logger.warning(
"The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
" in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
" Please pass a `device_map` that contains `language_model` to remove this warning."
" Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
" more details on creating a `device_map` for large models.",
)
if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
@add_start_docstrings_to_model_forward(INSTRUCTBLIP_INPUTS_DOCSTRING)
def forward(
self,
pixel_values: torch.FloatTensor,
qformer_input_ids: torch.FloatTensor,
qformer_attention_mask: Optional[torch.LongTensor] = None,
input_ids: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# step 1: forward the images through the vision encoder,
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
image_embeds = vision_outputs[0]
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]
# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
if self.config.use_decoder_only_language_model:
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
else:
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
if not return_dict:
return (vision_outputs, query_outputs, outputs)
return InstructBlipForConditionalGenerationModelOutput(
vision_outputs=vision_outputs,
qformer_outputs=query_outputs,
language_model_outputs=outputs,
)
@add_start_docstrings(
"""
InstructBLIP Model for generating text given an image and an optional text prompt. The model consists of a vision
@ -1336,11 +1489,13 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
def get_decoder(self):
return self.language_model.get_decoder()
# Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._tie_weights
def _tie_weights(self):
if not self.config.use_decoder_only_language_model:
self.language_model.encoder.embed_tokens = self.language_model.shared
self.language_model.decoder.embed_tokens = self.language_model.shared
# Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._preprocess_accelerate
def _preprocess_accelerate(self):
r"""
Some pre-processing hacks to make the model `accelerate` compatible. Check
@ -1645,6 +1800,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
__all__ = [
"InstructBlipQFormerModel",
"InstructBlipPreTrainedModel",
"InstructBlipModel",
"InstructBlipForConditionalGeneration",
"InstructBlipVisionModel",
]
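# --- Editorial usage sketch (not part of this diff) for the new `InstructBlipModel` base class:
# it returns raw vision / Q-Former / language-model outputs without a generation head. This is a
# minimal sketch assuming the class loads existing InstructBLIP checkpoints; the checkpoint id is
# used only for illustration and the image is a dummy placeholder.
import torch
from PIL import Image
from transformers import InstructBlipModel, InstructBlipProcessor

checkpoint = "Salesforce/instructblip-vicuna-7b"
processor = InstructBlipProcessor.from_pretrained(checkpoint)
model = InstructBlipModel.from_pretrained(checkpoint)

image = Image.new("RGB", (224, 224))  # dummy image instead of a real photo
inputs = processor(images=image, text="Describe the image.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.language_model_outputs.last_hidden_state.shape)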

View File

@ -45,7 +45,7 @@ from ...utils import (
replace_return_docstrings,
torch_int,
)
from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM
from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from .configuration_instructblipvideo import (
InstructBlipVideoConfig,
InstructBlipVideoQFormerConfig,
@ -945,6 +945,9 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
config_class = InstructBlipVideoConfig
base_model_prefix = "blip"
supports_gradient_checkpointing = True
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False # not all LM backbones support it (e.g. T5)
_no_split_modules = [
"InstructBlipVideoQFormerEmbeddings",
@ -969,7 +972,7 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
elif isinstance(module, InstructBlipVideoVisionEmbeddings):
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
elif isinstance(module, InstructBlipVideoForConditionalGeneration):
elif isinstance(module, (InstructBlipVideoForConditionalGeneration, InstructBlipVideoModel)):
module.query_tokens.data.zero_()
@ -1269,6 +1272,166 @@ class InstructBlipVideoForConditionalGenerationModelOutput(ModelOutput):
)
@add_start_docstrings(
"""
InstructBlipVideo base Model consisting of language model, qformer and vision encoder.
""",
INSTRUCTBLIPVIDEO_START_DOCSTRING,
)
class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
main_input_name = "pixel_values"
_keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
def __init__(self, config: InstructBlipVideoConfig):
super().__init__(config)
self.vision_model = InstructBlipVideoVisionModel(config.vision_config)
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
self.qformer = InstructBlipVideoQFormerModel(config.qformer_config)
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
self.language_model = AutoModel.from_config(config.text_config)
if self.language_model._no_split_modules is not None:
self._no_split_modules.extend(self.language_model._no_split_modules)
if self.language_model._keep_in_fp32_modules is not None:
self._keep_in_fp32_modules.extend(self.language_model._keep_in_fp32_modules)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def _tie_weights(self):
if not self.config.use_decoder_only_language_model:
self.language_model.encoder.embed_tokens = self.language_model.shared
self.language_model.decoder.embed_tokens = self.language_model.shared
def _preprocess_accelerate(self):
r"""
Some pre-processing hacks to make the model `accelerate` compatible. Check
https://github.com/huggingface/transformers/pull/21707 for more details.
"""
hf_device_map = self.hf_device_map
if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
# warn users about unexpected behavior when using multi-GPU + InstructBlipVideo + `accelerate`.
logger.warning(
"The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
" in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
" Please pass a `device_map` that contains `language_model` to remove this warning."
" Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
" more details on creating a `device_map` for large models.",
)
if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
@add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING)
def forward(
self,
pixel_values: torch.FloatTensor,
qformer_input_ids: torch.FloatTensor,
qformer_attention_mask: Optional[torch.LongTensor] = None,
input_ids: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# step 1: forward the images through the vision encoder,
# frames are processed in a batched way and unbatched afterwards (videos always have frames=4)
batch_size, frames, channel, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
image_embeds = vision_outputs[0]
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]
# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
# unbatch inputs back, each video-frame gets `num_query_tokens` seq length
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
if self.config.use_decoder_only_language_model:
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
else:
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
if not return_dict:
return (vision_outputs, query_outputs, outputs)
return InstructBlipVideoForConditionalGenerationModelOutput(
vision_outputs=vision_outputs,
qformer_outputs=query_outputs,
language_model_outputs=outputs,
)
@add_start_docstrings(
"""
InstructBlipVideo Model for generating text given an image and an optional text prompt. The model consists of a vision
@ -1682,5 +1845,6 @@ __all__ = [
"InstructBlipVideoVisionModel",
"InstructBlipVideoPreTrainedModel",
"InstructBlipVideoQFormerModel",
"InstructBlipVideoModel",
"InstructBlipVideoForConditionalGeneration",
]
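# --- Editorial shape walk-through (not part of this diff) of the video path above: frames are
# folded into the batch dimension for the vision tower and Q-Former, then the projected query
# tokens are unfolded back so the language model sees `num_query_tokens * frames` visual positions
# per video. Sizes below are illustrative, not taken from a config.
import torch

batch_size, frames, channel, height, width = 2, 4, 3, 224, 224
num_query_tokens, hidden_size = 32, 4096

pixel_values = torch.randn(batch_size, frames, channel, height, width)
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)  # (8, 3, 224, 224)

language_model_inputs = torch.randn(batch_size * frames, num_query_tokens, hidden_size)
language_model_inputs = language_model_inputs.reshape(batch_size, num_query_tokens * frames, -1)
print(pixel_values.shape, language_model_inputs.shape)  # (8, 3, 224, 224) and (2, 128, 4096)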

View File

@ -27,6 +27,7 @@ from transformers.models.instructblip.configuration_instructblip import (
from transformers.models.instructblip.modeling_instructblip import (
InstructBlipForConditionalGeneration,
InstructBlipForConditionalGenerationModelOutput,
InstructBlipModel,
InstructBlipPreTrainedModel,
InstructBlipQFormerModel,
InstructBlipVisionModel,
@ -34,7 +35,7 @@ from transformers.models.instructblip.modeling_instructblip import (
from ...configuration_utils import PretrainedConfig
from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from ...utils import logging
from ...utils import add_start_docstrings_to_model_forward, logging
from ..auto import CONFIG_MAPPING, AutoConfig
@ -191,6 +192,110 @@ class InstructBlipVideoForConditionalGenerationModelOutput(InstructBlipForCondit
pass
INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING = None
class InstructBlipVideoModel(InstructBlipModel):
@add_start_docstrings_to_model_forward(INSTRUCTBLIPVIDEO_INPUTS_DOCSTRING)
def forward(
self,
pixel_values: torch.FloatTensor,
qformer_input_ids: torch.FloatTensor,
qformer_attention_mask: Optional[torch.LongTensor] = None,
input_ids: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
use_cache: Optional[bool] = None,
) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# step 1: forward the images through the vision encoder,
# frames are processed in a batched way and unbatched afterwards (videos always have frames=4)
batch_size, frames, channel, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
image_embeds = vision_outputs[0]
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
# difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
if qformer_attention_mask is None:
qformer_attention_mask = torch.ones_like(qformer_input_ids)
qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
query_outputs = self.qformer(
input_ids=qformer_input_ids,
attention_mask=qformer_attention_mask,
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0][:, : query_tokens.size(1), :]
# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
# unbatch inputs back, each video-frame gets `num_query_tokens` seq length
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
if self.config.use_decoder_only_language_model:
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
else:
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
)
if not return_dict:
return (vision_outputs, query_outputs, outputs)
return InstructBlipVideoForConditionalGenerationModelOutput(
vision_outputs=vision_outputs,
qformer_outputs=query_outputs,
language_model_outputs=outputs,
)
class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
def forward(
self,
@ -508,5 +613,6 @@ __all__ = [
"InstructBlipVideoVisionModel",
"InstructBlipVideoPreTrainedModel",
"InstructBlipVideoQFormerModel",
"InstructBlipVideoModel",
"InstructBlipVideoForConditionalGeneration",
]

View File

@ -31,7 +31,7 @@ from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
@ -45,7 +45,7 @@ from ...utils import (
replace_return_docstrings,
torch_int,
)
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_internvl import InternVLConfig, InternVLVisionConfig
@ -608,14 +608,13 @@ INTERNVL_START_DOCSTRING = r"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
"The bare InternVL Model outputting raw hidden-states without any specific head on top.",
INTERNVL_START_DOCSTRING,
)
class InternVLPreTrainedModel(PreTrainedModel):
config_class = InternVLConfig
base_model_prefix = "model"
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["InternVLVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@ -654,15 +653,13 @@ class InternVLMultiModalProjector(nn.Module):
@dataclass
class InternVLCausalLMOutputWithPast(ModelOutput):
class InternVLModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for InternVL causal language model (or autoregressive) outputs.
Base class for InternVL outputs, with hidden states and attentions.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
@ -681,15 +678,10 @@ class InternVLCausalLMOutputWithPast(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
@ -771,23 +763,18 @@ INTERNVL_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"""The INTERNVL model which consists of a vision backbone and a language model.""",
"""The InternVL model which consists of a vision backbone and a language model, without a language modeling head.""",
INTERNVL_START_DOCSTRING,
)
class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin):
class InternVLModel(InternVLPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: InternVLConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = InternVLMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@ -796,18 +783,6 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@ -853,12 +828,221 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
return vision_features
@add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, InternVLModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (input_ids == self.config.image_token_id).sum()
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = InternVLModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
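The merge step above works by filling every position of the image placeholder token with one row of the projected image features via `masked_scatter`. A standalone toy sketch of that splice (the token id, tensor sizes and values below are made up purely for illustration):

import torch

image_token_id = 32000                                   # hypothetical placeholder id
input_ids = torch.tensor([[1, 32000, 32000, 5, 6]])      # two image positions in a 5-token sequence
inputs_embeds = torch.zeros(1, 5, 4)                     # (batch, seq_len, hidden)
image_features = torch.arange(8, dtype=torch.float).reshape(2, 4)  # (n_image_tokens, hidden)

special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0, 1])  # tensor([0., 1., 2., 3.]) -- the first projected image vector in place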
def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
"""Perform pixel shuffle downsampling on vision features.
Args:
vision_features (`torch.Tensor`):
Input tensor of shape (batch_size, width, height, channels).
scale_factor (`float`, *optional*, defaults to `0.5`):
Factor by which to downsample. Default is 0.5, which halves the dimensions.
Returns:
vision_features (`torch.Tensor`):
Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)).
"""
batch_size, width, height, channels = vision_features.size()
if height % scale_factor != 0 or width % scale_factor != 0:
raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.")
# Reshape to allow downsampling
vision_features = vision_features.view(
batch_size, width, int(height * scale_factor), int(channels / scale_factor)
)
# Permute dimensions to align downsampled axis correctly
vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
# Reshape to achieve final downsampled dimensions
vision_features = vision_features.view(
batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2))
)
# Swap height and width back for proper orientation
vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
return vision_features
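As a shape sanity check, the reshape/permute sequence above can be mirrored in a standalone sketch (the function name and example tensor are illustrative only): halving the spatial grid with `scale_factor=0.5` multiplies the channel dimension by four.

import torch

def pixel_shuffle_sketch(x: torch.Tensor, scale_factor: float = 0.5) -> torch.Tensor:
    # Mirrors the pixel_shuffle method above: trades spatial resolution for channels.
    b, w, h, c = x.size()
    x = x.view(b, w, int(h * scale_factor), int(c / scale_factor))
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(b, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor**2)))
    return x.permute(0, 2, 1, 3).contiguous()

features = torch.randn(2, 16, 16, 1024)            # (batch, width, height, channels)
print(pixel_shuffle_sketch(features).shape)        # torch.Size([2, 8, 8, 4096])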
@dataclass
class InternVLCausalLMOutputWithPast(ModelOutput):
"""
Base class for InternVL causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
@add_start_docstrings(
"""The INTERNVL model which consists of a vision backbone and a language model.""",
INTERNVL_START_DOCSTRING,
)
class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: InternVLConfig):
super().__init__(config)
self.model = InternVLModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=InternVLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@ -872,7 +1056,7 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: Optional[torch.Tensor] = None,
image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, InternVLCausalLMOutputWithPast]:
r"""
@ -925,7 +1109,6 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
>>> print(processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True))
The images depict the Statue of Liberty and the Golden Gate Bridge.
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -940,73 +1123,32 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
image_sizes=image_sizes,
**lm_kwargs,
)
logits = outputs[0]
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return InternVLCausalLMOutputWithPast(
loss=loss,
@ -1014,7 +1156,7 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
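The head class no longer shifts labels by hand: it projects only the positions selected by `logits_to_keep` and delegates the loss to `self.loss_function`. A minimal illustration of the slicing behaviour (`0` keeps every position because `slice(-0, None)` spans the whole axis; an integer `k` keeps the last `k`, which is what generation needs):

import torch

hidden_states = torch.randn(1, 10, 64)             # (batch, seq_len, hidden)
for logits_to_keep in (0, 1):
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    print(logits_to_keep, hidden_states[:, slice_indices, :].shape)
# 0 torch.Size([1, 10, 64])
# 1 torch.Size([1, 1, 64])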
def prepare_inputs_for_generation(
@ -1030,7 +1172,7 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -1047,45 +1189,66 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
return model_inputs
def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
"""Perform pixel shuffle downsampling on vision features.
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
vision_features (`torch.Tensor`):
Input tensor of shape (batch_size, width, height, channels).
scale_factor (`float`, *optional*, defaults to `0.5`):
Factor by which to downsample. Default is 0.5, which halves the dimensions.
Returns:
vision_features (`torch.Tensor`):
Downsampled tensor of shape (batch_size, height*scale_factor, width*scale_factor, channels/(scale_factor^2)).
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
batch_size, width, height, channels = vision_features.size()
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
if height % scale_factor != 0 or width % scale_factor != 0:
raise ValueError("Height and width must be divisible by scale_factor for proper downsampling.")
# Reshape to allow downsampling
vision_features = vision_features.view(
batch_size, width, int(height * scale_factor), int(channels / scale_factor)
)
# Permute dimensions to align downsampled axis correctly
vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
# Reshape to achieve final downsampled dimensions
vision_features = vision_features.view(
batch_size, int(height * scale_factor), int(width * scale_factor), int(channels / (scale_factor**2))
)
# Swap height and width back for proper orientation
vision_features = vision_features.permute(0, 2, 1, 3).contiguous()
return vision_features
return causal_mask
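A standalone illustration of the mask construction above, for a 4-token prefill against a static cache of length 6 (no 2D padding mask is applied here; the numbers are illustrative):

import torch

sequence_length, target_length = 4, 6
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(sequence_length)

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(1, 1, -1, -1)
print((causal_mask == 0).int()[0, 0])
# Each row shows which key positions a query may attend to: a lower-triangular block over
# the first four columns, with the two not-yet-filled cache slots fully masked.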
__all__ = [
"InternVLVisionPreTrainedModel",
"InternVLVisionModel",
"InternVLPreTrainedModel",
"InternVLModel",
"InternVLForConditionalGeneration",
]
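The `_checkpoint_conversion_mapping` added on the new head class maps key prefixes from the old flat checkpoint layout onto the nested `model.*` layout. The remapping itself happens inside `from_pretrained`; the helper below is only a sketch of what those regex patterns do to individual state-dict keys:

import re

mapping = {
    "^language_model.model": "model.language_model",
    "^vision_tower": "model.vision_tower",
    "^multi_modal_projector": "model.multi_modal_projector",
    "^language_model.lm_head": "lm_head",
}

def rename(key: str) -> str:
    # Apply the first matching prefix rewrite; later patterns are skipped once one hits.
    for pattern, replacement in mapping.items():
        key, n = re.subn(pattern, replacement, key)
        if n:
            break
    return key

print(rename("language_model.model.layers.0.self_attn.q_proj.weight"))
# model.language_model.layers.0.self_attn.q_proj.weight
print(rename("language_model.lm_head.weight"))
# lm_head.weight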

View File

@ -39,7 +39,12 @@ from ...utils import (
from ..clip.modeling_clip import CLIPMLP
from ..janus.modeling_janus import JanusVisionAttention
from ..llama.modeling_llama import LlamaRMSNorm
from ..llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, LlavaPreTrainedModel
from ..llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
LlavaModel,
LlavaPreTrainedModel,
)
from .configuration_internvl import InternVLConfig, InternVLVisionConfig
@ -573,11 +578,7 @@ class InternVLMultiModalProjector(nn.Module):
return hidden_states
class InternVLCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
class InternVLForConditionalGeneration(LlavaForConditionalGeneration):
class InternVLModel(LlavaModel):
def pixel_shuffle(self, vision_features: torch.Tensor, scale_factor: float = 0.5):
"""Perform pixel shuffle downsampling on vision features.
@ -658,6 +659,13 @@ class InternVLForConditionalGeneration(LlavaForConditionalGeneration):
return vision_features
class InternVLCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
class InternVLForConditionalGeneration(LlavaForConditionalGeneration):
@can_return_tuple
@add_start_docstrings_to_model_forward(INTERNVL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=InternVLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(**super_kwargs):
@ -718,5 +726,6 @@ __all__ = [
"InternVLVisionPreTrainedModel",
"InternVLVisionModel",
"InternVLPreTrainedModel",
"InternVLModel",
"InternVLForConditionalGeneration",
]

View File

@ -1563,7 +1563,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
inputs_embeds = self.get_input_embeddings()(input_tokens)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = self._get_initial_cache_position(seq_len, device, model_kwargs)
if model_kwargs.get("past_key_values", None) is None:
# Prepare cache if not provided.

View File

@ -1378,7 +1378,7 @@ class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
inputs_embeds = self.get_input_embeddings()(input_tokens)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
model_kwargs = self._get_initial_cache_position(seq_len, device, model_kwargs)
if model_kwargs.get("past_key_values", None) is None:
# Prepare cache if not provided.

View File

@ -23,16 +23,17 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_outputs import ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_llava import LlavaConfig
@ -44,6 +45,39 @@ _CONFIG_FOR_DOC = "LlavaConfig"
_CHECKPOINT_FOR_DOC = "llava-hf/llava-1.5-7b-hf"
@dataclass
class LlavaModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for Llava outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class LlavaCausalLMOutputWithPast(ModelOutput):
"""
@ -72,7 +106,7 @@ class LlavaCausalLMOutputWithPast(ModelOutput):
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`.
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
@ -124,14 +158,13 @@ LLAVA_START_DOCSTRING = r"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
"The bare Llava Model outputting raw hidden-states without any specific head on top.",
LLAVA_START_DOCSTRING,
)
class LlavaPreTrainedModel(PreTrainedModel):
config_class = LlavaConfig
base_model_prefix = "model"
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["LlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@ -149,6 +182,9 @@ class LlavaPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
LLAVA_INPUTS_DOCSTRING = r"""
@ -229,23 +265,18 @@ LLAVA_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"""The LLAVA model which consists of a vision backbone and a language model.""",
"""The Llava model which consists of a vision backbone and a language model, without a language modeling head.""",
LLAVA_START_DOCSTRING,
)
class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
class LlavaModel(LlavaPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: LlavaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.language_model = AutoModel.from_config(config.text_config)
self.post_init()
def get_input_embeddings(self):
@ -254,18 +285,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def get_image_features(
self,
pixel_values: torch.FloatTensor,
@ -277,7 +296,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
@ -312,12 +331,146 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, LlavaModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_tokens = (input_ids == self.config.image_token_id).sum()
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = LlavaModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
@add_start_docstrings(
"""The LLAVA model which consists of a vision backbone and a language model.""",
LLAVA_START_DOCSTRING,
)
class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: LlavaConfig):
super().__init__(config)
self.model = LlavaModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
@ -331,7 +484,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: Optional[torch.Tensor] = None,
image_sizes: torch.Tensor = None,
**lm_kwargs,
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
r"""
@ -371,7 +524,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -386,73 +538,32 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
image_sizes=image_sizes,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
image_sizes=image_sizes,
**lm_kwargs,
)
logits = outputs[0]
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return LlavaCausalLMOutputWithPast(
loss=loss,
@ -460,7 +571,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(
@ -476,7 +587,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@ -493,5 +604,61 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
return model_inputs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
__all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel"]
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel", "LlavaModel"]

View File

@ -14,12 +14,17 @@
# limitations under the License.
"""Image processor class for LLaVa-NeXT."""
import math
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict, select_best_resolution
from ...image_processing_utils import (
BaseImageProcessor,
BatchFeature,
get_patch_output_size,
get_size_dict,
select_best_resolution,
)
from ...image_transforms import (
PaddingMode,
convert_to_rgb,
@ -99,23 +104,6 @@ def expand_to_square(image: np.array, background_color, input_data_format) -> np
return result
def _get_patch_output_size(image, target_resolution, input_data_format):
original_height, original_width = get_image_size(image, channel_dim=input_data_format)
target_height, target_width = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
return new_height, new_width
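The removed helper is now imported as `get_patch_output_size` from `image_processing_utils`; the computation is unchanged. A worked example of the aspect-ratio-preserving resize (sizes are illustrative): a 480x640 (height x width) image fitted into a 336x336 target fills the width and scales the height down.

import math

original_height, original_width = 480, 640
target_height, target_width = 336, 336

scale_w = target_width / original_width    # 0.525
scale_h = target_height / original_height  # 0.7
if scale_w < scale_h:
    new_width = target_width
    new_height = min(math.ceil(original_height * scale_w), target_height)
else:
    new_height = target_height
    new_width = min(math.ceil(original_width * scale_h), target_width)

print(new_height, new_width)  # 252 336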
class LlavaNextImageProcessor(BaseImageProcessor):
r"""
Constructs a LLaVa-NeXT image processor. Based on [`CLIPImageProcessor`] with incorporation of additional techniques
@ -429,7 +417,7 @@ class LlavaNextImageProcessor(BaseImageProcessor):
Returns:
np.array: The resized and padded image.
"""
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
@ -443,12 +431,12 @@ class LlavaNextImageProcessor(BaseImageProcessor):
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = _get_patch_output_size(image, target_resolution, input_data_format)
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = self.pad(image, padding=((paste_y, paste_y), (paste_x, paste_x)))
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
return padded_image
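The switch from `// 2` to `divmod` matters whenever the leftover space is odd: the old code dropped one pixel and the padded image fell one short of the target, while the remainder now goes to the bottom/right edge. A quick arithmetic check with illustrative numbers:

target_width, new_width = 336, 251

paste_x_old = (target_width - new_width) // 2       # 42 -> 42 + 251 + 42 = 335, one pixel short
paste_x, r_x = divmod(target_width - new_width, 2)  # (42, 1) -> 42 + 251 + 42 + 1 = 336, exact
print(paste_x_old, (paste_x, r_x))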

View File

@ -102,8 +102,8 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast):
A list of possible resolutions to use for processing high resolution images. Each item in the list should be a tuple or list
of the form `(height, width)`.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaNextFastImageProcessorKwargs]) -> BatchFeature:
@ -164,10 +164,10 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast):
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x + r_x, paste_y + r_y])
return padded_image

View File

@ -26,16 +26,17 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...image_processing_utils import select_best_resolution
from ...modeling_outputs import ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_llava_next import LlavaNextConfig
@ -139,18 +140,51 @@ def unpad_image(tensor, original_size):
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(round(original_height * scale_factor, 7))
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
new_height = min(math.ceil(original_height * scale_factor), current_height)
padding, r = divmod(current_height - new_height, 2)
unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
else:
scale_factor = current_height / original_height
new_width = int(round(original_width * scale_factor, 7))
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
new_width = min(math.ceil(original_width * scale_factor), current_width)
padding, r = divmod(current_width - new_width, 2)
unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
return unpadded_tensor
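The unpadding fix mirrors the padding change: the restored height/width is recomputed with `min(math.ceil(...), current)` and `divmod` strips the odd leftover row or column from the far edge. A standalone sketch of the height branch (numbers are illustrative):

import math
import torch

tensor = torch.randn(64, 24, 24)                     # (channels, current_height, current_width)
original_height, original_width = 500, 640           # width fills the grid, height was letterboxed
current_height, current_width = tensor.shape[1:]

scale_factor = current_width / original_width
new_height = min(math.ceil(original_height * scale_factor), current_height)   # 19
padding, r = divmod(current_height - new_height, 2)                           # (2, 1)
unpadded = tensor[:, padding : current_height - (padding + r), :]
print(unpadded.shape)                                # torch.Size([64, 19, 24])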
@dataclass
class LlavaNextModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for Llava-Next outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class LlavaNextCausalLMOutputWithPast(ModelOutput):
"""
@ -237,9 +271,9 @@ LLAVA_NEXT_START_DOCSTRING = r"""
)
class LlavaNextPreTrainedModel(PreTrainedModel):
config_class = LlavaNextConfig
base_model_prefix = "model"
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["LlavaNextVisionAttention"]
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
@ -254,7 +288,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, LlavaNextForConditionalGeneration):
elif isinstance(module, LlavaNextModel):
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
module.image_newline.data.normal_(mean=0.0, std=embed_std)
@ -340,10 +374,12 @@ LLAVA_NEXT_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"""The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
"""The Llava-Next model which consists of a vision backbone and a language model without language modeling head.""",
LLAVA_NEXT_START_DOCSTRING,
)
class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
class LlavaNextModel(LlavaNextPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def __init__(self, config: LlavaNextConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
@ -353,48 +389,16 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.language_model = AutoModel.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.post_init()
@property
def padding_side(self):
return self._padding_side
@padding_side.setter
def padding_side(self, padding_side: str):
if padding_side not in ["left", "right"]:
raise ValueError(f"{padding_side} is not `left` or `right`.")
self._padding_side = padding_side
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
# Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder
def get_decoder(self):
return self.language_model.get_decoder()
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@ -524,12 +528,152 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
@add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**lm_kwargs,
) -> Union[Tuple, LlavaNextModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
vision_feature_select_strategy=vision_feature_select_strategy,
image_newline=self.image_newline,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = LlavaNextModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
return output if return_dict else output.to_tuple()
@add_start_docstrings(
"""The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
LLAVA_NEXT_START_DOCSTRING,
)
class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^image_newline": "model.image_newline",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: LlavaNextConfig):
super().__init__(config)
self.model = LlavaNextModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaNextCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@ -597,45 +741,12 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
vision_feature_select_strategy=vision_feature_select_strategy,
image_newline=self.image_newline,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
outputs = self.language_model(
outputs = self.model(
input_ids,
pixel_values=pixel_values,
image_sizes=image_sizes,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
@@ -645,33 +756,17 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
logits = outputs[0]
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
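# Sketch of the `logits_to_keep` slicing above: when it is an int, only the last N
# hidden states are projected through the LM head, which is what generation needs.
# The tensor sizes below are arbitrary.
import torch

hidden_states = torch.randn(2, 10, 8)                  # (batch, seq_len, hidden)
lm_head = torch.nn.Linear(8, 32, bias=False)           # hidden -> vocab

logits_to_keep = 1                                     # decode step: last position only
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = lm_head(hidden_states[:, slice_indices, :])
assert logits.shape == (2, 1, 32)                      # instead of (2, 10, 32)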
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return LlavaNextCausalLMOutputWithPast(
loss=loss,
@@ -679,7 +774,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
image_hidden_states=outputs.image_hidden_states,
)
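# Roughly what the `self.loss_function(...)` call above computes for a causal LM:
# next-token cross-entropy with labels shifted by one and -100 ignored. This is a
# simplified stand-in and omits details such as logit upcasting and extra kwargs.
import torch
import torch.nn.functional as F

def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    shift_logits = logits[..., :-1, :].contiguous()    # predictions for positions 0..n-2
    shift_labels = labels[..., 1:].contiguous()        # targets are the next tokens
    return F.cross_entropy(
        shift_logits.view(-1, vocab_size),
        shift_labels.view(-1).to(shift_logits.device),
        ignore_index=-100,
    )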
def prepare_inputs_for_generation(
@@ -696,7 +791,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -714,5 +809,61 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
return model_inputs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
__all__ = ["LlavaNextForConditionalGeneration", "LlavaNextPreTrainedModel"]
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
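# Illustrative call of the static helper above (for inspection only; generation code
# calls it internally), assuming a transformers version that includes this helper.
# A 3-token query attends into a length-5 static cache: allowed positions are 0,
# everything else is the dtype minimum.
import torch
from transformers import LlavaNextForConditionalGeneration

mask = LlavaNextForConditionalGeneration._prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=torch.ones(1, 5, dtype=torch.long),
    sequence_length=3,
    target_length=5,
    dtype=torch.float32,
    cache_position=torch.arange(3),
    batch_size=1,
)
assert mask.shape == (1, 1, 3, 5)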
__all__ = ["LlavaNextForConditionalGeneration", "LlavaNextPreTrainedModel", "LlavaNextModel"]

View File

@@ -30,24 +30,65 @@ from torch import nn
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...image_processing_utils import select_best_resolution
from ...modeling_outputs import ModelOutput
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_llava_next_video import LlavaNextVideoConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlavaNextVideoConfig"
@dataclass
class LlavaNextVideoModelOutputWithPast(BaseModelOutputWithPast):
"""
Base class for Llava outputs, with hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
video_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size * num_frames, num_videos, sequence_length, hidden_size)`.
video_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
"""
image_hidden_states: Optional[torch.FloatTensor] = None
video_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class LlavaNextVideoCausalLMOutputWithPast(ModelOutput):
"""
@@ -167,6 +208,34 @@ LLAVA_NEXT_VIDEO_START_DOCSTRING = r"""
"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
LLAVA_NEXT_VIDEO_START_DOCSTRING,
)
class LlavaNextVideoPreTrainedModel(PreTrainedModel):
config_class = LlavaNextVideoConfig
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, LlavaNextVideoModel):
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
module.image_newline.data.normal_(mean=0.0, std=embed_std)
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
@@ -262,14 +331,14 @@ def unpad_image(tensor, original_size):
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(round(original_height * scale_factor, 7))
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
new_height = min(math.ceil(original_height * scale_factor), current_height)
padding, r = divmod(current_height - new_height, 2)
unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
else:
scale_factor = current_height / original_height
new_width = int(round(original_width * scale_factor, 7))
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
new_width = min(math.ceil(original_width * scale_factor), current_width)
padding, r = divmod(current_width - new_width, 2)
unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
return unpadded_tensor
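# Worked example (made-up sizes) for the revised unpadding arithmetic above: a
# 336x448 padded canvas holding an image whose original size was 300x500, so the
# padding was added along the height.
import math

original_height, original_width = 300, 500
current_height, current_width = 336, 448

scale_factor = current_width / original_width                                  # ~0.896
new_height = min(math.ceil(original_height * scale_factor), current_height)    # 269
padding, r = divmod(current_height - new_height, 2)                            # (33, 1)
# tensor[:, padding : current_height - (padding + r), :] keeps rows 33..301, i.e. exactly 269 rows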
@@ -355,38 +424,12 @@ LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
"""The Llava-Next model which consists of a vision backbone and a language model without language modeling head.""",
LLAVA_NEXT_VIDEO_START_DOCSTRING,
)
class LlavaNextVideoPreTrainedModel(PreTrainedModel):
config_class = LlavaNextVideoConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlavaNextVideoVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True
class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
def _init_weights(self, module):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, LlavaNextVideoForConditionalGeneration):
embed_std = 1 / math.sqrt(self.config.text_config.hidden_size)
module.image_newline.data.normal_(mean=0.0, std=embed_std)
@add_start_docstrings(
"""The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
LLAVA_NEXT_VIDEO_START_DOCSTRING,
)
class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
def __init__(
self,
config: LlavaNextVideoConfig,
@@ -399,43 +442,17 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std)
self.vocab_size = config.text_config.vocab_size
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]
self.language_model = AutoModel.from_config(config.text_config)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
self.vision_resampler = LlavaNextVideoPooler(config)
self.post_init()
@property
def padding_side(self):
return self._padding_side
@padding_side.setter
def padding_side(self, padding_side: str):
if padding_side not in ["left", "right"]:
raise ValueError(f"{padding_side} is not `left` or `right`.")
self._padding_side = padding_side
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@@ -564,13 +581,222 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
image_features = torch.split(image_features, image_num_patches, dim=0)
return image_features
@add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
pixel_values_videos: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**lm_kwargs,
) -> Union[Tuple, LlavaNextVideoModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self.vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
self.vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
"and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=self.vision_feature_layer,
vision_feature_select_strategy=self.vision_feature_select_strategy,
)
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
self.vision_feature_select_strategy,
image_newline=self.image_newline,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
video_features = self.get_video_features(
pixel_values_videos,
vision_feature_layer=self.vision_feature_layer,
vision_feature_select_strategy=self.vision_feature_select_strategy,
)
video_features = [feature.flatten(0, 1) for feature in video_features]
video_feature_lens = [feature.size(0) for feature in video_features]
video_features = torch.cat(video_features, dim=0)
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
output = LlavaNextVideoModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
)
return output if return_dict else output.to_tuple()
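# Hedged usage sketch of the new bare backbone defined above: it should be loadable on
# its own and return a `LlavaNextVideoModelOutputWithPast` instead of logits. The
# checkpoint id, prompt format and dummy image are assumptions made for illustration.
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaNextVideoModel

processor = AutoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")
model = LlavaNextVideoModel.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")

image = Image.new("RGB", (336, 336))
inputs = processor(text="USER: <image>\nDescribe the image. ASSISTANT:", images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
outputs.last_hidden_state       # (batch, seq_len, hidden_size)
outputs.image_hidden_states     # projected image features, or None when no images are passed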
def get_video_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
):
"""
Obtains video last hidden states from the vision tower and applies multimodal projection.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`):
The tensors corresponding to the input video.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
video_features (List[`torch.Tensor`]): List of video feature tensors, each containing the visual features of all
patches, of shape `(num_videos, video_length, embed_dim)`.
"""
batch_size, frames, channels, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_video_features = video_features.hidden_states[vision_feature_layer]
else:
hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_video_features = torch.cat(hs_pool, dim=-1)
if vision_feature_select_strategy == "default":
selected_video_features = selected_video_features[:, 1:]
elif vision_feature_select_strategy == "full":
selected_video_features = selected_video_features
# Same as image features except that video has pooling layer
video_features = self.vision_resampler(selected_video_features)
video_features = self.multi_modal_projector(video_features)
video_features = torch.split(video_features, frames, dim=0)
return video_features
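# Shape walk-through (illustrative numbers) for the frame handling in
# `get_video_features` above: 2 videos of 8 frames each are flattened into the batch
# dimension, encoded, and then split back per video. 144 pooled patches and a hidden
# size of 1024 are made-up stand-ins for the vision tower / resampler / projector output.
import torch

batch_size, frames, channels, height, width = 2, 8, 3, 336, 336
pixel_values = torch.randn(batch_size, frames, channels, height, width)

flat = pixel_values.reshape(batch_size * frames, channels, height, width)   # (16, 3, 336, 336)
projected = torch.randn(batch_size * frames, 144, 1024)                     # placeholder for encoder output

video_features = torch.split(projected, frames, dim=0)
assert len(video_features) == batch_size and video_features[0].shape == (8, 144, 1024)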
@add_start_docstrings(
"""The LLAVA-NeXT model which consists of a vision backbone and a language model.""",
LLAVA_NEXT_VIDEO_START_DOCSTRING,
)
class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^image_newline": "model.image_newline",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: LlavaNextVideoConfig):
super().__init__(config)
self.model = LlavaNextVideoModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Make modules available through the conditional generation class for BC
@property
def language_model(self):
return self.model.language_model
@property
def vision_tower(self):
return self.model.vision_tower
@property
def multi_modal_projector(self):
return self.model.multi_modal_projector
@can_return_tuple
@add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
pixel_values_videos: torch.FloatTensor = None,
image_sizes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -662,117 +888,47 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"USER: \nWhat's the content of the image? ASSISTANT: The image shows a red stop sign on a pole, with a traditional Chinese archway (...)"
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self.vision_feature_layer = (
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
self.vision_feature_select_strategy = (
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
raise ValueError(
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
"and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None and pixel_values.size(0) > 0:
image_features = self.get_image_features(
pixel_values,
image_sizes,
vision_feature_layer=self.vision_feature_layer,
vision_feature_select_strategy=self.vision_feature_select_strategy,
)
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
self.vision_feature_select_strategy,
image_newline=self.image_newline,
)
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
video_features = self.get_video_features(
pixel_values_videos,
vision_feature_layer=self.vision_feature_layer,
vision_feature_select_strategy=self.vision_feature_select_strategy,
)
video_features = [feature.flatten(0, 1) for feature in video_features]
video_feature_lens = [feature.size(0) for feature in video_features]
video_features = torch.cat(video_features, dim=0)
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
outputs = self.language_model(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
image_sizes=image_sizes,
**lm_kwargs,
)
logits = outputs[0]
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return LlavaNextVideoCausalLMOutputWithPast(
loss=loss,
@@ -780,8 +936,8 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
video_hidden_states=video_features if pixel_values_videos is not None else None,
image_hidden_states=outputs.image_hidden_states,
video_hidden_states=outputs.video_hidden_states,
)
def prepare_inputs_for_generation(
@@ -799,7 +955,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
):
# Overwritten -- extra custom processing
model_inputs = self.language_model.prepare_inputs_for_generation(
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -818,51 +974,60 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
return model_inputs
def get_video_features(
self,
pixel_values: torch.FloatTensor,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Obtains video last hidden states from the vision tower and applies multimodal projection.
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, channels, height, width)`):
The tensors corresponding to the input video.
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
Returns:
video_features (List[`torch.Tensor`]): List of video feature tensors, each containing the visual features of all
patches, of shape `(num_videos, video_length, embed_dim)`.
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
"""
batch_size, frames, channels, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
video_features = self.vision_tower(pixel_values, output_hidden_states=True)
# If we have one vision feature layer, return the corresponding hidden states,
# otherwise, select the hidden states of each feature layer and concatenate them
if isinstance(vision_feature_layer, int):
selected_video_features = video_features.hidden_states[vision_feature_layer]
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
hs_pool = [video_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
selected_video_features = torch.cat(hs_pool, dim=-1)
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
if vision_feature_select_strategy == "default":
selected_video_features = selected_video_features[:, 1:]
elif vision_feature_select_strategy == "full":
selected_video_features = selected_video_features
# Same as image features except that video has pooling layer
video_features = self.vision_resampler(selected_video_features)
video_features = self.multi_modal_projector(video_features)
video_features = torch.split(video_features, frames, dim=0)
return video_features
return causal_mask
__all__ = ["LlavaNextVideoForConditionalGeneration", "LlavaNextVideoPreTrainedModel"]
__all__ = ["LlavaNextVideoForConditionalGeneration", "LlavaNextVideoModel", "LlavaNextVideoPreTrainedModel"]

Some files were not shown because too many files have changed in this diff.