Merge branch 'main' into fixing_gptq_tests
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled

This commit is contained in:
Mohamed Mekkouri 2025-05-21 14:27:56 +02:00 committed by GitHub
commit cb7df519b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
116 changed files with 7318 additions and 681 deletions

View File

@ -39,55 +39,100 @@ jobs:
name: ci_results_run_models_gpu
path: /transformers/ci_results_run_models_gpu
- name: Check file
working-directory: /transformers
run: |
if [ -f ci_results_run_models_gpu/new_model_failures.json ]; then
echo "`ci_results_run_models_gpu/new_model_failures.json` exists, continue ..."
echo "process=true" >> $GITHUB_ENV
else
echo "`ci_results_run_models_gpu/new_model_failures.json` doesn't exist, abort."
echo "process=false" >> $GITHUB_ENV
fi
- uses: actions/download-artifact@v4
if: ${{ env.process == 'true' }}
with:
pattern: setup_values*
path: setup_values
merge-multiple: true
- name: Prepare some setup values
if: ${{ env.process == 'true' }}
run: |
if [ -f setup_values/prev_workflow_run_id.txt ]; then
echo "PREV_WORKFLOW_RUN_ID=$(cat setup_values/prev_workflow_run_id.txt)" >> $GITHUB_ENV
else
echo "PREV_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
fi
if [ -f setup_values/other_workflow_run_id.txt ]; then
echo "OTHER_WORKFLOW_RUN_ID=$(cat setup_values/other_workflow_run_id.txt)" >> $GITHUB_ENV
else
echo "OTHER_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
fi
- name: Update clone
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: git fetch && git checkout ${{ github.sha }}
- name: Get target commit
working-directory: /transformers/utils
if: ${{ env.process == 'true' }}
run: |
echo "END_SHA=$(TOKEN=${{ secrets.ACCESS_REPO_INFO_TOKEN }} python3 -c 'import os; from get_previous_daily_ci import get_last_daily_ci_run_commit; commit=get_last_daily_ci_run_commit(token=os.environ["TOKEN"]); print(commit)')" >> $GITHUB_ENV
echo "END_SHA=$(TOKEN=${{ secrets.ACCESS_REPO_INFO_TOKEN }} python3 -c 'import os; from get_previous_daily_ci import get_last_daily_ci_run_commit; commit=get_last_daily_ci_run_commit(token=os.environ["TOKEN"], workflow_run_id=os.environ["PREV_WORKFLOW_RUN_ID"]); print(commit)')" >> $GITHUB_ENV
- name: Checkout to `start_sha`
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: git fetch && git checkout ${{ inputs.start_sha }}
- name: Reinstall transformers in edit mode (remove the one installed during docker image build)
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: python3 -m pip uninstall -y transformers && python3 -m pip install -e .
- name: NVIDIA-SMI
if: ${{ env.process == 'true' }}
run: |
nvidia-smi
- name: Environment
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: |
python3 utils/print_env.py
- name: Show installed libraries and their versions
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: pip freeze
- name: Check failed tests
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: python3 utils/check_bad_commit.py --start_commit ${{ inputs.start_sha }} --end_commit ${{ env.END_SHA }} --file ci_results_run_models_gpu/new_model_failures.json --output_file new_model_failures_with_bad_commit.json
- name: Show results
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: |
ls -l new_model_failures_with_bad_commit.json
cat new_model_failures_with_bad_commit.json
- name: Checkout back
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: |
git checkout ${{ inputs.start_sha }}
- name: Process report
shell: bash
working-directory: /transformers
if: ${{ env.process == 'true' }}
env:
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN: ${{ secrets.TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN }}
run: |
python3 utils/process_bad_commit_report.py
@ -95,7 +140,9 @@ jobs:
- name: Process report
shell: bash
working-directory: /transformers
if: ${{ env.process == 'true' }}
env:
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN: ${{ secrets.TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN }}
run: |
{
@ -105,7 +152,7 @@ jobs:
} >> "$GITHUB_ENV"
- name: Send processed report
if: ${{ !endsWith(env.REPORT_TEXT, '{}') }}
if: ${{ env.process == 'true' && !endsWith(env.REPORT_TEXT, '{}') }}
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
with:
# Slack channel id, channel name, or user id to post message.

View File

@ -8,8 +8,43 @@ on:
push:
branches:
- run_scheduled_ci*
workflow_dispatch:
inputs:
prev_workflow_run_id:
description: 'previous workflow run id to compare'
type: string
required: false
default: ""
other_workflow_run_id:
description: 'other workflow run id to compare'
type: string
required: false
default: ""
# Used for `push` to easily modify the target workflow runs to compare against
env:
prev_workflow_run_id: ""
other_workflow_run_id: ""
jobs:
setup:
  name: Setup
  runs-on: ubuntu-22.04
  steps:
    - name: Setup
      # The run ids are handed to the script through environment variables
      # instead of being expanded inside the script text, so a
      # `workflow_dispatch` input can never be interpreted as shell code.
      # For `push` events `inputs.*` is empty and the workflow-level `env`
      # values are used instead.
      env:
        PREV_RUN_ID: ${{ inputs.prev_workflow_run_id || env.prev_workflow_run_id }}
        OTHER_RUN_ID: ${{ inputs.other_workflow_run_id || env.other_workflow_run_id }}
      run: |
        mkdir "setup_values"
        echo "$PREV_RUN_ID" > "setup_values/prev_workflow_run_id.txt"
        echo "$OTHER_RUN_ID" > "setup_values/other_workflow_run_id.txt"
    # Publish the two files so that later jobs can download them with the
    # `setup_values*` artifact pattern.
    - name: Upload artifacts
      uses: actions/upload-artifact@v4
      with:
        name: setup_values
        path: setup_values
model-ci:
name: Model CI
uses: ./.github/workflows/self-scheduled.yml

View File

@ -39,6 +39,21 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/download-artifact@v4
- name: Prepare some setup values
run: |
if [ -f setup_values/prev_workflow_run_id.txt ]; then
echo "PREV_WORKFLOW_RUN_ID=$(cat setup_values/prev_workflow_run_id.txt)" >> $GITHUB_ENV
else
echo "PREV_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
fi
if [ -f setup_values/other_workflow_run_id.txt ]; then
echo "OTHER_WORKFLOW_RUN_ID=$(cat setup_values/other_workflow_run_id.txt)" >> $GITHUB_ENV
else
echo "OTHER_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
fi
- name: Send message to Slack
if: ${{ inputs.job != 'run_quantization_torch_gpu' }}
env:
@ -50,7 +65,6 @@ jobs:
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
CI_EVENT: ${{ inputs.ci_event }}
CI_SHA: ${{ github.sha }}
CI_WORKFLOW_REF: ${{ github.workflow_ref }}
CI_TEST_JOB: ${{ inputs.job }}
SETUP_STATUS: ${{ inputs.setup_status }}
# We pass `needs.setup.outputs.matrix` as the argument. A processing in `notification_service.py` to change
@ -58,7 +72,6 @@ jobs:
# For a job that doesn't depend on (i.e. `needs`) `setup`, the value for `inputs.folder_slices` would be an
# empty string, and the called script still get one argument (which is the empty string).
run: |
sudo apt-get install -y curl
pip install huggingface_hub
pip install slack_sdk
pip show slack_sdk
@ -86,7 +99,6 @@ jobs:
# We pass `needs.setup.outputs.quantization_matrix` as the argument. A processing in `notification_service_quantization.py` to change
# `quantization/bnb` to `quantization_bnb` is required, as the artifact names use `_` instead of `/`.
run: |
sudo apt-get install -y curl
pip install huggingface_hub
pip install slack_sdk
pip show slack_sdk

View File

@ -455,6 +455,8 @@
title: Falcon
- local: model_doc/falcon3
title: Falcon3
- local: model_doc/falcon_h1
title: FalconH1
- local: model_doc/falcon_mamba
title: FalconMamba
- local: model_doc/flan-t5

View File

@ -39,7 +39,7 @@ Checkout all Bamba-9B model checkpoints [here](https://github.com/foundation-mod
<!---
## Usage Tips
Tips:
Tips:
- The architecture is based on Mamba-2 models.
@ -63,7 +63,35 @@ response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
```
## Padding-Free Training
Bamba supports padding-free training in which distinct training examples can be concatenated
together while nevertheless processing the inputs as though they belonged to separate batches. When
the examples are of varying lengths, padding-free training can provide significant speed ups and
memory savings compared to batching the examples together and using padding, as the unnecessary
compute and memory due to padding is avoided entirely. The performance gains depend on factors such
as the model and the data distribution, but throughput gains up to [~2x are commonly
seen](https://github.com/huggingface/transformers/pull/35861#issue-2807873129).
Using padding-free training with Bamba requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d`
packages, and the following arguments must be passed to the model in addition to `input_ids` and
`labels`:
* `position_ids: torch.LongTensor`: the position index of each token in each sequence.
* `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
* Each of the [`FlashAttentionKwargs`]
* `cu_seq_lens_q: torch.LongTensor`: The cumulative sequence lengths of all queries.
* `cu_seq_lens_k: torch.LongTensor`: The cumulative sequence lengths of all keys.
* `max_length_q: int`: the longest query length in the batch.
* `max_length_k: int`: the longest key length in the batch.
The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] can be used
to programmatically generate the above set of additional arguments using `return_seq_idx=True` and
`return_flash_attn_kwargs=True`. See [this blog post](https://huggingface.co/blog/packing-with-FA2)
for additional information.
[[autodoc]] BambaForCausalLM
- forward
This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).

View File

@ -0,0 +1,65 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# FalconH1
## Overview
The FalconH1 model was developed by the TII Pretraining team. A comprehensive research paper covering the architecture, pretraining dynamics, experimental results, and conclusions is forthcoming. You can read more about this series on [this website](https://github.com/tiiuae/Falcon-H1).
## Contributors
This model was contributed by [DhiyaEddine](https://huggingface.co/DhiyaEddine), [ybelkada](https://huggingface.co/ybelkada), [JingweiZuo](https://huggingface.co/JingweiZuo), [IlyasChahed](https://huggingface.co/IChahed), and [MaksimVelikanov](https://huggingface.co/yellowvm).
The original code can be found [here](https://github.com/tiiuae/Falcon-H1).
## FalconH1Config
| Model | Depth | Dim | Attn Heads | KV | Mamba Heads | d_head | d_state | Ctx Len |
|-----------|--------|------|------------|----|--------------|--------------|------|-----------------|
| H1 0.5B | 36 | 1024 | 8 | 2 | 24 | 64 / 64 | 128 | 4K, 16K-SFT |
| H1 1.5B | 24 | 2048 | 8 | 2 | 48 | 128 / 64 | 256 | 128K |
| H1 1.5B-d | 66 | 1280 | 6 | 2 | 24 | 128 / 64 | 256 | 128K |
| H1 3B | 32 | 2560 | 10 | 2 | 32 | 128 / 128 | 256 | 128K |
| H1 7B | 44 | 3072 | 12 | 2 | 24 | 128 / 128 | 256 | 256K |
| H1 34B | 72 | 5120 | 20 | 4 | 32 | 128 / 128 | 256 | 256K |
[[autodoc]] FalconH1Config
<!---
## Usage Tips
Tips:
- The architecture is based on Mamba-2 models.
## FalconH1Model
[[autodoc]] FalconH1Model
- forward
-->
## FalconH1ForCausalLM
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("tiiuae/Falcon-H1-7B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("tiiuae/Falcon-H1-7B-Instruct")
message = ["Mamba is a snake with following properties "]
inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
```
[[autodoc]] FalconH1ForCausalLM
- forward
This HF implementation is contributed by [younesbelkada](https://github.com/younesbelkada) and [DhiaEddineRhaiem](https://github.com/dhiaEddineRhaiem).

View File

@ -147,7 +147,7 @@ print(processor.decode(output[0], skip_special_tokens=True))
### Multi image inference
LLaVa-OneVision can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). For that you have to use checkpoints with an "ov" suffix. Here is how you can do it:
LLaVa-OneVision can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). For that you have to use checkpoints with an "ov" suffix. For multi-image cases, we recommend using a **nested list of images** as input. Otherwise, every image will be patchified and consume a lot of memory. Here is how you can do it:
```python
import requests

435
examples/3D_parallel.py Normal file
View File

@ -0,0 +1,435 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is used to test training a model using Tensor Parallelism, Data Parallelism and Context Parallelism.
Usage:
export CUDA_VISIBLE_DEVICES=0,1,2,3
export CUDA_VISIBLE_DEVICES=4,5,6,7
export CUDA_VISIBLE_DEVICES=5,6,7
TP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 examples/3D_parallel.py
CP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 examples/3D_parallel.py
CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=4 examples/3D_parallel.py
DP_SIZE=2 CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=8 examples/3D_parallel.py
TP_SIZE=1 CP_SIZE=4 torchrun --nproc_per_node=4 examples/3D_parallel.py
TP_SIZE=1 DP_SIZE=4 torchrun --nproc_per_node=4 examples/3D_parallel.py
TP_SIZE=4 DP_SIZE=1 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 examples/3D_parallel.py
IGNORE_SANITY=1 CP_SIZE=1 TP_SIZE=1 DP_SIZE=1 torchrun --nproc_per_node=1 --rdzv_endpoint=localhost:29504 examples/3D_parallel.py
"""
import logging
import os
from contextlib import nullcontext
from typing import Iterable
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.optim as optim
import wandb
from datasets import load_dataset
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer
# torch.use_deterministic_algorithms(True)
# Only cuDNN is forced to be deterministic; the stricter global switch above is left disabled.
torch.backends.cudnn.deterministic = True

# Root logging configuration: INFO and above, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)
# from torch.distributed.tensor.experimental._attention import set_rotate_method
# set_rotate_method("alltoall") # CP rotate shards using all-to-all
def main():
    """Train a causal LM with a (dp, tp, cp) device mesh and checkpoint the result.

    The parallel layout is read from the ``TP_SIZE``, ``DP_SIZE`` and ``CP_SIZE``
    environment variables (each defaults to 1). When ``RANK`` and ``WORLD_SIZE`` are
    set (i.e. the script is launched through ``torchrun``) the function:

    1. initialises the NCCL process group and builds the device mesh,
    2. starts a wandb run on global rank 0,
    3. loads the model with tensor parallelism and optionally wraps it in FSDP
       (``NO_SHARD``) for data parallelism,
    4. tokenizes, packs and shuffles a slice of TinyStories,
    5. runs ``num_train_steps`` optimisation steps, and
    6. saves a distributed checkpoint and tears the process group down.

    NOTE(review): the mesh objects, ``rank`` and ``local_rank`` are only created in
    the distributed branch but are used unconditionally further down, so the script
    effectively requires a ``torchrun`` launch even though several statements are
    guarded with ``dist.is_initialized()``.
    """
    tp_size = int(os.environ.get("TP_SIZE", 1))
    dp_size = int(os.environ.get("DP_SIZE", 1))
    cp_size = int(os.environ.get("CP_SIZE", 1))  # Add CP size configuration

    sdpa_backend = SDPBackend.FLASH_ATTENTION  # For CP
    # sdpa_backend = SDPBackend.MATH # For CP
    global_batch_size = 8  # Desired global batch size
    seq_len = 1024  # Sequence length
    num_train_steps = 10000  # Number of training steps
    LR = 1e-5
    model_name = "HuggingFaceTB/SmolLM2-1.7B"
    # model_name = "unsloth/Llama-3.2-1B"

    CHECKPOINT_DIR = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}"

    # Initialize distributed environment
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        dist.init_process_group("nccl")
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)

        assert world_size == tp_size * dp_size * cp_size, (
            f"World size ({world_size}) must equal TP size ({tp_size}) * DP size ({dp_size}) * CP size ({cp_size})"
        )

        # Ranks are laid out as a 3-D grid whose axes are named dp / tp / cp.
        mesh = torch.arange(world_size).reshape(dp_size, tp_size, cp_size)
        world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp"))
        tp_mesh = world_mesh["tp"]
        dp_mesh = world_mesh["dp"]
        cp_mesh = world_mesh["cp"]
        # Register a flattened "dp_cp" dimension; it is looked up later for the loss all-reduce.
        world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
        logger.info(f"Created DeviceMesh: {world_mesh}")
        logger.info(
            f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}, DP: {dp_mesh.get_local_rank()}, TP: {tp_mesh.get_local_rank()}, CP: {cp_mesh.get_local_rank()}"
        )

        if dist.get_rank() == 0:
            wandb.init(
                project="tp_dp_test",
                config={
                    "tp_size": tp_size,
                    "dp_size": dp_size,
                    "cp_size": cp_size,
                    "global_batch_size": global_batch_size,
                    "model_name": model_name,
                    "dataset": "roneneldan/TinyStories-1M",
                    "seq_len": seq_len,
                    "lr": LR,
                    "weight_decay": 0.1,
                },
                name=f"llama_tp{tp_size}_dp{dp_size}_cp{cp_size}"
                if model_name == "unsloth/Llama-3.2-1B"
                else f"tp{tp_size}_dp{dp_size}_cp{cp_size}",
            )
            logger.info("Wandb initialized.")
            # Log the current file to wandb
            wandb.save("test_train.py")

    # Load model and tokenizer
    logger.info(f"Loading model and tokenizer from {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.info(f"Set pad_token to eos_token: {tokenizer.pad_token}")

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_mesh=tp_mesh if dist.is_initialized() else None,
        tp_plan="auto",
        torch_dtype=torch.bfloat16,
    )
    logger.info(f"Model loaded onto device mesh: {tp_mesh}")
    device = torch.device(f"cuda:{local_rank}")
    logger.info(f"Using device: {device} for non-model tensors")

    # `use_ddp` tells `all_reduce_grads` whether the dp axis is already synchronised by the wrapper.
    use_ddp = False
    if dist.is_initialized() and dp_mesh.size() > 1:
        model = FSDP(model, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD)
        use_ddp = True

    model.train()

    logger.info("Loading TinyStories dataset...")
    raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")  # Use 1% for faster testing

    def tokenize_function(examples):
        # Tokenize the text without padding
        tokenized_batch = tokenizer(
            examples["text"], padding=False, truncation=True, max_length=seq_len, return_tensors=None
        )
        # Set labels to be the same as input_ids for Causal LM
        tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
        return tokenized_batch

    tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    logger.info(f"Dataset loaded and tokenized. Size: {len(tokenized_dataset)}")

    # Create packed sequences
    def create_packed_sequences(examples):
        # Flatten all sequences
        all_tokens = []
        for input_ids in examples["input_ids"]:
            all_tokens.extend(input_ids)

        # Split into sequences of seq_len + 1 (for input + label)
        num_sequences = len(all_tokens) // (seq_len + 1)
        packed_input_ids = []
        packed_labels = []

        for i in range(num_sequences):
            start_idx = i * (seq_len + 1)
            end_idx = start_idx + (seq_len + 1)
            # Get the full sequence
            full_sequence = all_tokens[start_idx:end_idx]
            # For input_ids, remove the last token
            packed_input_ids.append(full_sequence[:-1])
            # For labels, remove the first token
            packed_labels.append(full_sequence[1:])

        return {"input_ids": packed_input_ids, "labels": packed_labels}

    # Apply packing to the dataset
    packed_dataset = tokenized_dataset.map(
        create_packed_sequences,
        batched=True,
        remove_columns=tokenized_dataset.column_names,
        batch_size=1000,  # Process in batches for efficiency
        num_proc=60,
    )
    logger.info(f"Dataset packed. New size: {len(packed_dataset)}")

    # Shuffle the packed dataset
    packed_dataset = packed_dataset.shuffle(seed=42)
    logger.info("Packed dataset shuffled")

    # Calculate local batch size
    if dist.is_initialized():
        assert global_batch_size % dp_mesh.size() == 0, (
            f"Global batch size ({global_batch_size}) must be divisible by DP size ({dp_mesh.size()})"
        )
        local_batch_size = global_batch_size // dp_mesh.size()
    else:
        local_batch_size = global_batch_size

    logger.info(
        f"Global batch size: {global_batch_size}, DP size: {dp_size if dist.is_initialized() else 1}, Local batch size: {local_batch_size}"
    )

    # Simple collate function since sequences are already packed
    def collate_fn(batch):
        input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
        labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
        return {"input_ids": input_ids, "labels": labels}

    if dist.is_initialized():
        # Each dp replica reads its own slice; tp/cp ranks of one replica see the same samples.
        sampler = DistributedSampler(
            packed_dataset, num_replicas=dp_mesh.size(), rank=dp_mesh.get_local_rank(), shuffle=False
        )
    else:
        sampler = None

    dataloader = DataLoader(
        packed_dataset,
        batch_size=local_batch_size,
        sampler=sampler,
        shuffle=False,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    logger.info(f"DataLoader created. Distributed: {dist.is_initialized()}")

    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1)

    # The load-balancing switch is a process-wide option, so it is set once before the
    # loop instead of on every step.
    from torch.distributed.tensor.experimental._attention import _cp_options

    _cp_options.enable_load_balance = False

    # Training loop
    logger.info(f"Starting training for {num_train_steps} steps...")
    model.train()
    step = 0
    while step < num_train_steps:
        for batch in dataloader:
            if step >= num_train_steps:
                break  # Exit loop if max steps reached

            # Move batch to appropriate device
            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad()

            # Add position_ids to batch before CP sharding
            batch_size = batch["input_ids"].shape[0]
            position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
            batch["position_ids"] = position_ids

            with sdpa_kernel(sdpa_backend):  # TODO: ideally move this to attention implementation
                # With cp > 1 the listed buffers are sharded along their sequence dimension (dim 1).
                cp_context = (
                    nullcontext()
                    if cp_mesh.size() == 1
                    else context_parallel(
                        cp_mesh,
                        buffers=[
                            batch["input_ids"],
                            batch["labels"],
                            batch["position_ids"],
                        ],
                        buffer_seq_dims=[1, 1, 1],
                    )
                )
                with cp_context:
                    # Pop labels from batch before model forward pass
                    labels = batch.pop("labels")
                    outputs = model(**batch)  # [mbs, seq_len/cp]
                    logits = outputs.logits

                    # Compute loss with shifted labels
                    # (labels were already shifted by one token while packing).
                    loss = model.loss_function(
                        logits=logits, labels=None, shift_labels=labels, vocab_size=model.config.vocab_size
                    )
                    loss.backward()

                # all reduce grads across dp_cp if applicable
                all_reduce_grads(model, world_mesh, use_ddp=use_ddp)

                if hasattr(model, "clip_grad_norm_"):
                    gradnorm = model.clip_grad_norm_(max_norm=1.0, norm_type=2.0)  # TODO: fix reported gradnorm
                else:
                    # only works with FSDP's NO_SHARD otherwise we should use FSDP's clip_grad_norm_
                    assert len(list(model.parameters())) > 5, "No parameters found in model. Probably DDP bug.."
                    gradnorm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0, foreach=True)

                optimizer.step()
                # allreduce loss across cp_dp before logging
                if dist.is_initialized() and (cp_mesh.size() > 1 or dp_mesh.size() > 1):
                    dist.all_reduce(loss, group=world_mesh["dp_cp"].get_group(), op=dist.ReduceOp.AVG)
                current_loss = loss.item()

                # Log loss and gradnorm to wandb (only on rank 0 of dp group)
                if not dist.is_initialized() or dist.get_rank() == 0:
                    logger.info(
                        f"Step: {step} | GBS: {global_batch_size} | DP: {dp_mesh.size()} | TP: {tp_mesh.size()} | CP: {cp_mesh.size()} | Loss: {current_loss} | Gradnorm: {gradnorm} | lr: {LR}"
                    )
                    wandb.log(
                        {
                            "train/loss": current_loss,
                            "train/gradnorm": gradnorm,
                            "step": step,
                            "lr": LR,
                            "GBS": global_batch_size,
                        }
                    )

            step += 1  # Increment step count

    logger.info("Training loop finished.")

    # Save model using DCP (only if distributed)
    if dist.is_initialized():
        state_dict = {"app": AppState(model, optimizer)}
        dcp.save(
            state_dict=state_dict,
            checkpoint_id=CHECKPOINT_DIR,
        )
        logger.info(f"Saved checkpoint to {CHECKPOINT_DIR}")
    else:
        # Fallback to regular save for non-distributed case
        save_dir = "test_model_nondist"
        model.save_pretrained(save_dir, safe_serialization=False)
        tokenizer.save_pretrained(save_dir)  # Save tokenizer too
        logger.info(f"Saved model to {save_dir}")

    # The rank must be read while the process group still exists: querying it after
    # `destroy_process_group()` fails, which would skip `wandb.finish()`.
    is_rank_zero = dist.is_initialized() and dist.get_rank() == 0

    if dist.is_initialized():
        dist.destroy_process_group()
        logger.info("Cleaned up distributed process group")

    # Finish wandb run on rank 0
    if is_rank_zero:
        wandb.finish()
        logger.info("Wandb run finished.")
def all_reduce_grads(model, world_mesh, use_ddp):
"""All reduce gradients across dp_cp if applicable."""
cp_mesh = world_mesh["cp"]
if use_ddp:
# DDP/FSDP takes care of syncing grads
mesh = cp_mesh
else:
mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
if dist.is_initialized() and mesh.size() > 1:
for name, param in model.named_parameters():
if param.grad is not None:
# Workaround for cross-mesh communication limitation with DTensor gradients
if isinstance(param.grad, DTensor):
local_grad = param.grad.to_local()
# Ensure grad requires grad for inplace modification checks (might not be needed)
# local_grad = local_grad.detach().requires_grad_(True)
torch.distributed.all_reduce(local_grad, op=torch.distributed.ReduceOp.SUM, group=mesh.get_group())
local_grad = local_grad / mesh.size()
# Assign averaged grad back - need careful handling if DTensor structure is complex
# This simple assignment might work if the grad structure matches param structure
param.grad = DTensor.from_local(
local_grad, device_mesh=param.grad.device_mesh, placements=param.grad.placements
)
else:
# Handle regular tensors if any exist (e.g. buffers not converted to DTensor)
torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG, group=mesh.get_group())
class AppState(Stateful):
"""Wrapper for checkpointing the Application State including model and optimizer."""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {"model": model_state_dict, "optim": optimizer_state_dict}
def load_state_dict(self, state_dict):
set_state_dict(
self.model, self.optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"]
)
def clip_grad_norm_(
parameters: Iterable[torch.Tensor],
max_norm: float,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
foreach: bool | None = None,
) -> torch.Tensor:
"""
Clip the gradient norm of an iterable of parameters.
"""
# Filter out parameters with no gradients
parameters = [p for p in parameters if p.grad is not None]
assert len(parameters) > 0, "No parameters with gradients found"
# Calculate total norm
if norm_type == float("inf"):
total_norm = max(p.grad.detach().abs().max() for p in parameters)
else:
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type)
# Convert DTensor to local tensor if needed
if isinstance(total_norm, DTensor):
total_norm = total_norm.full_tensor()
# Clip gradients
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.detach().mul_(clip_coef)
return total_norm
# Script entry point: run the training example only when executed directly (e.g. via torchrun).
if __name__ == "__main__":
    main()

View File

@ -60,7 +60,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset

View File

@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")

View File

@ -55,7 +55,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -0,0 +1,793 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is used to test training a model using Tensor Parallelism, Data Parallelism and Context Parallelism.
Usage:
export CUDA_VISIBLE_DEVICES=0,1,2,3
export CUDA_VISIBLE_DEVICES=4,5,6,7
export CUDA_VISIBLE_DEVICES=5,6,7
TP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 test_train.py
CP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 test_train.py
CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=4 test_train.py
TP_SIZE=1 CP_SIZE=4 torchrun --nproc_per_node=4 test_train.py
TP_SIZE=1 DP_SIZE=4 torchrun --nproc_per_node=4 test_train.py
TP_SIZE=4 DP_SIZE=1 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 test_train.py
IGNORE_SANITY=1 CP_SIZE=1 TP_SIZE=1 DP_SIZE=1 torchrun --nproc_per_node=1 --rdzv_endpoint=localhost:29504 test_train.py
"""
import logging
import os
from contextlib import nullcontext
from typing import Dict, Iterable, Optional
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torch.optim as optim
import wandb
from datasets import load_dataset
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils.data import DataLoader, default_collate
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer
# Cross-rank sanity checks run unless IGNORE_SANITY=1 is exported.
ignore_sanity_checks = int(os.environ.get("IGNORE_SANITY", "0")) == 1

# torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True

# Set up logging: timestamp, level, logger name, message at INFO level.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)

# from torch.distributed.tensor.experimental._attention import set_rotate_method
# set_rotate_method("alltoall")  # rotate shards using all-to-all
def main():
    """Train a causal LM with tensor (TP), data (DP) and context (CP) parallelism.

    Parallel sizes are read from the TP_SIZE, DP_SIZE and CP_SIZE environment
    variables. When RANK and WORLD_SIZE are set (e.g. under torchrun) a 3D
    DeviceMesh with dims ("dp", "tp", "cp") is built and the model is loaded with
    `tp_plan="auto"` on the TP sub-mesh; otherwise the script runs on a single
    device without any mesh.

    The function then packs the TinyStories dataset into fixed-length sequences,
    trains for `num_train_steps` steps (logging loss and grad-norm to wandb), and
    saves a checkpoint. In distributed mode the checkpoint is reloaded into a fresh
    model and weights, optimizer state and logits are compared against the trained
    model.

    Unless IGNORE_SANITY=1, extra collectives assert that tensors are (or are not)
    replicated across the DP/TP/CP groups at each stage.
    """
    tp_size = int(os.environ.get("TP_SIZE", 1))
    dp_size = int(os.environ.get("DP_SIZE", 4))
    cp_size = int(os.environ.get("CP_SIZE", 1))  # Add CP size configuration
    sdpa_backend = SDPBackend.FLASH_ATTENTION  # For CP
    # sdpa_backend = SDPBackend.MATH # For CP
    global_batch_size = 8  # Desired global batch size
    seq_len = 1024  # Sequence length
    num_train_steps = 10000  # Number of training steps
    LR = 1e-5
    model_name = "HuggingFaceTB/SmolLM2-1.7B"
    # model_name = "unsloth/Llama-3.2-1B"

    CHECKPOINT_DIR = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}"

    # Initialize distributed environment
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        dist.init_process_group("nccl")
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)

        assert world_size == tp_size * dp_size * cp_size, (
            f"World size ({world_size}) must equal TP size ({tp_size}) * DP size ({dp_size}) * CP size ({cp_size})"
        )

        mesh = torch.arange(world_size).reshape(dp_size, tp_size, cp_size)
        world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp"))
        tp_mesh = world_mesh["tp"]
        dp_mesh = world_mesh["dp"]
        cp_mesh = world_mesh["cp"]
        # Registers the flattened "dp_cp" dim so world_mesh["dp_cp"] can be used later.
        world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
        logger.info(f"Created DeviceMesh: {world_mesh}")
        logger.info(
            f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}, DP: {dp_mesh.get_local_rank()}, TP: {tp_mesh.get_local_rank()}, CP: {cp_mesh.get_local_rank()}"
        )

        if dist.get_rank() == 0:
            wandb.init(
                project="tp_dp_test",
                config={
                    "tp_size": tp_size,
                    "dp_size": dp_size,
                    "cp_size": cp_size,
                    "global_batch_size": global_batch_size,
                    "model_name": model_name,
                    "dataset": "roneneldan/TinyStories-1M",
                    "seq_len": seq_len,
                    "lr": LR,
                    "weight_decay": 0.1,
                },
                name=f"llama_tp{tp_size}_dp{dp_size}_cp{cp_size}"
                if model_name == "unsloth/Llama-3.2-1B"
                else f"tp{tp_size}_dp{dp_size}_cp{cp_size}",
            )
            logger.info(f"ignore_sanity_checks is set to: {ignore_sanity_checks}")
            logger.info("Wandb initialized.")
            # Log the current file to wandb
            wandb.save("test_train.py")
    else:
        logger.info("Running in non-distributed mode. DeviceMesh not applicable.")
        rank = 0
        world_size = 1
        local_rank = 0
        # No meshes exist in this mode; bind the names so later code can test for None
        # instead of hitting a NameError.
        world_mesh = tp_mesh = dp_mesh = cp_mesh = None
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        wandb.init(
            project="tp_dp_test",
            config={
                "tp_size": 1,
                "dp_size": 1,
                "global_batch_size": global_batch_size,
                "model_name": model_name,
                "dataset": "roneneldan/TinyStories-1M",
                "seq_len": seq_len,
            },
            name="llama_tp1_dp1_nondist" if model_name == "unsloth/Llama-3.2-1B" else "tp1_dp1_nondist",
        )
        logger.info("Wandb initialized for non-distributed run.")

    # Effective group sizes, valid in both distributed and non-distributed mode.
    distributed = dist.is_initialized()
    dp_world = dp_mesh.size() if distributed else 1
    tp_world = tp_mesh.size() if distributed else 1
    cp_world = cp_mesh.size() if distributed else 1

    # Load model and tokenizer
    logger.info(f"Loading model and tokenizer from {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.info(f"Set pad_token to eos_token: {tokenizer.pad_token}")

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_mesh=tp_mesh if distributed else None,
        # A TP plan only makes sense with an initialized process group.
        tp_plan="auto" if distributed else None,
        torch_dtype=torch.bfloat16,
    )
    logger.info(f"Model loaded onto device mesh: {tp_mesh}")

    if distributed:
        assert model.config.num_key_value_heads % tp_mesh.size() == 0, (
            f"num_key_value_heads={model.config.num_key_value_heads} must be divisible by tp_size={tp_mesh.size()}"
        )
        device = torch.device(f"cuda:{local_rank}")
    else:
        model = model.to(device)

    logger.info(f"Using device: {device} for non-model tensors")
    use_ddp = False
    if distributed and dp_world > 1:
        # FSDP1
        model = FSDP(model, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD)
        # FSDP2
        # for transformer_block in model.model.layers:
        #     fully_shard(transformer_block, mesh=dp_mesh, reshard_after_forward=False)
        # fully_shard(model.model, mesh=dp_mesh, reshard_after_forward=False)
        # DDP
        # replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
        # assert len(list(model.parameters()))>5, "No parameters found in model. Probably DDP/FSDP bug.."  # TODO: we should be cautious abt using model.parameters()
        use_ddp = True

    model.train()
    assert len(list(model.parameters())) > 0, "No parameters found in model. Probably DDP bug.."
    assert len([p for p in model.parameters() if p.requires_grad]) > 0, (
        "No gradients found in model. Probably DDP bug.."
    )

    if distributed and not ignore_sanity_checks:
        # assert model is replicated across all dp
        for name, param in model.named_parameters():
            sanity_check_tensor_sync(param, dp_mesh)

        # assert model is different across tp (only for sharded params)
        for name, param in model.named_parameters():
            if isinstance(param, DTensor) and param.placements[0].is_shard():
                # Only check sharded parameters for non-sync across TP
                sanity_check_tensor_sync(param, tp_mesh, not_sync=True)
            elif isinstance(param, DTensor) and param.placements[0].is_replicate():
                # Replicated parameters should be the same across TP
                sanity_check_tensor_sync(param, tp_mesh)

        # assert model is replicated across cp
        for name, param in model.named_parameters():
            sanity_check_tensor_sync(param, cp_mesh)

    # Load and preprocess TinyStories dataset
    logger.info("Loading TinyStories dataset...")
    raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")  # Use 1% for faster testing

    def tokenize_function(examples):
        # Tokenize the text without padding
        tokenized_batch = tokenizer(
            examples["text"], padding=False, truncation=True, max_length=seq_len, return_tensors=None
        )
        # Set labels to be the same as input_ids for Causal LM
        tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
        return tokenized_batch

    tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    logger.info(f"Dataset loaded and tokenized. Size: {len(tokenized_dataset)}")

    # Create packed sequences
    def create_packed_sequences(examples):
        # Flatten all sequences
        all_tokens = []
        for input_ids in examples["input_ids"]:
            all_tokens.extend(input_ids)

        # Split into sequences of seq_len + 1 (for input + label)
        num_sequences = len(all_tokens) // (seq_len + 1)
        packed_input_ids = []
        packed_labels = []

        for i in range(num_sequences):
            start_idx = i * (seq_len + 1)
            end_idx = start_idx + (seq_len + 1)
            # Get the full sequence
            full_sequence = all_tokens[start_idx:end_idx]
            # For input_ids, remove the last token
            packed_input_ids.append(full_sequence[:-1])
            # For labels, remove the first token
            packed_labels.append(full_sequence[1:])

        return {"input_ids": packed_input_ids, "labels": packed_labels}

    # Apply packing to the dataset
    packed_dataset = tokenized_dataset.map(
        create_packed_sequences,
        batched=True,
        remove_columns=tokenized_dataset.column_names,
        batch_size=1000,  # Process in batches for efficiency
        num_proc=60,
    )
    logger.info(f"Dataset packed. New size: {len(packed_dataset)}")

    # Shuffle the packed dataset
    packed_dataset = packed_dataset.shuffle(seed=42)
    logger.info("Packed dataset shuffled")

    # Calculate local batch size
    if distributed:
        assert global_batch_size % dp_mesh.size() == 0, (
            f"Global batch size ({global_batch_size}) must be divisible by DP size ({dp_mesh.size()})"
        )
        local_batch_size = global_batch_size // dp_mesh.size()
    else:
        local_batch_size = global_batch_size

    logger.info(
        f"Global batch size: {global_batch_size}, DP size: {dp_size if dist.is_initialized() else 1}, Local batch size: {local_batch_size}"
    )

    # Simple collate function since sequences are already packed
    def collate_fn(batch):
        input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
        labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
        return {"input_ids": input_ids, "labels": labels}

    if distributed:
        sampler = DistributedSampler(
            packed_dataset, num_replicas=dp_mesh.size(), rank=dp_mesh.get_local_rank(), shuffle=False
        )
    else:
        sampler = None

    dataloader = DataLoader(
        packed_dataset,
        batch_size=local_batch_size,
        sampler=sampler,
        shuffle=False,
        collate_fn=collate_fn,
    )
    logger.info(f"DataLoader created. Distributed: {dist.is_initialized()}")

    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1)

    # Disable CP load balancing once, up front (was re-imported and re-set on every step).
    from torch.distributed.tensor.experimental._attention import _cp_options

    _cp_options.enable_load_balance = False

    # Training loop
    logger.info(f"Starting training for {num_train_steps} steps...")
    model.train()
    step = 0
    while step < num_train_steps:
        for batch in dataloader:
            if step >= num_train_steps:
                break  # Exit loop if max steps reached

            # Move batch to appropriate device
            batch = {k: v.to(device) for k, v in batch.items()}

            # Sanity checks for batch distribution (only if distributed)
            if distributed and not ignore_sanity_checks:
                # check batch is same across all tp
                sanity_check_tensor_sync(batch["input_ids"], tp_mesh)
                # check batch is different across dp
                sanity_check_tensor_sync(batch["input_ids"], dp_mesh, not_sync=True)

            optimizer.zero_grad()

            # Add position_ids to batch before CP sharding
            batch_size = batch["input_ids"].shape[0]
            position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
            batch["position_ids"] = position_ids

            with sdpa_kernel(sdpa_backend):  # TODO: ideally move this to attention implementation
                cp_context = (
                    nullcontext()
                    if cp_world == 1
                    else context_parallel(
                        cp_mesh,
                        buffers=[
                            batch["input_ids"],
                            batch["labels"],
                            batch["position_ids"],
                        ],  # TODO: need to add attention mask
                        buffer_seq_dims=[1, 1, 1],
                    )
                )
                with cp_context:
                    # Pop labels from batch before model forward pass
                    labels = batch.pop("labels")
                    outputs = model(**batch)  # [mbs, seq_len/cp]
                    logits = outputs.logits

                    # Compute loss with shifted labels (labels are already shifted by the packing step)
                    loss = model.loss_function(
                        logits=logits, labels=None, shift_labels=labels, vocab_size=model.config.vocab_size
                    )

                    # Sanity checks for logits
                    if distributed and not ignore_sanity_checks:
                        # sanity_check_tensor_sync(logits, tp_mesh) # TODO: only true without sequence parallel
                        sanity_check_tensor_sync(logits, dp_mesh, not_sync=True)
                        sanity_check_tensor_sync(logits, cp_mesh, not_sync=True)

                    loss.backward()

                # all reduce grads across dp_cp if applicable (there is no mesh in non-distributed mode)
                if distributed:
                    all_reduce_grads(model, world_mesh, use_ddp=use_ddp)

                # Sanity checks for gradients (only if distributed)
                if distributed and not ignore_sanity_checks:
                    # check grads are not same across all tp (for sharded grads)
                    for name, param in model.named_parameters():
                        if param.grad is not None and isinstance(param.grad, DTensor):
                            if param.grad.placements[0].is_shard():
                                sanity_check_tensor_sync(param.grad, tp_mesh, not_sync=True)
                            elif param.grad.placements[0].is_replicate():
                                sanity_check_tensor_sync(param.grad, tp_mesh)
                    # check grads are same across dp
                    for name, param in model.named_parameters():
                        if param.grad is not None and dp_mesh.size() > 1:
                            sanity_check_tensor_sync(param.grad, dp_mesh)
                    # check grads are same across cp
                    for name, param in model.named_parameters():
                        if param.grad is not None and cp_mesh.size() > 1:
                            sanity_check_tensor_sync(param.grad, cp_mesh)

                # Calculate gradient norm and clip gradients
                if hasattr(model, "clip_grad_norm_"):
                    # when using FSDP or DDP, model.parameters() doesn't work
                    gradnorm = model.clip_grad_norm_(max_norm=1.0, norm_type=2.0)
                else:
                    assert len(list(model.parameters())) > 2, "No parameters found in model. Probably DDP bug.."
                    assert len([p for p in model.parameters() if p.requires_grad]) > 2, (
                        "No gradients found in model. Probably DDP bug.."
                    )
                    assert len([p for p in model.parameters() if p.grad is not None]) > 2, (
                        "No gradients found in model. Probably DDP bug.."
                    )
                    # only works with FSDP's NO_SHARD otherwise we should use FSDP's clip_grad_norm_
                    gradnorm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0, foreach=True)

                optimizer.step()

                # Sanity checks for updated model parameters (only if distributed)
                if distributed and not ignore_sanity_checks:
                    # check updated model is different across all tp (for sharded params)
                    for name, param in model.named_parameters():
                        if isinstance(param, DTensor):
                            if param.placements[0].is_shard():
                                sanity_check_tensor_sync(param, tp_mesh, not_sync=True)
                            elif param.placements[0].is_replicate():
                                sanity_check_tensor_sync(param, tp_mesh)
                    # check updated model is same across dp
                    for name, param in model.named_parameters():
                        sanity_check_tensor_sync(param, dp_mesh)
                    # check updated model is same across cp
                    for name, param in model.named_parameters():
                        sanity_check_tensor_sync(param, cp_mesh)

                # allreduce loss across cp_dp before logging
                if distributed and (cp_world > 1 or dp_world > 1):
                    dist.all_reduce(loss, group=world_mesh["dp_cp"].get_group(), op=dist.ReduceOp.AVG)
                current_loss = loss.item()

                # Log loss and gradnorm to wandb (global rank 0 only, or always when non-distributed)
                if not distributed or dist.get_rank() == 0:
                    logger.info(
                        f"Step: {step} | GBS: {global_batch_size} | DP: {dp_world} | TP: {tp_world} | CP: {cp_world} | Loss: {current_loss} | Gradnorm: {gradnorm} | lr: {LR}"
                    )
                    wandb.log(
                        {
                            "train/loss": current_loss,
                            "train/gradnorm": gradnorm,
                            "step": step,
                            "lr": LR,
                            "GBS": global_batch_size,
                        }
                    )

            step += 1  # Increment step count

    logger.info("Training loop finished.")

    # Save model using DCP (only if distributed)
    if distributed:
        state_dict = {"app": AppState(model, optimizer)}
        dcp.save(
            state_dict=state_dict,
            checkpoint_id=CHECKPOINT_DIR,
        )
        logger.info(f"Saved checkpoint to {CHECKPOINT_DIR}")
    else:
        # Fallback to regular save for non-distributed case
        save_dir = "test_model_nondist"
        model.save_pretrained(save_dir, safe_serialization=False)
        tokenizer.save_pretrained(save_dir)  # Save tokenizer too
        logger.info(f"Saved model to {save_dir}")

    # Example of loading the checkpoint (only if distributed)
    if distributed:
        # Create a new model instance
        logger.info("Creating new model instance for verification")
        new_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_mesh=tp_mesh,
            torch_dtype=torch.bfloat16,  # Use same dtype
        )
        new_optimizer = optim.AdamW(new_model.parameters(), lr=LR)

        # Load checkpoint into new model
        state_dict = {"app": AppState(new_model, new_optimizer)}
        dcp.load(
            state_dict=state_dict,
            checkpoint_id=CHECKPOINT_DIR,
        )
        logger.info("Loaded checkpoint into new model")

        # Verify model weights match
        logger.info("Verifying model weights match...")
        for (name1, param1), (name2, param2) in zip(model.named_parameters(), new_model.named_parameters()):
            torch.testing.assert_close(
                param1.to_local(),
                param2.to_local(),
                rtol=1e-3,
                atol=1e-3,
                msg=f"Weights mismatch in {name1} vs {name2}",
            )

        # Verify optimizer states match
        logger.info("Verifying optimizer states match...")
        for name1, state1 in optimizer.state_dict().items():
            state2 = new_optimizer.state_dict()[name1]
            if name1 == "state":
                # Compare state dictionaries for each parameter
                for param_id, param_state1 in state1.items():
                    param_state2 = state2[param_id]
                    # Compare each state component (step, exp_avg, exp_avg_sq)
                    for key, value1 in param_state1.items():
                        value2 = param_state2[key]
                        if isinstance(value1, DTensor):
                            # Convert DTensors to local tensors for comparison
                            torch.testing.assert_close(
                                value1.to_local(),
                                value2.to_local(),
                                rtol=1e-5,
                                atol=1e-5,
                                msg=f"Optimizer state mismatch in state[{param_id}][{key}]",
                            )
                        else:
                            torch.testing.assert_close(
                                value1,
                                value2,
                                rtol=1e-5,
                                atol=1e-5,
                                msg=f"Optimizer state mismatch in state[{param_id}][{key}]",
                            )
            elif name1 == "param_groups":
                # Compare param_groups (excluding the actual params list)
                for i, (group1, group2) in enumerate(zip(state1, state2)):
                    for key in group1:
                        if key != "params":  # Skip comparing the params list
                            assert group1[key] == group2[key], f"Param group mismatch in param_groups[{i}][{key}]"

        # Run a forward pass with both models to verify outputs match
        logger.info("Running forward pass verification...")
        with torch.no_grad():
            # Use the last batch for verification
            batch = {k: v.to(device) for k, v in batch.items()}  # Ensure batch is on correct device
            original_outputs = model(**batch)
            new_outputs = new_model(**batch)
            torch.testing.assert_close(
                original_outputs.logits.to_local(),
                new_outputs.logits.to_local(),
                rtol=1e-3,
                atol=1e-3,
                msg="Model outputs do not match!",
            )  # Increased tolerance slightly for bf16

    # Clean up distributed environment and finish wandb run
    if distributed:
        dist.destroy_process_group()
        logger.info("Cleaned up distributed process group")
        # Finish wandb run on rank 0. Use the rank cached at startup: dist.get_rank()
        # is no longer usable once the default process group has been destroyed.
        if rank == 0:
            wandb.finish()
            logger.info("Wandb run finished.")
    else:
        wandb.finish()
        logger.info("Wandb run finished.")
def all_reduce_grads(model, world_mesh, use_ddp):
"""All reduce gradients across dp_cp if applicable."""
cp_mesh = world_mesh["cp"]
if use_ddp:
# DDP takes care of syncing grads
mesh = cp_mesh
else:
mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
if dist.is_initialized() and mesh.size() > 1:
for name, param in model.named_parameters():
if param.grad is not None:
# Workaround for cross-mesh communication limitation with DTensor gradients
if isinstance(param.grad, DTensor):
local_grad = param.grad.to_local()
# Ensure grad requires grad for inplace modification checks (might not be needed)
# local_grad = local_grad.detach().requires_grad_(True)
torch.distributed.all_reduce(local_grad, op=torch.distributed.ReduceOp.SUM, group=mesh.get_group())
local_grad = local_grad / mesh.size()
# Assign averaged grad back - need careful handling if DTensor structure is complex
# This simple assignment might work if the grad structure matches param structure
param.grad = DTensor.from_local(
local_grad, device_mesh=param.grad.device_mesh, placements=param.grad.placements
)
else:
# Handle regular tensors if any exist (e.g. buffers not converted to DTensor)
torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG, group=mesh.get_group())
class ContextParallelCollator:
"""Collator for context parallel training that splits sequences into chunks."""
def __init__(self, cp_mesh: Optional[DeviceMesh] = None):
self.cp_mesh = cp_mesh
def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
batch = default_collate(batch)
if self.cp_mesh is not None and self.cp_mesh.size() > 1:
# Get sequence length from the input batch
seq_len = batch["input_ids"].shape[1]
assert seq_len % self.cp_mesh.size() == 0, (
f"Sequence length {seq_len} must be divisible by CP size {self.cp_mesh.size()}"
)
chunk_size = seq_len // self.cp_mesh.size()
cp_rank = self.cp_mesh.get_local_rank()
start_idx = cp_rank * chunk_size
end_idx = start_idx + chunk_size
# Keep only the local chunk of the sequence
batch["input_ids"] = batch["input_ids"][:, start_idx:end_idx]
batch["attention_mask"] = batch["attention_mask"][:, start_idx:end_idx]
batch["labels"] = batch["labels"][:, start_idx:end_idx]
return batch
class AppState(Stateful):
"""Wrapper for checkpointing the Application State including model and optimizer."""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {"model": model_state_dict, "optim": optimizer_state_dict}
def load_state_dict(self, state_dict):
set_state_dict(
self.model, self.optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"]
)
def sanity_check_tensor_sync(
tensor: torch.Tensor, mesh: DeviceMesh, rtol: float = 1e-4, atol: float = 1e-4, not_sync: bool = False
) -> None:
"""
Verify that a tensor is synchronized (or not synchronized) across all processes in the mesh's process group.
Handles both regular tensors and DTensors.
Args:
tensor (torch.Tensor): The tensor to check for synchronization (can be DTensor)
mesh (DeviceMesh): The device mesh containing the process group
rtol (float): Relative tolerance for comparison
atol (float): Absolute tolerance for comparison
not_sync (bool): If True, asserts that tensors are NOT synchronized. If False, asserts they are synchronized.
"""
if not dist.is_initialized() or mesh.size() == 1:
return # No need to check in non-distributed mode
# Get the process group from the mesh
pg = mesh.get_group()
# Convert DTensor to local tensor if needed
if hasattr(tensor, "to_local"):
local_tensor = tensor.to_local()
else:
local_tensor = tensor
# Gather tensors from all processes
world_size = dist.get_world_size(pg)
gathered_tensors = [torch.empty_like(local_tensor) for _ in range(world_size)]
dist.all_gather(gathered_tensors, local_tensor, group=pg)
# Compare each tensor with the first one
for i in range(1, world_size):
try:
torch.testing.assert_close(gathered_tensors[0], gathered_tensors[i], rtol=rtol, atol=atol)
except AssertionError as e:
if not_sync:
continue
# # Add detailed debugging for logit synchronization issues
# print(f"\nLogit synchronization error between rank 0 and rank {i}:")
# print(f"Tensor shape: {gathered_tensors[0].shape}")
# print(f"Number of mismatched elements: {(gathered_tensors[0] != gathered_tensors[i]).sum()}")
# print(f"Percentage of mismatched elements: {((gathered_tensors[0] != gathered_tensors[i]).sum() / gathered_tensors[0].numel() * 100):.2f}%")
# # Find the first few mismatches
# mismatches = torch.nonzero(gathered_tensors[0] != gathered_tensors[i])
# print("\nFirst few mismatches:")
# for idx in mismatches[:5]:
# idx = tuple(idx.tolist())
# print(f"Index {idx}:")
# print(f"Rank 0 value: {gathered_tensors[0][idx]}")
# print(f"Rank {i} value: {gathered_tensors[i][idx]}")
# print(f"Absolute difference: {abs(gathered_tensors[0][idx] - gathered_tensors[i][idx])}")
# print(f"Relative difference: {abs(gathered_tensors[0][idx] - gathered_tensors[i][idx]) / max(abs(gathered_tensors[0][idx]), abs(gathered_tensors[i][idx]))}")
# # Check if differences are systematic (e.g., all positive or negative)
# diff = gathered_tensors[0] - gathered_tensors[i]
# print(f"\nDifference statistics:")
# print(f"Mean difference: {diff.mean()}")
# print(f"Std difference: {diff.std()}")
# print(f"Max positive difference: {diff.max()}")
# print(f"Max negative difference: {diff.min()}")
raise e
def clip_grad_norm_(
parameters: Iterable[torch.Tensor],
max_norm: float,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
foreach: bool | None = None,
) -> torch.Tensor:
"""
Clip the gradient norm of an iterable of parameters.
"""
# Filter out parameters with no gradients
parameters = [p for p in parameters if p.grad is not None]
assert len(parameters) > 0, "No parameters with gradients found"
# Calculate total norm
if norm_type == float("inf"):
total_norm = max(p.grad.detach().abs().max() for p in parameters)
else:
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type)
# Convert DTensor to local tensor if needed
if isinstance(total_norm, DTensor):
total_norm = total_norm.full_tensor()
# Clip gradients
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.detach().mul_(clip_coef)
return total_norm
def check_params_sync(model_params, original_params):
"""
Check if original_params are being updated in sync with model parameters.
Args:
model_params: Iterator of model parameters after update
original_params: List of original parameters before DDP wrapping
"""
for mp, op in zip(model_params, original_params):
if isinstance(mp, DTensor):
mp = mp.to_local()
if isinstance(op, DTensor):
op = op.to_local()
if not torch.allclose(mp.data, op.data, rtol=0, atol=0):
raise RuntimeError(f"Parameters out of sync: model param {mp.data} != original param {op.data}")
return True
def get_parameters(model: nn.Module) -> Iterable[torch.Tensor]:
    """
    Walk the module tree depth-first and yield trainable tensors found on submodules.

    For every direct child in `model._modules`, tensors stored in the child's instance
    `__dict__` that have `requires_grad` set are yielded first, then the child's own
    submodules are visited recursively. The root module's own attributes are not inspected.

    Args:
        model (nn.Module): The model to get parameters from

    Returns:
        Iterable[torch.Tensor]: An iterator over the tensors found
    """
    for child in model._modules.values():
        # Tensors held directly as attributes of this child
        yield from (
            value for value in child.__dict__.values() if isinstance(value, torch.Tensor) and value.requires_grad
        )
        # Then everything below the child
        yield from get_parameters(child)
def update_model_parameters(model: nn.Module) -> None:
    """
    Rebuild model._parameters from named_modules() so weights and biases of all submodules are tracked.

    The root module's `_parameters` dict is replaced. For every submodule that exposes a
    non-None `weight` and/or `bias`, an entry "<module name>.weight" / "<module name>.bias"
    is added, where a leading "module." wrapper prefix is dropped from the module name.

    Args:
        model (nn.Module): The model to update parameters for
    """
    wrapper_prefix = "module."

    # Clear existing parameters
    model._parameters = {}

    # Add parameters from named_modules
    for name, module in model.named_modules():
        # Skip the root module itself
        if name == "":
            continue

        # Strip only a *leading* 'module.' prefix. A plain str.replace would also mangle
        # names that merely contain it, e.g. 'submodule.0' -> 'sub0'.
        param_name = name[len(wrapper_prefix):] if name.startswith(wrapper_prefix) else name

        # Add weight and bias parameters if they exist
        for attr in ("weight", "bias"):
            value = getattr(module, attr, None)
            if value is not None:
                model._parameters[f"{param_name}.{attr}"] = value
# Script entry point: run the training routine only when executed directly (e.g. via torchrun).
if __name__ == "__main__":
    main()

View File

@ -44,7 +44,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

View File

@ -0,0 +1,94 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss
# Minimal context-parallel (CP) training step: every rank holds the full model (wrapped in
# DDP so updates stay in sync) while the sequence dimension of the inputs is sharded
# across a 1D device mesh spanning all ranks.
world_size = int(os.environ.get("WORLD_SIZE", "1"))
cp_mesh = init_device_mesh("cuda", (world_size,))
rank = torch.distributed.get_node_local_rank()

device = "cuda"
dtype = torch.bfloat16
sdpa_backend = SDPBackend.FLASH_ATTENTION

# prepare inputs
batch_size = 1
seq_len = 128

input_ids = torch.randint(low=8, high=64, size=(batch_size, seq_len), device=device)

ignore_index = -100
# When using CP, we need to use `shift_labels`:
# pad one ignored position on the right, then drop the first token, so that
# shift_labels[:, i] == input_ids[:, i + 1] and the last position is ignored.
shift_labels = torch.nn.functional.pad(input_ids, (0, 1), value=ignore_index)
shift_labels = shift_labels[..., 1:].contiguous()

position_ids = (
    torch.cumsum(torch.ones(size=input_ids.size(), dtype=input_ids.dtype, device=input_ids.device), dim=1) - 1
)

# sync input as they are created randomly
dist.broadcast(input_ids, src=0)
dist.broadcast(shift_labels, src=0)
dist.broadcast(position_ids, src=0)

# model and optimizer
repo_id = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=dtype, device_map=device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

model.train()
model.zero_grad()
optimizer.zero_grad()

# For loss
vocab_size = model.config.vocab_size

# so training could be synced
model = DDP(model, device_ids=[rank])

# prepare for CP
buffers = (input_ids, shift_labels, position_ids)
buffer_seq_dims = (1, 1, 1)
# `no_restore_buffers=set(buffers)` is required if `loss.backward` is outside `context_parallel`.
# no_restore_buffers = set(buffers)
no_restore_buffers = None

# run with CP
with sdpa_kernel(sdpa_backend):
    with context_parallel(
        cp_mesh,
        buffers=buffers,
        buffer_seq_dims=buffer_seq_dims,
        no_restore_buffers=no_restore_buffers,
    ):
        outputs = model(input_ids, shift_labels=shift_labels, position_ids=position_ids)
        print(outputs.logits.shape)

        # So far we need to compute `loss` outside `model.forward` when using `shift_labels`
        # loss = outputs.loss
        loss = ForCausalLMLoss(logits=outputs.logits, labels=None, shift_labels=shift_labels, vocab_size=vocab_size)

        # This could be outside `context_parallel` context if `no_restore_buffers` is specified
        loss.backward()
        optimizer.step()

# Tear down the process group explicitly so the script exits cleanly.
if dist.is_initialized():
    dist.destroy_process_group()

View File

@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)

View File

@ -42,7 +42,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -47,7 +47,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -52,7 +52,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

View File

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

View File

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)

View File

@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -45,7 +45,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = logging.getLogger(__name__)

View File

@ -53,7 +53,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)
# You should update this to your particular problem to have better documentation of `model_type`

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")

View File

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logging.basicConfig(level=logging.INFO)
logger = get_logger(__name__)

View File

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")

View File

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)

View File

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)

View File

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version(
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"

View File

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@ -49,7 +49,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = logging.getLogger(__name__)

View File

@ -61,7 +61,7 @@ except (ModuleNotFoundError, ImportError):
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = logging.getLogger(__name__)

View File

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
# region Checking dependencies
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
task_to_keys = {
"cola": ("sentence", None),

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
# region Dependencies and constants
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -451,7 +451,7 @@ install_requires = [
setup(
name="transformers",
version="4.52.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.53.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.52.0.dev0"
__version__ = "4.53.0.dev0"
from pathlib import Path
from typing import TYPE_CHECKING

View File

@ -1985,7 +1985,9 @@ class GenerationMixin:
instantiated, writes it to `model_kwargs`, under the name expected by the model.
"""
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
is_hybrid_cache = any(class_name in self.__class__.__name__.lower() for class_name in ["mamba", "falconh1"])
cache_name = "past_key_values" if not is_hybrid_cache else "cache_params"
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)

View File

@ -142,7 +142,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["tensor_parallel"] = [
"shard_and_distribute_module",
"SUPPORTED_TP_STYLES",
"ALL_PARALLEL_STYLES",
"translate_to_torch_parallel_style",
]
try:
@ -271,7 +271,7 @@ if TYPE_CHECKING:
pass
else:
from .tensor_parallel import (
SUPPORTED_TP_STYLES,
ALL_PARALLEL_STYLES,
shard_and_distribute_module,
translate_to_torch_parallel_style,
)

View File

@ -362,8 +362,8 @@ def _replace_with_bitnet_linear(
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
use_rms_norm=quantization_config.use_rms_norm,
rms_norm_eps=quantization_config.rms_norm_eps,
use_rms_norm=quantization_config.use_rms_norm if quantization_config else False,
rms_norm_eps=quantization_config.rms_norm_eps if quantization_config else 1e-6,
)
model._modules[name].requires_grad_(False)
has_been_replaced = True

View File

@ -13,11 +13,15 @@
# limitations under the License.
from __future__ import annotations
import operator
import os
import re
from functools import lru_cache, partial
from typing import List, Optional, Tuple, Union
from collections.abc import MutableMapping
from functools import partial, reduce
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch import nn
from ..utils import is_torch_greater_or_equal, logging
@ -35,6 +39,56 @@ if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
from torch.distributed.tensor import DTensor, Placement, Replicate, Shard
def initialize_tensor_parallelism(tp_plan, tp_size=None):
    r"""
    Sets up the device mesh and initializes the backend for tensor parallelism.
    This function is called when the model is loaded and the TP plan is set to 'auto'.

    Args:
        tp_plan: The requested tensor-parallel plan. When `None`, nothing is initialized.
        tp_size (`int`, *optional*): Size of the 1-D device mesh to create. Defaults to the world size
            of the default process group.

    Returns:
        A `(tp_device, device_map, device_mesh)` tuple, or `(None, None, None)` when `tp_plan` is `None`.
        `device_map` is the same `torch.device` object as `tp_device`.

    Raises:
        EnvironmentError: If `torch<2.5`, or if `torch.distributed` had to be initialized here and the
            initialization failed (for instance because `RANK`, `LOCAL_RANK` or `WORLD_SIZE` is not set).
    """
    if tp_plan is None:
        return None, None, None

    if not is_torch_greater_or_equal("2.5"):
        raise EnvironmentError("Tensor parallel is only supported for `torch>=2.5`.")

    # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
    device_type = torch._C._get_accelerator().type
    # Resolve the backend module (`torch.cuda`, `torch.xpu`, ...) before the branch below: it is also
    # needed when the caller already initialized `torch.distributed` and the branch is skipped.
    current_device = getattr(torch, device_type)

    if not torch.distributed.is_initialized():
        try:
            rank = int(os.environ["RANK"])
            local_rank = int(os.environ["LOCAL_RANK"])
            world_size = int(os.environ["WORLD_SIZE"])

            backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "ccl", "hpu": "hccl"}
            backend = backend_map.get(device_type)
            # On CPU, use `ccl` instead of `gloo` when CCL workers are configured.
            if device_type == "cpu" and int(os.environ.get("CCL_WORKER_COUNT", 0)):
                backend = "ccl"

            torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
            if device_type != "cpu":
                current_device.set_device(local_rank)
        except Exception as e:
            raise EnvironmentError(
                "We tried to initialize torch.distributed for you, but it failed. Make "
                "sure you init torch distributed in your script to use `tp_plan='auto'`."
            ) from e

    # CPU has no device index; every other accelerator reports its currently selected device.
    index = current_device.current_device() if device_type != "cpu" else None
    tp_device = torch.device(device_type, index)

    # Silence output for non-primary ranks
    if index is not None and index > 0:
        import sys

        sys.stdout = open(os.devnull, "w")
        sys.stderr = open(os.devnull, "w")

    device_map = tp_device
    tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
    device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
    return tp_device, device_map, device_mesh
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
"""
Convert block count or proportions to block sizes.
@ -220,18 +274,38 @@ def repack_weights(
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
    """
    Generalized tensor sharding across a multi-dimensional device mesh.

    The size of `empty_param` along `dim` is split into `world_size` equal chunks, where `world_size` is
    the product of `device_mesh.shape`, and the chunk at position `rank` is sliced out of `param`.
    The chunk size uses floor division, so trailing elements are dropped when the dimension is not
    divisible by `world_size`.

    Args:
        param: The tensor to shard. It is only indexed with a tuple of slices.
        empty_param (torch.Tensor): A tensor used for shape reference (`.dim()` and `.shape`).
        device_mesh: Device mesh whose `.shape` is `[d_0, ..., d_n]`.
        rank (int): Global rank of the current process/device.
        dim (int): Dimension along which to shard the tensor. Negative values count from the end.

    Returns:
        The slice of `param` covering `[rank * shard_size, (rank + 1) * shard_size)` along `dim`.

    Raises:
        ValueError: If `dim` is outside `[-param_dim, param_dim)` or `rank` is not smaller than the
            mesh size.
    """
    param_dim = empty_param.dim()
    # Reject out-of-range negative dims too: without the lower bound, e.g. `dim=-3` on a 2-D tensor
    # would normalize to `-1` and silently shard the wrong axis.
    if dim < -param_dim or dim >= param_dim:
        raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
    if dim < 0:
        dim = param_dim + dim

    # Flatten the mesh to get the total number of devices (initial value 1 covers an empty shape)
    mesh_shape = device_mesh.shape
    world_size = reduce(operator.mul, mesh_shape, 1)

    if rank >= world_size:
        raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")

    shard_size = empty_param.shape[dim] // world_size
    start = rank * shard_size
    end = start + shard_size

    # Construct slicing index dynamically: full slices everywhere except along `dim`
    slice_indices = [slice(None)] * param_dim
    slice_indices[dim] = slice(start, end)

    return param[tuple(slice_indices)]
def distribute_module(
@ -339,6 +413,41 @@ class IsolatedParallel(TensorParallelLayer):
)
class ReplicateParallel(TensorParallelLayer):
    """
    Tensor-parallel style that replicates a layer's computation instead of sharding it (for example in
    SP regions when sequence parallelism is not used).
    """

    def __init__(self, *, use_dtensor=True, use_local_output=True):
        super().__init__()
        self.use_dtensor = use_dtensor
        self.use_local_output = use_local_output
        # Inputs, desired inputs and outputs all use a replicated placement.
        self.input_layouts = (Replicate(),)
        self.desired_input_layouts = (Replicate(),)
        self.output_layouts = (Replicate(),)

    @staticmethod
    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
        # TODO: figure out dynamo support for instance method and switch this to instance method
        # Only the first positional input is considered; it is wrapped as a DTensor with
        # `input_layouts` unless it already is one.
        first_input = inputs[0]
        if isinstance(first_input, DTensor):
            return first_input
        return DTensor.from_local(first_input, device_mesh, input_layouts, run_check=False)

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        # Hand back the local tensor when requested, otherwise pass the outputs through unchanged.
        if use_local_output:
            return outputs.to_local()
        return outputs

    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
        # No slicing: the full parameter is cast, optionally made contiguous, and wrapped as a
        # replicated DTensor.
        full_param = param[...].to(param_casting_dtype)
        if to_contiguous:
            full_param = full_param.contiguous()
        return DTensor.from_local(full_param, device_mesh, [Replicate()], run_check=False)
class ColwiseParallel(TensorParallelLayer):
"""
General tensor parallel layer for transformers.
@ -611,52 +720,62 @@ class SequenceParallel(TensorParallelLayer):
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
SUPPORTED_TP_STYLES = {
"colwise",
"rowwise",
"colwise_rep",
"rowwise_rep",
"local_colwise",
"local_rowwise",
"local",
"gather",
"local_packed_rowwise",
"sequence_parallel",
}
@lru_cache
def translate_to_torch_parallel_style(style: str):
class ParallelInterface(MutableMapping):
"""
In model configurations, we use a neutral type (string) to specify parallel
styles, here we translate them into torch.distributed tensor-parallel
types.
Dict-like object keeping track of allowed tensor-parallel styles. You can easily add a new parallel style
with a call to `register()`. If a model needs to locally overwrite an existing parallel style, say `colwise`,
it needs to declare a new instance of this class inside the `modeling_<model>.py`, and declare it on that instance.
"""
if not isinstance(style, str):
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
if style == "colwise":
return ColwiseParallel()
elif style == "rowwise":
return RowwiseParallel()
elif style == "colwise_rep":
return ColwiseParallel(output_layouts=Replicate())
elif style == "rowwise_rep":
return RowwiseParallel(input_layouts=Replicate())
elif style == "local_colwise":
return ColwiseParallel(use_dtensor=False)
elif style == "local_rowwise":
return RowwiseParallel(use_dtensor=False)
elif style == "local":
return IsolatedParallel()
elif style == "gather":
return GatherParallel()
elif style == "local_packed_rowwise":
return PackedRowwiseParallel(use_dtensor=False)
elif style == "sequence_parallel":
return SequenceParallel()
else:
raise ValueError(f"Unsupported parallel style value: {style}")
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
# a new instance is created (in order to locally override a given function)
_global_mapping = {
"colwise": ColwiseParallel(),
"rowwise": RowwiseParallel(),
"colwise_rep": ColwiseParallel(output_layouts=Replicate()),
"rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
"local_colwise": ColwiseParallel(use_dtensor=False),
"local_rowwise": RowwiseParallel(use_dtensor=False),
"local": IsolatedParallel(),
"gather": GatherParallel(),
"local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
"sequence_parallel": SequenceParallel(),
"replicate": ReplicateParallel(),
}
def __init__(self):
self._local_mapping = {}
def __getitem__(self, key):
# First check if instance has a local override
if key in self._local_mapping:
return self._local_mapping[key]
return self._global_mapping[key]
def __setitem__(self, key, value):
# Allow local update of the default functions without impacting other instances
self._local_mapping.update({key: value})
def __delitem__(self, key):
del self._local_mapping[key]
def __iter__(self):
# Ensure we use all keys, with the overwritten ones on top
return iter({**self._global_mapping, **self._local_mapping})
def __len__(self):
return len(self._global_mapping.keys() | self._local_mapping.keys())
@classmethod
def register(cls, key: str, value: Callable):
cls._global_mapping.update({key: value})
def valid_keys(self) -> List[str]:
return list(self.keys())
# Global ParallelInterface shared by all models which do not need to overwrite any of the existing ones
ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
def convert_local_tensor_to_dtensor(
@ -722,13 +841,15 @@ def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, curr
# 1. We add hooks to the layer being loaded:
if current_module_plan is not None:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
try:
tp_layer.prepare_module_tp(module, device_mesh)
except NotImplementedError as e:
print(
f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}"
)
module._hf_tp_plan = current_module_plan
module.__repr__ = lambda: f"{module.__repr__()}\nTP Plan: {current_module_plan}"
# 2. We add hooks to the parent module if needed
if "." in layer_name:
@ -736,9 +857,11 @@ def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, curr
generic_name = re.sub(r"\d+", "*", parent_layer_name)
# The module itself needs hooks
if module_plan := tp_plan.get(generic_name, False):
tp_layer = translate_to_torch_parallel_style(module_plan)
tp_layer = ALL_PARALLEL_STYLES[module_plan]
module_to_tp_ = model.get_submodule(parent_layer_name)
tp_layer.prepare_module_tp(module_to_tp_, device_mesh)
module_to_tp_._hf_tp_plan = current_module_plan
module_to_tp_.__repr__ = lambda: f"{module_to_tp_.__repr__()}\nTP Plan: {current_module_plan}"
def shard_and_distribute_module(
@ -760,28 +883,29 @@ def shard_and_distribute_module(
current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
if current_module_plan is None:
current_module_plan = "replicate"
if dist.get_rank() == 0:
logger.info(f"Tensor parallel plan for {param_name} not found, using default 'replicate' plan.")
else:
if dist.get_rank() == 0:
logger.info(f"Tensor parallel plan for {param_name}: {current_module_plan}")
# Add hooks to the module if not done yet
# add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
if not getattr(module_to_tp, "_is_hooked", False):
add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
module_to_tp._is_hooked = True
if current_module_plan is not None:
try:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
param = tp_layer.partition_tensor(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
except NotImplementedError as e:
print(
f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
)
else:
# TODO log no plan modules in set
# print("No plan for", parameter_name,end ="\n")
param = param[...].to(param_casting_dtype)
if is_contiguous:
param = param.contiguous()
try:
tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
param = tp_layer.partition_tensor(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
except NotImplementedError as e:
print(
f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
)
# SUPER IMPORTANT we have to use setattr
# otherwise loading is crazy slow

View File

@ -66,7 +66,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/index.html#installation for installation"
" instructions."
)
raise
@ -360,7 +360,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/index.html#installation for installation"
" instructions."
)
raise

View File

@ -62,8 +62,9 @@ from .integrations.flash_attention import flash_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.tensor_parallel import (
SUPPORTED_TP_STYLES,
ALL_PARALLEL_STYLES,
_get_parameter_tp_plan,
initialize_tensor_parallelism,
repack_weights,
replace_state_dict_local_with_dtensor,
shard_and_distribute_module,
@ -797,7 +798,7 @@ def _load_state_dict_into_meta_model(
param_name,
casting_dtype,
to_contiguous,
int(os.environ["RANK"]), # the rank
device_mesh.get_local_rank(),
device_mesh,
)
else:
@ -1964,9 +1965,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
for _, v in self._tp_plan.items():
if v not in SUPPORTED_TP_STYLES:
if v not in ALL_PARALLEL_STYLES:
raise ValueError(
f"Unsupported tensor parallel style {v}. Supported styles are {SUPPORTED_TP_STYLES}"
f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}"
)
def dequantize(self):
@ -3559,6 +3560,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
if safe_serialization:
# TODO: fix safe_serialization for tied weights
# Safetensors does not allow tensor aliasing.
# We're going to remove aliases before saving
ptrs = collections.defaultdict(list)
@ -4040,6 +4042,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
`torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
tp_size (`str`, *optional*):
A torch tensor parallel degree. If not provided would default to world size.
device_mesh (`torch.distributed.DeviceMesh`, *optional*):
A torch device mesh. If not provided would default to world size. Used only for tensor parallel for now.
offload_folder (`str` or `os.PathLike`, *optional*):
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_state_dict (`bool`, *optional*):
@ -4137,6 +4141,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
tp_size = kwargs.pop("tp_size", None)
device_mesh = kwargs.pop("device_mesh", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
@ -4172,59 +4177,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
# `device_map` pointing to the correct device
device_mesh = None
if tp_plan is not None:
if not is_torch_greater_or_equal("2.5"):
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type
if not torch.distributed.is_initialized():
try:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
if device_type == "cuda":
torch.distributed.init_process_group(
"nccl", rank=rank, world_size=world_size, init_method="env://"
)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
elif device_type == "cpu":
cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
elif device_type == "xpu":
torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size)
torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))
elif device_type == "hpu":
torch.distributed.init_process_group("hccl", rank=rank, world_size=world_size)
torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
except Exception as e:
raise EnvironmentError(
"We tried to initialize torch.distributed for you, but it failed, make"
"sure you init torch distributed in your script to use `tp_plan='auto'`"
) from e
# Get device with index assuming equal number of devices per host
if device_type == "xpu":
index = torch.xpu.current_device()
elif device_type == "hpu":
index = torch.hpu.current_device()
if device_mesh is None and tp_plan is not None:
tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None)
else:
index = None if device_type == "cpu" else torch.cuda.current_device()
tp_device = torch.device(device_type, index)
if index is not None and index > 0:
import sys
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
# This is the easiest way to dispatch to the current process device
device_map = tp_device
# Assuming sharding the model onto the world when tp_size not provided
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
# TODO: make device_mesh support multiple dimensions
if device_mesh.ndim == 1:
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
if use_auth_token is not None:
warnings.warn(
@ -5142,7 +5102,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
name,
casting_dtype,
to_contiguous,
os.environ["RANK"],
device_mesh.get_local_rank(),
device_mesh,
)

View File

@ -103,6 +103,7 @@ if TYPE_CHECKING:
from .ernie import *
from .esm import *
from .falcon import *
from .falcon_h1 import *
from .falcon_mamba import *
from .fastspeech2_conformer import *
from .flaubert import *

View File

@ -364,19 +364,23 @@ class AriaImageProcessor(BaseImageProcessor):
return resized_image
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
original_height, original_width = original_resolution
target_height, target_width = target_resolution
paste_x, r_x = divmod(target_width - original_width, 2)
paste_y, r_y = divmod(target_height - original_height, 2)
return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
def _pad_for_patching(
    self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
) -> np.array:
    """
    Pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image: The image to pad.
        target_resolution: The `(height, width)` the padded image should reach.
        input_data_format: Channel layout of `image`, forwarded to `get_patch_output_size`.

    Returns:
        The result of `self.pad` applied to `image` with the computed padding.
    """
    # Size the image occupies inside the target; the padding fills the remaining space on each axis.
    new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
    # Delegate the split of the size difference to the shared helper instead of recomputing it inline.
    padding = self._get_padding_size(new_resolution, target_resolution)

    padded_image = self.pad(image, padding=padding)

    return padded_image

View File

@ -748,19 +748,23 @@ class AriaImageProcessor(BaseImageProcessor):
return resized_image
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
original_height, original_width = original_resolution
target_height, target_width = target_resolution
paste_x, r_x = divmod(target_width - original_width, 2)
paste_y, r_y = divmod(target_height - original_height, 2)
return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
def _pad_for_patching(
    self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
) -> np.array:
    """
    Pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image: The image to pad.
        target_resolution: The `(height, width)` the padded image should reach.
        input_data_format: Channel layout of `image`, forwarded to `get_patch_output_size`.

    Returns:
        The result of `self.pad` applied to `image` with the computed padding.
    """
    # Size the image occupies inside the target; the padding fills the remaining space on each axis.
    new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
    # Delegate the split of the size difference to the shared helper instead of recomputing it inline.
    padding = self._get_padding_size(new_resolution, target_resolution)

    padded_image = self.pad(image, padding=padding)

    return padded_image

View File

@ -118,6 +118,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("ernie_m", "ErnieMConfig"),
("esm", "EsmConfig"),
("falcon", "FalconConfig"),
("falcon_h1", "FalconH1Config"),
("falcon_mamba", "FalconMambaConfig"),
("fastspeech2_conformer", "FastSpeech2ConformerConfig"),
("flaubert", "FlaubertConfig"),
@ -481,6 +482,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("esm", "ESM"),
("falcon", "Falcon"),
("falcon3", "Falcon3"),
("falcon_h1", "FalconH1"),
("falcon_mamba", "FalconMamba"),
("fastspeech2_conformer", "FastSpeech2Conformer"),
("flan-t5", "FLAN-T5"),

View File

@ -115,6 +115,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("ernie_m", "ErnieMModel"),
("esm", "EsmModel"),
("falcon", "FalconModel"),
("falcon_h1", "FalconH1Model"),
("falcon_mamba", "FalconMambaModel"),
("fastspeech2_conformer", "FastSpeech2ConformerModel"),
("flaubert", "FlaubertModel"),
@ -558,6 +559,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("emu3", "Emu3ForCausalLM"),
("ernie", "ErnieForCausalLM"),
("falcon", "FalconForCausalLM"),
("falcon_h1", "FalconH1ForCausalLM"),
("falcon_mamba", "FalconMambaForCausalLM"),
("fuyu", "FuyuForCausalLM"),
("gemma", "GemmaForCausalLM"),

View File

@ -24,7 +24,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Tuple, Union
from functools import partial
from typing import Callable, Optional, Tuple, TypedDict, Union
import torch
from torch import nn
@ -61,6 +62,31 @@ else:
logger = logging.get_logger(__name__)
class BambaFlashAttentionKwargs(TypedDict, total=False):
    """
    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
    Use cases include padding-free training and fewer `torch.compile` graph breaks.

    All keys are optional (`total=False`).

    Attributes:
        cu_seq_lens_q (`torch.LongTensor`):
            Cumulative sequence lengths for the query state.
        cu_seq_lens_k (`torch.LongTensor`):
            Cumulative sequence lengths for the key state.
        max_length_q (`int`):
            Maximum sequence length for query state.
        max_length_k (`int`):
            Maximum sequence length for key state.
        seq_idx (`torch.IntTensor`):
            Index of each packed sequence.
    """

    cu_seq_lens_q: torch.LongTensor
    cu_seq_lens_k: torch.LongTensor
    max_length_q: int
    max_length_k: int
    seq_idx: torch.IntTensor
# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
"""
@ -487,6 +513,7 @@ class BambaMixer(nn.Module):
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.IntTensor] = None,
):
# 1. Gated MLP's linear projection
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
@ -569,7 +596,7 @@ class BambaMixer(nn.Module):
A,
D=self.D,
chunk_size=self.chunk_size,
seq_idx=None, # was seq_idx
seq_idx=seq_idx,
activation=self.activation,
rmsnorm_weight=self.norm.weight,
rmsnorm_eps=self.norm.variance_epsilon,
@ -610,6 +637,7 @@ class BambaMixer(nn.Module):
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=seq_idx,
).transpose(1, 2)
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
@ -629,7 +657,7 @@ class BambaMixer(nn.Module):
chunk_size=self.chunk_size,
D=self.D,
z=None,
seq_idx=None,
seq_idx=seq_idx,
return_final_states=True,
dt_bias=self.dt_bias,
dt_softplus=True,
@ -863,9 +891,15 @@ class BambaMixer(nn.Module):
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.IntTensor] = None,
**kwargs,
):
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
if seq_idx is not None:
raise NotImplementedError(
"`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
)
dtype = hidden_states.dtype
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
@ -939,7 +973,7 @@ class BambaDecoderLayer(nn.Module):
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs,
**kwargs: Unpack[BambaFlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@ -959,8 +993,8 @@ class BambaDecoderLayer(nn.Module):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
padding-free training and/or improve torch.compile performance.
"""
residual = hidden_states
@ -974,6 +1008,7 @@ class BambaDecoderLayer(nn.Module):
cache_params=past_key_value,
cache_position=cache_position,
attention_mask=attention_mask,
**kwargs,
)
self_attn_weights = None
elif self.layer_type == "attention":
@ -1076,7 +1111,7 @@ class BambaModel(BambaPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # NOOP kwargs, for now
**kwargs: Unpack[BambaFlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -1128,7 +1163,7 @@ class BambaModel(BambaPreTrainedModel):
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
partial(decoder_layer.__call__, **kwargs),
hidden_states,
layer_mask,
position_ids,
@ -1148,6 +1183,7 @@ class BambaModel(BambaPreTrainedModel):
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -19,7 +19,8 @@
# limitations under the License.
"""PyTorch Bamba model."""
from typing import Optional, Tuple, Union
from functools import partial
from typing import Optional, Tuple, TypedDict, Union
import torch
import torch.utils.checkpoint
@ -46,7 +47,12 @@ from transformers.models.mamba2.modeling_mamba2 import (
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, can_return_tuple, logging
from ...processing_utils import Unpack
from ...utils import (
auto_docstring,
can_return_tuple,
logging,
)
from ...utils.import_utils import is_causal_conv1d_available, is_flash_attn_2_available, is_mamba_2_ssm_available
from .configuration_bamba import BambaConfig
@ -71,6 +77,31 @@ is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_c
logger = logging.get_logger(__name__)
class BambaFlashAttentionKwargs(TypedDict, total=False):
    """
    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
    Use cases include padding-free training and fewer `torch.compile` graph breaks.

    All keys are optional (`total=False`).

    Attributes:
        cu_seq_lens_q (`torch.LongTensor`):
            Cumulative sequence lengths for the query state.
        cu_seq_lens_k (`torch.LongTensor`):
            Cumulative sequence lengths for the key state.
        max_length_q (`int`):
            Maximum sequence length for query state.
        max_length_k (`int`):
            Maximum sequence length for key state.
        seq_idx (`torch.IntTensor`):
            Index of each packed sequence.
    """

    cu_seq_lens_q: torch.LongTensor
    cu_seq_lens_k: torch.LongTensor
    max_length_q: int
    max_length_k: int
    seq_idx: torch.IntTensor
# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
"""
@ -278,6 +309,7 @@ class BambaMixer(nn.Module):
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.IntTensor] = None,
):
# 1. Gated MLP's linear projection
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
@ -360,7 +392,7 @@ class BambaMixer(nn.Module):
A,
D=self.D,
chunk_size=self.chunk_size,
seq_idx=None, # was seq_idx
seq_idx=seq_idx,
activation=self.activation,
rmsnorm_weight=self.norm.weight,
rmsnorm_eps=self.norm.variance_epsilon,
@ -401,6 +433,7 @@ class BambaMixer(nn.Module):
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=seq_idx,
).transpose(1, 2)
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
@ -420,7 +453,7 @@ class BambaMixer(nn.Module):
chunk_size=self.chunk_size,
D=self.D,
z=None,
seq_idx=None,
seq_idx=seq_idx,
return_final_states=True,
dt_bias=self.dt_bias,
dt_softplus=True,
@ -654,9 +687,15 @@ class BambaMixer(nn.Module):
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.IntTensor] = None,
**kwargs,
):
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
if seq_idx is not None:
raise NotImplementedError(
"`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
)
dtype = hidden_states.dtype
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
@ -701,7 +740,7 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer):
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs,
**kwargs: Unpack[BambaFlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@ -721,8 +760,8 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
padding-free training and/or improve torch.compile performance.
"""
residual = hidden_states
@ -736,6 +775,7 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer):
cache_params=past_key_value,
cache_position=cache_position,
attention_mask=attention_mask,
**kwargs,
)
self_attn_weights = None
elif self.layer_type == "attention":
@ -838,7 +878,7 @@ class BambaModel(BambaPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # NOOP kwargs, for now
**kwargs: Unpack[BambaFlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@ -890,7 +930,7 @@ class BambaModel(BambaPreTrainedModel):
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
partial(decoder_layer.__call__, **kwargs),
hidden_states,
layer_mask,
position_ids,
@ -910,6 +950,7 @@ class BambaModel(BambaPreTrainedModel):
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = layer_outputs[0]

View File

@ -0,0 +1,27 @@
# Copyright 2025 TII and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
# Static type checkers take this branch and see the public names through the star imports.
if TYPE_CHECKING:
    from .configuration_falcon_h1 import *
    from .modeling_falcon_h1 import *
else:
    import sys

    # At runtime, replace this package's entry in `sys.modules` with a `_LazyModule` built from the import
    # structure derived from this file (presumably so the submodules are only imported on first access).
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@ -0,0 +1,283 @@
# coding=utf-8
# Copyright 2025 TII and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FalconH1 model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class FalconH1Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`FalconH1Model`]. It is used to instantiate a
    FalconH1Model model according to the specified arguments, defining the model architecture.

    The FalconH1Model is a hybrid [mamba2](https://github.com/state-spaces/mamba) architecture with SwiGLU.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 128000):
            Vocabulary size of the FalconH1 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`FalconH1Model`]
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
            model has a output word embedding layer.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 14336):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If `None` is passed, `num_attention_heads` is used.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
            Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
            integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
            logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
            sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint
            significantly.
        pad_token_id (`int`, *optional*, defaults to 0):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            Max cached sequence length for the model
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        mamba_d_ssm (`int`, *optional*, defaults to 1024):
            The dimension of the SSM state space latents. If `None`, `mamba_expand * hidden_size` is used as the mamba
            intermediate size instead.
        mamba_n_heads (`int`, *optional*, defaults to 128):
            The number of mamba heads used in the v2 implementation. Must divide the mamba intermediate size.
        mamba_d_head (`int` or `"auto"`, *optional*, defaults to `"auto"`):
            Head embedding dimension size. `"auto"` resolves to the mamba intermediate size divided by
            `mamba_n_heads`.
        mamba_n_groups (`int`, *optional*, defaults to 1):
            The number of the mamba groups used in the v2 implementation.
        mamba_d_state (`int`, *optional*, defaults to 256):
            The dimension the mamba state space latents
        mamba_d_conv (`int`, *optional*, defaults to 4):
            The size of the mamba convolution kernel
        mamba_expand (`int`, *optional*, defaults to 2):
            Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
        mamba_chunk_size (`int`, *optional*, defaults to 256):
            The chunks in which to break the sequence when doing prefill/training
        mamba_conv_bias (`bool`, *optional*, defaults to `True`):
            Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
        mamba_proj_bias (`bool`, *optional*, defaults to `False`):
            Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
        mamba_norm_before_gate (`bool`, *optional*, defaults to `True`):
            Whether to use RMSNorm before the gate in the Mamba block
        mamba_rms_norm (`bool`, *optional*, defaults to `False`):
            Whether to use RMSNorm instead of LayerNorm in the Mamba block
        projectors_bias (`bool`, *optional*, defaults to `False`):
            Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the attention block
        rope_theta (`float`, *optional*, defaults to 100000.0):
            The theta value used for the RoPE embeddings.
        rope_scaling (`float`, *optional*):
            The scaling value used for the RoPE embeddings. If `None`, no scaling is applied.
        lm_head_multiplier (`float`, *optional*, defaults to 1.0):
            The multiplier for the LM head. This is used to scale the output of the LM head.
        embedding_multiplier (`float`, *optional*, defaults to 1.0):
            The multiplier for the embedding layer. This is used to scale the output of the embedding layer.
        mlp_multipliers (`List[float]`, *optional*):
            The multipliers for the MLP layers. This is used to scale the output of the MLP layers. The first value is
            the multiplier of gate layer, the second value is the multiplier of the down_proj layer. Falls back to
            `[1.0, 1.0]` when `None`.
        key_multiplier (`float`, *optional*):
            The multiplier for the key layer. This is used to scale the output of the key layer. Falls back to `1.0`
            when `None`.
        attention_out_multiplier (`float`, *optional*):
            The multiplier for the attention output layer. This is used to scale the output of the attention output
            layer. Falls back to `1.0` when `None`.
        attention_in_multiplier (`float`, *optional*):
            The multiplier for the attention input layer. This is used to scale the output of the attention input layer.
            Falls back to `1.0` when `None`.
        ssm_multipliers (`List[float]`, *optional*):
            The multipliers for the SSM layers. This is used to scale the output of the SSM layers. Falls back to a
            list of five `1.0` values when `None`.
        ssm_in_multiplier (`float`, *optional*):
            The multiplier for the SSM input layer. This is used to scale the output of the SSM input layer. Falls
            back to `1.0` when `None`.
        ssm_out_multiplier (`float`, *optional*):
            The multiplier for the SSM output layer. This is used to scale the output of the SSM output layer. Falls
            back to `1.0` when `None`.
    """

    model_type = "falcon_h1"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=128000,
        tie_word_embeddings=False,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        num_logits_to_keep=1,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        max_position_embeddings=8192,
        attention_dropout=0.0,
        mamba_d_ssm=1024,
        mamba_n_heads=128,
        mamba_d_head="auto",
        mamba_n_groups=1,
        mamba_d_state=256,
        mamba_d_conv=4,
        mamba_expand=2,
        mamba_chunk_size=256,
        mamba_conv_bias=True,
        mamba_proj_bias=False,
        mamba_norm_before_gate=True,
        mamba_rms_norm=False,
        projectors_bias=False,
        rope_theta=100000.0,
        rope_scaling=None,
        lm_head_multiplier=1.0,
        embedding_multiplier=1.0,
        mlp_multipliers=None,
        key_multiplier=None,
        attention_out_multiplier=None,
        attention_in_multiplier=None,
        ssm_multipliers=None,
        ssm_in_multiplier=None,
        ssm_out_multiplier=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.attention_dropout = attention_dropout
        # Neither bias flag is exposed in the signature; both are fixed to False here.
        self.attention_bias = False
        self.mlp_bias = False

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps

        self.use_cache = use_cache
        self.num_logits_to_keep = num_logits_to_keep

        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.projectors_bias = projectors_bias

        # The mamba intermediate size is `mamba_d_ssm` when given, otherwise `mamba_expand * hidden_size`.
        mamba_intermediate = mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm

        if mamba_intermediate % mamba_n_heads != 0:
            raise ValueError(
                "mamba_n_heads must divide the mamba intermediate size "
                "(mamba_d_ssm if set, otherwise mamba_expand * hidden_size)"
            )

        # for the mamba_v2, must satisfy the following
        if mamba_d_head == "auto":
            mamba_d_head = mamba_intermediate // mamba_n_heads

        if mamba_d_head * mamba_n_heads != mamba_intermediate:
            raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size")

        self.mamba_d_ssm = mamba_d_ssm
        self.mamba_n_heads = mamba_n_heads
        self.mamba_d_head = mamba_d_head
        self.mamba_n_groups = mamba_n_groups
        self.mamba_d_state = mamba_d_state
        self.mamba_d_conv = mamba_d_conv
        self.mamba_expand = mamba_expand
        self.mamba_chunk_size = mamba_chunk_size
        self.mamba_conv_bias = mamba_conv_bias
        self.mamba_proj_bias = mamba_proj_bias

        self.mamba_norm_before_gate = mamba_norm_before_gate
        self.mamba_rms_norm = mamba_rms_norm

        self.lm_head_multiplier = lm_head_multiplier
        self.embedding_multiplier = embedding_multiplier

        # Every optional multiplier falls back to a neutral scale of 1.0 when left unset.
        self.mlp_multipliers = mlp_multipliers if mlp_multipliers is not None else [1.0, 1.0]
        self.attention_out_multiplier = attention_out_multiplier if attention_out_multiplier is not None else 1.0
        self.attention_in_multiplier = attention_in_multiplier if attention_in_multiplier is not None else 1.0
        self.key_multiplier = key_multiplier if key_multiplier is not None else 1.0
        self.ssm_multipliers = ssm_multipliers if ssm_multipliers is not None else [1.0, 1.0, 1.0, 1.0, 1.0]
        self.ssm_in_multiplier = ssm_in_multiplier if ssm_in_multiplier is not None else 1.0
        self.ssm_out_multiplier = ssm_out_multiplier if ssm_out_multiplier is not None else 1.0

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        """Per-layer block type: one `"attention"` entry for each of the `num_hidden_layers` layers."""
        return ["attention" for _ in range(self.num_hidden_layers)]
__all__ = ["FalconH1Config"]

View File

@ -0,0 +1,151 @@
# coding=utf-8
# Copyright 2025 TII and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This script converts a FalconH1 checkpoint that can be loaded through `AutoModelForCausalLM` (with `trust_remote_code=True`) into the native FalconH1 format of HuggingFace `transformers`. The remote modeling code of the input checkpoint may require the `mamba_ssm` package to be installed."""
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, FalconH1Config, FalconH1ForCausalLM
# Maps name fragments (keys) to their replacement fragments (values). Presumably applied to the parameter
# names of the input checkpoint to obtain the HF FalconH1 names; the code that consumes it is not shown here.
# NOTE(review): some entries look order-dependent (".norm." is rewritten to ".input_layernorm." before
# ".mamba.input_layernorm." is mapped back to ".mamba.norm."), so insertion order likely matters; confirm
# against the loop that applies this mapping before reordering entries.
CONVERSION_MAPPING = {
    "backbone": "model",
    "embeddings": "embed_tokens",
    "mixer.": "",
    "mixer_ssm": "mamba",
    "mixer_attn": "self_attn",
    "mlp.": "feed_forward.",
    "mlp_norm": "pre_ff_layernorm",
    "ssm_proj": "mamba.in_proj",
    "attn_out_proj": "o_proj",
    ".norm.": ".input_layernorm.",
    ".mamba.input_layernorm.": ".mamba.norm.",
    ".ssm_out_proj.": ".mamba.out_proj.",
    "norm_f": "final_layernorm",
}
def convert_falcon_h1_to_hf(input_model_path, output_path):
    """
    Convert a Falcon-H1 checkpoint into the `FalconH1ForCausalLM` layout and save it.

    The source model is loaded in bfloat16 through `AutoModelForCausalLM` (with
    `trust_remote_code=True`), a `FalconH1Config` is built from the source config, parameter names
    are rewritten with `CONVERSION_MAPPING`, and the fused `attn_proj` weight is split into
    `q_proj` / `k_proj` / `v_proj`.

    Args:
        input_model_path: Value passed to `from_pretrained` for both the tokenizer and the source model.
        output_path: Directory passed to `save_pretrained` for the converted model and the tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(input_model_path)
    model = AutoModelForCausalLM.from_pretrained(
        input_model_path, torch_dtype=torch.bfloat16, trust_remote_code=True, low_cpu_mem_usage=True
    )
    src_config = model.config

    # Round the feed-forward width up to the next even number (adds 1 when odd, 0 when even).
    intermediate_size = int(src_config.expansion_factor * src_config.hidden_size)
    intermediate_size += intermediate_size % 2

    new_config = FalconH1Config(
        vocab_size=src_config.vocab_size,
        tie_word_embeddings=src_config.tie_word_embeddings,
        hidden_size=src_config.hidden_size,
        intermediate_size=intermediate_size,
        mamba_d_state=src_config.state_size,
        num_hidden_layers=src_config.num_hidden_layers,
        mamba_use_mlp=src_config.use_mlp,
        rms_norm_eps=src_config.layer_norm_epsilon,
        pad_token_id=src_config.pad_token_id,
        eos_token_id=src_config.eos_token_id,
        mamba_expand=src_config.expand,
        mamba_d_conv=src_config.conv_kernel,
        mamba_n_groups=src_config.n_groups,
        mamba_n_heads=src_config.num_heads,
        mamba_norm_before_gate=src_config.norm_before_gate,
        mamba_rms_norm=src_config.rms_norm,
        mamba_d_ssm=src_config.d_ssm,
        attention_bias=src_config.use_bias,
        projectors_bias=src_config.use_bias,
        mamba_conv_bias=src_config.use_conv_bias,
        hidden_act=src_config.hidden_act,
        use_cache=src_config.use_cache,
        mamba_chunk_size=src_config.chunk_size,
        num_attention_heads=src_config.num_heads_mha,
        num_key_value_heads=src_config.num_key_value_heads,
        head_dim=src_config.head_dim_mha,
        lm_head_multiplier=src_config.lm_head_multiplier,
        embedding_multiplier=src_config.embedding_multiplier,
        mlp_multipliers=src_config.mlp_multipliers,
        key_multiplier=src_config.key_multiplier,
        attention_out_multiplier=src_config.attention_out_multiplier,
        attention_in_multiplier=src_config.attention_in_multiplier,
        ssm_multipliers=src_config.ssm_multipliers,
        ssm_in_multiplier=src_config.ssm_in_multiplier,
        ssm_out_multiplier=src_config.ssm_out_multiplier,
        rope_theta=src_config.rope_theta,
    )

    new_state_dict = {}
    for old_key, tensor in model.state_dict().items():
        new_key = old_key
        for pattern, replacement in CONVERSION_MAPPING.items():
            # The membership test uses the original key; the replacement is applied to the
            # key as renamed so far.
            if pattern in old_key:
                new_key = new_key.replace(pattern, replacement)
        # `str.replace` leaves the key untouched when the substring is absent.
        new_key = new_key.replace("mamba.input_layernorm", "mamba.norm")

        if "self_attn.attn_proj" not in new_key:
            new_state_dict[new_key] = tensor
            continue

        # Fused attention projection: split along dim 0 into query, key and value weights.
        q_rows = new_config.num_attention_heads * new_config.head_dim
        kv_rows = new_config.num_key_value_heads * new_config.head_dim
        pieces = tensor.split([q_rows, kv_rows, kv_rows], dim=0)
        for proj_name, piece in zip(("q_proj", "k_proj", "v_proj"), pieces):
            new_state_dict[new_key.replace("attn_proj", proj_name)] = piece

    # Build the target model on the meta device, then adopt the converted tensors via `assign=True`.
    with torch.device("meta"):
        new_model = FalconH1ForCausalLM(new_config)
    del model

    new_model.load_state_dict(new_state_dict, strict=True, assign=True)
    new_model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)
if __name__ == "__main__":
    # Command-line entry point: both the input checkpoint directory and the output directory are required.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-i",
        "--mamba_ssm_checkpoint_directory",
        required=True,
        type=str,
        help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
    )
    cli.add_argument(
        "-o",
        "--output_dir",
        required=True,
        type=str,
        help="Path to directory to save the converted output model to.",
    )
    options = cli.parse_args()

    convert_falcon_h1_to_hf(options.mamba_ssm_checkpoint_directory, options.output_dir)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1062,10 +1062,21 @@ class Gemma3Model(Gemma3PreTrainedModel):
if token_type_ids is not None and sequence_length != 1:
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
# Find where a new image block starts: 1 if image and previous not image
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
is_image = token_type_ids == 1
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
same_image_mask[image_group_ids == -1] = False # remove non-image
image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
causal_mask = causal_mask.clone()
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
token_type_mask, 0.0
image_mask, 0.0
)
if attention_mask is not None:

View File

@ -781,10 +781,21 @@ class Gemma3Model(PaliGemmaModel):
if token_type_ids is not None and sequence_length != 1:
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
# Find where a new image block starts: 1 if image and previous not image
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
is_image = token_type_ids == 1
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
same_image_mask[image_group_ids == -1] = False # remove non-image
image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
causal_mask = causal_mask.clone()
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
token_type_mask, 0.0
image_mask, 0.0
)
if attention_mask is not None:

View File

@ -439,6 +439,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.IntTensor] = None,
):
# 1. Gated MLP's linear projection
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
@ -521,7 +522,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
A,
D=self.D,
chunk_size=self.chunk_size,
seq_idx=None, # was seq_idx
seq_idx=seq_idx,
activation=self.activation,
rmsnorm_weight=self.norm.weight,
rmsnorm_eps=self.norm.variance_epsilon,
@ -562,6 +563,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=seq_idx,
).transpose(1, 2)
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
@ -581,7 +583,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
chunk_size=self.chunk_size,
D=self.D,
z=None,
seq_idx=None,
seq_idx=seq_idx,
return_final_states=True,
dt_bias=self.dt_bias,
dt_softplus=True,
@ -815,9 +817,15 @@ class GraniteMoeHybridMambaLayer(nn.Module):
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.IntTensor] = None,
**kwargs,
):
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
if seq_idx is not None:
raise NotImplementedError(
"`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
)
dtype = hidden_states.dtype
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66

View File

@ -144,7 +144,7 @@ class Llama4TextMoe(nn.Module):
def forward(self, hidden_states):
batch, seq_len, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_dim)
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = self.router(hidden_states)
tokens_per_expert = batch * seq_len
@ -258,6 +258,33 @@ def eager_attention_forward(
return attn_output, attn_weights
# Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32
def vision_eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Eager attention used by the vision attention layer.

    Key/value heads are expanded with `repeat_kv`, scores are scaled by `module.head_dim**-0.5`
    (the `scaling` argument is accepted but not read), and the softmax is applied without an
    explicit dtype.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has dims 1 and 2 swapped and is contiguous.
    """
    num_groups = module.num_key_value_groups
    expanded_keys = repeat_kv(key, num_groups)
    expanded_values = repeat_kv(value, num_groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * module.head_dim**-0.5
    if attention_mask is not None:
        # Crop the mask's last dim to the key length before adding it to the scores.
        kv_len = expanded_keys.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    probs = nn.functional.softmax(scores, dim=-1)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, expanded_values).transpose(1, 2).contiguous()
    return context, probs
class Llama4TextAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@ -534,10 +561,10 @@ class Llama4TextModel(Llama4PreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))
if use_cache and past_key_values is None:
if self.config.get_text_config().get("attention_chunk_size") is not None:
if self.config.get_text_config().attention_chunk_size is not None:
past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
else:
past_key_values = DynamicCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@ -730,7 +757,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and chunked_attention_mask is not None:
chunked_attention_mask = chunked_attention_mask.bool()
causal_mask = causal_mask.bool()
causal_mask = causal_mask != torch.finfo(dtype).min
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
@ -1099,7 +1126,7 @@ class Llama4VisionAttention(nn.Module):
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
attention_interface: Callable = eager_attention_forward
attention_interface: Callable = vision_eager_attention_forward
# flex disable because breaks on TP 8, embed is 88 not power of 2
if self.config._attn_implementation not in ["eager", "flex_attention"]:
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@ -1117,7 +1144,7 @@ class Llama4VisionAttention(nn.Module):
value_states,
None,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=None,
scaling=None, # TODO Might be enforced here for TP compatibility as scaling is not just sqrt(head_dim)
is_causal=False, # HAS TO BE ENFORCED
**kwargs,
)

View File

@ -424,19 +424,23 @@ class LlavaNextImageProcessor(BaseImageProcessor):
return resized_image
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
original_height, original_width = original_resolution
target_height, target_width = target_resolution
paste_x, r_x = divmod(target_width - original_width, 2)
paste_y, r_y = divmod(target_height - original_height, 2)
return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
def _pad_for_patching(
self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
) -> np.array:
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
padding = self._get_padding_size(new_resolution, target_resolution)
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
padded_image = self.pad(image, padding=padding)
return padded_image

View File

@ -141,19 +141,23 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast):
return resized_image
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
original_height, original_width = original_resolution
target_height, target_width = target_resolution
paste_x, r_x = divmod(target_width - original_width, 2)
paste_y, r_y = divmod(target_height - original_height, 2)
return [paste_x, paste_y, paste_x + r_x, paste_y + r_y]
def _pad_for_patching(
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
) -> "torch.Tensor":
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
padding = self._get_padding_size(new_resolution, target_resolution)
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x + r_x, paste_y + r_y])
padded_image = F.pad(image, padding=padding)
return padded_image

View File

@ -315,6 +315,14 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
return resized_image
# Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._get_padding_size
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
original_height, original_width = original_resolution
target_height, target_width = target_resolution
paste_x, r_x = divmod(target_width - original_width, 2)
paste_y, r_y = divmod(target_height - original_height, 2)
return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
# Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._pad_for_patching
def _pad_for_patching(
self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
@ -322,13 +330,10 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
padding = self._get_padding_size(new_resolution, target_resolution)
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
padded_image = self.pad(image, padding=padding)
return padded_image
@ -437,6 +442,85 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
return pixel_values
# Adapted from transformers.models.llava.image_processing_llava.LlavaImageProcessor.pad_to_square
# (this version resolves `input_data_format` itself when the caller leaves it unset)
def pad_to_square(
    self,
    image: np.ndarray,
    background_color: Union[int, Tuple[int, int, int]] = 0,
    data_format: Optional[Union[str, ChannelDimension]] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
    """
    Pads an image to a square based on the longest edge.

    Args:
        image (`np.ndarray`):
            The image to pad.
        background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
            The color to use for the padding. Can be an integer for single-channel images or a
            tuple of integers for multi-channel images. If passed as an integer
            in multi-channel mode, it will default to `0` in subsequent channels.
        data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use same as the input image.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            If unset, will use the inferred format of the input image.

    Returns:
        `np.ndarray`: The padded image.

    Raises:
        `ValueError`: If `background_color` is a sequence whose length differs from the number of channels.
    """
    # Resolve the layout once so that the size lookup, the channel count and the paste below all
    # agree. Without this, an unset format was inferred for the size but treated as channels-last
    # everywhere else, which broke channels-first inputs.
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)
    channels_first = input_data_format == ChannelDimension.FIRST

    height, width = get_image_size(image, input_data_format)
    num_channels = image.shape[0] if channels_first else image.shape[-1]

    if height == width:
        # Already square: only convert the channel layout when an output format was requested.
        if data_format is not None:
            image = to_channel_dimension_format(image, data_format, input_data_format)
        return image

    max_dim = max(height, width)

    # Ensure background_color is the correct shape
    if isinstance(background_color, int):
        # Only the first channel receives the color; the others stay at the zero fill.
        background_color = [background_color]
    elif len(background_color) != num_channels:
        raise ValueError(
            f"background_color must have no more than {num_channels} elements to match the number of channels"
        )

    if channels_first:
        result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype)
        for i, color in enumerate(background_color):
            result[i, :, :] = color
        if width > height:
            # Landscape: center the image vertically.
            start = (max_dim - height) // 2
            result[:, start : start + height, :] = image
        else:
            # Portrait: center the image horizontally.
            start = (max_dim - width) // 2
            result[:, :, start : start + width] = image
    else:
        result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype)
        for i, color in enumerate(background_color):
            result[:, :, i] = color
        if width > height:
            start = (max_dim - height) // 2
            result[start : start + height, :, :] = image
        else:
            start = (max_dim - width) // 2
            result[:, start : start + width, :] = image

    if data_format is not None:
        return to_channel_dimension_format(result, data_format, input_data_format)
    return result
def _preprocess(
self,
images: ImageInput,
@ -595,6 +679,17 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
do_pad = do_pad if do_pad is not None else self.do_pad
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
if isinstance(images, (tuple, list)) and isinstance(images[0], (tuple, list)):
# if the first element is a list, we assume that all elements are lists
batch_num_images = [len(x) for x in images]
elif isinstance(images, (tuple, list)):
# treat this as a single-image case for backward compatibility
batch_num_images = [1] * len(images)
else:
batch_num_images = [1]
# only single image patching is supported
need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
images = make_flat_list_of_images(images)
if not valid_images(images):
@ -630,25 +725,34 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
size_tuple = (
(size["height"], size["width"])
if "height" in size and "width" in size
else (size["shortest_edge"], size["shortest_edge"])
)
new_images = []
image_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images]
for image in images:
# convert image into a list of patches
# we intentionally use the same data format as the input data format
size_tuple = (
(size["height"], size["width"])
if "height" in size and "width" in size
else (size["shortest_edge"], size["shortest_edge"])
)
image_patches = self.get_image_patches(
image,
image_grid_pinpoints,
size=size_tuple,
patch_size=size_tuple[0],
resample=resample,
data_format=input_data_format,
input_data_format=input_data_format,
)
for i, image in enumerate(images):
if need_patching[i]:
# convert image into a list of patches
# we intentionally use the same data format as the input data format
image_patches = self.get_image_patches(
image,
image_grid_pinpoints,
size=size_tuple,
patch_size=size_tuple[0],
resample=resample,
data_format=input_data_format,
input_data_format=input_data_format,
)
else:
padded_image = self.pad_to_square(
image=image,
background_color=tuple(int(x * 255) for x in self.image_mean),
input_data_format=input_data_format,
)
image_patches = [padded_image]
# preprocess patches
pixel_values = self._preprocess(
@ -671,7 +775,8 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
processed_images = self._pad_for_batching(new_images)
return BatchFeature(
data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
data={"pixel_values": processed_images, "image_sizes": image_sizes, "batch_num_images": batch_num_images},
tensor_type=return_tensors,
)

View File

@ -19,7 +19,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
import torch
@ -89,6 +89,15 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
@auto_docstring
def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaOnevisionFastImageProcessorKwargs]) -> BatchFeature:
    # Record how many images each sample holds and forward it to the base `preprocess`
    # through `kwargs["batch_num_images"]`.
    if not isinstance(images, (tuple, list)):
        # a lone image counts as one sample with one image
        batch_num_images = [1]
    elif isinstance(images[0], (tuple, list)):
        # if the first element is a list, we assume that all elements are lists
        batch_num_images = [len(sample) for sample in images]
    else:
        # flat list: treat this as a single-image case for backward compatibility
        batch_num_images = [1] * len(images)

    kwargs["batch_num_images"] = batch_num_images
    return super().preprocess(images, **kwargs)
def _prepare_images_structure(
@ -137,19 +146,23 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
return resized_image
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
original_height, original_width = original_resolution
target_height, target_width = target_resolution
paste_x, r_x = divmod(target_width - original_width, 2)
paste_y, r_y = divmod(target_height - original_height, 2)
return [paste_x, paste_y, paste_x + r_x, paste_y + r_y]
def _pad_for_patching(
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
) -> "torch.Tensor":
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
padding = self._get_padding_size(new_resolution, target_resolution)
paste_x, r_x = divmod(target_width - new_width, 2)
paste_y, r_y = divmod(target_height - new_height, 2)
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x + r_x, paste_y + r_y])
padded_image = F.pad(image, padding=padding)
return padded_image
@ -234,10 +247,15 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
do_pad: bool,
batch_num_images: List[int],
return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
processed_images = []
image_sizes = []
# only single image patching is supported
need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
# Determine the size tuple
if size and size.height and size.width:
size_tuple = (size.height, size.width)
@ -252,14 +270,20 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
else:
patch_size = size.shortest_edge
for image in images:
image_patches = self._get_image_patches(
image,
image_grid_pinpoints,
size=size_tuple,
patch_size=patch_size,
interpolation=interpolation,
)
for i, image in enumerate(images):
if need_patching[i]:
image_patches = self._get_image_patches(
image,
image_grid_pinpoints,
size=size_tuple,
patch_size=patch_size,
interpolation=interpolation,
)
else:
padded_image = self.pad_to_square(
images=image, background_color=tuple(int(x * 255) for x in self.image_mean)
)
image_patches = [padded_image]
# Group images by size for batched processing
processed_image_patches_grouped = {}
@ -289,8 +313,52 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
processed_images = self._pad_for_batching(processed_images)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(
data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
data={"pixel_values": processed_images, "image_sizes": image_sizes, "batch_num_images": batch_num_images},
tensor_type=return_tensors,
)
# Copied from transformers.models.llava.image_processing_llava_fast.LlavaImageProcessorFast.pad_to_square
def pad_to_square(
    self,
    images: "torch.Tensor",
    background_color: Union[int, Tuple[int, int, int]] = 0,
) -> "torch.Tensor":
    """
    Pads an image to a square based on the longest edge.

    Args:
        images (`torch.Tensor`):
            The images to pad. The channel count is read from dim 1 for 4-dimensional input
            and from dim 0 otherwise.
        background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
            The fill used for the padding. An integer is expanded to one value per channel,
            with `0` for every channel after the first.

    Returns:
        `torch.Tensor`: The padded images, or the input unchanged when it is already square.
    """
    height, width = get_image_size(images, ChannelDimension.FIRST)
    if height == width:
        return images

    is_batched = len(images.shape) == 4
    num_channels = images.shape[1] if is_batched else images.shape[0]

    if isinstance(background_color, int):
        background_color = [background_color] + [0] * (num_channels - 1)
    elif len(background_color) != num_channels:
        raise ValueError(
            f"background_color must have no more than {num_channels} elements to match the number of channels"
        )

    # Split the missing width/height in two; when the gap is odd the extra pixel goes to the
    # third/fourth padding entry.
    side = max(height, width)
    gap_w = side - width
    gap_h = side - height
    first_w = gap_w // 2
    first_h = gap_h // 2
    padding = [first_w, first_h, gap_w - first_w, gap_h - first_h]

    return F.pad(images, padding=padding, fill=background_color)
__all__ = ["LlavaOnevisionImageProcessorFast"]

View File

@ -419,8 +419,9 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
self,
pixel_values: torch.FloatTensor,
image_sizes: torch.Tensor,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
vision_feature_layer: Union[int, List[int]],
vision_feature_select_strategy: str,
batch_num_images: Optional[torch.LongTensor] = None,
):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
@ -430,34 +431,34 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
The tensors corresponding to the input images.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each images (H, W).
vision_feature_layer (`Union[int, List[int]]`, *optional*):
vision_feature_layer (`Union[int, List[int]]`):
The index of the layer to select the vision feature. If multiple indices are provided,
the vision feature of the corresponding indices will be concatenated to form the
vision features.
vision_feature_select_strategy (`str`, *optional*):
vision_feature_select_strategy (`str`):
The feature selection strategy used to select the vision feature from the vision backbone.
Can be one of `"default"` or `"full"`
batch_num_images (`torch.LongTensor`, *optional*):
Number of images in each sample.
Returns:
image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
and are of shape `(num_patches, image_length, embed_dim)`).
"""
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
# ! infer image_num_patches from image_sizes
if batch_num_images is None:
# treat this as a single-image case for backward compatibility
need_patching = [True] * len(image_sizes)
else:
need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.config.image_grid_pinpoints,
patch_size=self.config.vision_config.image_size,
)
for imsize in image_sizes
if should_patch
else 1
for imsize, should_patch in zip(image_sizes, need_patching)
]
if pixel_values.dim() == 5:
# stacked if input is (batch_size, num_patches, num_channels, height, width)
@ -500,6 +501,7 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
vision_aspect_ratio: Optional[str] = None,
batch_num_images: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@ -520,6 +522,8 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
If `"full"`, the full vision features are used.
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
Aspect ratio used when processing image features. The default value is "anyres_max_9".
batch_num_images (`torch.LongTensor`, *optional*):
Number of images in each sample.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@ -558,6 +562,7 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
image_sizes,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
batch_num_images=batch_num_images,
)
image_features, feature_lens = self.pack_image_features(
image_features,
@ -749,6 +754,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
vision_aspect_ratio: Optional[str] = None,
batch_num_images: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@ -771,6 +777,8 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
If `"full"`, the full vision features are used.
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
Aspect ratio used when processing image features. The default value is "anyres_max_9".
batch_num_images (`torch.LongTensor`, *optional*):
Number of images in each sample.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
@ -832,6 +840,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
vision_aspect_ratio=vision_aspect_ratio,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
batch_num_images=batch_num_images,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,

View File

@ -28,18 +28,59 @@ from transformers.models.llava_next_video.modeling_llava_next_video import (
LlavaNextVideoModelOutputWithPast,
LlavaNextVideoPreTrainedModel,
get_anyres_image_grid_shape,
image_size_to_num_patches,
unpad_image,
)
from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import DefaultFastImageProcessorKwargs, group_images_by_shape, reorder_images
from ...image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
get_image_size,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...processing_utils import Unpack
from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
from ...utils import (
TensorType,
auto_docstring,
can_return_tuple,
is_torchdynamo_compiling,
is_torchvision_available,
is_torchvision_v2_available,
logging,
)
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
logger = logging.get_logger(__name__)
class LlavaOnevisionFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    """
    image_grid_pinpoints (`List[List[int]]`, *optional*):
        A list of possible resolutions to use for processing high resolution images. The best resolution is selected
        based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
        method.
    do_pad (`bool`, *optional*):
        Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
        number of patches in the batch. Padding will be applied to the bottom and right with zeros.
    """

    # Extra keyword arguments accepted by `LlavaOnevisionImageProcessorFast.preprocess`, on top of the ones
    # inherited from `DefaultFastImageProcessorKwargs`. The docstring above is kept in the
    # "name (`type`, *optional*):" layout on purpose.
    # NOTE(review): presumably `@auto_docstring` parses that layout -- confirm before reformatting it.

    # Candidate (height, width) resolutions for high-resolution inputs.
    image_grid_pinpoints: Optional[List[List[int]]]
    # Whether to pad every image's patch dimension to the largest patch count in the batch.
    do_pad: Optional[bool]
class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = OPENAI_CLIP_MEAN
@ -56,6 +97,147 @@ class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast):
image_grid_pinpoints = [[384, 384], [384, 768], [384, 1152], [384, 1536], [384, 1920], [384, 2304], [768, 384], [768, 768], [768, 1152], [768, 1536], [768, 1920], [768, 2304], [1152, 384], [1152, 768], [1152, 1152], [1152, 1536], [1152, 1920], [1152, 2304], [1536, 384], [1536, 768], [1536, 1152], [1536, 1536], [1536, 1920], [1536, 2304], [1920, 384], [1920, 768], [1920, 1152], [1920, 1536], [1920, 1920], [1920, 2304], [2304, 384], [2304, 768], [2304, 1152], [2304, 1536], [2304, 1920], [2304, 2304]] # fmt: skip
model_input_names = ["pixel_values_videos"]
# Copied from transformers.models.llava.image_processing_llava_fast.LlavaImageProcessorFast.pad_to_square
def pad_to_square(
    self,
    images: "torch.Tensor",
    background_color: Union[int, Tuple[int, int, int]] = 0,
) -> "torch.Tensor":
    """
    Pads an image to a square based on the longest edge.

    Args:
        images (`torch.Tensor`):
            The images to pad, in channels-first layout. A 4D input is read as
            `(batch, channels, height, width)`, anything else as `(channels, height, width)`.
        background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
            The color to use for the padding. Can be an integer for single channel or a
            tuple of integers for multi-channel images. If passed as an integer
            in multi-channel mode, it will default to `0` in subsequent channels.

    Returns:
        `torch.Tensor`: The padded images. The input is returned unchanged when it is already square.

    Raises:
        ValueError: If `background_color` is a sequence whose length differs from the number of channels.
    """
    height, width = get_image_size(images, ChannelDimension.FIRST)

    # Already square: nothing to pad.
    if height == width:
        return images

    num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0]
    if isinstance(background_color, int):
        # A bare integer only colors the first channel; the remaining channels are filled with 0.
        background_color = [background_color] + [0] * (num_channels - 1)
    elif len(background_color) != num_channels:
        # The check is an exact match (shorter sequences are rejected too), so the message says so.
        raise ValueError(
            f"background_color must have exactly {num_channels} elements to match the number of channels"
        )

    # Center the original content: split the missing rows/columns between both sides, giving the
    # odd leftover pixel (if any) to the right/bottom side.
    max_dim = max(height, width)
    paste_x_left = (max_dim - width) // 2
    paste_y_left = (max_dim - height) // 2
    paste_x_right = max_dim - width - paste_x_left
    paste_y_right = max_dim - height - paste_y_left

    # torchvision's `pad` takes the padding as [left, top, right, bottom].
    padded_images = F.pad(
        images, padding=[paste_x_left, paste_y_left, paste_x_right, paste_y_right], fill=background_color
    )

    return padded_images
@auto_docstring
def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaOnevisionFastImageProcessorKwargs]) -> BatchFeature:
    # Work out how many images each sample contributes before the input is handed to the parent
    # implementation; `_preprocess` relies on these counts to decide which images get patched.
    if not isinstance(images, (tuple, list)):
        # A lone image is one sample holding one image.
        counts = [1]
    elif isinstance(images[0], (tuple, list)):
        # Nested input: if the first element is a list, all elements are assumed to be lists.
        counts = [len(sample) for sample in images]
    else:
        # Flat list: every entry is treated as its own single-image sample (backward compatibility).
        counts = [1] * len(images)

    kwargs["batch_num_images"] = counts
    return super().preprocess(images, **kwargs)
def _preprocess(
    self,
    images: List["torch.Tensor"],
    do_resize: bool,
    size: SizeDict,
    image_grid_pinpoints: List[List[int]],
    interpolation: Optional["F.InterpolationMode"],
    do_center_crop: bool,
    crop_size: SizeDict,
    do_rescale: bool,
    rescale_factor: float,
    do_normalize: bool,
    image_mean: Optional[Union[float, List[float]]],
    image_std: Optional[Union[float, List[float]]],
    do_pad: bool,
    batch_num_images: List[int],
    return_tensors: Optional[Union[str, TensorType]],
) -> BatchFeature:
    """
    Convert a flat list of images into preprocessed patch stacks.

    Images that are the only image of their sample are split into patches with `_get_image_patches`;
    images belonging to a multi-image sample are instead padded to a square and kept as one patch.
    Each image's patches are then optionally resized, center-cropped, and rescaled/normalized.

    Args:
        images: Flat list of images, indexed in the same order as the per-sample counts expand to.
        batch_num_images: Number of images in each sample.
        do_pad: Whether to run `_pad_for_batching` over the per-image patch stacks.
        return_tensors: When set, patches (and then images) are stacked into tensors.
        (The remaining arguments are forwarded to the resize / crop / rescale / normalize steps.)

    Returns:
        `BatchFeature` with `pixel_values`, `image_sizes` (size of each input image as returned by
        `get_image_size`, taken before any padding) and the `batch_num_images` that was passed in.
    """
    # One flag per image: only single-image samples are patched.
    patch_flags = [count == 1 for count in batch_num_images for _ in range(count)]

    # Target size handed to the patching routine.
    has_explicit_hw = size and size.height and size.width
    target_size = (size.height, size.width) if has_explicit_hw else (size.shortest_edge, size.shortest_edge)

    # Edge length of a single patch: crop size wins, then the resize height, then the shortest edge.
    patch_edge = (
        crop_size.height
        if crop_size and crop_size.height
        else size.height
        if size and size.height
        else size.shortest_edge
    )

    pixel_batches = []
    original_sizes = []
    for idx, image in enumerate(images):
        if patch_flags[idx]:
            patches = self._get_image_patches(
                image,
                image_grid_pinpoints,
                size=target_size,
                patch_size=patch_edge,
                interpolation=interpolation,
            )
        else:
            # Fill color is the processor's own mean, scaled from [0, 1] to [0, 255].
            fill = tuple(int(channel_mean * 255) for channel_mean in self.image_mean)
            patches = [self.pad_to_square(images=image, background_color=fill)]

        # Patches of identical shape are processed together as one stacked tensor.
        grouped, grouped_index = group_images_by_shape(patches)
        processed_by_shape = {}
        for shape, stacked in grouped.items():
            if do_resize:
                stacked = self.resize(
                    image=stacked,
                    size=size,
                    interpolation=interpolation,
                )
            if do_center_crop:
                stacked = self.center_crop(stacked, crop_size)
            # Fused rescale and normalize
            processed_by_shape[shape] = self.rescale_and_normalize(
                stacked, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )

        # Restore the original patch order after the shape-wise grouping.
        ordered = reorder_images(processed_by_shape, grouped_index)
        if return_tensors:
            ordered = torch.stack(ordered, dim=0)
        pixel_batches.append(ordered)
        original_sizes.append(get_image_size(image, ChannelDimension.FIRST))

    if do_pad:
        pixel_batches = self._pad_for_batching(pixel_batches)
    if return_tensors:
        pixel_batches = torch.stack(pixel_batches, dim=0)

    return BatchFeature(
        data={"pixel_values": pixel_batches, "image_sizes": original_sizes, "batch_num_images": batch_num_images},
        tensor_type=return_tensors,
    )
class LlavaOnevisionModelOutputWithPast(LlavaNextVideoModelOutputWithPast):
    # Adds no fields or behavior of its own; everything comes from `LlavaNextVideoModelOutputWithPast`.
    # NOTE(review): presumably kept so the model exposes an output class under its own name -- confirm.
    pass
@ -154,6 +336,76 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
image_features = image_features.view(batch_frames, -1, dim)
return image_features
def get_image_features(
    self,
    pixel_values: torch.FloatTensor,
    image_sizes: torch.Tensor,
    vision_feature_layer: Union[int, List[int]],
    vision_feature_select_strategy: str,
    batch_num_images: Optional[torch.LongTensor] = None,
):
    """
    Obtains image last hidden states from the vision tower and applies the multimodal projection.

    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
            The tensors corresponding to the input images. An already flattened 4D tensor is used as is.
        image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
            Actual image size of each image (H, W).
        vision_feature_layer (`Union[int, List[int]]`):
            The index of the layer to select the vision feature. If multiple indices are provided,
            the vision feature of the corresponding indices will be concatenated to form the
            vision features.
        vision_feature_select_strategy (`str`):
            The feature selection strategy used to select the vision feature from the vision backbone.
            Can be one of `"default"` or `"full"`
        batch_num_images (`torch.LongTensor`, *optional*):
            Number of images in each sample.

    Returns:
        image_features (List[`torch.Tensor`]): List of image feature tensors, each containing the visual features
        of all patches of one image, of shape `(num_patches, image_length, embed_dim)`.

    Raises:
        ValueError: If `pixel_values` is neither 4- nor 5-dimensional.
    """
    # Only images that are alone in their sample are patched; images of multi-image samples
    # contribute exactly one patch each.
    if batch_num_images is None:
        # No counts given: treat every image as a single-image sample (backward compatibility).
        patch_flags = [True] * len(image_sizes)
    else:
        patch_flags = [count == 1 for count in batch_num_images for _ in range(count)]

    # Infer the number of patches of every image from its size.
    image_num_patches = []
    for imsize, should_patch in zip(image_sizes, patch_flags):
        if should_patch:
            image_num_patches.append(
                image_size_to_num_patches(
                    image_size=imsize,
                    grid_pinpoints=self.config.image_grid_pinpoints,
                    patch_size=self.config.vision_config.image_size,
                )
            )
        else:
            image_num_patches.append(1)

    n_dims = pixel_values.dim()
    if n_dims == 5:
        # Stacked input (batch_size, num_patches, num_channels, height, width): keep only the
        # patches each image really has and flatten them along the first dimension.
        kept_patches = [sample[:count] for sample, count in zip(pixel_values, image_num_patches)]
        pixel_values = torch.cat(kept_patches, dim=0)
    elif n_dims != 4:
        # A 4D input is expected to already be (num_patches, num_channels, height, width).
        raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

    vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
    hidden_states = vision_outputs.hidden_states

    # With a single feature layer, take its hidden states directly; otherwise concatenate the
    # hidden states of all requested layers along the feature dimension.
    if isinstance(vision_feature_layer, int):
        selected = hidden_states[vision_feature_layer]
    else:
        selected = torch.cat([hidden_states[layer_idx] for layer_idx in vision_feature_layer], dim=-1)

    if vision_feature_select_strategy == "default":
        # Drop the first position of the sequence dimension.
        selected = selected[:, 1:]
    # "full" (and any other value) keeps the selected features unchanged.

    projected = self.multi_modal_projector(selected)
    # One tensor per image, each holding that image's patches.
    return torch.split(projected, image_num_patches, dim=0)
def get_video_features(
self,
pixel_values: torch.FloatTensor,
@ -214,6 +466,7 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
vision_aspect_ratio: Optional[str] = None,
batch_num_images: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@ -234,6 +487,8 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
If `"full"`, the full vision features are used.
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
Aspect ratio used when processing image features. The default value is "anyres_max_9".
batch_num_images (`torch.LongTensor`, *optional*):
Number of images in each sample.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@ -272,6 +527,7 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
image_sizes,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
batch_num_images=batch_num_images,
)
image_features, feature_lens = self.pack_image_features(
image_features,
@ -355,6 +611,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
vision_aspect_ratio: Optional[str] = None,
batch_num_images: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@ -377,6 +634,8 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat
If `"full"`, the full vision features are used.
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
Aspect ratio used when processing image features. The default value is "anyres_max_9".
batch_num_images (`torch.LongTensor`, *optional*):
Number of images in each sample.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
@ -438,6 +697,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat
vision_aspect_ratio=vision_aspect_ratio,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
batch_num_images=batch_num_images,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,

View File

@ -170,12 +170,15 @@ class LlavaOnevisionProcessor(ProcessorMixin):
if images is not None:
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
batch_num_images = iter(image_inputs["batch_num_images"])
image_sizes = iter(image_inputs["image_sizes"])
height, width = get_image_size(
to_numpy_array(image_inputs["pixel_values"][0][0]),
channel_dim=output_kwargs["images_kwargs"].get("data_format"),
)
text, num_image_tokens = self._expand_image_tokens(text, image_sizes, height, width, self.image_token)
text, num_image_tokens = self._expand_image_tokens(
text, image_sizes, height, width, self.image_token, batch_num_images
)
if videos is not None:
video_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"])
@ -205,23 +208,29 @@ class LlavaOnevisionProcessor(ProcessorMixin):
height: int,
width: int,
special_token: str,
num_frames: int = 1,
batch_num_images: Iterable[int],
):
prompt_strings = []
max_num_vision_tokens = 0
for sample in text:
if special_token in sample:
is_multi_image = next(batch_num_images) != 1
else:
is_multi_image = False
while special_token in sample:
image_size_list = next(image_sizes)
original_size = image_size_list[0] if num_frames != 1 else image_size_list
if not isinstance(original_size, (list, tuple)):
# cast to list to avoid numerical precision errors when calculating unpadding
original_size = original_size.tolist()
orig_height, orig_width = original_size
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
if is_multi_image:
num_image_tokens = self.num_image_tokens + 1 # one for image_newline
else:
original_size = next(image_sizes)
if not isinstance(original_size, (list, tuple)):
# cast to list to avoid numerical precision errors when calculating unpadding
original_size = original_size.tolist()
orig_height, orig_width = original_size
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
max_num_vision_tokens = max(max_num_vision_tokens, num_image_tokens)
if self.vision_feature_select_strategy == "default":
num_image_tokens -= 1
sample = sample.replace(special_token, "<placeholder>" * num_image_tokens * num_frames, 1)
sample = sample.replace(special_token, "<placeholder>" * num_image_tokens, 1)
prompt_strings.append(sample)
text = [sample.replace("<placeholder>", special_token) for sample in prompt_strings]
return text, max_num_vision_tokens

View File

@ -486,8 +486,6 @@ class MllamaTextCrossAttention(nn.Module):
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
key_states = self.k_norm(key_states)
if past_key_value is not None:
@ -850,7 +848,7 @@ class MllamaRotaryEmbedding(nn.Module):
@auto_docstring
class MllamaPreTrainedModel(PreTrainedModel):
config_class = MllamaConfig
base_model_prefix = "model"
base_model_prefix = ""
supports_gradient_checkpointing = True
_no_split_modules = [
"MllamaVisionEncoderLayer",

View File

@ -40,7 +40,7 @@ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask,
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig
@ -358,7 +358,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.get_text_config().initializer_range
@ -1659,9 +1659,9 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
@ -1676,9 +1676,9 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_tokens = (input_ids == self.config.video_token_id).sum()
n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features:
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
@ -1694,20 +1694,32 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (
if position_ids is None:
attention_mask_2d = attention_mask
if attention_mask is not None and attention_mask.ndim == 4:
attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2)
attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
attention_mask_2d = (1.0 - attention_mask_2d).int()
# Calculate RoPE index once per generation in the pre-fill stage only.
# When compiling, we can't check tensor values thus we check only input length
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
# models currently cannot do assisted decoding
prefill_compiled_stage = is_torchdynamo_compiling() and (
(input_ids is not None and input_ids.shape[1] != 1)
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
)
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
(cache_position is not None and cache_position[0] == 0)
or self.rope_deltas is None
or (past_key_values is None or past_key_values.get_seq_length() == 0)
):
)
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index(
input_ids,
image_grid_thw,
video_grid_thw,
second_per_grid_ts,
attention_mask,
second_per_grid_ts=second_per_grid_ts,
attention_mask=attention_mask_2d,
)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
@ -1747,6 +1759,61 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
)
return output if return_dict else output.to_tuple()
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
@dataclass
class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput):
@ -2108,60 +2175,5 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
return input_ids, model_kwargs
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel", "Qwen2_5_VLTextModel"]

View File

@ -50,7 +50,7 @@ from ...image_utils import ImageInput
from ...modeling_flash_attention_utils import is_flash_attn_available
from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
from ...utils import is_torchdynamo_compiling, logging
from ...video_utils import VideoInput
@ -647,9 +647,9 @@ class Qwen2_5_VLModel(Qwen2VLModel):
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_tokens = (input_ids == self.config.image_token_id).sum()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
@ -664,9 +664,9 @@ class Qwen2_5_VLModel(Qwen2VLModel):
if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_tokens = (input_ids == self.config.video_token_id).sum()
n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features:
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
@ -682,20 +682,32 @@ class Qwen2_5_VLModel(Qwen2VLModel):
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (
if position_ids is None:
attention_mask_2d = attention_mask
if attention_mask is not None and attention_mask.ndim == 4:
attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2)
attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
attention_mask_2d = (1.0 - attention_mask_2d).int()
# Calculate RoPE index once per generation in the pre-fill stage only.
# When compiling, we can't check tensor values thus we check only input length
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
# models currently cannot do assisted decoding
prefill_compiled_stage = is_torchdynamo_compiling() and (
(input_ids is not None and input_ids.shape[1] != 1)
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
)
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
(cache_position is not None and cache_position[0] == 0)
or self.rope_deltas is None
or (past_key_values is None or past_key_values.get_seq_length() == 0)
):
)
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index(
input_ids,
image_grid_thw,
video_grid_thw,
second_per_grid_ts,
attention_mask,
second_per_grid_ts=second_per_grid_ts,
attention_mask=attention_mask_2d,
)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids

View File

@ -924,7 +924,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.get_text_config().initializer_range
@ -1616,16 +1616,28 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (
if position_ids is None:
attention_mask_2d = attention_mask
if attention_mask is not None and attention_mask.ndim == 4:
attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2)
attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
attention_mask_2d = (1.0 - attention_mask_2d).int()
# Calculate RoPE index once per generation in the pre-fill stage only.
# When compiling, we can't check tensor values thus we check only input length
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
# models currently cannot do assisted decoding
prefill_compiled_stage = is_torchdynamo_compiling() and (
(input_ids is not None and input_ids.shape[1] != 1)
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
)
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
(cache_position is not None and cache_position[0] == 0)
or self.rope_deltas is None
or (past_key_values is None or past_key_values.get_seq_length() == 0)
):
)
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask
input_ids, image_grid_thw, video_grid_thw, attention_mask_2d
)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
@ -1662,6 +1674,62 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
)
return output if return_dict else output.to_tuple()
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
@ -1974,61 +2042,5 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
return input_ids, model_kwargs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
__all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel", "Qwen2VLTextModel"]

View File

@ -972,7 +972,9 @@ class Trainer:
)
return remove_columns_collator
def _get_train_sampler(self, train_dataset) -> Optional[torch.utils.data.Sampler]:
def _get_train_sampler(self, train_dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
if train_dataset is None:
train_dataset = self.train_dataset
if train_dataset is None or not has_length(train_dataset):
return None

View File

@ -36,7 +36,9 @@ BLACK_SQUARE = "■"
WHITE_SQUARE = ""
def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_window=None, token_type_ids=None):
def generate_attention_matrix_from_mask(
words, mask, img_token="<img>", sliding_window=None, token_type_ids=None, image_seq_length=None
):
"""
Generates an attention matrix from a given attention mask.
@ -80,6 +82,14 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
for j in range(n)
)
if token_type_ids is not None:
is_special = token_type_ids == 1
token_type_buckets = torch.where(
(token_type_ids.cumsum(-1) % 5 + is_special).bool(), token_type_ids.cumsum(-1), 0
)
boundaries = torch.arange(0, image_seq_length + 1, image_seq_length)
token_type_buckets = torch.bucketize(token_type_buckets, boundaries=boundaries)
# Print headers
legend = f"{GREEN}{BLACK_SQUARE}{RESET}: i == j (diagonal) {YELLOW}{BLACK_SQUARE}{RESET}: token_type_ids"
output.append(" " + legend)
@ -103,7 +113,6 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
if sliding_window is not None
else ""
)
for i, word in enumerate(words):
word_repr = repr(word).ljust(max_word_length)
colored_word = f"{YELLOW}{word_repr}{RESET}" if img_token in word else word_repr
@ -121,7 +130,9 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
if sliding_window is not None:
sliding_window_row = " ".join(
f"{YELLOW}{BLACK_SQUARE}{RESET}"
if img_token in words[j] and img_token in words[i]
if img_token in words[j]
and img_token in words[i]
and token_type_buckets[0, i] == token_type_buckets[0, j]
else f"{GREEN}{BLACK_SQUARE}{RESET}"
if i == j
else BLACK_SQUARE
@ -170,7 +181,8 @@ class AttentionMaskVisualizer:
if self.config.model_type in PROCESSOR_MAPPING_NAMES:
img = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
img = Image.open(requests.get(img, stream=True).raw)
processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=5)
image_seq_length = 5
processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=image_seq_length)
if hasattr(processor, "image_token"):
image_token = processor.image_token
else:
@ -179,7 +191,7 @@ class AttentionMaskVisualizer:
if image_token:
input_sentence = input_sentence.replace("<img>", image_token)
inputs = processor(img, input_sentence, suffix=suffix, return_tensors="pt")
inputs = processor(images=img, text=input_sentence, suffix=suffix, return_tensors="pt")
self.image_token = processor.tokenizer.convert_ids_to_tokens([processor.image_token_id])[0]
@ -223,6 +235,7 @@ class AttentionMaskVisualizer:
img_token=self.image_token,
sliding_window=getattr(self.config, "sliding_window", None),
token_type_ids=kwargs.get("token_type_ids", None),
image_seq_length=image_seq_length,
)
print(f_string)
print(f"{top_bottom_border}")

View File

@ -1222,6 +1222,9 @@ def is_keras_nlp_available():
def is_in_notebook():
try:
# Check if we are running inside Marimo
if "marimo" in sys.modules:
return True
# Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
get_ipython = sys.modules["IPython"].get_ipython
if "IPKernelApp" not in get_ipython().config:

View File

@ -180,6 +180,7 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
all_model_classes = (AriaModel, AriaForConditionalGeneration) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
test_torchscript = False
_is_composite = True
def setUp(self):

Some files were not shown because too many files have changed in this diff Show More