mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Merge branch 'main' into fixing_gptq_tests
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled
This commit is contained in:
commit cb7df519b9
51
.github/workflows/check_failed_model_tests.yml
vendored
@@ -39,55 +39,100 @@ jobs:
name: ci_results_run_models_gpu
path: /transformers/ci_results_run_models_gpu

- name: Check file
working-directory: /transformers
run: |
if [ -f ci_results_run_models_gpu/new_model_failures.json ]; then
echo "`ci_results_run_models_gpu/new_model_failures.json` exists, continue ..."
echo "process=true" >> $GITHUB_ENV
else
echo "`ci_results_run_models_gpu/new_model_failures.json` doesn't exist, abort."
echo "process=false" >> $GITHUB_ENV
fi

- uses: actions/download-artifact@v4
if: ${{ env.process == 'true' }}
with:
pattern: setup_values*
path: setup_values
merge-multiple: true

- name: Prepare some setup values
if: ${{ env.process == 'true' }}
run: |
if [ -f setup_values/prev_workflow_run_id.txt ]; then
echo "PREV_WORKFLOW_RUN_ID=$(cat setup_values/prev_workflow_run_id.txt)" >> $GITHUB_ENV
else
echo "PREV_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
fi

if [ -f setup_values/other_workflow_run_id.txt ]; then
echo "OTHER_WORKFLOW_RUN_ID=$(cat setup_values/other_workflow_run_id.txt)" >> $GITHUB_ENV
else
echo "OTHER_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
fi

- name: Update clone
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: git fetch && git checkout ${{ github.sha }}

- name: Get target commit
working-directory: /transformers/utils
if: ${{ env.process == 'true' }}
run: |
echo "END_SHA=$(TOKEN=${{ secrets.ACCESS_REPO_INFO_TOKEN }} python3 -c 'import os; from get_previous_daily_ci import get_last_daily_ci_run_commit; commit=get_last_daily_ci_run_commit(token=os.environ["TOKEN"]); print(commit)')" >> $GITHUB_ENV
echo "END_SHA=$(TOKEN=${{ secrets.ACCESS_REPO_INFO_TOKEN }} python3 -c 'import os; from get_previous_daily_ci import get_last_daily_ci_run_commit; commit=get_last_daily_ci_run_commit(token=os.environ["TOKEN"], workflow_run_id=os.environ["PREV_WORKFLOW_RUN_ID"]); print(commit)')" >> $GITHUB_ENV

- name: Checkout to `start_sha`
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: git fetch && git checkout ${{ inputs.start_sha }}

- name: Reinstall transformers in edit mode (remove the one installed during docker image build)
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: python3 -m pip uninstall -y transformers && python3 -m pip install -e .

- name: NVIDIA-SMI
if: ${{ env.process == 'true' }}
run: |
nvidia-smi

- name: Environment
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: |
python3 utils/print_env.py

- name: Show installed libraries and their versions
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: pip freeze

- name: Check failed tests
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: python3 utils/check_bad_commit.py --start_commit ${{ inputs.start_sha }} --end_commit ${{ env.END_SHA }} --file ci_results_run_models_gpu/new_model_failures.json --output_file new_model_failures_with_bad_commit.json

- name: Show results
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: |
ls -l new_model_failures_with_bad_commit.json
cat new_model_failures_with_bad_commit.json

- name: Checkout back
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: |
git checkout ${{ inputs.start_sha }}

- name: Process report
shell: bash
working-directory: /transformers
if: ${{ env.process == 'true' }}
env:
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN: ${{ secrets.TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN }}
run: |
python3 utils/process_bad_commit_report.py
@@ -95,7 +140,9 @@ jobs:
- name: Process report
shell: bash
working-directory: /transformers
if: ${{ env.process == 'true' }}
env:
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN: ${{ secrets.TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN }}
run: |
{
@@ -105,7 +152,7 @@ jobs:
} >> "$GITHUB_ENV"

- name: Send processed report
if: ${{ !endsWith(env.REPORT_TEXT, '{}') }}
if: ${{ env.process == 'true' && !endsWith(env.REPORT_TEXT, '{}') }}
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
with:
# Slack channel id, channel name, or user id to post message.
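For readability, the inline `python3 -c` call in the updated `Get target commit` step above is roughly equivalent to the standalone sketch below. It assumes it is run from `utils/` in the repository with `TOKEN` and `PREV_WORKFLOW_RUN_ID` exported, exactly as the step does; the behaviour for an empty run id is an assumption, not something verified here.

```python
# Rough standalone equivalent of the one-liner in the "Get target commit" step.
# Assumes: executed from transformers/utils, with TOKEN and PREV_WORKFLOW_RUN_ID set.
import os

from get_previous_daily_ci import get_last_daily_ci_run_commit

token = os.environ["TOKEN"]
prev_run_id = os.environ["PREV_WORKFLOW_RUN_ID"]  # may be "" (default written by the setup job)

# With workflow_run_id passed, the helper targets that specific previous run;
# presumably an empty value keeps the old behaviour of using the latest daily CI run.
commit = get_last_daily_ci_run_commit(token=token, workflow_run_id=prev_run_id)
print(commit)  # captured by the workflow as END_SHA
```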
35
.github/workflows/self-scheduled-caller.yml
vendored
@@ -8,8 +8,43 @@ on:
push:
branches:
- run_scheduled_ci*
workflow_dispatch:
inputs:
prev_workflow_run_id:
description: 'previous workflow run id to compare'
type: string
required: false
default: ""
other_workflow_run_id:
description: 'other workflow run id to compare'
type: string
required: false
default: ""


# Used for `push` to easily modify the target workflow runs to compare against
env:
prev_workflow_run_id: ""
other_workflow_run_id: ""


jobs:
setup:
name: Setup
runs-on: ubuntu-22.04
steps:
- name: Setup
run: |
mkdir "setup_values"
echo "${{ inputs.prev_workflow_run_id || env.prev_workflow_run_id }}" > "setup_values/prev_workflow_run_id.txt"
echo "${{ inputs.other_workflow_run_id || env.other_workflow_run_id }}" > "setup_values/other_workflow_run_id.txt"

- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: setup_values
path: setup_values

model-ci:
name: Model CI
uses: ./.github/workflows/self-scheduled.yml
18
.github/workflows/slack-report.yml
vendored
@@ -39,6 +39,21 @@ jobs:

- uses: actions/checkout@v4
- uses: actions/download-artifact@v4

- name: Prepare some setup values
run: |
if [ -f setup_values/prev_workflow_run_id.txt ]; then
echo "PREV_WORKFLOW_RUN_ID=$(cat setup_values/prev_workflow_run_id.txt)" >> $GITHUB_ENV
else
echo "PREV_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
fi

if [ -f setup_values/other_workflow_run_id.txt ]; then
echo "OTHER_WORKFLOW_RUN_ID=$(cat setup_values/other_workflow_run_id.txt)" >> $GITHUB_ENV
else
echo "OTHER_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
fi

- name: Send message to Slack
if: ${{ inputs.job != 'run_quantization_torch_gpu' }}
env:
@@ -50,7 +65,6 @@ jobs:
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
CI_EVENT: ${{ inputs.ci_event }}
CI_SHA: ${{ github.sha }}
CI_WORKFLOW_REF: ${{ github.workflow_ref }}
CI_TEST_JOB: ${{ inputs.job }}
SETUP_STATUS: ${{ inputs.setup_status }}
# We pass `needs.setup.outputs.matrix` as the argument. A processing in `notification_service.py` to change
@@ -58,7 +72,6 @@ jobs:
# For a job that doesn't depend on (i.e. `needs`) `setup`, the value for `inputs.folder_slices` would be an
# empty string, and the called script still gets one argument (which is the empty string).
run: |
sudo apt-get install -y curl
pip install huggingface_hub
pip install slack_sdk
pip show slack_sdk
@@ -86,7 +99,6 @@ jobs:
# We pass `needs.setup.outputs.quantization_matrix` as the argument. A processing in `notification_service_quantization.py` to change
# `quantization/bnb` to `quantization_bnb` is required, as the artifact names use `_` instead of `/`.
run: |
sudo apt-get install -y curl
pip install huggingface_hub
pip install slack_sdk
pip show slack_sdk
@@ -455,6 +455,8 @@
title: Falcon
- local: model_doc/falcon3
title: Falcon3
- local: model_doc/falcon_h1
title: FalconH1
- local: model_doc/falcon_mamba
title: FalconMamba
- local: model_doc/flan-t5
@@ -39,7 +39,7 @@ Checkout all Bamba-9B model checkpoints [here](https://github.com/foundation-mod
<!---
## Usage Tips

Tips:
Tips:

- The architecture is based on Mamba-2 models.

@@ -63,7 +63,35 @@ response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
```


## Padding-Free Training

Bamba supports padding-free training in which distinct training examples can be concatenated
together while nevertheless processing the inputs as though they belonged to separate batches. When
the examples are of varying lengths, padding-free training can provide significant speed ups and
memory savings compared to batching the examples together and using padding, as the unnecessary
compute and memory due to padding is avoided entirely. The performance gains depend on factors such
as the model and the data distribution, but throughput gains up to [~2x are commonly
seen](https://github.com/huggingface/transformers/pull/35861#issue-2807873129).

Using padding-free training with Bamba requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d`
packages, and the following arguments must be passed to the model in addition to `input_ids` and
`labels`:
* `position_ids: torch.LongTensor`: the position index of each token in each sequence.
* `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
* Each of the [`FlashAttentionKwargs`]
* `cu_seq_lens_q: torch.LongTensor`: The cumulative sequence lengths of all queries.
* `cu_seq_lens_k: torch.LongTensor`: The cumulative sequence lengths of all keys.
* `max_length_q: int`: the longest query length in the batch.
* `max_length_k: int`: the longest key length in the batch.

The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] can be used
to programmatically generate the above set of additional arguments using `return_seq_idx=True` and
`return_flash_attn_kwargs=True`. See [this blog post](https://huggingface.co/blog/packing-with-FA2)
for additional information.
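As a minimal sketch of how these inputs can be produced with [`DataCollatorWithFlattening`] (the checkpoint id and the toy texts below are illustrative only, and `flash-attn`, `mamba-ssm`, and `causal-conv1d` are assumed to be installed):

```python
# Minimal padding-free sketch: the collator flattens variable-length examples into one
# row and emits seq_idx / cu_seq_lens_* / max_length_* instead of an attention_mask.
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorWithFlattening

model_id = "ibm-ai-platform/Bamba-9B"  # illustrative checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="flash_attention_2", torch_dtype="bfloat16", device_map="auto"
)

# Two examples of different lengths; note: no padding and no attention_mask.
features = [
    {"input_ids": tokenizer("The quick brown fox").input_ids},
    {"input_ids": tokenizer("jumps over the lazy dog near the river bank").input_ids},
]

collator = DataCollatorWithFlattening(return_seq_idx=True, return_flash_attn_kwargs=True)
batch = collator(features)
print(sorted(batch.keys()))  # input_ids, labels, position_ids, seq_idx, cu_seq_lens_*, max_length_*

# max_length_q / max_length_k are plain ints, so only move the tensor entries to the model device.
batch = {k: v.to(model.device) if hasattr(v, "to") else v for k, v in batch.items()}
loss = model(**batch).loss
```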

[[autodoc]] BambaForCausalLM
- forward

This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
65
docs/source/en/model_doc/falcon_h1.md
Normal file
@@ -0,0 +1,65 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# FalconH1

## Overview

The FalconH1 model was developed by the TII Pretraining team. A comprehensive research paper covering the architecture, pretraining dynamics, experimental results, and conclusions is forthcoming. You can read more about this series on [this website](https://github.com/tiiuae/Falcon-H1).

## Contributors

This model was contributed by [DhiyaEddine](https://huggingface.co/DhiyaEddine), [ybelkada](https://huggingface.co/ybelkada), [JingweiZuo](https://huggingface.co/JingweiZuo), [IlyasChahed](https://huggingface.co/IChahed), and [MaksimVelikanov](https://huggingface.co/yellowvm).
The original code can be found [here](https://github.com/tiiuae/Falcon-H1).


## FalconH1Config

| Model     | Depth | Dim  | Attn Heads | KV | Mamba Heads | d_head    | d_state | Ctx Len     |
|-----------|-------|------|------------|----|-------------|-----------|---------|-------------|
| H1 0.5B   | 36    | 1024 | 8          | 2  | 24          | 64 / 64   | 128     | 4K, 16K-SFT |
| H1 1.5B   | 24    | 2048 | 8          | 2  | 48          | 128 / 64  | 256     | 128K        |
| H1 1.5B-d | 66    | 1280 | 6          | 2  | 24          | 128 / 64  | 256     | 128K        |
| H1 3B     | 32    | 2560 | 10         | 2  | 32          | 128 / 128 | 256     | 128K        |
| H1 7B     | 44    | 3072 | 12         | 2  | 24          | 128 / 128 | 256     | 256K        |
| H1 34B    | 72    | 5120 | 20         | 4  | 32          | 128 / 128 | 256     | 256K        |


[[autodoc]] FalconH1Config

<!---
## Usage Tips
Tips:
- The architecture is based on Mamba-2 models.
## FalconH1Model
[[autodoc]] FalconH1Model
- forward
-->

## FalconH1ForCausalLM

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("tiiuae/Falcon-H1-7B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("tiiuae/Falcon-H1-7B-Instruct")

message = ["Mamba is a snake with following properties "]
inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
```

[[autodoc]] FalconH1ForCausalLM
- forward

This HF implementation is contributed by [younesbelkada](https://github.com/younesbelkada) and [DhiaEddineRhaiem](https://github.com/dhiaEddineRhaiem).
@@ -147,7 +147,7 @@ print(processor.decode(output[0], skip_special_tokens=True))

### Multi image inference

LLaVa-OneVision can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). For that you have to use checkpoints with an "ov" suffix. Here is how you can do it:
LLaVa-OneVision can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). For that you have to use checkpoints with an "ov" suffix. For multi-image cases, we recommend using a **nested list of images** as input. Otherwise, every image will be patchified and consume a lot of memory. Here is how you can do it:

```python
import requests
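# Illustrative continuation (not part of this change): for one prompt with several images,
# pass the images as a nested list so the processor groups them per prompt instead of
# patchifying every image independently. Checkpoint id, URLs, and prompt are assumptions.
from PIL import Image
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration

model_id = "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"  # an "ov" checkpoint, as recommended above
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaOnevisionForConditionalGeneration.from_pretrained(model_id, device_map="auto")

# `requests` is imported at the top of the snippet above.
image1 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
image2 = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)

conversation = [
    {
        "role": "user",
        "content": [{"type": "image"}, {"type": "image"}, {"type": "text", "text": "What do these images show?"}],
    }
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# One prompt, two images -> a nested list keeps both images attached to that prompt.
inputs = processor(images=[[image1, image2]], text=[prompt], return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(output[0], skip_special_tokens=True))
```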
435
examples/3D_parallel.py
Normal file
@@ -0,0 +1,435 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""":
|
||||
This script is used to test training a model using Tensor Parallelism and Data Parallelism.
|
||||
|
||||
Usage:
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
export CUDA_VISIBLE_DEVICES=4,5,6,7
|
||||
export CUDA_VISIBLE_DEVICES=5,6,7
|
||||
TP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 examples/3D_parallel.py
|
||||
CP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 examples/3D_parallel.py
|
||||
CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=4 examples/3D_parallel.py
|
||||
DP_SIZE=2 CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=8 examples/3D_parallel.py
|
||||
|
||||
TP_SIZE=1 CP_SIZE=4 torchrun --nproc_per_node=4 examples/3D_parallel.py
|
||||
TP_SIZE=1 DP_SIZE=4 torchrun --nproc_per_node=4 examples/3D_parallel.py
|
||||
TP_SIZE=4 DP_SIZE=1 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 examples/3D_parallel.py
|
||||
IGNORE_SANITY=1 CP_SIZE=1 TP_SIZE=1 DP_SIZE=1 torchrun --nproc_per_node=1 --rdzv_endpoint=localhost:29504 examples/3D_parallel.py
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed.checkpoint as dcp
|
||||
import torch.optim as optim
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
|
||||
from torch.distributed.checkpoint.stateful import Stateful
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp import ShardingStrategy
|
||||
from torch.distributed.tensor import DTensor
|
||||
from torch.distributed.tensor.experimental import context_parallel
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# from torch.distributed.tensor.experimental._attention import set_rotate_method
|
||||
|
||||
# set_rotate_method("alltoall") # CP rotate shards using all-to-all
|
||||
|
||||
|
||||
def main():
|
||||
tp_size = int(os.environ.get("TP_SIZE", 1))
|
||||
dp_size = int(os.environ.get("DP_SIZE", 1))
|
||||
cp_size = int(os.environ.get("CP_SIZE", 1)) # Add CP size configuration
|
||||
sdpa_backend = SDPBackend.FLASH_ATTENTION # For CP
|
||||
# sdpa_backend = SDPBackend.MATH # For CP
|
||||
global_batch_size = 8 # Desired global batch size
|
||||
seq_len = 1024 # Sequence length
|
||||
num_train_steps = 10000 # Number of training steps
|
||||
LR = 1e-5
|
||||
model_name = "HuggingFaceTB/SmolLM2-1.7B"
|
||||
# model_name = "unsloth/Llama-3.2-1B"
|
||||
|
||||
CHECKPOINT_DIR = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}"
|
||||
|
||||
# Initialize distributed environment
|
||||
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||
dist.init_process_group("nccl")
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
assert world_size == tp_size * dp_size * cp_size, (
|
||||
f"World size ({world_size}) must equal TP size ({tp_size}) * DP size ({dp_size}) * CP size ({cp_size})"
|
||||
)
|
||||
|
||||
mesh = torch.arange(world_size).reshape(dp_size, tp_size, cp_size)
|
||||
world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp"))
|
||||
tp_mesh = world_mesh["tp"]
|
||||
dp_mesh = world_mesh["dp"]
|
||||
cp_mesh = world_mesh["cp"]
|
||||
world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
|
||||
logger.info(f"Created DeviceMesh: {world_mesh}")
|
||||
logger.info(
|
||||
f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}, DP: {dp_mesh.get_local_rank()}, TP: {tp_mesh.get_local_rank()}, CP: {cp_mesh.get_local_rank()}"
|
||||
)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
wandb.init(
|
||||
project="tp_dp_test",
|
||||
config={
|
||||
"tp_size": tp_size,
|
||||
"dp_size": dp_size,
|
||||
"cp_size": cp_size,
|
||||
"global_batch_size": global_batch_size,
|
||||
"model_name": model_name,
|
||||
"dataset": "roneneldan/TinyStories-1M",
|
||||
"seq_len": seq_len,
|
||||
"lr": LR,
|
||||
"weight_decay": 0.1,
|
||||
},
|
||||
name=f"llama_tp{tp_size}_dp{dp_size}_cp{cp_size}"
|
||||
if model_name == "unsloth/Llama-3.2-1B"
|
||||
else f"tp{tp_size}_dp{dp_size}_cp{cp_size}",
|
||||
)
|
||||
logger.info("Wandb initialized.")
|
||||
# Log the current file to wandb
|
||||
wandb.save("test_train.py")
|
||||
|
||||
# Load model and tokenizer
|
||||
logger.info(f"Loading model and tokenizer from {model_name}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
logger.info(f"Set pad_token to eos_token: {tokenizer.pad_token}")
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
device_mesh=tp_mesh if dist.is_initialized() else None,
|
||||
tp_plan="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
logger.info(f"Model loaded onto device mesh: {tp_mesh}")
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
logger.info(f"Using device: {device} for non-model tensors")
|
||||
use_ddp = False
|
||||
if dist.is_initialized() and dp_mesh.size() > 1:
|
||||
model = FSDP(model, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD)
|
||||
use_ddp = True
|
||||
pass
|
||||
|
||||
model.train()
|
||||
|
||||
logger.info("Loading TinyStories dataset...")
|
||||
raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") # Use 1% for faster testing
|
||||
|
||||
def tokenize_function(examples):
|
||||
# Tokenize the text without padding
|
||||
tokenized_batch = tokenizer(
|
||||
examples["text"], padding=False, truncation=True, max_length=seq_len, return_tensors=None
|
||||
)
|
||||
# Set labels to be the same as input_ids for Causal LM
|
||||
tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
|
||||
return tokenized_batch
|
||||
|
||||
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
|
||||
logger.info(f"Dataset loaded and tokenized. Size: {len(tokenized_dataset)}")
|
||||
|
||||
# Create packed sequences
|
||||
def create_packed_sequences(examples):
|
||||
# Flatten all sequences
|
||||
all_tokens = []
|
||||
for input_ids in examples["input_ids"]:
|
||||
all_tokens.extend(input_ids)
|
||||
|
||||
# Split into sequences of seq_len + 1 (for input + label)
|
||||
num_sequences = len(all_tokens) // (seq_len + 1)
|
||||
packed_input_ids = []
|
||||
packed_labels = []
|
||||
|
||||
for i in range(num_sequences):
|
||||
start_idx = i * (seq_len + 1)
|
||||
end_idx = start_idx + (seq_len + 1)
|
||||
# Get the full sequence
|
||||
full_sequence = all_tokens[start_idx:end_idx]
|
||||
# For input_ids, remove the last token
|
||||
packed_input_ids.append(full_sequence[:-1])
|
||||
# For labels, remove the first token
|
||||
packed_labels.append(full_sequence[1:])
|
||||
|
||||
return {"input_ids": packed_input_ids, "labels": packed_labels}
|
||||
|
||||
# Apply packing to the dataset
|
||||
packed_dataset = tokenized_dataset.map(
|
||||
create_packed_sequences,
|
||||
batched=True,
|
||||
remove_columns=tokenized_dataset.column_names,
|
||||
batch_size=1000, # Process in batches for efficiency
|
||||
num_proc=60,
|
||||
)
|
||||
logger.info(f"Dataset packed. New size: {len(packed_dataset)}")
|
||||
|
||||
# Shuffle the packed dataset
|
||||
packed_dataset = packed_dataset.shuffle(seed=42)
|
||||
logger.info("Packed dataset shuffled")
|
||||
|
||||
# Calculate local batch size
|
||||
if dist.is_initialized():
|
||||
assert global_batch_size % dp_mesh.size() == 0, (
|
||||
f"Global batch size ({global_batch_size}) must be divisible by DP size ({dp_mesh.size()})"
|
||||
)
|
||||
local_batch_size = global_batch_size // dp_mesh.size()
|
||||
else:
|
||||
local_batch_size = global_batch_size
|
||||
|
||||
logger.info(
|
||||
f"Global batch size: {global_batch_size}, DP size: {dp_size if dist.is_initialized() else 1}, Local batch size: {local_batch_size}"
|
||||
)
|
||||
|
||||
# Simple collate function since sequences are already packed
|
||||
def collate_fn(batch):
|
||||
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
|
||||
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
|
||||
if dist.is_initialized():
|
||||
sampler = DistributedSampler(
|
||||
packed_dataset, num_replicas=dp_mesh.size(), rank=dp_mesh.get_local_rank(), shuffle=False
|
||||
)
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
dataloader = DataLoader(
|
||||
packed_dataset,
|
||||
batch_size=local_batch_size,
|
||||
sampler=sampler,
|
||||
shuffle=False,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=True,
|
||||
)
|
||||
logger.info(f"DataLoader created. Distributed: {dist.is_initialized()}")
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1)
|
||||
|
||||
# Training loop
|
||||
logger.info(f"Starting training for {num_train_steps} steps...")
|
||||
model.train()
|
||||
step = 0
|
||||
while step < num_train_steps:
|
||||
for batch in dataloader:
|
||||
if step >= num_train_steps:
|
||||
break # Exit loop if max steps reached
|
||||
|
||||
# Move batch to appropriate device
|
||||
batch = {k: v.to(device) for k, v in batch.items()}
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Add position_ids to batch before CP sharding
|
||||
batch_size = batch["input_ids"].shape[0]
|
||||
position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
|
||||
batch["position_ids"] = position_ids
|
||||
from torch.distributed.tensor.experimental._attention import _cp_options
|
||||
|
||||
_cp_options.enable_load_balance = False
|
||||
|
||||
with sdpa_kernel(sdpa_backend): # TODO: ideally move this to attention implementation
|
||||
cp_context = (
|
||||
nullcontext()
|
||||
if cp_mesh.size() == 1
|
||||
else context_parallel(
|
||||
cp_mesh,
|
||||
buffers=[
|
||||
batch["input_ids"],
|
||||
batch["labels"],
|
||||
batch["position_ids"],
|
||||
],
|
||||
buffer_seq_dims=[1, 1, 1],
|
||||
)
|
||||
)
|
||||
with cp_context:
|
||||
# Pop labels from batch before model forward pass
|
||||
labels = batch.pop("labels")
|
||||
outputs = model(**batch) # [mbs, seq_len/cp]
|
||||
loss = outputs.loss
|
||||
logits = outputs.logits
|
||||
|
||||
# Compute loss with shifted labels
|
||||
loss = model.loss_function(
|
||||
logits=logits, labels=None, shift_labels=labels, vocab_size=model.config.vocab_size
|
||||
)
|
||||
loss.backward()
|
||||
|
||||
# all reduce grads across dp_cp if applicable
|
||||
all_reduce_grads(model, world_mesh, use_ddp=use_ddp)
|
||||
|
||||
if hasattr(model, "clip_grad_norm_"):
|
||||
gradnorm = model.clip_grad_norm_(max_norm=1.0, norm_type=2.0) # TODO: fix reported gradnorm
|
||||
else:
|
||||
# only works with FSDP's NO_SHARD otherwise we should use FSDP's clip_grad_norm_
|
||||
assert len(list(model.parameters())) > 5, "No parameters found in model. Probably DDP bug.."
|
||||
gradnorm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0, foreach=True)
|
||||
|
||||
optimizer.step()
|
||||
# allreduce loss across cp_dp before logging
|
||||
if dist.is_initialized() and (cp_mesh.size() > 1 or dp_mesh.size() > 1):
|
||||
dist.all_reduce(loss, group=world_mesh["dp_cp"].get_group(), op=dist.ReduceOp.AVG)
|
||||
current_loss = loss.item()
|
||||
|
||||
# Log loss and gradnorm to wandb (only on rank 0 of dp group)
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
logger.info(
|
||||
f"Step: {step} | GBS: {global_batch_size} | DP: {dp_mesh.size()} | TP: {tp_mesh.size()} | CP: {cp_mesh.size()} | Loss: {current_loss} | Gradnorm: {gradnorm} | lr: {LR}"
|
||||
)
|
||||
wandb.log(
|
||||
{
|
||||
"train/loss": current_loss,
|
||||
"train/gradnorm": gradnorm,
|
||||
"step": step,
|
||||
"lr": LR,
|
||||
"GBS": global_batch_size,
|
||||
}
|
||||
)
|
||||
|
||||
step += 1 # Increment step count
|
||||
|
||||
logger.info("Training loop finished.")
|
||||
|
||||
# Save model using DCP (only if distributed)
|
||||
if dist.is_initialized():
|
||||
state_dict = {"app": AppState(model, optimizer)}
|
||||
dcp.save(
|
||||
state_dict=state_dict,
|
||||
checkpoint_id=CHECKPOINT_DIR,
|
||||
)
|
||||
logger.info(f"Saved checkpoint to {CHECKPOINT_DIR}")
|
||||
else:
|
||||
# Fallback to regular save for non-distributed case
|
||||
save_dir = "test_model_nondist"
|
||||
model.save_pretrained(save_dir, safe_serialization=False)
|
||||
tokenizer.save_pretrained(save_dir) # Save tokenizer too
|
||||
logger.info(f"Saved model to {save_dir}")
|
||||
|
||||
dist.destroy_process_group()
|
||||
logger.info("Cleaned up distributed process group")
|
||||
# Finish wandb run on rank 0
|
||||
if dist.get_rank() == 0:
|
||||
wandb.finish()
|
||||
logger.info("Wandb run finished.")
|
||||
|
||||
|
||||
def all_reduce_grads(model, world_mesh, use_ddp):
|
||||
"""All reduce gradients across dp_cp if applicable."""
|
||||
cp_mesh = world_mesh["cp"]
|
||||
if use_ddp:
|
||||
# DDP/FSDP takes care of syncing grads
|
||||
mesh = cp_mesh
|
||||
else:
|
||||
mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
|
||||
if dist.is_initialized() and mesh.size() > 1:
|
||||
for name, param in model.named_parameters():
|
||||
if param.grad is not None:
|
||||
# Workaround for cross-mesh communication limitation with DTensor gradients
|
||||
if isinstance(param.grad, DTensor):
|
||||
local_grad = param.grad.to_local()
|
||||
# Ensure grad requires grad for inplace modification checks (might not be needed)
|
||||
# local_grad = local_grad.detach().requires_grad_(True)
|
||||
torch.distributed.all_reduce(local_grad, op=torch.distributed.ReduceOp.SUM, group=mesh.get_group())
|
||||
local_grad = local_grad / mesh.size()
|
||||
# Assign averaged grad back - need careful handling if DTensor structure is complex
|
||||
# This simple assignment might work if the grad structure matches param structure
|
||||
param.grad = DTensor.from_local(
|
||||
local_grad, device_mesh=param.grad.device_mesh, placements=param.grad.placements
|
||||
)
|
||||
else:
|
||||
# Handle regular tensors if any exist (e.g. buffers not converted to DTensor)
|
||||
torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG, group=mesh.get_group())
|
||||
|
||||
|
||||
class AppState(Stateful):
|
||||
"""Wrapper for checkpointing the Application State including model and optimizer."""
|
||||
|
||||
def __init__(self, model, optimizer=None):
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
|
||||
def state_dict(self):
|
||||
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
|
||||
return {"model": model_state_dict, "optim": optimizer_state_dict}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
set_state_dict(
|
||||
self.model, self.optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"]
|
||||
)
|
||||
|
||||
|
||||
def clip_grad_norm_(
|
||||
parameters: Iterable[torch.Tensor],
|
||||
max_norm: float,
|
||||
norm_type: float = 2.0,
|
||||
error_if_nonfinite: bool = False,
|
||||
foreach: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Clip the gradient norm of an iterable of parameters.
|
||||
"""
|
||||
# Filter out parameters with no gradients
|
||||
parameters = [p for p in parameters if p.grad is not None]
|
||||
assert len(parameters) > 0, "No parameters with gradients found"
|
||||
|
||||
# Calculate total norm
|
||||
if norm_type == float("inf"):
|
||||
total_norm = max(p.grad.detach().abs().max() for p in parameters)
|
||||
else:
|
||||
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type)
|
||||
|
||||
# Convert DTensor to local tensor if needed
|
||||
if isinstance(total_norm, DTensor):
|
||||
total_norm = total_norm.full_tensor()
|
||||
|
||||
# Clip gradients
|
||||
clip_coef = max_norm / (total_norm + 1e-6)
|
||||
if clip_coef < 1:
|
||||
for p in parameters:
|
||||
p.grad.detach().mul_(clip_coef)
|
||||
|
||||
return total_norm
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -60,7 +60,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")

Array = Any
Dataset = datasets.arrow_dataset.Dataset

@@ -59,7 +59,7 @@ from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")

require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")


@@ -55,7 +55,7 @@ from transformers.utils import check_min_version, send_example_telemetry

logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")

Array = Any
Dataset = datasets.arrow_dataset.Dataset

@@ -56,7 +56,7 @@ from transformers.utils.versions import require_version

logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
793
examples/pytorch/3d_parallel_checks.py
Normal file
@@ -0,0 +1,793 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""":
|
||||
This script is used to test training a model using Tensor Parallelism and Data Parallelism.
|
||||
|
||||
Usage:
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
export CUDA_VISIBLE_DEVICES=4,5,6,7
|
||||
export CUDA_VISIBLE_DEVICES=5,6,7
|
||||
TP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 test_train.py
|
||||
CP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 test_train.py
|
||||
CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=4 test_train.py
|
||||
|
||||
TP_SIZE=1 CP_SIZE=4 torchrun --nproc_per_node=4 test_train.py
|
||||
TP_SIZE=1 DP_SIZE=4 torchrun --nproc_per_node=4 test_train.py
|
||||
TP_SIZE=4 DP_SIZE=1 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 test_train.py
|
||||
IGNORE_SANITY=1 CP_SIZE=1 TP_SIZE=1 DP_SIZE=1 torchrun --nproc_per_node=1 --rdzv_endpoint=localhost:29504 test_train.py
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Dict, Iterable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed.checkpoint as dcp
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
|
||||
from torch.distributed.checkpoint.stateful import Stateful
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp import ShardingStrategy
|
||||
from torch.distributed.tensor import DTensor
|
||||
from torch.distributed.tensor.experimental import context_parallel
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
from torch.utils.data import DataLoader, default_collate
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
ignore_sanity_checks = int(os.environ.get("IGNORE_SANITY", 0)) == 1
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# from torch.distributed.tensor.experimental._attention import set_rotate_method
|
||||
|
||||
# set_rotate_method("alltoall") # rotate shards using all-to-all
|
||||
|
||||
|
||||
def main():
|
||||
tp_size = int(os.environ.get("TP_SIZE", 1))
|
||||
dp_size = int(os.environ.get("DP_SIZE", 4))
|
||||
cp_size = int(os.environ.get("CP_SIZE", 1)) # Add CP size configuration
|
||||
sdpa_backend = SDPBackend.FLASH_ATTENTION # For CP
|
||||
# sdpa_backend = SDPBackend.MATH # For CP
|
||||
global_batch_size = 8 # Desired global batch size
|
||||
seq_len = 1024 # Sequence length
|
||||
num_train_steps = 10000 # Number of training steps
|
||||
LR = 1e-5
|
||||
model_name = "HuggingFaceTB/SmolLM2-1.7B"
|
||||
# model_name = "unsloth/Llama-3.2-1B"
|
||||
|
||||
CHECKPOINT_DIR = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}"
|
||||
|
||||
# Initialize distributed environment
|
||||
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||
dist.init_process_group("nccl")
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
assert world_size == tp_size * dp_size * cp_size, (
|
||||
f"World size ({world_size}) must equal TP size ({tp_size}) * DP size ({dp_size}) * CP size ({cp_size})"
|
||||
)
|
||||
|
||||
mesh = torch.arange(world_size).reshape(dp_size, tp_size, cp_size)
|
||||
world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp"))
|
||||
tp_mesh = world_mesh["tp"]
|
||||
dp_mesh = world_mesh["dp"]
|
||||
cp_mesh = world_mesh["cp"]
|
||||
world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
|
||||
logger.info(f"Created DeviceMesh: {world_mesh}")
|
||||
logger.info(
|
||||
f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}, DP: {dp_mesh.get_local_rank()}, TP: {tp_mesh.get_local_rank()}, CP: {cp_mesh.get_local_rank()}"
|
||||
)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
wandb.init(
|
||||
project="tp_dp_test",
|
||||
config={
|
||||
"tp_size": tp_size,
|
||||
"dp_size": dp_size,
|
||||
"cp_size": cp_size,
|
||||
"global_batch_size": global_batch_size,
|
||||
"model_name": model_name,
|
||||
"dataset": "roneneldan/TinyStories-1M",
|
||||
"seq_len": seq_len,
|
||||
"lr": LR,
|
||||
"weight_decay": 0.1,
|
||||
},
|
||||
name=f"llama_tp{tp_size}_dp{dp_size}_cp{cp_size}"
|
||||
if model_name == "unsloth/Llama-3.2-1B"
|
||||
else f"tp{tp_size}_dp{dp_size}_cp{cp_size}",
|
||||
)
|
||||
logger.info(f"ignore_sanity_checks is set to: {ignore_sanity_checks}")
|
||||
logger.info("Wandb initialized.")
|
||||
# Log the current file to wandb
|
||||
wandb.save("test_train.py")
|
||||
|
||||
else:
|
||||
logger.info("Running in non-distributed mode. DeviceMesh not applicable.")
|
||||
rank = 0
|
||||
world_size = 1
|
||||
local_rank = 0
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
wandb.init(
|
||||
project="tp_dp_test",
|
||||
config={
|
||||
"tp_size": 1,
|
||||
"dp_size": 1,
|
||||
"global_batch_size": global_batch_size,
|
||||
"model_name": model_name,
|
||||
"dataset": "roneneldan/TinyStories-1M",
|
||||
"seq_len": seq_len,
|
||||
},
|
||||
name="llama_tp1_dp1_nondist" if model_name == "unsloth/Llama-3.2-1B" else "tp1_dp1_nondist",
|
||||
)
|
||||
logger.info("Wandb initialized for non-distributed run.")
|
||||
|
||||
# Load model and tokenizer
|
||||
logger.info(f"Loading model and tokenizer from {model_name}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
logger.info(f"Set pad_token to eos_token: {tokenizer.pad_token}")
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
device_mesh=tp_mesh if dist.is_initialized() else None,
|
||||
tp_plan="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
logger.info(f"Model loaded onto device mesh: {tp_mesh}")
|
||||
|
||||
if dist.is_initialized():
|
||||
assert model.config.num_key_value_heads % tp_mesh.size() == 0, (
|
||||
f"num_key_value_heads={model.config.num_key_value_heads} must be divisible by tp_size={tp_mesh.size()}"
|
||||
)
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
else:
|
||||
model = model.to(device)
|
||||
|
||||
logger.info(f"Using device: {device} for non-model tensors")
|
||||
use_ddp = False
|
||||
if dist.is_initialized() and dp_mesh.size() > 1:
|
||||
# FSDP1
|
||||
model = FSDP(model, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD)
|
||||
# FSDP2
|
||||
# for transformer_block in model.model.layers:
|
||||
# fully_shard(transformer_block, mesh=dp_mesh, reshard_after_forward=False)
|
||||
# fully_shard(model.model, mesh=dp_mesh, reshard_after_forward=False)
|
||||
# DDP
|
||||
# replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
|
||||
# assert len(list(model.parameters()))>5, "No parameters found in model. Probably DDP/FSDP bug.." # TODO: we should be cautious abt using model.parameters()
|
||||
use_ddp = True
|
||||
|
||||
model.train()
|
||||
assert len(list(model.parameters())) > 0, "No parameters found in model. Probably DDP bug.."
|
||||
assert len([p for p in model.parameters() if p.requires_grad]) > 0, (
|
||||
"No gradients found in model. Probably DDP bug.."
|
||||
)
|
||||
|
||||
if dist.is_initialized() and not ignore_sanity_checks:
|
||||
# assert model is replicated across all dp
|
||||
for name, param in model.named_parameters():
|
||||
sanity_check_tensor_sync(param, dp_mesh)
|
||||
|
||||
# assert model is different across tp (only for sharded params)
|
||||
for name, param in model.named_parameters():
|
||||
if isinstance(param, DTensor) and param.placements[0].is_shard():
|
||||
# Only check sharded parameters for non-sync across TP
|
||||
sanity_check_tensor_sync(param, tp_mesh, not_sync=True)
|
||||
elif isinstance(param, DTensor) and param.placements[0].is_replicate():
|
||||
# Replicated parameters should be the same across TP
|
||||
sanity_check_tensor_sync(param, tp_mesh)
|
||||
|
||||
# assert model is replicated across cp
|
||||
for name, param in model.named_parameters():
|
||||
sanity_check_tensor_sync(param, cp_mesh)
|
||||
|
||||
# Load and preprocess TinyStories dataset
|
||||
logger.info("Loading TinyStories dataset...")
|
||||
raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") # Use 1% for faster testing
|
||||
|
||||
def tokenize_function(examples):
|
||||
# Tokenize the text without padding
|
||||
tokenized_batch = tokenizer(
|
||||
examples["text"], padding=False, truncation=True, max_length=seq_len, return_tensors=None
|
||||
)
|
||||
# Set labels to be the same as input_ids for Causal LM
|
||||
tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
|
||||
return tokenized_batch
|
||||
|
||||
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
|
||||
logger.info(f"Dataset loaded and tokenized. Size: {len(tokenized_dataset)}")
|
||||
|
||||
# Create packed sequences
|
||||
def create_packed_sequences(examples):
|
||||
# Flatten all sequences
|
||||
all_tokens = []
|
||||
for input_ids in examples["input_ids"]:
|
||||
all_tokens.extend(input_ids)
|
||||
|
||||
# Split into sequences of seq_len + 1 (for input + label)
|
||||
num_sequences = len(all_tokens) // (seq_len + 1)
|
||||
packed_input_ids = []
|
||||
packed_labels = []
|
||||
|
||||
for i in range(num_sequences):
|
||||
start_idx = i * (seq_len + 1)
|
||||
end_idx = start_idx + (seq_len + 1)
|
||||
# Get the full sequence
|
||||
full_sequence = all_tokens[start_idx:end_idx]
|
||||
# For input_ids, remove the last token
|
||||
packed_input_ids.append(full_sequence[:-1])
|
||||
# For labels, remove the first token
|
||||
packed_labels.append(full_sequence[1:])
|
||||
|
||||
return {"input_ids": packed_input_ids, "labels": packed_labels}
|
||||
|
||||
# Apply packing to the dataset
|
||||
packed_dataset = tokenized_dataset.map(
|
||||
create_packed_sequences,
|
||||
batched=True,
|
||||
remove_columns=tokenized_dataset.column_names,
|
||||
batch_size=1000, # Process in batches for efficiency
|
||||
num_proc=60,
|
||||
)
|
||||
logger.info(f"Dataset packed. New size: {len(packed_dataset)}")
|
||||
|
||||
# Shuffle the packed dataset
|
||||
packed_dataset = packed_dataset.shuffle(seed=42)
|
||||
logger.info("Packed dataset shuffled")
|
||||
|
||||
# Calculate local batch size
|
||||
if dist.is_initialized():
|
||||
assert global_batch_size % dp_mesh.size() == 0, (
|
||||
f"Global batch size ({global_batch_size}) must be divisible by DP size ({dp_mesh.size()})"
|
||||
)
|
||||
local_batch_size = global_batch_size // dp_mesh.size()
|
||||
else:
|
||||
local_batch_size = global_batch_size
|
||||
|
||||
logger.info(
|
||||
f"Global batch size: {global_batch_size}, DP size: {dp_size if dist.is_initialized() else 1}, Local batch size: {local_batch_size}"
|
||||
)
|
||||
|
||||
# Simple collate function since sequences are already packed
|
||||
def collate_fn(batch):
|
||||
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
|
||||
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
|
||||
if dist.is_initialized():
|
||||
sampler = DistributedSampler(
|
||||
packed_dataset, num_replicas=dp_mesh.size(), rank=dp_mesh.get_local_rank(), shuffle=False
|
||||
)
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
dataloader = DataLoader(
|
||||
packed_dataset,
|
||||
batch_size=local_batch_size,
|
||||
sampler=sampler,
|
||||
shuffle=False,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
logger.info(f"DataLoader created. Distributed: {dist.is_initialized()}")
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1)
|
||||
|
||||
# Training loop
|
||||
logger.info(f"Starting training for {num_train_steps} steps...")
|
||||
model.train()
|
||||
step = 0
|
||||
while step < num_train_steps:
|
||||
for batch in dataloader:
|
||||
if step >= num_train_steps:
|
||||
break # Exit loop if max steps reached
|
||||
|
||||
# Move batch to appropriate device
|
||||
batch = {k: v.to(device) for k, v in batch.items()}
|
||||
|
||||
# Sanity checks for batch distribution (only if distributed)
|
||||
if dist.is_initialized() and not ignore_sanity_checks:
|
||||
# check batch is same across all tp
|
||||
sanity_check_tensor_sync(batch["input_ids"], tp_mesh)
|
||||
# check batch is different across dp
|
||||
sanity_check_tensor_sync(batch["input_ids"], dp_mesh, not_sync=True)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Add position_ids to batch before CP sharding
|
||||
batch_size = batch["input_ids"].shape[0]
|
||||
position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
|
||||
batch["position_ids"] = position_ids
|
||||
from torch.distributed.tensor.experimental._attention import _cp_options
|
||||
|
||||
_cp_options.enable_load_balance = False
|
||||
|
||||
with sdpa_kernel(sdpa_backend): # TODO: ideally move this to attention implementation
|
||||
cp_context = (
|
||||
nullcontext()
|
||||
if cp_mesh.size() == 1
|
||||
else context_parallel(
|
||||
cp_mesh,
|
||||
buffers=[
|
||||
batch["input_ids"],
|
||||
batch["labels"],
|
||||
batch["position_ids"],
|
||||
], # TODO: need to add attention mask
|
||||
buffer_seq_dims=[1, 1, 1],
|
||||
)
|
||||
)
|
||||
with cp_context:
|
||||
# Pop labels from batch before model forward pass
|
||||
labels = batch.pop("labels")
|
||||
outputs = model(**batch) # [mbs, seq_len/cp]
|
||||
loss = outputs.loss
|
||||
logits = outputs.logits
|
||||
|
||||
# Compute loss with shifted labels
|
||||
loss = model.loss_function(
|
||||
logits=logits, labels=None, shift_labels=labels, vocab_size=model.config.vocab_size
|
||||
)
|
||||
|
||||
# Sanity checks for logits
|
||||
if dist.is_initialized() and not ignore_sanity_checks:
|
||||
# sanity_check_tensor_sync(logits, tp_mesh) # TODO: only true without sequence parallel
|
||||
sanity_check_tensor_sync(logits, dp_mesh, not_sync=True)
|
||||
sanity_check_tensor_sync(logits, cp_mesh, not_sync=True)
|
||||
|
||||
loss.backward()
|
||||
|
||||
# all reduce grads across dp_cp if applicable
|
||||
all_reduce_grads(model, world_mesh, use_ddp=use_ddp)
|
||||
|
||||
# Sanity checks for gradients (only if distributed)
|
||||
if dist.is_initialized() and not ignore_sanity_checks:
|
||||
# check grads are not same across all tp (for sharded grads)
|
||||
for name, param in model.named_parameters():
|
||||
if param.grad is not None and isinstance(param.grad, DTensor):
|
||||
if param.grad.placements[0].is_shard():
|
||||
sanity_check_tensor_sync(param.grad, tp_mesh, not_sync=True)
|
||||
elif param.grad.placements[0].is_replicate():
|
||||
sanity_check_tensor_sync(param.grad, tp_mesh)
|
||||
# check grads are same across dp
|
||||
for name, param in model.named_parameters():
|
||||
if param.grad is not None and dp_mesh.size() > 1:
|
||||
sanity_check_tensor_sync(param.grad, dp_mesh)
|
||||
# check grads are same across cp
|
||||
for name, param in model.named_parameters():
|
||||
if param.grad is not None and cp_mesh.size() > 1:
|
||||
sanity_check_tensor_sync(param.grad, cp_mesh)
|
||||
|
||||
# Calculate gradient norm and clip gradients
|
||||
if hasattr(model, "clip_grad_norm_"):
|
||||
# when using FSDP or DDP, model.parameters() doesn't work
|
||||
gradnorm = model.clip_grad_norm_(max_norm=1.0, norm_type=2.0)
|
||||
else:
|
||||
assert len(list(model.parameters())) > 2, "No parameters found in model. Probably DDP bug.."
|
||||
assert len([p for p in model.parameters() if p.requires_grad]) > 2, (
|
||||
"No gradients found in model. Probably DDP bug.."
|
||||
)
|
||||
assert len([p for p in model.parameters() if p.grad is not None]) > 2, (
|
||||
"No gradients found in model. Probably DDP bug.."
|
||||
)
|
||||
# only works with FSDP's NO_SHARD otherwise we should use FSDP's clip_grad_norm_
|
||||
gradnorm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0, foreach=True)
|
||||
|
||||
optimizer.step()
|
||||
# Sanity checks for updated model parameters (only if distributed)
|
||||
if dist.is_initialized() and not ignore_sanity_checks:
|
||||
# check updated model is different across all tp (for sharded params)
|
||||
for name, param in model.named_parameters():
|
||||
if isinstance(param, DTensor):
|
||||
if param.placements[0].is_shard():
|
||||
sanity_check_tensor_sync(param, tp_mesh, not_sync=True)
|
||||
elif param.placements[0].is_replicate():
|
||||
sanity_check_tensor_sync(param, tp_mesh)
|
||||
# check updated model is same across dp
|
||||
for name, param in model.named_parameters():
|
||||
sanity_check_tensor_sync(param, dp_mesh)
|
||||
# check updated model is same across cp
|
||||
for name, param in model.named_parameters():
|
||||
sanity_check_tensor_sync(param, cp_mesh)
|
||||
|
||||
# allreduce loss across cp_dp before logging
|
||||
if dist.is_initialized() and (cp_mesh.size() > 1 or dp_mesh.size() > 1):
|
||||
dist.all_reduce(loss, group=world_mesh["dp_cp"].get_group(), op=dist.ReduceOp.AVG)
|
||||
current_loss = loss.item()
|
||||
|
||||
# Log loss and gradnorm to wandb (only on rank 0 of dp group)
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
logger.info(
|
||||
f"Step: {step} | GBS: {global_batch_size} | DP: {dp_mesh.size()} | TP: {tp_mesh.size()} | CP: {cp_mesh.size()} | Loss: {current_loss} | Gradnorm: {gradnorm} | lr: {LR}"
|
||||
)
|
||||
wandb.log(
|
||||
{
|
||||
"train/loss": current_loss,
|
||||
"train/gradnorm": gradnorm,
|
||||
"step": step,
|
||||
"lr": LR,
|
||||
"GBS": global_batch_size,
|
||||
}
|
||||
)
|
||||
|
||||
step += 1 # Increment step count
|
||||
|
||||
logger.info("Training loop finished.")
|
||||
|
||||
# Save model using DCP (only if distributed)
|
||||
if dist.is_initialized():
|
||||
state_dict = {"app": AppState(model, optimizer)}
|
||||
dcp.save(
|
||||
state_dict=state_dict,
|
||||
checkpoint_id=CHECKPOINT_DIR,
|
||||
)
|
||||
logger.info(f"Saved checkpoint to {CHECKPOINT_DIR}")
|
||||
else:
|
||||
# Fallback to regular save for non-distributed case
|
||||
save_dir = "test_model_nondist"
|
||||
model.save_pretrained(save_dir, safe_serialization=False)
|
||||
tokenizer.save_pretrained(save_dir) # Save tokenizer too
|
||||
logger.info(f"Saved model to {save_dir}")
|
||||
|
||||
# Example of loading the checkpoint (only if distributed)
|
||||
if dist.is_initialized():
|
||||
# Create a new model instance
|
||||
logger.info("Creating new model instance for verification")
|
||||
new_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
device_mesh=tp_mesh,
|
||||
torch_dtype=torch.bfloat16, # Use same dtype
|
||||
)
|
||||
new_optimizer = optim.AdamW(new_model.parameters(), lr=LR)
|
||||
|
||||
# Load checkpoint into new model
|
||||
state_dict = {"app": AppState(new_model, new_optimizer)}
|
||||
dcp.load(
|
||||
state_dict=state_dict,
|
||||
checkpoint_id=CHECKPOINT_DIR,
|
||||
)
|
||||
logger.info("Loaded checkpoint into new model")
|
||||
|
||||
# Verify model weights match
|
||||
logger.info("Verifying model weights match...")
|
||||
for (name1, param1), (name2, param2) in zip(model.named_parameters(), new_model.named_parameters()):
|
||||
torch.testing.assert_close(
|
||||
param1.to_local(),
|
||||
param2.to_local(),
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
msg=f"Weights mismatch in {name1} vs {name2}",
|
||||
)
|
||||
|
||||
# Verify optimizer states match
|
||||
logger.info("Verifying optimizer states match...")
|
||||
for name1, state1 in optimizer.state_dict().items():
|
||||
state2 = new_optimizer.state_dict()[name1]
|
||||
if name1 == "state":
|
||||
# Compare state dictionaries for each parameter
|
||||
for param_id, param_state1 in state1.items():
|
||||
param_state2 = state2[param_id]
|
||||
# Compare each state component (step, exp_avg, exp_avg_sq)
|
||||
for key, value1 in param_state1.items():
|
||||
value2 = param_state2[key]
|
||||
if isinstance(value1, DTensor):
|
||||
# Convert DTensors to local tensors for comparison
|
||||
torch.testing.assert_close(
|
||||
value1.to_local(),
|
||||
value2.to_local(),
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
msg=f"Optimizer state mismatch in state[{param_id}][{key}]",
|
||||
)
|
||||
else:
|
||||
torch.testing.assert_close(
|
||||
value1,
|
||||
value2,
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
msg=f"Optimizer state mismatch in state[{param_id}][{key}]",
|
||||
)
|
||||
elif name1 == "param_groups":
|
||||
# Compare param_groups (excluding the actual params list)
|
||||
for i, (group1, group2) in enumerate(zip(state1, state2)):
|
||||
for key in group1:
|
||||
if key != "params": # Skip comparing the params list
|
||||
assert group1[key] == group2[key], f"Param group mismatch in param_groups[{i}][{key}]"
|
||||
|
||||
# Run a forward pass with both models to verify outputs match
|
||||
logger.info("Running forward pass verification...")
|
||||
with torch.no_grad():
|
||||
# Use the last batch for verification
|
||||
batch = {k: v.to(device) for k, v in batch.items()} # Ensure batch is on correct device
|
||||
original_outputs = model(**batch)
|
||||
new_outputs = new_model(**batch)
|
||||
torch.testing.assert_close(
|
||||
original_outputs.logits.to_local(),
|
||||
new_outputs.logits.to_local(),
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
msg="Model outputs do not match!",
|
||||
) # Increased tolerance slightly for bf16
|
||||
|
||||
# Clean up distributed environment and finish wandb run
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
logger.info("Cleaned up distributed process group")
|
||||
# Finish wandb run on rank 0
|
||||
if dist.get_rank() == 0:
|
||||
wandb.finish()
|
||||
logger.info("Wandb run finished.")
|
||||
else:
|
||||
wandb.finish()
|
||||
logger.info("Wandb run finished.")
|
||||
|
||||
|
||||
def all_reduce_grads(model, world_mesh, use_ddp):
|
||||
"""All reduce gradients across dp_cp if applicable."""
|
||||
cp_mesh = world_mesh["cp"]
|
||||
if use_ddp:
|
||||
# DDP takes care of syncing grads
|
||||
mesh = cp_mesh
|
||||
else:
|
||||
mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
|
||||
if dist.is_initialized() and mesh.size() > 1:
|
||||
for name, param in model.named_parameters():
|
||||
if param.grad is not None:
|
||||
# Workaround for cross-mesh communication limitation with DTensor gradients
|
||||
if isinstance(param.grad, DTensor):
|
||||
local_grad = param.grad.to_local()
|
||||
# Ensure grad requires grad for inplace modification checks (might not be needed)
|
||||
# local_grad = local_grad.detach().requires_grad_(True)
|
||||
torch.distributed.all_reduce(local_grad, op=torch.distributed.ReduceOp.SUM, group=mesh.get_group())
|
||||
local_grad = local_grad / mesh.size()
|
||||
# Assign averaged grad back - need careful handling if DTensor structure is complex
|
||||
# This simple assignment might work if the grad structure matches param structure
|
||||
param.grad = DTensor.from_local(
|
||||
local_grad, device_mesh=param.grad.device_mesh, placements=param.grad.placements
|
||||
)
|
||||
else:
|
||||
# Handle regular tensors if any exist (e.g. buffers not converted to DTensor)
|
||||
torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG, group=mesh.get_group())
|
||||
|
||||
|
||||
class ContextParallelCollator:
|
||||
"""Collator for context parallel training that splits sequences into chunks."""
|
||||
|
||||
def __init__(self, cp_mesh: Optional[DeviceMesh] = None):
|
||||
self.cp_mesh = cp_mesh
|
||||
|
||||
def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
batch = default_collate(batch)
|
||||
if self.cp_mesh is not None and self.cp_mesh.size() > 1:
|
||||
# Get sequence length from the input batch
|
||||
seq_len = batch["input_ids"].shape[1]
|
||||
assert seq_len % self.cp_mesh.size() == 0, (
|
||||
f"Sequence length {seq_len} must be divisible by CP size {self.cp_mesh.size()}"
|
||||
)
|
||||
chunk_size = seq_len // self.cp_mesh.size()
|
||||
cp_rank = self.cp_mesh.get_local_rank()
|
||||
start_idx = cp_rank * chunk_size
|
||||
end_idx = start_idx + chunk_size
|
||||
|
||||
# Keep only the local chunk of the sequence
|
||||
batch["input_ids"] = batch["input_ids"][:, start_idx:end_idx]
|
||||
batch["attention_mask"] = batch["attention_mask"][:, start_idx:end_idx]
|
||||
batch["labels"] = batch["labels"][:, start_idx:end_idx]
|
||||
|
||||
return batch
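
# Hedged usage sketch (not part of the diff above): the chunking ContextParallelCollator performs
# per CP rank, reproduced with plain tensors so it runs without torch.distributed. The values of
# `seq_len` and `cp_size` below are illustrative assumptions, not taken from the training script.
import torch

seq_len, cp_size = 128, 4
input_ids = torch.arange(seq_len).unsqueeze(0)      # toy [1, 128] batch
chunk_size = seq_len // cp_size                     # 32 positions kept on each CP rank
for cp_rank in range(cp_size):
    start, end = cp_rank * chunk_size, (cp_rank + 1) * chunk_size
    local_ids = input_ids[:, start:end]             # the slice __call__ keeps for this rank
    assert local_ids.shape == (1, chunk_size)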
|
||||
|
||||
|
||||
class AppState(Stateful):
|
||||
"""Wrapper for checkpointing the Application State including model and optimizer."""
|
||||
|
||||
def __init__(self, model, optimizer=None):
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
|
||||
def state_dict(self):
|
||||
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
|
||||
return {"model": model_state_dict, "optim": optimizer_state_dict}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
set_state_dict(
|
||||
self.model, self.optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"]
|
||||
)
|
||||
|
||||
|
||||
def sanity_check_tensor_sync(
|
||||
tensor: torch.Tensor, mesh: DeviceMesh, rtol: float = 1e-4, atol: float = 1e-4, not_sync: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Verify that a tensor is synchronized (or not synchronized) across all processes in the mesh's process group.
|
||||
Handles both regular tensors and DTensors.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The tensor to check for synchronization (can be DTensor)
|
||||
mesh (DeviceMesh): The device mesh containing the process group
|
||||
rtol (float): Relative tolerance for comparison
|
||||
atol (float): Absolute tolerance for comparison
|
||||
not_sync (bool): If True, asserts that tensors are NOT synchronized. If False, asserts they are synchronized.
|
||||
"""
|
||||
if not dist.is_initialized() or mesh.size() == 1:
|
||||
return # No need to check in non-distributed mode
|
||||
|
||||
# Get the process group from the mesh
|
||||
pg = mesh.get_group()
|
||||
|
||||
# Convert DTensor to local tensor if needed
|
||||
if hasattr(tensor, "to_local"):
|
||||
local_tensor = tensor.to_local()
|
||||
else:
|
||||
local_tensor = tensor
|
||||
|
||||
# Gather tensors from all processes
|
||||
world_size = dist.get_world_size(pg)
|
||||
gathered_tensors = [torch.empty_like(local_tensor) for _ in range(world_size)]
|
||||
dist.all_gather(gathered_tensors, local_tensor, group=pg)
|
||||
|
||||
# Compare each tensor with the first one
|
||||
for i in range(1, world_size):
|
||||
try:
|
||||
torch.testing.assert_close(gathered_tensors[0], gathered_tensors[i], rtol=rtol, atol=atol)
|
||||
except AssertionError as e:
|
||||
if not_sync:
|
||||
continue
|
||||
# # Add detailed debugging for logit synchronization issues
|
||||
# print(f"\nLogit synchronization error between rank 0 and rank {i}:")
|
||||
# print(f"Tensor shape: {gathered_tensors[0].shape}")
|
||||
# print(f"Number of mismatched elements: {(gathered_tensors[0] != gathered_tensors[i]).sum()}")
|
||||
# print(f"Percentage of mismatched elements: {((gathered_tensors[0] != gathered_tensors[i]).sum() / gathered_tensors[0].numel() * 100):.2f}%")
|
||||
|
||||
# # Find the first few mismatches
|
||||
# mismatches = torch.nonzero(gathered_tensors[0] != gathered_tensors[i])
|
||||
# print("\nFirst few mismatches:")
|
||||
# for idx in mismatches[:5]:
|
||||
# idx = tuple(idx.tolist())
|
||||
# print(f"Index {idx}:")
|
||||
# print(f"Rank 0 value: {gathered_tensors[0][idx]}")
|
||||
# print(f"Rank {i} value: {gathered_tensors[i][idx]}")
|
||||
# print(f"Absolute difference: {abs(gathered_tensors[0][idx] - gathered_tensors[i][idx])}")
|
||||
# print(f"Relative difference: {abs(gathered_tensors[0][idx] - gathered_tensors[i][idx]) / max(abs(gathered_tensors[0][idx]), abs(gathered_tensors[i][idx]))}")
|
||||
|
||||
# # Check if differences are systematic (e.g., all positive or negative)
|
||||
# diff = gathered_tensors[0] - gathered_tensors[i]
|
||||
# print(f"\nDifference statistics:")
|
||||
# print(f"Mean difference: {diff.mean()}")
|
||||
# print(f"Std difference: {diff.std()}")
|
||||
# print(f"Max positive difference: {diff.max()}")
|
||||
# print(f"Max negative difference: {diff.min()}")
|
||||
raise e
|
||||
|
||||
|
||||
def clip_grad_norm_(
|
||||
parameters: Iterable[torch.Tensor],
|
||||
max_norm: float,
|
||||
norm_type: float = 2.0,
|
||||
error_if_nonfinite: bool = False,
|
||||
foreach: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Clip the gradient norm of an iterable of parameters.
|
||||
"""
|
||||
# Filter out parameters with no gradients
|
||||
parameters = [p for p in parameters if p.grad is not None]
|
||||
assert len(parameters) > 0, "No parameters with gradients found"
|
||||
|
||||
# Calculate total norm
|
||||
if norm_type == float("inf"):
|
||||
total_norm = max(p.grad.detach().abs().max() for p in parameters)
|
||||
else:
|
||||
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type)
|
||||
|
||||
# Convert DTensor to local tensor if needed
|
||||
if isinstance(total_norm, DTensor):
|
||||
total_norm = total_norm.full_tensor()
|
||||
|
||||
# Clip gradients
|
||||
clip_coef = max_norm / (total_norm + 1e-6)
|
||||
if clip_coef < 1:
|
||||
for p in parameters:
|
||||
p.grad.detach().mul_(clip_coef)
|
||||
|
||||
return total_norm
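
# Hedged usage sketch (assumes the clip_grad_norm_ defined above is in scope): with plain,
# non-DTensor gradients it behaves like torch.nn.utils.clip_grad_norm_, returning the
# pre-clipping norm and scaling gradients in place when that norm exceeds max_norm.
import torch
from torch import nn

layer = nn.Linear(4, 4)
layer(torch.randn(8, 4)).pow(2).sum().backward()            # produce some gradients
total_norm = clip_grad_norm_(layer.parameters(), max_norm=1.0)
print(float(total_norm))                                    # global norm before clipping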
|
||||
|
||||
|
||||
def check_params_sync(model_params, original_params):
|
||||
"""
|
||||
Check if original_params are being updated in sync with model parameters.
|
||||
|
||||
Args:
|
||||
model_params: Iterator of model parameters after update
|
||||
original_params: List of original parameters before DDP wrapping
|
||||
"""
|
||||
for mp, op in zip(model_params, original_params):
|
||||
if isinstance(mp, DTensor):
|
||||
mp = mp.to_local()
|
||||
if isinstance(op, DTensor):
|
||||
op = op.to_local()
|
||||
if not torch.allclose(mp.data, op.data, rtol=0, atol=0):
|
||||
raise RuntimeError(f"Parameters out of sync: model param {mp.data} != original param {op.data}")
|
||||
return True
|
||||
|
||||
|
||||
def get_parameters(model: nn.Module) -> Iterable[torch.Tensor]:
|
||||
"""
|
||||
Get all parameters from a model by iterating over its modules.
|
||||
This is an alternative to model.parameters() that works with DTensor models.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to get parameters from
|
||||
|
||||
Returns:
|
||||
Iterable[torch.Tensor]: An iterator over all parameters in the model
|
||||
"""
|
||||
for name, module in model._modules.items():
|
||||
# Look for parameters in module attributes
|
||||
for attr_name, attr in module.__dict__.items():
|
||||
if isinstance(attr, torch.Tensor) and attr.requires_grad:
|
||||
yield attr
|
||||
# Recursively get parameters from submodules
|
||||
for param in get_parameters(module):
|
||||
yield param
|
||||
|
||||
|
||||
def update_model_parameters(model: nn.Module) -> None:
|
||||
"""
|
||||
Update model._parameters using named_modules() to ensure all parameters are properly tracked.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to update parameters for
|
||||
"""
|
||||
# Clear existing parameters
|
||||
model._parameters = {}
|
||||
|
||||
# Add parameters from named_modules
|
||||
for name, module in model.named_modules():
|
||||
# Skip the root module itself
|
||||
if name == "":
|
||||
continue
|
||||
|
||||
# Get the parameter name by removing 'module.' prefix if it exists
|
||||
param_name = name.replace("module.", "")
|
||||
|
||||
# Add weight and bias parameters if they exist
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
model._parameters[f"{param_name}.weight"] = module.weight
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
model._parameters[f"{param_name}.bias"] = module.bias
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -44,7 +44,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
|
||||
|
||||
|
94 examples/pytorch/context_parallel.py (new file)
@ -0,0 +1,94 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
from torch.distributed.tensor.experimental import context_parallel
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers.loss.loss_utils import ForCausalLMLoss
|
||||
|
||||
|
||||
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||
cp_mesh = init_device_mesh("cuda", (world_size,))
|
||||
rank = torch.distributed.get_node_local_rank()
|
||||
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
sdpa_backend = SDPBackend.FLASH_ATTENTION
|
||||
|
||||
# prepare inputs
|
||||
batch_size = 1
|
||||
seq_len = 128
|
||||
|
||||
input_ids = torch.randint(low=8, high=64, size=(batch_size, seq_len), device=device)
|
||||
|
||||
ignore_index = -100
|
||||
# When using CP, we need to use `shift_labels`
|
||||
shift_labels = torch.nn.functional.pad(input_ids, (0, 1), value=ignore_index)
|
||||
shift_labels = shift_labels[..., 1:].contiguous()
|
||||
|
||||
position_ids = (
|
||||
torch.cumsum(torch.ones(size=input_ids.size(), dtype=input_ids.dtype, device=input_ids.device), dim=1) - 1
|
||||
)
|
||||
|
||||
# sync input as they are created randomly
|
||||
dist.broadcast(input_ids, src=0)
|
||||
dist.broadcast(shift_labels, src=0)
|
||||
dist.broadcast(position_ids, src=0)
|
||||
|
||||
# model and optimizer
|
||||
repo_id = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=dtype, device_map=device)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
|
||||
|
||||
model.train()
|
||||
model.zero_grad()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# For loss
|
||||
vocab_size = model.config.vocab_size
|
||||
|
||||
# so training could be synced
|
||||
model = DDP(model, device_ids=[rank])
|
||||
|
||||
# prepare for CP
|
||||
buffers = (input_ids, shift_labels, position_ids)
|
||||
buffer_seq_dims = (1, 1, 1)
|
||||
# `no_restore_buffers=set(buffers)` is required if `loss.backward` is outside `context_parallel`.
|
||||
# no_restore_buffers = set(buffers)
|
||||
no_restore_buffers = None
|
||||
|
||||
# run with CP
|
||||
with sdpa_kernel(sdpa_backend):
|
||||
with context_parallel(
|
||||
cp_mesh,
|
||||
buffers=buffers,
|
||||
buffer_seq_dims=buffer_seq_dims,
|
||||
no_restore_buffers=no_restore_buffers,
|
||||
):
|
||||
outputs = model(input_ids, shift_labels=shift_labels, position_ids=position_ids)
|
||||
print(outputs.logits.shape)
|
||||
|
||||
# For now, we need to compute the `loss` outside `model.forward` when using `shift_labels`
|
||||
# loss = outputs.loss
|
||||
loss = ForCausalLMLoss(logits=outputs.logits, labels=None, shift_labels=shift_labels, vocab_size=vocab_size)
|
||||
|
||||
# This could be outside `context_parallel` context if `no_restore_buffers` is specified
|
||||
loss.backward()
|
||||
optimizer.step()
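
# Hedged follow-up check (assumes the script above was launched with torchrun so that
# world_size ranks exist): context_parallel shards the buffers along the sequence dimension,
# so each rank should only see seq_len // world_size positions in its local logits.
expected_shape = (batch_size, seq_len // world_size, vocab_size)
assert outputs.logits.shape == expected_shape, f"unexpected local logits shape: {tuple(outputs.logits.shape)}"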
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -42,7 +42,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -47,7 +47,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -45,7 +45,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -53,7 +53,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
# You should update this to your particular problem to have better documentation of `model_type`
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")
|
||||
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = get_logger(__name__)
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version(
|
||||
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -61,7 +61,7 @@ except (ModuleNotFoundError, ImportError):
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Checking dependencies
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
task_to_keys = {
|
||||
"cola": ("sentence", None),
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Dependencies and constants
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
2 setup.py
@ -451,7 +451,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.52.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.53.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.52.0.dev0"
|
||||
__version__ = "4.53.0.dev0"
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
@ -1985,7 +1985,9 @@ class GenerationMixin:
|
||||
instantiated, writes it to `model_kwargs`, under the name expected by the model.
|
||||
"""
|
||||
|
||||
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
|
||||
is_hybrid_cache = any(class_name in self.__class__.__name__.lower() for class_name in ["mamba", "falconh1"])
|
||||
cache_name = "past_key_values" if not is_hybrid_cache else "cache_params"
|
||||
|
||||
requires_cross_attention_cache = (
|
||||
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
|
||||
)
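
# Hedged illustration (class names below are examples, not from the diff): the hybrid-cache
# check above routes mamba/falconh1 models to `cache_params` and everything else to
# `past_key_values`.
for cls_name in ("MambaForCausalLM", "FalconH1ForCausalLM", "LlamaForCausalLM"):
    is_hybrid = any(k in cls_name.lower() for k in ("mamba", "falconh1"))
    print(cls_name, "->", "cache_params" if is_hybrid else "past_key_values")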
|
||||
|
@ -142,7 +142,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["tensor_parallel"] = [
|
||||
"shard_and_distribute_module",
|
||||
"SUPPORTED_TP_STYLES",
|
||||
"ALL_PARALLEL_STYLES",
|
||||
"translate_to_torch_parallel_style",
|
||||
]
|
||||
try:
|
||||
@ -271,7 +271,7 @@ if TYPE_CHECKING:
|
||||
pass
|
||||
else:
|
||||
from .tensor_parallel import (
|
||||
SUPPORTED_TP_STYLES,
|
||||
ALL_PARALLEL_STYLES,
|
||||
shard_and_distribute_module,
|
||||
translate_to_torch_parallel_style,
|
||||
)
|
||||
|
@ -362,8 +362,8 @@ def _replace_with_bitnet_linear(
|
||||
bias=module.bias is not None,
|
||||
device=module.weight.device,
|
||||
dtype=module.weight.dtype,
|
||||
use_rms_norm=quantization_config.use_rms_norm,
|
||||
rms_norm_eps=quantization_config.rms_norm_eps,
|
||||
use_rms_norm=quantization_config.use_rms_norm if quantization_config else False,
|
||||
rms_norm_eps=quantization_config.rms_norm_eps if quantization_config else 1e-6,
|
||||
)
|
||||
model._modules[name].requires_grad_(False)
|
||||
has_been_replaced = True
|
||||
|
@ -13,11 +13,15 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
import os
|
||||
import re
|
||||
from functools import lru_cache, partial
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from collections.abc import MutableMapping
|
||||
from functools import partial, reduce
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
|
||||
from ..utils import is_torch_greater_or_equal, logging
|
||||
@ -35,6 +39,56 @@ if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
|
||||
from torch.distributed.tensor import DTensor, Placement, Replicate, Shard
|
||||
|
||||
|
||||
def initialize_tensor_parallelism(tp_plan, tp_size=None):
|
||||
r"""
|
||||
Sets up the device mesh and initializes the backend for tensor parallelism.
|
||||
This function is called when the model is loaded and the TP plan is set to 'auto'.
|
||||
"""
|
||||
if tp_plan is None:
|
||||
return None, None, None
|
||||
|
||||
if not is_torch_greater_or_equal("2.5"):
|
||||
raise EnvironmentError("Tensor parallel is only supported for `torch>=2.5`.")
|
||||
|
||||
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
|
||||
device_type = torch._C._get_accelerator().type
|
||||
if not torch.distributed.is_initialized():
|
||||
try:
|
||||
rank = int(os.environ["RANK"])
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
|
||||
backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "ccl", "hpu": "hccl"}
|
||||
backend = backend_map.get(device_type)
|
||||
if device_type == "cpu" and int(os.environ.get("CCL_WORKER_COUNT", 0)):
|
||||
backend = "ccl"
|
||||
|
||||
torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
||||
current_device = getattr(torch, device_type)
|
||||
if device_type != "cpu":
|
||||
current_device.set_device(local_rank)
|
||||
|
||||
except Exception as e:
|
||||
raise EnvironmentError(
|
||||
"We tried to initialize torch.distributed for you, but it failed. Make "
|
||||
"sure you init torch distributed in your script to use `tp_plan='auto'`."
|
||||
) from e
|
||||
index = current_device.current_device() if device_type != "cpu" else None
|
||||
tp_device = torch.device(device_type, index)
|
||||
|
||||
# Silence output for non-primary ranks
|
||||
if index is not None and index > 0:
|
||||
import sys
|
||||
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
sys.stderr = open(os.devnull, "w")
|
||||
|
||||
device_map = tp_device
|
||||
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
|
||||
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
|
||||
return tp_device, device_map, device_mesh
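
# Hedged usage sketch (assumes a torchrun launch so RANK/LOCAL_RANK/WORLD_SIZE are set): the
# helper above initializes the process group if needed and returns the local device, a
# device_map pointing at it, and a 1-D device mesh sized to tp_size (world size by default).
tp_device, device_map, device_mesh = initialize_tensor_parallelism("auto")
print(tp_device, device_map, device_mesh.size())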
|
||||
|
||||
|
||||
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
|
||||
"""
|
||||
Convert block count or proportions to block sizes.
|
||||
@ -220,18 +274,38 @@ def repack_weights(
|
||||
|
||||
|
||||
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
|
||||
if dim == 0:
|
||||
size_ = empty_param.shape[0]
|
||||
param = param[rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size()), ...]
|
||||
elif dim == 1 or dim == -2:
|
||||
size_ = empty_param.shape[-2]
|
||||
param = param[..., rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size()), :]
|
||||
elif dim == 2 or dim == -1:
|
||||
size_ = empty_param.shape[-1]
|
||||
param = param[..., rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size())]
|
||||
else:
|
||||
raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
|
||||
return param
|
||||
"""
|
||||
Generalized tensor sharding across a multi-dimensional device mesh.
|
||||
|
||||
Args:
|
||||
param (torch.Tensor): The tensor to shard.
|
||||
empty_param (torch.Tensor): A tensor used for shape reference.
|
||||
device_mesh (`DeviceMesh`): Device mesh of shape [d_0, ..., d_n]; its flattened size gives the number of shards.
|
||||
rank (int): Global rank of the current process/device.
|
||||
dim (int): Dimension along which to shard the tensor.
|
||||
"""
|
||||
param_dim = empty_param.dim()
|
||||
if dim < 0:
|
||||
dim = param_dim + dim
|
||||
if dim >= param_dim:
|
||||
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
|
||||
|
||||
# Flatten the mesh to get the total number of devices
|
||||
mesh_shape = device_mesh.shape
|
||||
world_size = reduce(operator.mul, mesh_shape)
|
||||
|
||||
if rank >= world_size:
|
||||
raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
|
||||
|
||||
shard_size = empty_param.shape[dim] // world_size
|
||||
start = rank * shard_size
|
||||
end = start + shard_size
|
||||
|
||||
# Construct slicing index dynamically
|
||||
slice_indices = [slice(None)] * param_dim
|
||||
slice_indices[dim] = slice(start, end)
|
||||
|
||||
return param[tuple(slice_indices)]
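
# Hedged worked example of the slicing arithmetic above (values are illustrative): for an
# (8, 6) weight, a flattened mesh of 4 devices and dim=0, rank 2 keeps rows 4..5.
import torch

param = torch.arange(48).reshape(8, 6)
world_size, rank, dim = 4, 2, 0
shard_size = param.shape[dim] // world_size                  # 2 rows per rank
slice_indices = [slice(None)] * param.dim()
slice_indices[dim] = slice(rank * shard_size, (rank + 1) * shard_size)
local_shard = param[tuple(slice_indices)]                    # shape (2, 6), rows 4 and 5
assert local_shard.shape == (2, 6) and local_shard[0, 0].item() == 24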
|
||||
|
||||
|
||||
def distribute_module(
|
||||
@ -339,6 +413,41 @@ class IsolatedParallel(TensorParallelLayer):
|
||||
)
|
||||
|
||||
|
||||
class ReplicateParallel(TensorParallelLayer):
|
||||
"""
|
||||
This class is used to replicate computation in a TP layer (used in SP regions when we don't use sequence parallelism for example)
|
||||
"""
|
||||
|
||||
def __init__(self, *, use_dtensor=True, use_local_output=True):
|
||||
super().__init__()
|
||||
self.input_layouts = (Replicate(),)
|
||||
self.output_layouts = (Replicate(),)
|
||||
self.desired_input_layouts = (Replicate(),)
|
||||
self.use_local_output = use_local_output
|
||||
self.use_dtensor = use_dtensor
|
||||
|
||||
@staticmethod
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||
# TODO: figure out dynamo support for instance method and switch this to instance method
|
||||
# annotate module input placements/sharding with input_layouts
|
||||
input_tensor = inputs[0]
|
||||
if not isinstance(input_tensor, DTensor):
|
||||
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
|
||||
|
||||
return input_tensor
|
||||
|
||||
@staticmethod
|
||||
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
||||
return outputs.to_local() if use_local_output else outputs
|
||||
|
||||
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
|
||||
param = param[...].to(param_casting_dtype)
|
||||
if to_contiguous:
|
||||
param = param.contiguous()
|
||||
param = DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
|
||||
return param
|
||||
|
||||
|
||||
class ColwiseParallel(TensorParallelLayer):
|
||||
"""
|
||||
General tensor parallel layer for transformers.
|
||||
@ -611,52 +720,62 @@ class SequenceParallel(TensorParallelLayer):
|
||||
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
|
||||
|
||||
|
||||
SUPPORTED_TP_STYLES = {
|
||||
"colwise",
|
||||
"rowwise",
|
||||
"colwise_rep",
|
||||
"rowwise_rep",
|
||||
"local_colwise",
|
||||
"local_rowwise",
|
||||
"local",
|
||||
"gather",
|
||||
"local_packed_rowwise",
|
||||
"sequence_parallel",
|
||||
}
|
||||
|
||||
|
||||
@lru_cache
|
||||
def translate_to_torch_parallel_style(style: str):
|
||||
class ParallelInterface(MutableMapping):
|
||||
"""
|
||||
In model configurations, we use a neutral type (string) to specify parallel
|
||||
styles, here we translate them into torch.distributed tensor-parallel
|
||||
types.
|
||||
Dict-like object keeping track of allowed tensor-parallel styles. You can easily add a new parallel style
|
||||
with a call to `register()`. If a model needs to locally overwrite an existing style, say `colwise`,
|
||||
it needs to declare a new instance of this class inside the `modeling_<model>.py`, and declare it on that instance.
|
||||
"""
|
||||
if not isinstance(style, str):
|
||||
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
|
||||
|
||||
if style == "colwise":
|
||||
return ColwiseParallel()
|
||||
elif style == "rowwise":
|
||||
return RowwiseParallel()
|
||||
elif style == "colwise_rep":
|
||||
return ColwiseParallel(output_layouts=Replicate())
|
||||
elif style == "rowwise_rep":
|
||||
return RowwiseParallel(input_layouts=Replicate())
|
||||
elif style == "local_colwise":
|
||||
return ColwiseParallel(use_dtensor=False)
|
||||
elif style == "local_rowwise":
|
||||
return RowwiseParallel(use_dtensor=False)
|
||||
elif style == "local":
|
||||
return IsolatedParallel()
|
||||
elif style == "gather":
|
||||
return GatherParallel()
|
||||
elif style == "local_packed_rowwise":
|
||||
return PackedRowwiseParallel(use_dtensor=False)
|
||||
elif style == "sequence_parallel":
|
||||
return SequenceParallel()
|
||||
else:
|
||||
raise ValueError(f"Unsupported parallel style value: {style}")
|
||||
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
|
||||
# a new instance is created (in order to locally override a given function)
|
||||
_global_mapping = {
|
||||
"colwise": ColwiseParallel(),
|
||||
"rowwise": RowwiseParallel(),
|
||||
"colwise_rep": ColwiseParallel(output_layouts=Replicate()),
|
||||
"rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
|
||||
"local_colwise": ColwiseParallel(use_dtensor=False),
|
||||
"local_rowwise": RowwiseParallel(use_dtensor=False),
|
||||
"local": IsolatedParallel(),
|
||||
"gather": GatherParallel(),
|
||||
"local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
|
||||
"sequence_parallel": SequenceParallel(),
|
||||
"replicate": ReplicateParallel(),
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self._local_mapping = {}
|
||||
|
||||
def __getitem__(self, key):
|
||||
# First check if instance has a local override
|
||||
if key in self._local_mapping:
|
||||
return self._local_mapping[key]
|
||||
return self._global_mapping[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
# Allow local update of the default functions without impacting other instances
|
||||
self._local_mapping.update({key: value})
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self._local_mapping[key]
|
||||
|
||||
def __iter__(self):
|
||||
# Ensure we use all keys, with the overwritten ones on top
|
||||
return iter({**self._global_mapping, **self._local_mapping})
|
||||
|
||||
def __len__(self):
|
||||
return len(self._global_mapping.keys() | self._local_mapping.keys())
|
||||
|
||||
@classmethod
|
||||
def register(cls, key: str, value: Callable):
|
||||
cls._global_mapping.update({key: value})
|
||||
|
||||
def valid_keys(self) -> List[str]:
|
||||
return list(self.keys())
|
||||
|
||||
|
||||
# Global ParallelInterface shared by all models which do not need to overwrite any of the existing styles
|
||||
ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
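
# Hedged sketch (the "my_colwise" name and the subclass are hypothetical): ALL_PARALLEL_STYLES
# is dict-like, so a custom tensor-parallel style can be registered next to the built-in ones
# and then looked up the same way shard_and_distribute_module does.
class MyColwiseStyle(ColwiseParallel):
    pass

ALL_PARALLEL_STYLES.register("my_colwise", MyColwiseStyle())
assert "my_colwise" in ALL_PARALLEL_STYLES.valid_keys()
tp_layer = ALL_PARALLEL_STYLES["my_colwise"]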
|
||||
|
||||
|
||||
def convert_local_tensor_to_dtensor(
|
||||
@ -722,13 +841,15 @@ def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, curr
|
||||
|
||||
# 1. We add hooks to the layer being loaded:
|
||||
if current_module_plan is not None:
|
||||
tp_layer = translate_to_torch_parallel_style(current_module_plan)
|
||||
tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
|
||||
try:
|
||||
tp_layer.prepare_module_tp(module, device_mesh)
|
||||
except NotImplementedError as e:
|
||||
print(
|
||||
f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}"
|
||||
)
|
||||
module._hf_tp_plan = current_module_plan
|
||||
module.__repr__ = lambda: f"{module.__repr__()}\nTP Plan: {current_module_plan}"
|
||||
|
||||
# 2. We add hooks to the parent module if needed
|
||||
if "." in layer_name:
|
||||
@ -736,9 +857,11 @@ def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, curr
|
||||
generic_name = re.sub(r"\d+", "*", parent_layer_name)
|
||||
# The module itself needs hooks
|
||||
if module_plan := tp_plan.get(generic_name, False):
|
||||
tp_layer = translate_to_torch_parallel_style(module_plan)
|
||||
tp_layer = ALL_PARALLEL_STYLES[module_plan]
|
||||
module_to_tp_ = model.get_submodule(parent_layer_name)
|
||||
tp_layer.prepare_module_tp(module_to_tp_, device_mesh)
|
||||
module_to_tp_._hf_tp_plan = current_module_plan
|
||||
module_to_tp_.__repr__ = lambda: f"{module_to_tp_.__repr__()}\nTP Plan: {current_module_plan}"
|
||||
|
||||
|
||||
def shard_and_distribute_module(
|
||||
@ -760,28 +883,29 @@ def shard_and_distribute_module(
|
||||
|
||||
current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
|
||||
|
||||
if current_module_plan is None:
|
||||
current_module_plan = "replicate"
|
||||
if dist.get_rank() == 0:
|
||||
logger.info(f"Tensor parallel plan for {param_name} not found, using default 'replicate' plan.")
|
||||
else:
|
||||
if dist.get_rank() == 0:
|
||||
logger.info(f"Tensor parallel plan for {param_name}: {current_module_plan}")
|
||||
|
||||
# Add hooks to the module if not done yet
|
||||
# add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
|
||||
if not getattr(module_to_tp, "_is_hooked", False):
|
||||
add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
|
||||
module_to_tp._is_hooked = True
|
||||
|
||||
if current_module_plan is not None:
|
||||
try:
|
||||
tp_layer = translate_to_torch_parallel_style(current_module_plan)
|
||||
param = tp_layer.partition_tensor(
|
||||
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
|
||||
)
|
||||
except NotImplementedError as e:
|
||||
print(
|
||||
f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
|
||||
)
|
||||
else:
|
||||
# TODO log no plan modules in set
|
||||
# print("No plan for", parameter_name,end ="\n")
|
||||
param = param[...].to(param_casting_dtype)
|
||||
if is_contiguous:
|
||||
param = param.contiguous()
|
||||
try:
|
||||
tp_layer = ALL_PARALLEL_STYLES[current_module_plan]
|
||||
param = tp_layer.partition_tensor(
|
||||
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
|
||||
)
|
||||
except NotImplementedError as e:
|
||||
print(
|
||||
f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
|
||||
)
|
||||
|
||||
# SUPER IMPORTANT we have to use setattr
|
||||
# otherwise loading is crazy slow
|
||||
|
@ -66,7 +66,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
logger.error(
|
||||
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
|
||||
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
||||
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/index.html#installation for installation"
|
||||
" instructions."
|
||||
)
|
||||
raise
|
||||
@ -360,7 +360,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
logger.error(
|
||||
"Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see"
|
||||
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
|
||||
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/index.html#installation for installation"
|
||||
" instructions."
|
||||
)
|
||||
raise
|
||||
|
@ -62,8 +62,9 @@ from .integrations.flash_attention import flash_attention_forward
|
||||
from .integrations.flex_attention import flex_attention_forward
|
||||
from .integrations.sdpa_attention import sdpa_attention_forward
|
||||
from .integrations.tensor_parallel import (
|
||||
SUPPORTED_TP_STYLES,
|
||||
ALL_PARALLEL_STYLES,
|
||||
_get_parameter_tp_plan,
|
||||
initialize_tensor_parallelism,
|
||||
repack_weights,
|
||||
replace_state_dict_local_with_dtensor,
|
||||
shard_and_distribute_module,
|
||||
@ -797,7 +798,7 @@ def _load_state_dict_into_meta_model(
|
||||
param_name,
|
||||
casting_dtype,
|
||||
to_contiguous,
|
||||
int(os.environ["RANK"]), # the rank
|
||||
device_mesh.get_local_rank(),
|
||||
device_mesh,
|
||||
)
|
||||
else:
|
||||
@ -1964,9 +1965,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
|
||||
for _, v in self._tp_plan.items():
|
||||
if v not in SUPPORTED_TP_STYLES:
|
||||
if v not in ALL_PARALLEL_STYLES:
|
||||
raise ValueError(
|
||||
f"Unsupported tensor parallel style {v}. Supported styles are {SUPPORTED_TP_STYLES}"
|
||||
f"Unsupported tensor parallel style {v}. Supported styles are {ALL_PARALLEL_STYLES}"
|
||||
)
|
||||
|
||||
def dequantize(self):
|
||||
@ -3559,6 +3560,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
|
||||
|
||||
if safe_serialization:
|
||||
# TODO: fix safe_serialization for tied weights
|
||||
# Safetensors does not allow tensor aliasing.
|
||||
# We're going to remove aliases before saving
|
||||
ptrs = collections.defaultdict(list)
|
||||
@ -4040,6 +4042,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
`torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
|
||||
tp_size (`int`, *optional*):
|
||||
The tensor parallel degree. If not provided, it defaults to the world size.
|
||||
device_mesh (`torch.distributed.DeviceMesh`, *optional*):
|
||||
A torch device mesh. If not provided, one is created from the world size. Currently only used for tensor parallelism.
|
||||
offload_folder (`str` or `os.PathLike`, *optional*):
|
||||
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
|
||||
offload_state_dict (`bool`, *optional*):
|
||||
@ -4137,6 +4141,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
gguf_file = kwargs.pop("gguf_file", None)
|
||||
tp_plan = kwargs.pop("tp_plan", None)
|
||||
tp_size = kwargs.pop("tp_size", None)
|
||||
device_mesh = kwargs.pop("device_mesh", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
||||
|
||||
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
|
||||
@ -4172,59 +4177,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
# We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
|
||||
# `device_map` pointing to the correct device
|
||||
device_mesh = None
|
||||
if tp_plan is not None:
|
||||
if not is_torch_greater_or_equal("2.5"):
|
||||
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
|
||||
|
||||
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
|
||||
device_type = torch._C._get_accelerator().type
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
try:
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
if device_type == "cuda":
|
||||
torch.distributed.init_process_group(
|
||||
"nccl", rank=rank, world_size=world_size, init_method="env://"
|
||||
)
|
||||
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
|
||||
elif device_type == "cpu":
|
||||
cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
|
||||
torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
|
||||
elif device_type == "xpu":
|
||||
torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size)
|
||||
torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))
|
||||
elif device_type == "hpu":
|
||||
torch.distributed.init_process_group("hccl", rank=rank, world_size=world_size)
|
||||
torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))
|
||||
|
||||
except Exception as e:
|
||||
raise EnvironmentError(
|
||||
"We tried to initialize torch.distributed for you, but it failed, make"
|
||||
"sure you init torch distributed in your script to use `tp_plan='auto'`"
|
||||
) from e
|
||||
|
||||
# Get device with index assuming equal number of devices per host
|
||||
if device_type == "xpu":
|
||||
index = torch.xpu.current_device()
|
||||
elif device_type == "hpu":
|
||||
index = torch.hpu.current_device()
|
||||
if device_mesh is None and tp_plan is not None:
|
||||
tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None)
|
||||
else:
|
||||
index = None if device_type == "cpu" else torch.cuda.current_device()
|
||||
tp_device = torch.device(device_type, index)
|
||||
|
||||
if index is not None and index > 0:
|
||||
import sys
|
||||
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
sys.stderr = open(os.devnull, "w")
|
||||
# This is the easiest way to dispatch to the current process device
|
||||
device_map = tp_device
|
||||
|
||||
# Assume the model is sharded across the whole world size when tp_size is not provided
|
||||
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
|
||||
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
|
||||
# TODO: make device_mesh support multiple dimensions
|
||||
if device_mesh.ndim == 1:
|
||||
raise ValueError("device_mesh must be 1 dimensional and will be used for TP")
|
||||
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))
|
||||
|
||||
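Editor's note: with the manual process-group setup folded into `initialize_tensor_parallelism`, the user-facing entry point remains `from_pretrained(..., tp_plan="auto")` under `torchrun`. A minimal sketch of that usage (the model id and prompt are placeholders; it requires at least two visible devices and a build containing this change):

```python
# launch with: torchrun --nproc-per-node 2 load_with_tp.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"  # placeholder; any checkpoint whose model class defines a _tp_plan
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    tp_plan="auto",  # triggers initialize_tensor_parallelism() under the hood
)

inputs = tokenizer("Tensor parallel smoke test", return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.shape)
```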
if use_auth_token is not None:
|
||||
warnings.warn(
|
||||
@ -5142,7 +5102,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
name,
|
||||
casting_dtype,
|
||||
to_contiguous,
|
||||
os.environ["RANK"],
|
||||
device_mesh.get_local_rank(),
|
||||
device_mesh,
|
||||
)
|
||||
|
||||
|
@ -103,6 +103,7 @@ if TYPE_CHECKING:
|
||||
from .ernie import *
|
||||
from .esm import *
|
||||
from .falcon import *
|
||||
from .falcon_h1 import *
|
||||
from .falcon_mamba import *
|
||||
from .fastspeech2_conformer import *
|
||||
from .flaubert import *
|
||||
|
@ -364,19 +364,23 @@ class AriaImageProcessor(BaseImageProcessor):
|
||||
|
||||
return resized_image
|
||||
|
||||
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
|
||||
original_height, original_width = original_resolution
|
||||
target_height, target_width = target_resolution
|
||||
paste_x, r_x = divmod(target_width - original_width, 2)
|
||||
paste_y, r_y = divmod(target_height - original_height, 2)
|
||||
return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
|
||||
|
||||
def _pad_for_patching(
|
||||
self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> np.array:
|
||||
"""
|
||||
Pad an image to a target resolution while maintaining aspect ratio.
|
||||
"""
|
||||
target_height, target_width = target_resolution
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
padding = self._get_padding_size(new_resolution, target_resolution)
|
||||
|
||||
paste_x, r_x = divmod(target_width - new_width, 2)
|
||||
paste_y, r_y = divmod(target_height - new_height, 2)
|
||||
|
||||
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
|
||||
padded_image = self.pad(image, padding=padding)
|
||||
|
||||
return padded_image
|
||||
|
||||
|
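Editor's note: the refactor moves the divmod arithmetic into `_get_padding_size`, which returns `((top, bottom), (left, right))` amounts, splitting the slack in two and putting any odd pixel on the bottom/right. A standalone numeric check of that arithmetic (a module-level copy of the helper, not the processor method itself):

```python
def get_padding_size(original_resolution, target_resolution):
    # same arithmetic as the new _get_padding_size helper
    original_height, original_width = original_resolution
    target_height, target_width = target_resolution
    paste_x, r_x = divmod(target_width - original_width, 2)
    paste_y, r_y = divmod(target_height - original_height, 2)
    return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)

# a 300x401 patch padded into a 336x448 grid cell:
# height slack 36 -> (18, 18); width slack 47 -> (23, 24), extra pixel on the right
assert get_padding_size((300, 401), (336, 448)) == ((18, 18), (23, 24))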
@ -748,19 +748,23 @@ class AriaImageProcessor(BaseImageProcessor):
|
||||
|
||||
return resized_image
|
||||
|
||||
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
|
||||
original_height, original_width = original_resolution
|
||||
target_height, target_width = target_resolution
|
||||
paste_x, r_x = divmod(target_width - original_width, 2)
|
||||
paste_y, r_y = divmod(target_height - original_height, 2)
|
||||
return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
|
||||
|
||||
def _pad_for_patching(
|
||||
self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> np.array:
|
||||
"""
|
||||
Pad an image to a target resolution while maintaining aspect ratio.
|
||||
"""
|
||||
target_height, target_width = target_resolution
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
padding = self._get_padding_size(new_resolution, target_resolution)
|
||||
|
||||
paste_x, r_x = divmod(target_width - new_width, 2)
|
||||
paste_y, r_y = divmod(target_height - new_height, 2)
|
||||
|
||||
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
|
||||
padded_image = self.pad(image, padding=padding)
|
||||
|
||||
return padded_image
|
||||
|
||||
|
@ -118,6 +118,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("ernie_m", "ErnieMConfig"),
|
||||
("esm", "EsmConfig"),
|
||||
("falcon", "FalconConfig"),
|
||||
("falcon_h1", "FalconH1Config"),
|
||||
("falcon_mamba", "FalconMambaConfig"),
|
||||
("fastspeech2_conformer", "FastSpeech2ConformerConfig"),
|
||||
("flaubert", "FlaubertConfig"),
|
||||
@ -481,6 +482,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("esm", "ESM"),
|
||||
("falcon", "Falcon"),
|
||||
("falcon3", "Falcon3"),
|
||||
("falcon_h1", "FalconH1"),
|
||||
("falcon_mamba", "FalconMamba"),
|
||||
("fastspeech2_conformer", "FastSpeech2Conformer"),
|
||||
("flan-t5", "FLAN-T5"),
|
||||
|
@ -115,6 +115,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("ernie_m", "ErnieMModel"),
|
||||
("esm", "EsmModel"),
|
||||
("falcon", "FalconModel"),
|
||||
("falcon_h1", "FalconH1Model"),
|
||||
("falcon_mamba", "FalconMambaModel"),
|
||||
("fastspeech2_conformer", "FastSpeech2ConformerModel"),
|
||||
("flaubert", "FlaubertModel"),
|
||||
@ -558,6 +559,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("emu3", "Emu3ForCausalLM"),
|
||||
("ernie", "ErnieForCausalLM"),
|
||||
("falcon", "FalconForCausalLM"),
|
||||
("falcon_h1", "FalconH1ForCausalLM"),
|
||||
("falcon_mamba", "FalconMambaForCausalLM"),
|
||||
("fuyu", "FuyuForCausalLM"),
|
||||
("gemma", "GemmaForCausalLM"),
|
||||
|
@ -24,7 +24,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Tuple, TypedDict, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -61,6 +62,31 @@ else:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class BambaFlashAttentionKwargs(TypedDict, total=False):
|
||||
"""
|
||||
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
|
||||
Use cases include padding-free training and fewer `torch.compile` graph breaks.
|
||||
|
||||
Attributes:
|
||||
cu_seq_lens_q (`torch.LongTensor`):
Cumulative sequence lengths for the query state.
cu_seq_lens_k (`torch.LongTensor`):
Cumulative sequence lengths for the key state.
max_length_q (`int`):
Maximum sequence length for the query state.
max_length_k (`int`):
Maximum sequence length for the key state.
seq_idx (`torch.IntTensor`):
Index of each packed sequence.
|
||||
"""
|
||||
|
||||
cu_seq_lens_q: torch.LongTensor
|
||||
cu_seq_lens_k: torch.LongTensor
|
||||
max_length_q: int
|
||||
max_length_k: int
|
||||
seq_idx: torch.IntTensor
|
||||
|
||||
|
||||
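Editor's note: the new `BambaFlashAttentionKwargs` entries are plain `TypedDict` fields, so a trainer that packs several sequences into one batch row can build them directly from the per-sequence lengths. A minimal sketch (the helper name `build_packed_kwargs` and the example lengths are illustrative, not part of this diff):

```python
import torch

def build_packed_kwargs(seq_lens: list[int]) -> dict:
    """Build padding-free kwargs for several sequences packed into one batch row."""
    lengths = torch.tensor(seq_lens, dtype=torch.long)
    # cumulative boundaries starting at 0, e.g. [3, 5] -> [0, 3, 8]
    cu_seq_lens = torch.cat([torch.zeros(1, dtype=torch.long), lengths.cumsum(0)])
    # one entry per token marking which packed sequence it belongs to, shape (1, total_len)
    seq_idx = torch.cat(
        [torch.full((n,), i, dtype=torch.int32) for i, n in enumerate(seq_lens)]
    ).unsqueeze(0)
    return {
        "cu_seq_lens_q": cu_seq_lens,
        "cu_seq_lens_k": cu_seq_lens,
        "max_length_q": int(lengths.max()),
        "max_length_k": int(lengths.max()),
        "seq_idx": seq_idx,
    }

kwargs = build_packed_kwargs([3, 5])  # two sequences packed into a single row of length 8
```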
# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
|
||||
class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
|
||||
"""
|
||||
@ -487,6 +513,7 @@ class BambaMixer(nn.Module):
|
||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
seq_idx: Optional[torch.IntTensor] = None,
|
||||
):
|
||||
# 1. Gated MLP's linear projection
|
||||
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
||||
@ -569,7 +596,7 @@ class BambaMixer(nn.Module):
|
||||
A,
|
||||
D=self.D,
|
||||
chunk_size=self.chunk_size,
|
||||
seq_idx=None, # was seq_idx
|
||||
seq_idx=seq_idx,
|
||||
activation=self.activation,
|
||||
rmsnorm_weight=self.norm.weight,
|
||||
rmsnorm_eps=self.norm.variance_epsilon,
|
||||
@ -610,6 +637,7 @@ class BambaMixer(nn.Module):
|
||||
weight=self.conv1d.weight.squeeze(1),
|
||||
bias=self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
seq_idx=seq_idx,
|
||||
).transpose(1, 2)
|
||||
|
||||
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
|
||||
@ -629,7 +657,7 @@ class BambaMixer(nn.Module):
|
||||
chunk_size=self.chunk_size,
|
||||
D=self.D,
|
||||
z=None,
|
||||
seq_idx=None,
|
||||
seq_idx=seq_idx,
|
||||
return_final_states=True,
|
||||
dt_bias=self.dt_bias,
|
||||
dt_softplus=True,
|
||||
@ -863,9 +891,15 @@ class BambaMixer(nn.Module):
|
||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
seq_idx: Optional[torch.IntTensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
|
||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
|
||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
|
||||
if seq_idx is not None:
|
||||
raise NotImplementedError(
|
||||
"`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
|
||||
)
|
||||
dtype = hidden_states.dtype
|
||||
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
|
||||
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
|
||||
@ -939,7 +973,7 @@ class BambaDecoderLayer(nn.Module):
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs,
|
||||
**kwargs: Unpack[BambaFlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@ -959,8 +993,8 @@ class BambaDecoderLayer(nn.Module):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
padding-free training and/or to improve `torch.compile` performance.
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
@ -974,6 +1008,7 @@ class BambaDecoderLayer(nn.Module):
|
||||
cache_params=past_key_value,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
self_attn_weights = None
|
||||
elif self.layer_type == "attention":
|
||||
@ -1076,7 +1111,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs, # NOOP kwargs, for now
|
||||
**kwargs: Unpack[BambaFlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -1128,7 +1163,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
partial(decoder_layer.__call__, **kwargs),
|
||||
hidden_states,
|
||||
layer_mask,
|
||||
position_ids,
|
||||
@ -1148,6 +1183,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
@ -19,7 +19,8 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch Bamba model."""
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -46,7 +47,12 @@ from transformers.models.mamba2.modeling_mamba2 import (
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import auto_docstring, can_return_tuple, logging
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
)
|
||||
from ...utils.import_utils import is_causal_conv1d_available, is_flash_attn_2_available, is_mamba_2_ssm_available
|
||||
from .configuration_bamba import BambaConfig
|
||||
|
||||
@ -71,6 +77,31 @@ is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_c
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class BambaFlashAttentionKwargs(TypedDict, total=False):
|
||||
"""
|
||||
Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
|
||||
Use cases include padding-free training and fewer `torch.compile` graph breaks.
|
||||
|
||||
Attributes:
|
||||
cu_seq_lens_q (`torch.LongTensor`):
Cumulative sequence lengths for the query state.
cu_seq_lens_k (`torch.LongTensor`):
Cumulative sequence lengths for the key state.
max_length_q (`int`):
Maximum sequence length for the query state.
max_length_k (`int`):
Maximum sequence length for the key state.
seq_idx (`torch.IntTensor`):
Index of each packed sequence.
|
||||
"""
|
||||
|
||||
cu_seq_lens_q: torch.LongTensor
|
||||
cu_seq_lens_k: torch.LongTensor
|
||||
max_length_q: int
|
||||
max_length_k: int
|
||||
seq_idx: torch.IntTensor
|
||||
|
||||
|
||||
# Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
|
||||
class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
|
||||
"""
|
||||
@ -278,6 +309,7 @@ class BambaMixer(nn.Module):
|
||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
seq_idx: Optional[torch.IntTensor] = None,
|
||||
):
|
||||
# 1. Gated MLP's linear projection
|
||||
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
||||
@ -360,7 +392,7 @@ class BambaMixer(nn.Module):
|
||||
A,
|
||||
D=self.D,
|
||||
chunk_size=self.chunk_size,
|
||||
seq_idx=None, # was seq_idx
|
||||
seq_idx=seq_idx,
|
||||
activation=self.activation,
|
||||
rmsnorm_weight=self.norm.weight,
|
||||
rmsnorm_eps=self.norm.variance_epsilon,
|
||||
@ -401,6 +433,7 @@ class BambaMixer(nn.Module):
|
||||
weight=self.conv1d.weight.squeeze(1),
|
||||
bias=self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
seq_idx=seq_idx,
|
||||
).transpose(1, 2)
|
||||
|
||||
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
|
||||
@ -420,7 +453,7 @@ class BambaMixer(nn.Module):
|
||||
chunk_size=self.chunk_size,
|
||||
D=self.D,
|
||||
z=None,
|
||||
seq_idx=None,
|
||||
seq_idx=seq_idx,
|
||||
return_final_states=True,
|
||||
dt_bias=self.dt_bias,
|
||||
dt_softplus=True,
|
||||
@ -654,9 +687,15 @@ class BambaMixer(nn.Module):
|
||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
seq_idx: Optional[torch.IntTensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
|
||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
|
||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
|
||||
if seq_idx is not None:
|
||||
raise NotImplementedError(
|
||||
"`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
|
||||
)
|
||||
dtype = hidden_states.dtype
|
||||
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
|
||||
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
|
||||
@ -701,7 +740,7 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer):
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs,
|
||||
**kwargs: Unpack[BambaFlashAttentionKwargs],
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@ -721,8 +760,8 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
padding-free training and/or to improve `torch.compile` performance.
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
@ -736,6 +775,7 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer):
|
||||
cache_params=past_key_value,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
self_attn_weights = None
|
||||
elif self.layer_type == "attention":
|
||||
@ -838,7 +878,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs, # NOOP kwargs, for now
|
||||
**kwargs: Unpack[BambaFlashAttentionKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@ -890,7 +930,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
partial(decoder_layer.__call__, **kwargs),
|
||||
hidden_states,
|
||||
layer_mask,
|
||||
position_ids,
|
||||
@ -910,6 +950,7 @@ class BambaModel(BambaPreTrainedModel):
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
27
src/transformers/models/falcon_h1/__init__.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright 2025 TII and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_falcon_h1 import *
|
||||
from .modeling_falcon_h1 import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
283
src/transformers/models/falcon_h1/configuration_falcon_h1.py
Normal file
@ -0,0 +1,283 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 TII and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""FalconH1 model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FalconH1Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`FalconH1Model`]. It is used to instantiate a
|
||||
FalconH1Model model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with defaults taken from [ibm-fms/FalconH1-9.8b-2.2T-hf](https://huggingface.co/ibm-fms/FalconH1-9.8b-2.2T-hf).
|
||||
The FalconH1Model is a hybrid [mamba2](https://github.com/state-spaces/mamba) architecture with SwiGLU.
|
||||
The checkpoints are jointly trained by IBM, Princeton, and UIUC.
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 128000):
|
||||
Vocabulary size of the FalconH1 model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`FalconH1Model`]
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
|
||||
model has an output word embedding layer.
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 14336):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
|
||||
Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
|
||||
integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
|
||||
logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
|
||||
sequence may use a lot of memory, so setting `num_logits_to_keep=1` will reduce the memory footprint
|
||||
significantly.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
The id of the padding token.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
The id of the "beginning-of-sequence" token.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
The id of the "end-of-sequence" token.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 8192):
|
||||
Max cached sequence length for the model
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
mamba_d_ssm (`int`, *optional*, defaults to 1024):
|
||||
The dimension of the SSM state space latents.
|
||||
mamba_n_heads (`int`, *optional*, defaults to 128):
|
||||
The number of mamba heads used in the v2 implementation.
|
||||
mamba_d_head (`int`, *optional*, defaults to `"auto"`):
|
||||
Head embedding dimension size
|
||||
mamba_n_groups (`int`, *optional*, defaults to 1):
|
||||
The number of the mamba groups used in the v2 implementation.
|
||||
mamba_d_state (`int`, *optional*, defaults to 256):
|
||||
The dimension of the mamba state space latents
|
||||
mamba_d_conv (`int`, *optional*, defaults to 4):
|
||||
The size of the mamba convolution kernel
|
||||
mamba_expand (`int`, *optional*, defaults to 2):
|
||||
Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
|
||||
mamba_chunk_size (`int`, *optional*, defaults to 256):
|
||||
The chunks in which to break the sequence when doing prefill/training
|
||||
mamba_conv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
|
||||
mamba_proj_bias (`bool`, *optional*, defaults to `False`):
|
||||
Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
|
||||
mamba_norm_before_gate (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use RMSNorm before the gate in the Mamba block
|
||||
mamba_rms_norm (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use RMSNorm instead of LayerNorm in the Mamba block
|
||||
projectors_bias (`bool`, *optional*, defaults to `False`):
|
||||
Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the attention block
|
||||
rope_theta (`float`, *optional*, defaults to 100000.0):
|
||||
The theta value used for the RoPE embeddings.
|
||||
rope_scaling (`float`, *optional*):
|
||||
The scaling value used for the RoPE embeddings. If `None`, no scaling is applied.
|
||||
lm_head_multiplier (`float`, *optional*, defaults to 1.0):
|
||||
The multiplier for the LM head. This is used to scale the output of the LM head.
|
||||
embedding_multiplier (`float`, *optional*, defaults to 1.0):
|
||||
The multiplier for the embedding layer. This is used to scale the output of the embedding layer.
|
||||
mlp_multipliers (`List[float]`, *optional*):
|
||||
The multipliers for the MLP layers. This is used to scale the output of the MLP layers. The first value is
|
||||
the multiplier of gate layer, the second value is the multiplier of the down_proj layer.
|
||||
key_multiplier (`float`, *optional*):
|
||||
The multiplier for the key layer. This is used to scale the output of the key layer.
|
||||
attention_out_multiplier (`float`, *optional*):
|
||||
The multiplier for the attention output layer. This is used to scale the output of the attention output
|
||||
attention_in_multiplier (`float`, *optional*):
|
||||
The multiplier for the attention input layer. This is used to scale the output of the attention input layer.
|
||||
ssm_multipliers (`List[float]`, *optional*):
|
||||
The multipliers for the SSM layers. This is used to scale the output of the SSM layers.
|
||||
ssm_in_multiplier (`float`, *optional*):
|
||||
The multiplier for the SSM input layer. This is used to scale the output of the SSM input layer.
|
||||
ssm_out_multiplier (`float`, *optional*):
|
||||
The multiplier for the SSM output layer. This is used to scale the output of the SSM output layer.
|
||||
"""
|
||||
|
||||
model_type = "falcon_h1"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=128000,
|
||||
tie_word_embeddings=False,
|
||||
hidden_size=4096,
|
||||
intermediate_size=14336,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=8,
|
||||
hidden_act="silu",
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-5,
|
||||
use_cache=True,
|
||||
num_logits_to_keep=1,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
max_position_embeddings=8192,
|
||||
attention_dropout=0.0,
|
||||
mamba_d_ssm=1024,
|
||||
mamba_n_heads=128,
|
||||
mamba_d_head="auto",
|
||||
mamba_n_groups=1,
|
||||
mamba_d_state=256,
|
||||
mamba_d_conv=4,
|
||||
mamba_expand=2,
|
||||
mamba_chunk_size=256,
|
||||
mamba_conv_bias=True,
|
||||
mamba_proj_bias=False,
|
||||
mamba_norm_before_gate=True,
|
||||
mamba_rms_norm=False,
|
||||
projectors_bias=False,
|
||||
rope_theta=100000.0,
|
||||
rope_scaling=None,
|
||||
lm_head_multiplier=1.0,
|
||||
embedding_multiplier=1.0,
|
||||
mlp_multipliers=None,
|
||||
key_multiplier=None,
|
||||
attention_out_multiplier=None,
|
||||
attention_in_multiplier=None,
|
||||
ssm_multipliers=None,
|
||||
ssm_in_multiplier=None,
|
||||
ssm_out_multiplier=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.attention_dropout = attention_dropout
|
||||
self.attention_bias = False
|
||||
self.mlp_bias = False
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
|
||||
self.use_cache = use_cache
|
||||
self.num_logits_to_keep = num_logits_to_keep
|
||||
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = None
|
||||
self.rope_scaling = rope_scaling
|
||||
self.projectors_bias = projectors_bias
|
||||
mamba_intermediate = mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm
|
||||
|
||||
if mamba_intermediate % mamba_n_heads != 0:
|
||||
raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size")
|
||||
|
||||
# for the mamba_v2, must satisfy the following
|
||||
if mamba_d_head == "auto":
|
||||
mamba_d_head = mamba_intermediate // mamba_n_heads
|
||||
|
||||
if mamba_d_head * mamba_n_heads != mamba_intermediate:
|
||||
raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size")
|
||||
|
||||
self.mamba_d_ssm = mamba_d_ssm
|
||||
self.mamba_n_heads = mamba_n_heads
|
||||
self.mamba_d_head = mamba_d_head
|
||||
self.mamba_n_groups = mamba_n_groups
|
||||
self.mamba_d_state = mamba_d_state
|
||||
self.mamba_d_conv = mamba_d_conv
|
||||
self.mamba_expand = mamba_expand
|
||||
self.mamba_chunk_size = mamba_chunk_size
|
||||
self.mamba_conv_bias = mamba_conv_bias
|
||||
self.mamba_proj_bias = mamba_proj_bias
|
||||
|
||||
self.mamba_norm_before_gate = mamba_norm_before_gate
|
||||
self.mamba_rms_norm = mamba_rms_norm
|
||||
|
||||
self.lm_head_multiplier = lm_head_multiplier
|
||||
self.embedding_multiplier = embedding_multiplier
|
||||
|
||||
if mlp_multipliers is not None:
|
||||
self.mlp_multipliers = mlp_multipliers
|
||||
else:
|
||||
self.mlp_multipliers = [1.0, 1.0]
|
||||
|
||||
if attention_out_multiplier is not None:
|
||||
self.attention_out_multiplier = attention_out_multiplier
|
||||
else:
|
||||
self.attention_out_multiplier = 1.0
|
||||
|
||||
if attention_in_multiplier is not None:
|
||||
self.attention_in_multiplier = attention_in_multiplier
|
||||
else:
|
||||
self.attention_in_multiplier = 1.0
|
||||
|
||||
if key_multiplier is not None:
|
||||
self.key_multiplier = key_multiplier
|
||||
else:
|
||||
self.key_multiplier = 1.0
|
||||
|
||||
if ssm_multipliers is not None:
|
||||
self.ssm_multipliers = ssm_multipliers
|
||||
else:
|
||||
#
|
||||
self.ssm_multipliers = [1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
if ssm_in_multiplier is not None:
|
||||
self.ssm_in_multiplier = ssm_in_multiplier
|
||||
else:
|
||||
self.ssm_in_multiplier = 1.0
|
||||
|
||||
if ssm_out_multiplier is not None:
|
||||
self.ssm_out_multiplier = ssm_out_multiplier
|
||||
else:
|
||||
self.ssm_out_multiplier = 1.0
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def layers_block_type(self):
|
||||
return ["attention" for i in range(self.num_hidden_layers)]
|
||||
|
||||
|
||||
__all__ = ["FalconH1Config"]
|
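Editor's note: since `mamba_d_head="auto"` is resolved from the mamba intermediate size divided by `mamba_n_heads` at construction time, a tiny config makes the divisibility constraint easy to check. A minimal sketch with toy sizes (not a released checkpoint; requires a transformers build that includes this new model):

```python
from transformers import FalconH1Config

config = FalconH1Config(
    vocab_size=1024,
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    mamba_d_ssm=128,
    mamba_n_heads=16,  # must divide mamba_d_ssm (or mamba_expand * hidden_size when d_ssm is None)
)
print(config.mamba_d_head)  # 8, i.e. 128 // 16
```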
@ -0,0 +1,151 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 TII and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This script can be used to convert checkpoints provided in the `mamba_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, FalconH1Config, FalconH1ForCausalLM
|
||||
|
||||
|
||||
CONVERSION_MAPPING = {
|
||||
"backbone": "model",
|
||||
"embeddings": "embed_tokens",
|
||||
"mixer.": "",
|
||||
"mixer_ssm": "mamba",
|
||||
"mixer_attn": "self_attn",
|
||||
"mlp.": "feed_forward.",
|
||||
"mlp_norm": "pre_ff_layernorm",
|
||||
"ssm_proj": "mamba.in_proj",
|
||||
"attn_out_proj": "o_proj",
|
||||
".norm.": ".input_layernorm.",
|
||||
".mamba.input_layernorm.": ".mamba.norm.",
|
||||
".ssm_out_proj.": ".mamba.out_proj.",
|
||||
"norm_f": "final_layernorm",
|
||||
}
|
||||
|
||||
|
||||
def convert_falcon_h1_to_hf(input_model_path, output_path):
|
||||
tokenizer = AutoTokenizer.from_pretrained(input_model_path)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
input_model_path, torch_dtype=torch.bfloat16, trust_remote_code=True, low_cpu_mem_usage=True
|
||||
)
|
||||
|
||||
intermediate_size = int(model.config.expansion_factor * model.config.hidden_size)
|
||||
|
||||
if intermediate_size % 2 != 0:
|
||||
intermediate_size = intermediate_size + (intermediate_size % 2)
|
||||
|
||||
new_config = FalconH1Config(
|
||||
vocab_size=model.config.vocab_size,
|
||||
tie_word_embeddings=model.config.tie_word_embeddings,
|
||||
hidden_size=model.config.hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
mamba_d_state=model.config.state_size,
|
||||
num_hidden_layers=model.config.num_hidden_layers,
|
||||
mamba_use_mlp=model.config.use_mlp,
|
||||
rms_norm_eps=model.config.layer_norm_epsilon,
|
||||
pad_token_id=model.config.pad_token_id,
|
||||
eos_token_id=model.config.eos_token_id,
|
||||
mamba_expand=model.config.expand,
|
||||
mamba_d_conv=model.config.conv_kernel,
|
||||
mamba_n_groups=model.config.n_groups,
|
||||
mamba_n_heads=model.config.num_heads,
|
||||
mamba_norm_before_gate=model.config.norm_before_gate,
|
||||
mamba_rms_norm=model.config.rms_norm,
|
||||
mamba_d_ssm=model.config.d_ssm,
|
||||
attention_bias=model.config.use_bias,
|
||||
projectors_bias=model.config.use_bias,
|
||||
mamba_conv_bias=model.config.use_conv_bias,
|
||||
hidden_act=model.config.hidden_act,
|
||||
use_cache=model.config.use_cache,
|
||||
mamba_chunk_size=model.config.chunk_size,
|
||||
num_attention_heads=model.config.num_heads_mha,
|
||||
num_key_value_heads=model.config.num_key_value_heads,
|
||||
head_dim=model.config.head_dim_mha,
|
||||
lm_head_multiplier=model.config.lm_head_multiplier,
|
||||
embedding_multiplier=model.config.embedding_multiplier,
|
||||
mlp_multipliers=model.config.mlp_multipliers,
|
||||
key_multiplier=model.config.key_multiplier,
|
||||
attention_out_multiplier=model.config.attention_out_multiplier,
|
||||
attention_in_multiplier=model.config.attention_in_multiplier,
|
||||
ssm_multipliers=model.config.ssm_multipliers,
|
||||
ssm_in_multiplier=model.config.ssm_in_multiplier,
|
||||
ssm_out_multiplier=model.config.ssm_out_multiplier,
|
||||
rope_theta=model.config.rope_theta,
|
||||
)
|
||||
|
||||
old_state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
|
||||
for old_key, old_value in old_state_dict.items():
|
||||
new_key = old_key
|
||||
for conversion_key, conversion_value in CONVERSION_MAPPING.items():
|
||||
if conversion_key in old_key:
|
||||
new_key = new_key.replace(conversion_key, conversion_value)
|
||||
|
||||
if "mamba.input_layernorm" in new_key:
|
||||
new_key = new_key.replace("mamba.input_layernorm", "mamba.norm")
|
||||
|
||||
# Special processing for attention layers
|
||||
if "self_attn.attn_proj" in new_key:
|
||||
num_heads = new_config.num_attention_heads
|
||||
num_kv_heads = new_config.num_key_value_heads
|
||||
head_dim = new_config.head_dim
|
||||
q_proj, k_proj, v_proj = old_value.split(
|
||||
[
|
||||
num_heads * head_dim,
|
||||
num_kv_heads * head_dim,
|
||||
num_kv_heads * head_dim,
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
new_state_dict[new_key.replace("attn_proj", "q_proj")] = q_proj
|
||||
new_state_dict[new_key.replace("attn_proj", "k_proj")] = k_proj
|
||||
new_state_dict[new_key.replace("attn_proj", "v_proj")] = v_proj
|
||||
else:
|
||||
new_state_dict[new_key] = old_value
|
||||
|
||||
with torch.device("meta"):
|
||||
new_model = FalconH1ForCausalLM(new_config)
|
||||
|
||||
del model
|
||||
|
||||
new_model.load_state_dict(new_state_dict, strict=True, assign=True)
|
||||
|
||||
new_model.save_pretrained(output_path)
|
||||
tokenizer.save_pretrained(output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--mamba_ssm_checkpoint_directory",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to a directory containing the `pytorch_model.bin` mamba_ssm checkpoint file to be converted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_falcon_h1_to_hf(
|
||||
args.mamba_ssm_checkpoint_directory,
|
||||
args.output_dir,
|
||||
)
|
1692
src/transformers/models/falcon_h1/modeling_falcon_h1.py
Normal file
File diff suppressed because it is too large
1442
src/transformers/models/falcon_h1/modular_falcon_h1.py
Normal file
File diff suppressed because it is too large
@ -1062,10 +1062,21 @@ class Gemma3Model(Gemma3PreTrainedModel):
|
||||
if token_type_ids is not None and sequence_length != 1:
|
||||
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
|
||||
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
|
||||
token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
|
||||
|
||||
# Find where a new image block starts: 1 if the token is an image token and the previous one is not
# Images cannot attend to future images, but can attend to all previous images and to themselves bidirectionally
|
||||
is_image = token_type_ids == 1
|
||||
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
|
||||
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
|
||||
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
|
||||
|
||||
same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
|
||||
same_image_mask[image_group_ids == -1] = False # remove non-image
|
||||
image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
|
||||
|
||||
causal_mask = causal_mask.clone()
|
||||
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
|
||||
token_type_mask, 0.0
|
||||
image_mask, 0.0
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
|
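Editor's note: the new masking keeps bidirectional attention only within each image block instead of across all image tokens. A small standalone walk-through of the grouping step on a toy `token_type_ids` row, mirroring the tensor ops above:

```python
import torch
import torch.nn.functional as F

token_type_ids = torch.tensor([[0, 1, 1, 0, 1, 1, 1]])  # 0 = text token, 1 = image token
is_image = token_type_ids == 1
# a new image block starts where the current token is an image token and the previous one is not
new_image_start = is_image & ~F.pad(is_image, (1, 0), value=0)[:, :-1]
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
print(image_group_ids)  # tensor([[-1, 0, 0, -1, 1, 1, 1]]) -> two separate image blocks
```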
@ -781,10 +781,21 @@ class Gemma3Model(PaliGemmaModel):
|
||||
if token_type_ids is not None and sequence_length != 1:
|
||||
token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
|
||||
token_type_mask[token_type_ids == 0] = False # if text token do not change anything
|
||||
token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
|
||||
|
||||
# Find where a new image block starts: 1 if the token is an image token and the previous one is not
# Images cannot attend to future images, but can attend to all previous images and to themselves bidirectionally
|
||||
is_image = token_type_ids == 1
|
||||
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
|
||||
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
|
||||
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
|
||||
|
||||
same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
|
||||
same_image_mask[image_group_ids == -1] = False # remove non-image
|
||||
image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
|
||||
|
||||
causal_mask = causal_mask.clone()
|
||||
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
|
||||
token_type_mask, 0.0
|
||||
image_mask, 0.0
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
|
@ -439,6 +439,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
|
||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
seq_idx: Optional[torch.IntTensor] = None,
|
||||
):
|
||||
# 1. Gated MLP's linear projection
|
||||
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
||||
@ -521,7 +522,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
|
||||
A,
|
||||
D=self.D,
|
||||
chunk_size=self.chunk_size,
|
||||
seq_idx=None, # was seq_idx
|
||||
seq_idx=seq_idx,
|
||||
activation=self.activation,
|
||||
rmsnorm_weight=self.norm.weight,
|
||||
rmsnorm_eps=self.norm.variance_epsilon,
|
||||
@ -562,6 +563,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
|
||||
weight=self.conv1d.weight.squeeze(1),
|
||||
bias=self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
seq_idx=seq_idx,
|
||||
).transpose(1, 2)
|
||||
|
||||
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
|
||||
@ -581,7 +583,7 @@ class GraniteMoeHybridMambaLayer(nn.Module):
|
||||
chunk_size=self.chunk_size,
|
||||
D=self.D,
|
||||
z=None,
|
||||
seq_idx=None,
|
||||
seq_idx=seq_idx,
|
||||
return_final_states=True,
|
||||
dt_bias=self.dt_bias,
|
||||
dt_softplus=True,
|
||||
@ -815,9 +817,15 @@ class GraniteMoeHybridMambaLayer(nn.Module):
|
||||
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
seq_idx: Optional[torch.IntTensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
|
||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
|
||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
|
||||
if seq_idx is not None:
|
||||
raise NotImplementedError(
|
||||
"`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`"
|
||||
)
|
||||
dtype = hidden_states.dtype
|
||||
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
|
||||
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
|
||||
|
@ -144,7 +144,7 @@ class Llama4TextMoe(nn.Module):
|
||||
|
||||
def forward(self, hidden_states):
|
||||
batch, seq_len, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, self.hidden_dim)
|
||||
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
||||
router_logits = self.router(hidden_states)
|
||||
tokens_per_expert = batch * seq_len
|
||||
|
||||
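Editor's note on the `view` -> `reshape` change: `view` requires a layout that can be flattened without copying, while `reshape` falls back to a copy when needed, e.g. after a transpose. A purely illustrative standalone example (not taken from this diff):

```python
import torch

x = torch.randn(2, 3, 4).transpose(0, 1)  # non-contiguous, as can happen after permutes or sharding
try:
    x.view(-1, 4)
except RuntimeError as err:
    print("view fails:", err)
print(x.reshape(-1, 4).shape)  # reshape copies when necessary: torch.Size([6, 4])
```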
@ -258,6 +258,33 @@ def eager_attention_forward(
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32
|
||||
def vision_eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * module.head_dim**-0.5
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Llama4TextAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
@ -534,10 +561,10 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
||||
inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
if self.config.get_text_config().get("attention_chunk_size") is not None:
|
||||
if self.config.get_text_config().attention_chunk_size is not None:
|
||||
past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
|
||||
else:
|
||||
past_key_values = DynamicCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1])
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
@ -730,7 +757,7 @@ class Llama4TextModel(Llama4PreTrainedModel):
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and chunked_attention_mask is not None:
|
||||
chunked_attention_mask = chunked_attention_mask.bool()
|
||||
causal_mask = causal_mask.bool()
|
||||
causal_mask = causal_mask != torch.finfo(dtype).min
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
@ -1099,7 +1126,7 @@ class Llama4VisionAttention(nn.Module):
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
attention_interface: Callable = vision_eager_attention_forward
|
||||
# flex disable because breaks on TP 8, embed is 88 not power of 2
|
||||
if self.config._attn_implementation not in ["eager", "flex_attention"]:
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
@ -1117,7 +1144,7 @@ class Llama4VisionAttention(nn.Module):
|
||||
value_states,
|
||||
None,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=None,
|
||||
scaling=None, # TODO Might be enforced here for TP compatibility as scaling is not just sqrt(head_dim)
|
||||
is_causal=False, # HAS TO BE ENFORCED
|
||||
**kwargs,
|
||||
)
|
||||
|
@ -424,19 +424,23 @@ class LlavaNextImageProcessor(BaseImageProcessor):
|
||||
|
||||
return resized_image
|
||||
|
||||
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
|
||||
original_height, original_width = original_resolution
|
||||
target_height, target_width = target_resolution
|
||||
paste_x, r_x = divmod(target_width - original_width, 2)
|
||||
paste_y, r_y = divmod(target_height - original_height, 2)
|
||||
return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
|
||||
|
||||
def _pad_for_patching(
|
||||
self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> np.array:
|
||||
"""
|
||||
Pad an image to a target resolution while maintaining aspect ratio.
|
||||
"""
|
||||
target_height, target_width = target_resolution
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
padding = self._get_padding_size(new_resolution, target_resolution)
|
||||
|
||||
paste_x, r_x = divmod(target_width - new_width, 2)
|
||||
paste_y, r_y = divmod(target_height - new_height, 2)
|
||||
|
||||
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
|
||||
padded_image = self.pad(image, padding=padding)
|
||||
|
||||
return padded_image
|
||||
|
||||
|
@ -141,19 +141,23 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
return resized_image
|
||||
|
||||
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
|
||||
original_height, original_width = original_resolution
|
||||
target_height, target_width = target_resolution
|
||||
paste_x, r_x = divmod(target_width - original_width, 2)
|
||||
paste_y, r_y = divmod(target_height - original_height, 2)
|
||||
return [paste_x, paste_y, paste_x + r_x, paste_y + r_y]
|
||||
|
||||
def _pad_for_patching(
|
||||
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Pad an image to a target resolution while maintaining aspect ratio.
|
||||
"""
|
||||
target_height, target_width = target_resolution
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
padding = self._get_padding_size(new_resolution, target_resolution)
|
||||
|
||||
paste_x, r_x = divmod(target_width - new_width, 2)
|
||||
paste_y, r_y = divmod(target_height - new_height, 2)
|
||||
|
||||
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x + r_x, paste_y + r_y])
|
||||
padded_image = F.pad(image, padding=padding)
|
||||
|
||||
return padded_image
|
||||
|
||||
|
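Editor's note: the fast processor's `_get_padding_size` returns the amounts in `[left, top, right, bottom]` order as expected by torchvision's `pad`, whereas the slow NumPy path uses `((top, bottom), (left, right))`. A quick check of the ordering, assuming torchvision's v2 functional API is available:

```python
import torch
from torchvision.transforms.v2 import functional as F

image = torch.zeros(3, 300, 401)
padded = F.pad(image, padding=[23, 18, 24, 18])  # left, top, right, bottom
print(padded.shape)  # torch.Size([3, 336, 448])
```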
@ -315,6 +315,14 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
|
||||
|
||||
return resized_image
|
||||
|
||||
# Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._get_padding_size
|
||||
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
|
||||
original_height, original_width = original_resolution
|
||||
target_height, target_width = target_resolution
|
||||
paste_x, r_x = divmod(target_width - original_width, 2)
|
||||
paste_y, r_y = divmod(target_height - original_height, 2)
|
||||
return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
|
||||
|
||||
# Copied from transformers.models.llava_next.image_processing_llava_next.LlavaNextImageProcessor._pad_for_patching
|
||||
def _pad_for_patching(
|
||||
self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension
|
||||
@ -322,13 +330,10 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
|
||||
"""
|
||||
Pad an image to a target resolution while maintaining aspect ratio.
|
||||
"""
|
||||
target_height, target_width = target_resolution
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
padding = self._get_padding_size(new_resolution, target_resolution)
|
||||
|
||||
paste_x, r_x = divmod(target_width - new_width, 2)
|
||||
paste_y, r_y = divmod(target_height - new_height, 2)
|
||||
|
||||
padded_image = self.pad(image, padding=((paste_y, paste_y + r_y), (paste_x, paste_x + r_x)))
|
||||
padded_image = self.pad(image, padding=padding)
|
||||
|
||||
return padded_image
|
||||
|
||||
@ -437,6 +442,85 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
|
||||
|
||||
return pixel_values
|
||||
|
||||
# Copied from transformers.models.llava.image_processing_llava.LlavaImageProcessor.pad_to_square
|
||||
def pad_to_square(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
background_color: Union[int, Tuple[int, int, int]] = 0,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
) -> np.array:
|
||||
"""
|
||||
Pads an image to a square based on the longest edge.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
The image to pad.
|
||||
background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
|
||||
The color to use for the padding. Can be an integer for single channel or a
tuple of integers representing multi-channel images. If passed as an integer
in multi-channel mode, it will default to `0` in subsequent channels.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
If unset, will use same as the input image.
|
||||
input_data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format for the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
If unset, will use the inferred format of the input image.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The padded image.
|
||||
"""
|
||||
height, width = get_image_size(image, input_data_format)
|
||||
num_channels = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1]
|
||||
|
||||
if height == width:
|
||||
image = (
|
||||
to_channel_dimension_format(image, data_format, input_data_format)
|
||||
if data_format is not None
|
||||
else image
|
||||
)
|
||||
return image
|
||||
|
||||
max_dim = max(height, width)
|
||||
|
||||
# Ensure background_color is the correct shape
|
||||
if isinstance(background_color, int):
|
||||
background_color = [background_color]
|
||||
elif len(background_color) != num_channels:
|
||||
raise ValueError(
|
||||
f"background_color must have no more than {num_channels} elements to match the number of channels"
|
||||
)
|
||||
|
||||
if input_data_format == ChannelDimension.FIRST:
|
||||
result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype)
|
||||
for i, color in enumerate(background_color):
|
||||
result[i, :, :] = color
|
||||
if width > height:
|
||||
start = (max_dim - height) // 2
|
||||
result[:, start : start + height, :] = image
|
||||
else:
|
||||
start = (max_dim - width) // 2
|
||||
result[:, :, start : start + width] = image
|
||||
else:
|
||||
result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype)
|
||||
for i, color in enumerate(background_color):
|
||||
result[:, :, i] = color
|
||||
if width > height:
|
||||
start = (max_dim - height) // 2
|
||||
result[start : start + height, :, :] = image
|
||||
else:
|
||||
start = (max_dim - width) // 2
|
||||
result[:, start : start + width, :] = image
|
||||
|
||||
image = (
|
||||
to_channel_dimension_format(result, data_format, input_data_format) if data_format is not None else result
|
||||
)
|
||||
return image
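For the slow (NumPy) processor, `pad_to_square` centres the image on a square canvas filled with `background_color`. A small standalone sketch of the same idea for a channels-first array (pure NumPy, shapes and values are illustrative):

```python
import numpy as np

# Toy channels-first image: 3 x 2 x 4 (C, H, W), so height < width.
image = np.arange(24, dtype=np.uint8).reshape(3, 2, 4)
background_color = (255, 0, 0)

c, h, w = image.shape
max_dim = max(h, w)
canvas = np.zeros((c, max_dim, max_dim), dtype=image.dtype)
for i, color in enumerate(background_color):
    canvas[i] = color               # fill each channel with its background value

start = (max_dim - h) // 2          # centre along the shorter (height) axis
canvas[:, start : start + h, :] = image
print(canvas.shape)                 # (3, 4, 4)
```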
def _preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
@ -595,6 +679,17 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
|
||||
do_pad = do_pad if do_pad is not None else self.do_pad
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
|
||||
if isinstance(images, (tuple, list)) and isinstance(images[0], (tuple, list)):
|
||||
# if the first element is a list, we assume that all elements are lists
|
||||
batch_num_images = [len(x) for x in images]
|
||||
elif isinstance(images, (tuple, list)):
|
||||
# treat this as a single-image case for backward compatibility
|
||||
batch_num_images = [1] * len(images)
|
||||
else:
|
||||
batch_num_images = [1]
|
||||
# only single image patching is supported
|
||||
need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
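The new `batch_num_images` bookkeeping decides, per flattened image, whether anyres patching is applied (single-image samples) or the image is only padded to a square (multi-image samples). A quick sketch of how nested versus flat inputs map to these flags (plain Python; the inputs are hypothetical):

```python
# Hypothetical batches -- nested lists mean "one list of images per sample".
single_images = ["img0", "img1"]             # two samples, one image each
multi_images = [["a0", "a1"], ["b0"]]        # sample A has 2 images, sample B has 1

def num_images_per_sample(images):
    if isinstance(images, (tuple, list)) and isinstance(images[0], (tuple, list)):
        return [len(x) for x in images]      # nested: count images per sample
    if isinstance(images, (tuple, list)):
        return [1] * len(images)             # flat list: one image per sample
    return [1]                               # a single image

for batch in (single_images, multi_images):
    batch_num_images = num_images_per_sample(batch)
    # only single-image samples go through anyres patching
    need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
    print(batch_num_images, need_patching)
# [1, 1] [True, True]
# [2, 1] [False, False, True]
```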
images = make_flat_list_of_images(images)
|
||||
|
||||
if not valid_images(images):
|
||||
@ -630,25 +725,34 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
size_tuple = (
|
||||
(size["height"], size["width"])
|
||||
if "height" in size and "width" in size
|
||||
else (size["shortest_edge"], size["shortest_edge"])
|
||||
)
|
||||
|
||||
new_images = []
|
||||
image_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images]
|
||||
for image in images:
|
||||
# convert image into a list of patches
|
||||
# we intentionally use the same data format as the input data format
|
||||
size_tuple = (
|
||||
(size["height"], size["width"])
|
||||
if "height" in size and "width" in size
|
||||
else (size["shortest_edge"], size["shortest_edge"])
|
||||
)
|
||||
image_patches = self.get_image_patches(
|
||||
image,
|
||||
image_grid_pinpoints,
|
||||
size=size_tuple,
|
||||
patch_size=size_tuple[0],
|
||||
resample=resample,
|
||||
data_format=input_data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for i, image in enumerate(images):
|
||||
if need_patching[i]:
|
||||
# convert image into a list of patches
|
||||
# we intentionally use the same data format as the input data format
|
||||
image_patches = self.get_image_patches(
|
||||
image,
|
||||
image_grid_pinpoints,
|
||||
size=size_tuple,
|
||||
patch_size=size_tuple[0],
|
||||
resample=resample,
|
||||
data_format=input_data_format,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
else:
|
||||
padded_image = self.pad_to_square(
|
||||
image=image,
|
||||
background_color=tuple(int(x * 255) for x in self.image_mean),
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
image_patches = [padded_image]
|
||||
|
||||
# preprocess patches
|
||||
pixel_values = self._preprocess(
|
||||
@ -671,7 +775,8 @@ class LlavaOnevisionImageProcessor(BaseImageProcessor):
|
||||
processed_images = self._pad_for_batching(new_images)
|
||||
|
||||
return BatchFeature(
|
||||
data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
|
||||
data={"pixel_values": processed_images, "image_sizes": image_sizes, "batch_num_images": batch_num_images},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
|
||||
|
@ -19,7 +19,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -89,6 +89,15 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
@auto_docstring
|
||||
def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaOnevisionFastImageProcessorKwargs]) -> BatchFeature:
|
||||
if isinstance(images, (tuple, list)) and isinstance(images[0], (tuple, list)):
|
||||
# if the first element is a list, we assume that all elements are lists
|
||||
batch_num_images = [len(x) for x in images]
|
||||
elif isinstance(images, (tuple, list)):
|
||||
# treat this as a single-image case for backward compatibility
|
||||
batch_num_images = [1] * len(images)
|
||||
else:
|
||||
batch_num_images = [1]
|
||||
kwargs["batch_num_images"] = batch_num_images
|
||||
return super().preprocess(images, **kwargs)
|
||||
|
||||
def _prepare_images_structure(
|
||||
@ -137,19 +146,23 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
return resized_image
|
||||
|
||||
def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
|
||||
original_height, original_width = original_resolution
|
||||
target_height, target_width = target_resolution
|
||||
paste_x, r_x = divmod(target_width - original_width, 2)
|
||||
paste_y, r_y = divmod(target_height - original_height, 2)
|
||||
return [paste_x, paste_y, paste_x + r_x, paste_y + r_y]
|
||||
|
||||
def _pad_for_patching(
|
||||
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Pad an image to a target resolution while maintaining aspect ratio.
|
||||
"""
|
||||
target_height, target_width = target_resolution
|
||||
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
|
||||
padding = self._get_padding_size(new_resolution, target_resolution)
|
||||
|
||||
paste_x, r_x = divmod(target_width - new_width, 2)
|
||||
paste_y, r_y = divmod(target_height - new_height, 2)
|
||||
|
||||
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x + r_x, paste_y + r_y])
|
||||
padded_image = F.pad(image, padding=padding)
|
||||
|
||||
return padded_image
|
||||
|
||||
@ -234,10 +247,15 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
|
||||
image_mean: Optional[Union[float, List[float]]],
|
||||
image_std: Optional[Union[float, List[float]]],
|
||||
do_pad: bool,
|
||||
batch_num_images: List[int],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
) -> BatchFeature:
|
||||
processed_images = []
|
||||
image_sizes = []
|
||||
|
||||
# only single image patching is supported
|
||||
need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
|
||||
|
||||
# Determine the size tuple
|
||||
if size and size.height and size.width:
|
||||
size_tuple = (size.height, size.width)
|
||||
@ -252,14 +270,20 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
|
||||
else:
|
||||
patch_size = size.shortest_edge
|
||||
|
||||
for image in images:
|
||||
image_patches = self._get_image_patches(
|
||||
image,
|
||||
image_grid_pinpoints,
|
||||
size=size_tuple,
|
||||
patch_size=patch_size,
|
||||
interpolation=interpolation,
|
||||
)
|
||||
for i, image in enumerate(images):
|
||||
if need_patching[i]:
|
||||
image_patches = self._get_image_patches(
|
||||
image,
|
||||
image_grid_pinpoints,
|
||||
size=size_tuple,
|
||||
patch_size=patch_size,
|
||||
interpolation=interpolation,
|
||||
)
|
||||
else:
|
||||
padded_image = self.pad_to_square(
|
||||
images=image, background_color=tuple(int(x * 255) for x in self.image_mean)
|
||||
)
|
||||
image_patches = [padded_image]
|
||||
|
||||
# Group images by size for batched processing
|
||||
processed_image_patches_grouped = {}
|
||||
@ -289,8 +313,52 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast):
|
||||
processed_images = self._pad_for_batching(processed_images)
|
||||
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
return BatchFeature(
|
||||
data={"pixel_values": processed_images, "image_sizes": image_sizes}, tensor_type=return_tensors
|
||||
data={"pixel_values": processed_images, "image_sizes": image_sizes, "batch_num_images": batch_num_images},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llava.image_processing_llava_fast.LlavaImageProcessorFast.pad_to_square
|
||||
def pad_to_square(
|
||||
self,
|
||||
images: "torch.Tensor",
|
||||
background_color: Union[int, Tuple[int, int, int]] = 0,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Pads an image to a square based on the longest edge.
|
||||
|
||||
Args:
|
||||
images (`torch.Tensor`):
|
||||
The images to pad.
|
||||
background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
|
||||
The color to use for the padding. Can be an integer for single channel or a
tuple of integers representing multi-channel images. If passed as an integer
in multi-channel mode, it will default to `0` in subsequent channels.
|
||||
Returns:
|
||||
`torch.Tensor`: The padded images.
|
||||
"""
|
||||
height, width = get_image_size(images, ChannelDimension.FIRST)
|
||||
|
||||
if height == width:
|
||||
return images
|
||||
|
||||
num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0]
|
||||
if isinstance(background_color, int):
|
||||
background_color = [background_color] + [0] * (num_channels - 1)
|
||||
elif len(background_color) != num_channels:
|
||||
raise ValueError(
|
||||
f"background_color must have no more than {num_channels} elements to match the number of channels"
|
||||
)
|
||||
|
||||
max_dim = max(height, width)
|
||||
paste_x_left = (max_dim - width) // 2
|
||||
paste_y_left = (max_dim - height) // 2
|
||||
paste_x_right = max_dim - width - paste_x_left
|
||||
paste_y_right = max_dim - height - paste_y_left
|
||||
padded_images = F.pad(
|
||||
images, padding=[paste_x_left, paste_y_left, paste_x_right, paste_y_right], fill=background_color
|
||||
)
|
||||
|
||||
return padded_images
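In the fast processor the square padding is expressed as a single torchvision `F.pad` call with `[left, top, right, bottom]`. If torchvision is not at hand, the same centred result can be reproduced with `torch.nn.functional.pad`, whose tuple for the last two dimensions is ordered `(left, right, top, bottom)` — note the different order. A small sketch with a scalar fill and illustrative sizes:

```python
import torch
import torch.nn.functional as tnf

image = torch.randn(3, 300, 400)             # (C, H, W), wider than tall
height, width = image.shape[-2:]
max_dim = max(height, width)

left = (max_dim - width) // 2
top = (max_dim - height) // 2
right = max_dim - width - left
bottom = max_dim - height - top

# torch.nn.functional.pad orders the last two dims as (left, right, top, bottom)
padded = tnf.pad(image, (left, right, top, bottom), value=0.0)
print(padded.shape)                          # torch.Size([3, 400, 400])
```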
__all__ = ["LlavaOnevisionImageProcessorFast"]
|
||||
|
@ -419,8 +419,9 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
image_sizes: torch.Tensor,
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
batch_num_images: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and applies multimodal projection.
|
||||
@ -430,34 +431,34 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
The tensors corresponding to the input images.
|
||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||
Actual image size of each image (H, W).
|
||||
vision_feature_layer (`Union[int, List[int]]`, *optional*):
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`, *optional*):
|
||||
vision_feature_select_strategy (`str`):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`
|
||||
batch_num_images (`torch.LongTensor`, *optional*):
|
||||
Number of images in each sample.
|
||||
Returns:
|
||||
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing all the visual features of all patches
and each of shape `(num_patches, image_length, embed_dim)`.
|
||||
"""
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
|
||||
)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
|
||||
# ! infer image_num_patches from image_sizes
|
||||
if batch_num_images is None:
|
||||
# treat this as a single-image case for backward compatibility
|
||||
need_patching = [True] * len(image_sizes)
|
||||
else:
|
||||
need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
|
||||
image_num_patches = [
|
||||
image_size_to_num_patches(
|
||||
image_size=imsize,
|
||||
grid_pinpoints=self.config.image_grid_pinpoints,
|
||||
patch_size=self.config.vision_config.image_size,
|
||||
)
|
||||
for imsize in image_sizes
|
||||
if should_patch
|
||||
else 1
|
||||
for imsize, should_patch in zip(image_sizes, need_patching)
|
||||
]
|
||||
if pixel_values.dim() == 5:
|
||||
# stacked if input is (batch_size, num_patches, num_channels, height, width)
|
||||
@ -500,6 +501,7 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
vision_aspect_ratio: Optional[str] = None,
|
||||
batch_num_images: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
@ -520,6 +522,8 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
If `"full"`, the full vision features are used.
|
||||
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
|
||||
Aspect ratio used when processing image features. The default value is "anyres_max_9".
|
||||
batch_num_images (`torch.LongTensor`, *optional*):
|
||||
Number of images in each sample.
|
||||
"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@ -558,6 +562,7 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
image_sizes,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
batch_num_images=batch_num_images,
|
||||
)
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
@ -749,6 +754,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
vision_aspect_ratio: Optional[str] = None,
|
||||
batch_num_images: Optional[torch.LongTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
@ -771,6 +777,8 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
|
||||
If `"full"`, the full vision features are used.
|
||||
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
|
||||
Aspect ratio used when processing image features. The default value is "anyres_max_9".
|
||||
batch_num_images (`torch.LongTensor`, *optional*):
|
||||
Number of images in each sample.
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
@ -832,6 +840,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
|
||||
vision_aspect_ratio=vision_aspect_ratio,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
batch_num_images=batch_num_images,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
|
@ -28,18 +28,59 @@ from transformers.models.llava_next_video.modeling_llava_next_video import (
|
||||
LlavaNextVideoModelOutputWithPast,
|
||||
LlavaNextVideoPreTrainedModel,
|
||||
get_anyres_image_grid_shape,
|
||||
image_size_to_num_patches,
|
||||
unpad_image,
|
||||
)
|
||||
|
||||
from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
|
||||
from ...image_processing_utils import BatchFeature
|
||||
from ...image_processing_utils_fast import DefaultFastImageProcessorKwargs, group_images_by_shape, reorder_images
|
||||
from ...image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
SizeDict,
|
||||
get_image_size,
|
||||
)
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
is_torchdynamo_compiling,
|
||||
is_torchvision_available,
|
||||
is_torchvision_v2_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
if is_torchvision_available():
|
||||
if is_torchvision_v2_available():
|
||||
from torchvision.transforms.v2 import functional as F
|
||||
else:
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LlavaOnevisionFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
|
||||
"""
|
||||
image_grid_pinpoints (`List[List[int]]`, *optional*):
|
||||
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
|
||||
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
|
||||
method.
|
||||
do_pad (`bool`, *optional*):
|
||||
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
|
||||
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
|
||||
"""
|
||||
|
||||
image_grid_pinpoints: Optional[List[List[int]]]
|
||||
do_pad: Optional[bool]
|
||||
|
||||
|
||||
class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = OPENAI_CLIP_MEAN
|
||||
@ -56,6 +97,147 @@ class LlavaOnevisionImageProcessorFast(LlavaNextImageProcessorFast):
|
||||
image_grid_pinpoints = [[384, 384], [384, 768], [384, 1152], [384, 1536], [384, 1920], [384, 2304], [768, 384], [768, 768], [768, 1152], [768, 1536], [768, 1920], [768, 2304], [1152, 384], [1152, 768], [1152, 1152], [1152, 1536], [1152, 1920], [1152, 2304], [1536, 384], [1536, 768], [1536, 1152], [1536, 1536], [1536, 1920], [1536, 2304], [1920, 384], [1920, 768], [1920, 1152], [1920, 1536], [1920, 1920], [1920, 2304], [2304, 384], [2304, 768], [2304, 1152], [2304, 1536], [2304, 1920], [2304, 2304]] # fmt: skip
|
||||
model_input_names = ["pixel_values_videos"]
|
||||
|
||||
# Copied from transformers.models.llava.image_processing_llava_fast.LlavaImageProcessorFast.pad_to_square
|
||||
def pad_to_square(
|
||||
self,
|
||||
images: "torch.Tensor",
|
||||
background_color: Union[int, Tuple[int, int, int]] = 0,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Pads an image to a square based on the longest edge.
|
||||
|
||||
Args:
|
||||
images (`torch.Tensor`):
|
||||
The images to pad.
|
||||
background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
|
||||
The color to use for the padding. Can be an integer for single channel or a
tuple of integers representing multi-channel images. If passed as an integer
in multi-channel mode, it will default to `0` in subsequent channels.
|
||||
Returns:
|
||||
`torch.Tensor`: The padded images.
|
||||
"""
|
||||
height, width = get_image_size(images, ChannelDimension.FIRST)
|
||||
|
||||
if height == width:
|
||||
return images
|
||||
|
||||
num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0]
|
||||
if isinstance(background_color, int):
|
||||
background_color = [background_color] + [0] * (num_channels - 1)
|
||||
elif len(background_color) != num_channels:
|
||||
raise ValueError(
|
||||
f"background_color must have no more than {num_channels} elements to match the number of channels"
|
||||
)
|
||||
|
||||
max_dim = max(height, width)
|
||||
paste_x_left = (max_dim - width) // 2
|
||||
paste_y_left = (max_dim - height) // 2
|
||||
paste_x_right = max_dim - width - paste_x_left
|
||||
paste_y_right = max_dim - height - paste_y_left
|
||||
padded_images = F.pad(
|
||||
images, padding=[paste_x_left, paste_y_left, paste_x_right, paste_y_right], fill=background_color
|
||||
)
|
||||
|
||||
return padded_images
|
||||
|
||||
@auto_docstring
|
||||
def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaOnevisionFastImageProcessorKwargs]) -> BatchFeature:
|
||||
if isinstance(images, (tuple, list)) and isinstance(images[0], (tuple, list)):
|
||||
# if the first element is a list, we assume that all elements are lists
|
||||
batch_num_images = [len(x) for x in images]
|
||||
elif isinstance(images, (tuple, list)):
|
||||
# treat this as a single-image case for backward compatibility
|
||||
batch_num_images = [1] * len(images)
|
||||
else:
|
||||
batch_num_images = [1]
|
||||
kwargs["batch_num_images"] = batch_num_images
|
||||
return super().preprocess(images, **kwargs)
|
||||
|
||||
def _preprocess(
|
||||
self,
|
||||
images: List["torch.Tensor"],
|
||||
do_resize: bool,
|
||||
size: SizeDict,
|
||||
image_grid_pinpoints: List[List[int]],
|
||||
interpolation: Optional["F.InterpolationMode"],
|
||||
do_center_crop: bool,
|
||||
crop_size: SizeDict,
|
||||
do_rescale: bool,
|
||||
rescale_factor: float,
|
||||
do_normalize: bool,
|
||||
image_mean: Optional[Union[float, List[float]]],
|
||||
image_std: Optional[Union[float, List[float]]],
|
||||
do_pad: bool,
|
||||
batch_num_images: List[int],
|
||||
return_tensors: Optional[Union[str, TensorType]],
|
||||
) -> BatchFeature:
|
||||
processed_images = []
|
||||
image_sizes = []
|
||||
|
||||
# only single image patching is supported
|
||||
need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
|
||||
|
||||
# Determine the size tuple
|
||||
if size and size.height and size.width:
|
||||
size_tuple = (size.height, size.width)
|
||||
else:
|
||||
size_tuple = (size.shortest_edge, size.shortest_edge)
|
||||
|
||||
# Determine the patch size
|
||||
if crop_size and crop_size.height:
|
||||
patch_size = crop_size.height
|
||||
elif size and size.height:
|
||||
patch_size = size.height
|
||||
else:
|
||||
patch_size = size.shortest_edge
|
||||
|
||||
for i, image in enumerate(images):
|
||||
if need_patching[i]:
|
||||
image_patches = self._get_image_patches(
|
||||
image,
|
||||
image_grid_pinpoints,
|
||||
size=size_tuple,
|
||||
patch_size=patch_size,
|
||||
interpolation=interpolation,
|
||||
)
|
||||
else:
|
||||
padded_image = self.pad_to_square(
|
||||
images=image, background_color=tuple(int(x * 255) for x in self.image_mean)
|
||||
)
|
||||
image_patches = [padded_image]
|
||||
|
||||
# Group images by size for batched processing
|
||||
processed_image_patches_grouped = {}
|
||||
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(image_patches)
|
||||
for shape, stacked_image_patches in grouped_image_patches.items():
|
||||
if do_resize:
|
||||
stacked_image_patches = self.resize(
|
||||
image=stacked_image_patches,
|
||||
size=size,
|
||||
interpolation=interpolation,
|
||||
)
|
||||
if do_center_crop:
|
||||
stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
|
||||
# Fused rescale and normalize
|
||||
stacked_image_patches = self.rescale_and_normalize(
|
||||
stacked_image_patches, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||
)
|
||||
processed_image_patches_grouped[shape] = stacked_image_patches
|
||||
processed_image_patches = reorder_images(processed_image_patches_grouped, grouped_image_patches_index)
|
||||
processed_image_patches = (
|
||||
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
|
||||
)
|
||||
processed_images.append(processed_image_patches)
|
||||
image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
|
||||
|
||||
if do_pad:
|
||||
processed_images = self._pad_for_batching(processed_images)
|
||||
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
|
||||
return BatchFeature(
|
||||
data={"pixel_values": processed_images, "image_sizes": image_sizes, "batch_num_images": batch_num_images},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
|
||||
class LlavaOnevisionModelOutputWithPast(LlavaNextVideoModelOutputWithPast):
|
||||
pass
|
||||
@ -154,6 +336,76 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
image_features = image_features.view(batch_frames, -1, dim)
|
||||
return image_features
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
image_sizes: torch.Tensor,
|
||||
vision_feature_layer: Union[int, List[int]],
|
||||
vision_feature_select_strategy: str,
|
||||
batch_num_images: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and applies multimodal projection.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
|
||||
The tensors corresponding to the input images.
|
||||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||||
Actual image size of each image (H, W).
|
||||
vision_feature_layer (`Union[int, List[int]]`):
|
||||
The index of the layer to select the vision feature. If multiple indices are provided,
|
||||
the vision feature of the corresponding indices will be concatenated to form the
|
||||
vision features.
|
||||
vision_feature_select_strategy (`str`):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`
|
||||
batch_num_images (`torch.LongTensor`, *optional*):
|
||||
Number of images in each sample.
|
||||
Returns:
|
||||
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing all the visual features of all patches
and each of shape `(num_patches, image_length, embed_dim)`.
|
||||
"""
|
||||
# ! infer image_num_patches from image_sizes
|
||||
if batch_num_images is None:
|
||||
# treat this as a single-image case for backward compatibility
|
||||
need_patching = [True] * len(image_sizes)
|
||||
else:
|
||||
need_patching = [n == 1 for n in batch_num_images for _ in range(n)]
|
||||
image_num_patches = [
|
||||
image_size_to_num_patches(
|
||||
image_size=imsize,
|
||||
grid_pinpoints=self.config.image_grid_pinpoints,
|
||||
patch_size=self.config.vision_config.image_size,
|
||||
)
|
||||
if should_patch
|
||||
else 1
|
||||
for imsize, should_patch in zip(image_sizes, need_patching)
|
||||
]
|
||||
if pixel_values.dim() == 5:
|
||||
# stacked if input is (batch_size, num_patches, num_channels, height, width)
|
||||
_pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
|
||||
pixel_values = torch.cat(_pixel_values_list, dim=0)
|
||||
elif pixel_values.dim() != 4:
|
||||
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
|
||||
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
|
||||
|
||||
image_features = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
# If we have one vision feature layer, return the corresponding hidden states,
|
||||
# otherwise, select the hidden states of each feature layer and concatenate them
|
||||
if isinstance(vision_feature_layer, int):
|
||||
selected_image_feature = image_features.hidden_states[vision_feature_layer]
|
||||
else:
|
||||
hs_pool = [image_features.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
|
||||
selected_image_feature = torch.cat(hs_pool, dim=-1)
|
||||
|
||||
if vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
image_features = torch.split(image_features, image_num_patches, dim=0)
|
||||
return image_features
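Because unpatched (multi-image) entries contribute exactly one "patch", `image_num_patches` can vary per image: the 5D `pixel_values` tensor is sliced to each image's real patch count before the vision tower, and the projected features are split back per image afterwards. A toy sketch of that flatten/split round trip (shapes are illustrative):

```python
import torch

# Hypothetical batch: 2 images, padded to 5 patches each, 3x8x8 "pixels".
pixel_values = torch.randn(2, 5, 3, 8, 8)
image_num_patches = [5, 1]      # first image was patched, second was not

# Flatten: keep only the real patches of every image, then concatenate.
pixel_values_list = [pix[:n] for pix, n in zip(pixel_values, image_num_patches)]
flat = torch.cat(pixel_values_list, dim=0)
print(flat.shape)               # torch.Size([6, 3, 8, 8])

# After the vision tower + projector, features are split back per image.
fake_features = torch.randn(flat.shape[0], 16, 4)    # (total_patches, tokens, dim)
per_image = torch.split(fake_features, image_num_patches, dim=0)
print([t.shape for t in per_image])  # [torch.Size([5, 16, 4]), torch.Size([1, 16, 4])]
```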
def get_video_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
@ -214,6 +466,7 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
vision_aspect_ratio: Optional[str] = None,
|
||||
batch_num_images: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
@ -234,6 +487,8 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
If `"full"`, the full vision features are used.
|
||||
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
|
||||
Aspect ratio used when processing image features. The default value is "anyres_max_9".
|
||||
batch_num_images (`torch.LongTensor`, *optional*):
|
||||
Number of images in each sample.
|
||||
"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@ -272,6 +527,7 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
image_sizes,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
batch_num_images=batch_num_images,
|
||||
)
|
||||
image_features, feature_lens = self.pack_image_features(
|
||||
image_features,
|
||||
@ -355,6 +611,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat
|
||||
vision_feature_layer: Optional[Union[int, List[int]]] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
vision_aspect_ratio: Optional[str] = None,
|
||||
batch_num_images: Optional[torch.LongTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
@ -377,6 +634,8 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat
|
||||
If `"full"`, the full vision features are used.
|
||||
vision_aspect_ratio (`str`, *optional*, defaults to `"anyres_max_9"`):
|
||||
Aspect ratio used when processing image features. The default value is "anyres_max_9".
|
||||
batch_num_images (`torch.LongTensor`, *optional*):
|
||||
Number of images in each sample.
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
@ -438,6 +697,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaNextVideoForConditionalGenerat
|
||||
vision_aspect_ratio=vision_aspect_ratio,
|
||||
vision_feature_layer=vision_feature_layer,
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
batch_num_images=batch_num_images,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
|
@ -170,12 +170,15 @@ class LlavaOnevisionProcessor(ProcessorMixin):
|
||||
if images is not None:
|
||||
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||
|
||||
batch_num_images = iter(image_inputs["batch_num_images"])
|
||||
image_sizes = iter(image_inputs["image_sizes"])
|
||||
height, width = get_image_size(
|
||||
to_numpy_array(image_inputs["pixel_values"][0][0]),
|
||||
channel_dim=output_kwargs["images_kwargs"].get("data_format"),
|
||||
)
|
||||
text, num_image_tokens = self._expand_image_tokens(text, image_sizes, height, width, self.image_token)
|
||||
text, num_image_tokens = self._expand_image_tokens(
|
||||
text, image_sizes, height, width, self.image_token, batch_num_images
|
||||
)
|
||||
|
||||
if videos is not None:
|
||||
video_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"])
|
||||
@ -205,23 +208,29 @@ class LlavaOnevisionProcessor(ProcessorMixin):
|
||||
height: int,
|
||||
width: int,
|
||||
special_token: str,
|
||||
num_frames: int = 1,
|
||||
batch_num_images: Iterable[int],
|
||||
):
|
||||
prompt_strings = []
|
||||
max_num_vision_tokens = 0
|
||||
for sample in text:
|
||||
if special_token in sample:
|
||||
is_multi_image = next(batch_num_images) != 1
|
||||
else:
|
||||
is_multi_image = False
|
||||
while special_token in sample:
|
||||
image_size_list = next(image_sizes)
|
||||
original_size = image_size_list[0] if num_frames != 1 else image_size_list
|
||||
if not isinstance(original_size, (list, tuple)):
|
||||
# cast to list to avoid numerical precision errors when calculating unpadding
|
||||
original_size = original_size.tolist()
|
||||
orig_height, orig_width = original_size
|
||||
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
|
||||
if is_multi_image:
|
||||
num_image_tokens = self.num_image_tokens + 1 # one for image_newline
|
||||
else:
|
||||
original_size = next(image_sizes)
|
||||
if not isinstance(original_size, (list, tuple)):
|
||||
# cast to list to avoid numerical precision errors when calculating unpadding
|
||||
original_size = original_size.tolist()
|
||||
orig_height, orig_width = original_size
|
||||
num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
|
||||
max_num_vision_tokens = max(max_num_vision_tokens, num_image_tokens)
|
||||
if self.vision_feature_select_strategy == "default":
|
||||
num_image_tokens -= 1
|
||||
sample = sample.replace(special_token, "<placeholder>" * num_image_tokens * num_frames, 1)
|
||||
sample = sample.replace(special_token, "<placeholder>" * num_image_tokens, 1)
|
||||
prompt_strings.append(sample)
|
||||
text = [sample.replace("<placeholder>", special_token) for sample in prompt_strings]
|
||||
return text, max_num_vision_tokens
|
|
||||
value_states = self.v_proj(cross_attention_states)
|
||||
key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
key_states = self.k_norm(key_states)
|
||||
if past_key_value is not None:
|
||||
@ -850,7 +848,7 @@ class MllamaRotaryEmbedding(nn.Module):
|
||||
@auto_docstring
|
||||
class MllamaPreTrainedModel(PreTrainedModel):
|
||||
config_class = MllamaConfig
|
||||
base_model_prefix = "model"
|
||||
base_model_prefix = ""
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = [
|
||||
"MllamaVisionEncoderLayer",
|
||||
|
@ -40,7 +40,7 @@ from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask,
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
||||
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
|
||||
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig
|
||||
|
||||
|
||||
@ -358,7 +358,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.get_text_config().initializer_range
|
||||
@ -1659,9 +1659,9 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if n_image_tokens != n_image_features:
|
||||
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
@ -1676,9 +1676,9 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if n_video_tokens != n_video_features:
|
||||
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
@ -1694,20 +1694,32 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
||||
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||
# calculate RoPE index once per generation in the pre-fill stage only
|
||||
if (
|
||||
if position_ids is None:
|
||||
attention_mask_2d = attention_mask
|
||||
if attention_mask is not None and attention_mask.ndim == 4:
|
||||
attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2)
|
||||
attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
|
||||
attention_mask_2d = (1.0 - attention_mask_2d).int()
|
||||
|
||||
# Calculate RoPE index once per generation in the pre-fill stage only.
|
||||
# When compiling, we can't check tensor values thus we check only input length
|
||||
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
||||
# models currently cannot do assisted decoding
|
||||
prefill_compiled_stage = is_torchdynamo_compiling() and (
|
||||
(input_ids is not None and input_ids.shape[1] != 1)
|
||||
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
|
||||
)
|
||||
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
|
||||
(cache_position is not None and cache_position[0] == 0)
|
||||
or self.rope_deltas is None
|
||||
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
||||
):
|
||||
)
|
||||
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
|
||||
position_ids, rope_deltas = self.get_rope_index(
|
||||
input_ids,
|
||||
image_grid_thw,
|
||||
video_grid_thw,
|
||||
second_per_grid_ts,
|
||||
attention_mask,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
attention_mask=attention_mask_2d,
|
||||
)
|
||||
self.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
@ -1747,6 +1759,61 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
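The hunk above drops the old restriction that RoPE deltas could only be computed from a 2D mask: when a prepared 4D additive mask is passed, its diagonal is used to recover the original 2D padding mask (`finfo.min` marks padding, `0` marks real tokens). A small self-contained check of that inversion (pure PyTorch, toy sizes):

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# Original 2D padding mask: batch of 2, length 4, second sample has 1 pad token.
mask_2d = torch.tensor([[1, 1, 1, 1], [0, 1, 1, 1]], dtype=torch.int32)

# Build the usual inverted 4D additive mask: 0 = attend, finfo.min = masked.
causal = torch.triu(torch.full((4, 4), min_dtype, dtype=dtype), diagonal=1)
mask_4d = causal[None, None].repeat(2, 1, 1, 1)
mask_4d.masked_fill_(mask_2d[:, None, None, :] == 0, min_dtype)

# Recover the 2D mask from the diagonal, as done before `get_rope_index`.
diag = torch.diagonal(mask_4d[:, 0], dim1=1, dim2=2)
recovered = (1.0 - diag / min_dtype).int()
print(bool((recovered == mask_2d).all()))  # True
```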
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
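For reference, the static helper above (moved from the generation class onto the model) expands a 2D padding mask into the additive causal mask used by attention; with a static cache the key dimension is padded out to `target_length`. A toy run with tiny, made-up sizes:

```python
import torch

attention_mask = torch.tensor([[1, 1, 0]])      # real, real, pad
sequence_length, target_length, batch_size = 3, 5, 1
dtype = torch.float32
cache_position = torch.arange(sequence_length)
min_dtype = torch.finfo(dtype).min

causal_mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None].expand(batch_size, 1, -1, -1).clone()

mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    padding_mask == 0, min_dtype
)
# 0 where a query may attend, finfo.min where it may not
print((causal_mask == 0).squeeze())
```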
@dataclass
|
||||
class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput):
|
||||
@ -2108,60 +2175,5 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
|
||||
|
||||
return input_ids, model_kwargs
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel", "Qwen2_5_VLTextModel"]
|
||||
|
@ -50,7 +50,7 @@ from ...image_utils import ImageInput
|
||||
from ...modeling_flash_attention_utils import is_flash_attn_available
|
||||
from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import logging
|
||||
from ...utils import is_torchdynamo_compiling, logging
|
||||
from ...video_utils import VideoInput
|
||||
|
||||
|
||||
@ -647,9 +647,9 @@ class Qwen2_5_VLModel(Qwen2VLModel):
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if n_image_tokens != n_image_features:
|
||||
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
@ -664,9 +664,9 @@ class Qwen2_5_VLModel(Qwen2VLModel):
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
if n_video_tokens != n_video_features:
|
||||
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
@ -682,20 +682,32 @@ class Qwen2_5_VLModel(Qwen2VLModel):
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
||||
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||
# calculate RoPE index once per generation in the pre-fill stage only
|
||||
if (
|
||||
if position_ids is None:
|
||||
attention_mask_2d = attention_mask
|
||||
if attention_mask is not None and attention_mask.ndim == 4:
|
||||
attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2)
|
||||
attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
|
||||
attention_mask_2d = (1.0 - attention_mask_2d).int()
|
||||
|
||||
# Calculate RoPE index once per generation in the pre-fill stage only.
|
||||
# When compiling, we can't check tensor values thus we check only input length
|
||||
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
||||
# models currently cannot do assisted decoding
|
||||
prefill_compiled_stage = is_torchdynamo_compiling() and (
|
||||
(input_ids is not None and input_ids.shape[1] != 1)
|
||||
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
|
||||
)
|
||||
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
|
||||
(cache_position is not None and cache_position[0] == 0)
|
||||
or self.rope_deltas is None
|
||||
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
||||
):
|
||||
)
|
||||
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
|
||||
position_ids, rope_deltas = self.get_rope_index(
|
||||
input_ids,
|
||||
image_grid_thw,
|
||||
video_grid_thw,
|
||||
second_per_grid_ts,
|
||||
attention_mask,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
attention_mask=attention_mask_2d,
|
||||
)
|
||||
self.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
|
@ -924,7 +924,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.get_text_config().initializer_range
|
||||
@ -1616,16 +1616,28 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
||||
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||
# calculate RoPE index once per generation in the pre-fill stage only
|
||||
if (
|
||||
if position_ids is None:
|
||||
attention_mask_2d = attention_mask
|
||||
if attention_mask is not None and attention_mask.ndim == 4:
|
||||
attention_mask_2d = torch.diagonal(attention_mask_2d[:, 0], dim1=1, dim2=2)
|
||||
attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
|
||||
attention_mask_2d = (1.0 - attention_mask_2d).int()
|
||||
|
||||
# Calculate RoPE index once per generation in the pre-fill stage only.
|
||||
# When compiling, we can't check tensor values thus we check only input length
|
||||
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
||||
# models currently cannot do assisted decoding
|
||||
prefill_compiled_stage = is_torchdynamo_compiling() and (
|
||||
(input_ids is not None and input_ids.shape[1] != 1)
|
||||
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
|
||||
)
|
||||
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
|
||||
(cache_position is not None and cache_position[0] == 0)
|
||||
or self.rope_deltas is None
|
||||
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
||||
):
|
||||
)
|
||||
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
|
||||
position_ids, rope_deltas = self.get_rope_index(
|
||||
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
||||
input_ids, image_grid_thw, video_grid_thw, attention_mask_2d
|
||||
)
|
||||
self.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
@ -1662,6 +1674,62 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
|
||||
_checkpoint_conversion_mapping = {
|
||||
@ -1974,61 +2042,5 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):

        return input_ids, model_kwargs

    @staticmethod
    # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


__all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel", "Qwen2VLTextModel"]
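The helper above is removed from Qwen2VLForConditionalGeneration because the identical static method now lives on Qwen2VLModel. A hedged usage sketch, not part of this diff, assuming a transformers build that includes the version added earlier in this commit; the inputs are made up:

import torch
from transformers import Qwen2VLModel  # assumes transformers at (or after) this commit

pad_mask = torch.tensor([[1, 1, 1, 0]])  # hypothetical 2D padding mask (batch=1, kv_len=4)
cache_position = torch.arange(4)          # positions of the 4 query tokens

mask_4d = Qwen2VLModel._prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=pad_mask,
    sequence_length=4,
    target_length=4,
    dtype=torch.float32,
    cache_position=cache_position,
    batch_size=1,
)
print(mask_4d.shape)  # torch.Size([1, 1, 4, 4]); blocked entries hold torch.finfo(torch.float32).min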
@ -972,7 +972,9 @@ class Trainer:
        )
        return remove_columns_collator

    def _get_train_sampler(self, train_dataset) -> Optional[torch.utils.data.Sampler]:
    def _get_train_sampler(self, train_dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
        if train_dataset is None:
            train_dataset = self.train_dataset
        if train_dataset is None or not has_length(train_dataset):
            return None
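With the new signature, `train_dataset` is optional and falls back to `self.train_dataset`. A hedged sketch, not from this diff, of how a subclass might override the hook under that assumption; `MyTrainer` and the choice of a sequential sampler are illustrative only:

from typing import Optional

import torch
from torch.utils.data import Dataset, SequentialSampler

from transformers import Trainer


class MyTrainer(Trainer):
    def _get_train_sampler(self, train_dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
        # Mirror the new fallback behaviour, then pick a deterministic sampler.
        if train_dataset is None:
            train_dataset = self.train_dataset
        if train_dataset is None:
            return None
        return SequentialSampler(train_dataset)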
@ -36,7 +36,9 @@ BLACK_SQUARE = "■"
WHITE_SQUARE = "⬚"


def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_window=None, token_type_ids=None):
def generate_attention_matrix_from_mask(
    words, mask, img_token="<img>", sliding_window=None, token_type_ids=None, image_seq_length=None
):
    """
    Generates an attention matrix from a given attention mask.

@ -80,6 +82,14 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
        for j in range(n)
    )

    if token_type_ids is not None:
        is_special = token_type_ids == 1
        token_type_buckets = torch.where(
            (token_type_ids.cumsum(-1) % 5 + is_special).bool(), token_type_ids.cumsum(-1), 0
        )
        boundaries = torch.arange(0, image_seq_length + 1, image_seq_length)
        token_type_buckets = torch.bucketize(token_type_buckets, boundaries=boundaries)

    # Print headers
    legend = f"{GREEN}{BLACK_SQUARE}{RESET}: i == j (diagonal) {YELLOW}{BLACK_SQUARE}{RESET}: token_type_ids"
    output.append(" " + legend)
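The new `token_type_ids` / `image_seq_length` handling groups consecutive image tokens into per-image buckets with cumulative sums and torch.bucketize. A simplified standalone sketch of that idea, not the exact expression used above, with a made-up sequence (1 marks image tokens):

import torch

image_seq_length = 4
# Hypothetical sequence containing two 4-token images separated by text tokens.
token_type_ids = torch.tensor([[0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0]])

# Running count of image tokens, zeroed on text positions.
running = token_type_ids.cumsum(-1) * token_type_ids

# Bucketize so each block of `image_seq_length` image tokens gets its own id:
# 0 = text, 1 = first image, 2 = second image, ...
boundaries = torch.arange(0, int(running.max()) + image_seq_length, image_seq_length)
buckets = torch.bucketize(running, boundaries=boundaries)
print(buckets)  # tensor([[0, 1, 1, 1, 1, 0, 0, 2, 2, 2, 2, 0]])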
@ -103,7 +113,6 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
        if sliding_window is not None
        else ""
    )

    for i, word in enumerate(words):
        word_repr = repr(word).ljust(max_word_length)
        colored_word = f"{YELLOW}{word_repr}{RESET}" if img_token in word else word_repr
@ -121,7 +130,9 @@ def generate_attention_matrix_from_mask(words, mask, img_token="<img>", sliding_
        if sliding_window is not None:
            sliding_window_row = " ".join(
                f"{YELLOW}{BLACK_SQUARE}{RESET}"
                if img_token in words[j] and img_token in words[i]
                if img_token in words[j]
                and img_token in words[i]
                and token_type_buckets[0, i] == token_type_buckets[0, j]
                else f"{GREEN}{BLACK_SQUARE}{RESET}"
                if i == j
                else BLACK_SQUARE
@ -170,7 +181,8 @@ class AttentionMaskVisualizer:
        if self.config.model_type in PROCESSOR_MAPPING_NAMES:
            img = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
            img = Image.open(requests.get(img, stream=True).raw)
            processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=5)
            image_seq_length = 5
            processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=image_seq_length)
            if hasattr(processor, "image_token"):
                image_token = processor.image_token
            else:
@ -179,7 +191,7 @@ class AttentionMaskVisualizer:
            if image_token:
                input_sentence = input_sentence.replace("<img>", image_token)

            inputs = processor(img, input_sentence, suffix=suffix, return_tensors="pt")
            inputs = processor(images=img, text=input_sentence, suffix=suffix, return_tensors="pt")

            self.image_token = processor.tokenizer.convert_ids_to_tokens([processor.image_token_id])[0]

@ -223,6 +235,7 @@ class AttentionMaskVisualizer:
            img_token=self.image_token,
            sliding_window=getattr(self.config, "sliding_window", None),
            token_type_ids=kwargs.get("token_type_ids", None),
            image_seq_length=image_seq_length,
        )
        print(f_string)
        print(f"{top_bottom_border}")
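A hedged usage sketch of the visualizer touched above, not part of this diff; the checkpoint name is illustrative, and the tool needs network access to download the processor, config, and sample image:

from transformers.utils.attention_visualizer import AttentionMaskVisualizer

# Any multimodal checkpoint with a registered processor should work; this one is only an example.
visualizer = AttentionMaskVisualizer("google/paligemma2-3b-mix-224")
visualizer("<img> What is shown in this image?")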
@ -1222,6 +1222,9 @@ def is_keras_nlp_available():


def is_in_notebook():
    try:
        # Check if we are running inside Marimo
        if "marimo" in sys.modules:
            return True
        # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
        get_ipython = sys.modules["IPython"].get_ipython
        if "IPKernelApp" not in get_ipython().config:
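The marimo check relies on the fact that a host environment's package shows up in sys.modules once it has been imported into the process. A tiny standalone sketch of that detection pattern, not part of this diff; the helper name is made up:

import sys


def running_under(module_name: str) -> bool:
    # True if the given package has already been imported into this process,
    # which is how the marimo check above detects its host environment.
    return module_name in sys.modules


print(running_under("marimo"))  # False in a plain Python process
print(running_under("sys"))     # True: sys is always loaded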
@ -180,6 +180,7 @@ class AriaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMi
    all_model_classes = (AriaModel, AriaForConditionalGeneration) if is_torch_available() else ()
    test_pruning = False
    test_head_masking = False
    test_torchscript = False
    _is_composite = True

    def setUp(self):
Some files were not shown because too many files have changed in this diff.