Merge branch 'main' into add-aimv2-model

This commit is contained in:
Yaswanth Gali 2025-05-22 17:29:27 +05:30 committed by GitHub
commit 1a11e86dac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
313 changed files with 13213 additions and 9673 deletions

View File

@ -39,55 +39,100 @@ jobs:
name: ci_results_run_models_gpu
path: /transformers/ci_results_run_models_gpu
- name: Check file
working-directory: /transformers
run: |
if [ -f ci_results_run_models_gpu/new_model_failures.json ]; then
echo "`ci_results_run_models_gpu/new_model_failures.json` exists, continue ..."
echo "process=true" >> $GITHUB_ENV
else
echo "`ci_results_run_models_gpu/new_model_failures.json` doesn't exist, abort."
echo "process=false" >> $GITHUB_ENV
fi
- uses: actions/download-artifact@v4
if: ${{ env.process == 'true' }}
with:
pattern: setup_values*
path: setup_values
merge-multiple: true
- name: Prepare some setup values
if: ${{ env.process == 'true' }}
run: |
if [ -f setup_values/prev_workflow_run_id.txt ]; then
echo "PREV_WORKFLOW_RUN_ID=$(cat setup_values/prev_workflow_run_id.txt)" >> $GITHUB_ENV
else
echo "PREV_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
fi
if [ -f setup_values/other_workflow_run_id.txt ]; then
echo "OTHER_WORKFLOW_RUN_ID=$(cat setup_values/other_workflow_run_id.txt)" >> $GITHUB_ENV
else
echo "OTHER_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
fi
- name: Update clone
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: git fetch && git checkout ${{ github.sha }}
- name: Get target commit
working-directory: /transformers/utils
if: ${{ env.process == 'true' }}
run: |
echo "END_SHA=$(TOKEN=${{ secrets.ACCESS_REPO_INFO_TOKEN }} python3 -c 'import os; from get_previous_daily_ci import get_last_daily_ci_run_commit; commit=get_last_daily_ci_run_commit(token=os.environ["TOKEN"]); print(commit)')" >> $GITHUB_ENV
echo "END_SHA=$(TOKEN=${{ secrets.ACCESS_REPO_INFO_TOKEN }} python3 -c 'import os; from get_previous_daily_ci import get_last_daily_ci_run_commit; commit=get_last_daily_ci_run_commit(token=os.environ["TOKEN"], workflow_run_id=os.environ["PREV_WORKFLOW_RUN_ID"]); print(commit)')" >> $GITHUB_ENV
- name: Checkout to `start_sha`
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: git fetch && git checkout ${{ inputs.start_sha }}
- name: Reinstall transformers in edit mode (remove the one installed during docker image build)
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: python3 -m pip uninstall -y transformers && python3 -m pip install -e .
- name: NVIDIA-SMI
if: ${{ env.process == 'true' }}
run: |
nvidia-smi
- name: Environment
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: |
python3 utils/print_env.py
- name: Show installed libraries and their versions
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: pip freeze
- name: Check failed tests
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: python3 utils/check_bad_commit.py --start_commit ${{ inputs.start_sha }} --end_commit ${{ env.END_SHA }} --file ci_results_run_models_gpu/new_model_failures.json --output_file new_model_failures_with_bad_commit.json
- name: Show results
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: |
ls -l new_model_failures_with_bad_commit.json
cat new_model_failures_with_bad_commit.json
- name: Checkout back
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: |
git checkout ${{ inputs.start_sha }}
- name: Process report
shell: bash
working-directory: /transformers
if: ${{ env.process == 'true' }}
env:
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN: ${{ secrets.TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN }}
run: |
python3 utils/process_bad_commit_report.py
@ -95,7 +140,9 @@ jobs:
- name: Process report
shell: bash
working-directory: /transformers
if: ${{ env.process == 'true' }}
env:
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN: ${{ secrets.TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN }}
run: |
{
@ -105,7 +152,7 @@ jobs:
} >> "$GITHUB_ENV"
- name: Send processed report
if: ${{ !endsWith(env.REPORT_TEXT, '{}') }}
if: ${{ env.process == 'true' && !endsWith(env.REPORT_TEXT, '{}') }}
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
with:
# Slack channel id, channel name, or user id to post message.

View File

@ -8,8 +8,43 @@ on:
push:
branches:
- run_scheduled_ci*
workflow_dispatch:
inputs:
prev_workflow_run_id:
description: 'previous workflow run id to compare'
type: string
required: false
default: ""
other_workflow_run_id:
description: 'other workflow run id to compare'
type: string
required: false
default: ""
# Used for `push` to easily modiffy the target workflow runs to compare against
env:
prev_workflow_run_id: ""
other_workflow_run_id: ""
jobs:
setup:
name: Setup
runs-on: ubuntu-22.04
steps:
- name: Setup
run: |
mkdir "setup_values"
echo "${{ inputs.prev_workflow_run_id || env.prev_workflow_run_id }}" > "setup_values/prev_workflow_run_id.txt"
echo "${{ inputs.other_workflow_run_id || env.other_workflow_run_id }}" > "setup_values/other_workflow_run_id.txt"
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: setup_values
path: setup_values
model-ci:
name: Model CI
uses: ./.github/workflows/self-scheduled.yml

View File

@ -39,6 +39,21 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/download-artifact@v4
- name: Prepare some setup values
run: |
if [ -f setup_values/prev_workflow_run_id.txt ]; then
echo "PREV_WORKFLOW_RUN_ID=$(cat setup_values/prev_workflow_run_id.txt)" >> $GITHUB_ENV
else
echo "PREV_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
fi
if [ -f setup_values/other_workflow_run_id.txt ]; then
echo "OTHER_WORKFLOW_RUN_ID=$(cat setup_values/other_workflow_run_id.txt)" >> $GITHUB_ENV
else
echo "OTHER_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
fi
- name: Send message to Slack
if: ${{ inputs.job != 'run_quantization_torch_gpu' }}
env:
@ -50,7 +65,6 @@ jobs:
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
CI_EVENT: ${{ inputs.ci_event }}
CI_SHA: ${{ github.sha }}
CI_WORKFLOW_REF: ${{ github.workflow_ref }}
CI_TEST_JOB: ${{ inputs.job }}
SETUP_STATUS: ${{ inputs.setup_status }}
# We pass `needs.setup.outputs.matrix` as the argument. A processing in `notification_service.py` to change
@ -58,7 +72,6 @@ jobs:
# For a job that doesn't depend on (i.e. `needs`) `setup`, the value for `inputs.folder_slices` would be an
# empty string, and the called script still get one argument (which is the emtpy string).
run: |
sudo apt-get install -y curl
pip install huggingface_hub
pip install slack_sdk
pip show slack_sdk
@ -86,7 +99,6 @@ jobs:
# We pass `needs.setup.outputs.quantization_matrix` as the argument. A processing in `notification_service_quantization.py` to change
# `quantization/bnb` to `quantization_bnb` is required, as the artifact names use `_` instead of `/`.
run: |
sudo apt-get install -y curl
pip install huggingface_hub
pip install slack_sdk
pip show slack_sdk

View File

@ -71,6 +71,9 @@ RUN python3 -m pip install --no-cache-dir g2p-en
# For Some bitsandbytes tests
RUN python3 -m pip install --no-cache-dir einops
# For Some tests with `@require_liger_kernel`
RUN python3 -m pip install --no-cache-dir liger-kernel
# `kernels` may give different outputs (within 1e-5 range) even with the same model (weights) and the same inputs
RUN python3 -m pip uninstall -y kernels

View File

@ -455,6 +455,8 @@
title: Falcon
- local: model_doc/falcon3
title: Falcon3
- local: model_doc/falcon_h1
title: FalconH1
- local: model_doc/falcon_mamba
title: FalconMamba
- local: model_doc/flan-t5

View File

@ -125,4 +125,44 @@ would expect from a usual Python dictionary:
# You can also globally `register` a new function directly on it
>>> ALL_ATTENTION_FUNCTIONS.register("new_func", new_func)
```
```
## Attention Mask Interface
Having a new attention function may mean that you need a new format of attention mask to decide what key and value tokens
the query tokens should attend to. This is now possible with the `AttentionMaskInterface`! It works in the same way as
the `AttentionInterface`:
```python
from transformers import AttentionMaskInterface
from transformers.masking_utils import sdpa_mask
import torch
def my_new_sdpa_mask(*args, **kwargs):
print("I just entered the attention mask computation")
return sdpa_mask(*args, **kwargs)
AttentionMaskInterface.register("my_new_sdpa_mask", my_new_sdpa_mask)
```
The reason you have to register it is because we need to automatically correct your mask format based on the attention implementation (for example, flex attention uses a BlockMask format, while sdpa uses a 4D tensor).
By default, if you do not register an attention mask function along with your attention function, mask creation will be skipped
and `attention_mask=None` will be passed along to the Attention layers.
The default signature of the attention mask functions is the following:
```python
def custom_attention_mask(
batch_size: int, # required arg
cache_position: torch.Tensor, # required arg
kv_length: int, # required arg
kv_offset: int = 0, # required arg
mask_function: Callable = causal_mask_function, # required arg
attention_mask: Optional[torch.Tensor] = None, # required arg
**kwargs, # a few additional args may be passed as kwargs, especially the model's config is always passed
) -> Optional[torch.Tensor]:
```
It mostly works thanks to the `mask_function`, which is a `Callable` in the form of [torch's mask_mod functions](https://pytorch.org/blog/flexattention/), taking 4 indices as input and returning a boolean to indicate if this position should take part in the attention computation.
If you cannot use the `mask_function` to create your mask for some reason, you can try to work around it by doing something similar to our [torch export workaround](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py).

View File

@ -90,11 +90,6 @@ class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
if self.use_rel_pos:
attn_weights = self.add_decomposed_rel_pos(
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
@ -114,13 +109,14 @@ Load the model with [`~PreTrainedModel.from_pretrained`].
```py
from transformers import SamModel
from transformers.models.sam import modeling_sam
# replace the attention class in the modeling_sam module
modeling_sam.SamVisionAttention = SamVisionAttentionSplit
# load the pretrained SAM model
model = SamModel.from_pretrained("facebook/sam-vit-base")
# replace the attention class in the vision_encoder module
for layer in model.vision_encoder.layers:
if hasattr(layer, "attn"):
layer.attn = SamVisionAttentionSplit(model.config.vision_config, model.config.vision_config.window_size)
```
## LoRA
@ -138,7 +134,7 @@ config = LoraConfig(
# apply LoRA to q and v
target_modules=["q", "v"],
lora_dropout=0.1,
task_type="mask-generation"
task_type="FEATURE_EXTRACTION"
)
```
@ -152,5 +148,5 @@ Call [print_trainable_parameters](https://huggingface.co/docs/peft/package_refer
```py
model.print_trainable_parameters()
"trainable params: 608,256 || all params: 94,343,728 || trainable%: 0.6447"
"trainable params: 589,824 || all params: 94,274,096 || trainable%: 0.6256"
```

View File

@ -29,6 +29,11 @@ Most of those are only useful if you are studying the code of the models in the
[[autodoc]] AttentionInterface
- register
## Attention Mask Functions
[[autodoc]] AttentionMaskInterface
- register
## Rotary Position Embedding Functions
[[autodoc]] dynamic_rope_update

View File

@ -39,7 +39,7 @@ Checkout all Bamba-9B model checkpoints [here](https://github.com/foundation-mod
<!---
## Usage Tips
Tips:
Tips:
- The architecture is based on Mamba-2 models.
@ -63,7 +63,35 @@ response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
```
## Padding-Free Training
Bamba supports padding-free training in which distinct training examples can be concatenated
together while nevertheless processing the inputs as though they belonged to separate batches. When
the examples are of varying lengths, padding-free training can provide significant speed ups and
memory savings compared to batching the examples together and using padding, as the unnecessary
compute and memory due to padding is avoided entirely. The performance gains depend on factors such
as the model and the data distribution, but throughput gains up to [~2x are commonly
seen](https://github.com/huggingface/transformers/pull/35861#issue-2807873129).
Using padding-free training with Bamba requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d`
packages, and the following arguments must be passed to the model in addition to `input_ids` and
`labels`:
* `position_ids: torch.LongTensor`: the position index of each token in each sequence.
* `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
* Each of the [`FlashAttentionKwargs`]
* `cu_seq_lens_q: torch.LongTensor`: The cumulative sequence lengths of all queries.
* `cu_seq_lens_k: torch.LongTensor`: The cumulative sequence lengths of all keys.
* `max_length_q: int`: the longest query length in the batch.
* `max_length_k: int`: the longest key length in the batch.
The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] can be used
to programmatically generate the above set of additional arguments using `return_seq_idx=True` and
`return_flash_attn_kwargs=True`. See [this blog post](https://huggingface.co/blog/packing-with-FA2)
for additional information.
[[autodoc]] BambaForCausalLM
- forward
This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).

View File

@ -0,0 +1,65 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# FalconH1
## Overview
The FalconH1 model was developed by the TII Pretraining team. A comprehensive research paper covering the architecture, pretraining dynamics, experimental results, and conclusions is forthcoming. You can read more about this series in [this website](https://github.com/tiiuae/Falcon-H1).
## Contributors
This model was contributed by [DhiyaEddine](https://huggingface.co/DhiyaEddine), [ybelkada](https://huggingface.co/ybelkada), [JingweiZuo](https://huggingface.co/JingweiZuo), [IlyasChahed](https://huggingface.co/IChahed), and [MaksimVelikanov](https://huggingface.co/yellowvm).
The original code can be found [here](https://github.com/tiiuae/Falcon-H1).
## FalconH1Config
| Model | Depth | Dim | Attn Heads | KV | Mamba Heads | d_head | d_state | Ctx Len |
|-----------|--------|------|------------|----|--------------|--------------|------|-----------------|
| H1 0.5B | 36 | 1024 | 8 | 2 | 24 | 64 / 64 | 128 | 4K, 16K-SFT |
| H1 1.5B | 24 | 2048 | 8 | 2 | 48 | 128 / 64 | 256 | 128K |
| H1 1.5B-d | 66 | 1280 | 6 | 2 | 24 | 128 / 64 | 256 | 128K |
| H1 3B | 32 | 2560 | 10 | 2 | 32 | 128 / 128 | 256 | 128K |
| H1 7B | 44 | 3072 | 12 | 2 | 24 | 128 / 128 | 256 | 256K |
| H1 34B | 72 | 5120 | 20 | 4 | 32 | 128 / 128 | 256 | 256K |
[[autodoc]] FalconH1Config
<!---
## Usage Tips
Tips:
- The architecture is based on Mamba-2 models.
## FalconH1Model
[[autodoc]] FalconH1Model
- forward
-->
## FalconH1ForCausalLM
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("tiiuae/Falcon-H1-7B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("tiiuae/Falcon-H1-7B-Instruct")
message = ["Mamba is a snake with following properties "]
inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
```
[[autodoc]] FalconH1ForCausalLM
- forward
This HF implementation is contributed by [younesbelkada](https://github.com/younesbelkada) and [DhiaEddineRhaiem](https://github.com/dhiaEddineRhaiem).

View File

@ -147,7 +147,7 @@ print(processor.decode(output[0], skip_special_tokens=True))
### Multi image inference
LLaVa-OneVision can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). For that you have to use checkpoints with an "ov" suffix. Here is how you can do it:
LLaVa-OneVision can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). For that you have to use checkpoints with an "ov" suffix. For multi-image cases, we recommend using a **nested list of images** as input. Otherwise, every image will be patchified and consume a lot of memory. Here is how you can do it:
```python
import requests

View File

@ -14,85 +14,124 @@ rendered properly in your Markdown viewer.
-->
# Mamba
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
</div>
## Overview
# Mamba
The Mamba model was proposed in [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao.
[Mamba](https://huggingface.co/papers/2312.00752) is a selective structured state space model (SSMs) designed to work around Transformers computational inefficiency when dealing with long sequences. It is a completely attention-free architecture, and comprised of a combination of H3 and gated MLP blocks (Mamba block). Mamba's "content-based reasoning" allows it to focus on specific parts of an input depending on the current token. Mamba also uses a new hardware-aware parallel algorithm to compensate for the lack of convolutional operations. As a result, Mamba has fast inference and can scale to very long sequences.
This model is a new paradigm architecture based on `state-space-models`. You can read more about the intuition behind these [here](https://srush.github.io/annotated-s4/).
You can find all the original Mamba checkpoints under the [State Space Models](https://huggingface.co/state-spaces) organization.
The abstract from the paper is the following:
*Foundation models, now powering most of the exciting applications in deep learning, are almost universally based on the Transformer architecture and its core attention module. Many subquadratic-time architectures such as linear attention, gated convolution and recurrent models, and structured state space models (SSMs) have been developed to address Transformers' computational inefficiency on long sequences, but they have not performed as well as attention on important modalities such as language. We identify that a key weakness of such models is their inability to perform content-based reasoning, and make several improvements. First, simply letting the SSM parameters be functions of the input addresses their weakness with discrete modalities, allowing the model to selectively propagate or forget information along the sequence length dimension depending on the current token. Second, even though this change prevents the use of efficient convolutions, we design a hardware-aware parallel algorithm in recurrent mode. We integrate these selective SSMs into a simplified end-to-end neural network architecture without attention or even MLP blocks (Mamba). Mamba enjoys fast inference (5× higher throughput than Transformers) and linear scaling in sequence length, and its performance improves on real data up to million-length sequences. As a general sequence model backbone, Mamba achieves state-of-the-art performance across several modalities such as language, audio, and genomics. On language modeling, our Mamba-3B model outperforms Transformers of the same size and matches Transformers twice its size, both in pretraining and downstream evaluation.*
> [!TIP]
> Click on the Mamba models in the right sidebar for more examples of how to apply Mamba to different language tasks.
Tips:
The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and from the command line.
- Mamba is a new `state space model` architecture that rivals the classic Transformers. It is based on the line of progress on structured state space models, with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
- Mamba stacks `mixer` layers, which are the equivalent of `Attention` layers. The core logic of `mamba` is held in the `MambaMixer` class.
- Two implementations cohabit: one is optimized and uses fast cuda kernels, while the other one is naive but can run on any device!
- The current implementation leverages the original cuda kernels: the equivalent of flash attention for Mamba are hosted in the [`mamba-ssm`](https://github.com/state-spaces/mamba) and the [`causal_conv1d`](https://github.com/Dao-AILab/causal-conv1d) repositories. Make sure to install them if your hardware supports them!
- Contributions to make the naive path faster are welcome 🤗
<hfoptions id="usage">
<hfoption id="Pipeline">
This model was contributed by [ArthurZ](https://huggingface.co/ArthurZ).
The original code can be found [here](https://github.com/state-spaces/mamba).
# Usage
### A simple generation example:
```python
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
```py
import torch
from transformers import pipeline
pipeline = pipeline(
task="text-generation",
model="state-spaces/mamba-130m-hf",
torch_dtype=torch.float16,
device=0
)
pipeline("Plants create energy through a process known as")
```
</hfoption>
<hfoption id="AutoModel">
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float16, device_map="auto",)
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
output = model.generate(**input_ids)
print(tokenizer.decode(output[0], skip_special_tokens=True)
```
### Peft finetuning
The slow version is not very stable for training, and the fast one needs `float32`!
</hfoption>
<hfoption id="transformers CLI">
```python
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
model_id = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=4,
logging_dir='./logs',
logging_steps=10,
learning_rate=2e-3
)
lora_config = LoraConfig(
r=8,
target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
task_type="CAUSAL_LM",
bias="none"
)
trainer = SFTTrainer(
model=model,
processing_class=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
dataset_text_field="quote",
)
trainer.train()
```bash
echo -e "Plants create energy through a process known as" | transformers run --task text-generation --model state-spaces/mamba-130m-hf --device 0
```
</hfoption>
</hfoptions>
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
The example below uses [torchao](../quantization/torchao) to only quantize the weights to 4-bit integers.
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig
quantization_config = Int4WeightOnlyConfig(group_size=128)
quantization_config = TorchAoConfig(quant_type=quant_config)
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-2.8b-hf")
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-2.8b-hf", torch_dtype=torch.bfloat16, quantization_config=quantization_config, device_map="auto",)
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
output = model.generate(**input_ids)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
## Notes
- The current implementation uses the original CUDA kernels. The FlashAttention equivalent implementation is hosted in the [mamba-ssm](https://github.com/state-spaces/mamba) and [causal_conv1d](https://github.com/Dao-AILab/causal-conv1d) repositories. Make sure to install them if your hardware supports it!
- Mamba stacks `mixer` layers which are equivalent to `Attention` layers. You can find the main logic of Mamba in the `MambaMixer` class.
- The example below demonstrates how to fine-tune Mamba with [PEFT](https://huggingface.co/docs/peft).
```py
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
model_id = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=4,
logging_dir='./logs',
logging_steps=10,
learning_rate=2e-3
)
lora_config = LoraConfig(
r=8,
target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
task_type="CAUSAL_LM",
bias="none"
)
trainer = SFTTrainer(
model=model,
processing_class=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
dataset_text_field="quote",
)
trainer.train()
```
## MambaConfig
[[autodoc]] MambaConfig

View File

@ -43,8 +43,8 @@ import requests
from transformers import SamHQModel, SamHQProcessor
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base").to(device)
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
@ -69,8 +69,8 @@ import requests
from transformers import SamHQModel, SamHQProcessor
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base").to(device)
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

View File

@ -14,59 +14,77 @@ rendered properly in your Markdown viewer.
-->
# Swin Transformer
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
</div>
</div>
## Overview
# Swin Transformer
The Swin Transformer was proposed in [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)
by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo.
[Swin Transformer](https://huggingface.co/papers/2103.14030) is a hierarchical vision transformer. Images are processed in patches and windowed self-attention is used to capture local information. These windows are shifted across the image to allow for cross-window connections, capturing global information more efficiently. This hierarchical approach with shifted windows allows the Swin Transformer to process images effectively at different scales and achieve linear computational complexity relative to image size, making it a versatile backbone for various vision tasks like image classification and object detection.
The abstract from the paper is the following:
You can find all official Swin Transformer checkpoints under the [Microsoft](https://huggingface.co/microsoft?search_models=swin) organization.
*This paper presents a new vision Transformer, called Swin Transformer, that capably serves as a general-purpose backbone
for computer vision. Challenges in adapting Transformer from language to vision arise from differences between the two domains,
such as large variations in the scale of visual entities and the high resolution of pixels in images compared to words in text.
To address these differences, we propose a hierarchical Transformer whose representation is computed with \bold{S}hifted
\bold{win}dows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping
local windows while also allowing for cross-window connection. This hierarchical architecture has the flexibility to model at
various scales and has linear computational complexity with respect to image size. These qualities of Swin Transformer make it
compatible with a broad range of vision tasks, including image classification (87.3 top-1 accuracy on ImageNet-1K) and dense
prediction tasks such as object detection (58.7 box AP and 51.1 mask AP on COCO test-dev) and semantic segmentation
(53.5 mIoU on ADE20K val). Its performance surpasses the previous state-of-the-art by a large margin of +2.7 box AP and
+2.6 mask AP on COCO, and +3.2 mIoU on ADE20K, demonstrating the potential of Transformer-based models as vision backbones.
The hierarchical design and the shifted window approach also prove beneficial for all-MLP architectures.*
> [!TIP]
> Click on the Swin Transformer models in the right sidebar for more examples of how to apply Swin Transformer to different image tasks.
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/swin_transformer_architecture.png"
alt="drawing" width="600"/>
The example below demonstrates how to classify an image with [`Pipeline`] or the [`AutoModel`] class.
<small> Swin Transformer architecture. Taken from the <a href="https://arxiv.org/abs/2102.03334">original paper</a>.</small>
<hfoptions id="usage">
<hfoption id="Pipeline">
This model was contributed by [novice03](https://huggingface.co/novice03). The Tensorflow version of this model was contributed by [amyeroberts](https://huggingface.co/amyeroberts). The original code can be found [here](https://github.com/microsoft/Swin-Transformer).
```py
import torch
from transformers import pipeline
## Usage tips
pipeline = pipeline(
task="image-classification",
model="microsoft/swin-tiny-patch4-window7-224",
torch_dtype=torch.float16,
device=0
)
pipeline(images="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg")
```
</hfoption>
- Swin pads the inputs supporting any input height and width (if divisible by `32`).
- Swin can be used as a *backbone*. When `output_hidden_states = True`, it will output both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`.
<hfoption id="AutoModel">
## Resources
```py
import torch
import requests
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Swin Transformer.
image_processor = AutoImageProcessor.from_pretrained(
"microsoft/swin-tiny-patch4-window7-224",
use_fast=True,
)
model = AutoModelForImageClassification.from_pretrained(
"microsoft/swin-tiny-patch4-window7-224",
device_map="cuda"
)
<PipelineTag pipeline="image-classification"/>
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(image, return_tensors="pt").to("cuda")
- [`SwinForImageClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb).
- See also: [Image classification task guide](../tasks/image_classification)
with torch.no_grad():
logits = model(**inputs).logits
predicted_class_id = logits.argmax(dim=-1).item()
Besides that:
class_labels = model.config.id2label
predicted_class_label = class_labels[predicted_class_id]
print(f"The predicted class label is: {predicted_class_label}")
```
</hfoption>
</hfoptions>
- [`SwinForMaskedImageModeling`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
## Notes
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
- Swin can pad the inputs for any input height and width divisible by `32`.
- Swin can be used as a [backbone](../backbones). When `output_hidden_states = True`, it outputs both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`.
## SwinConfig

View File

@ -95,7 +95,7 @@ transcription[0]
## Notes
- Whisper relies on [`~GenerationMixin.generate`] for inference.
- Whisper relies a custom [`generate`] for inference, make sure to check the docs below.
- The [`WhisperProcessor`] can be used for preparing audio and decoding predicted ids back into text.
## WhisperConfig

View File

@ -54,8 +54,8 @@ For each model type, there is a separate class for each machine learning framewo
from transformers import AutoModelForCausalLM, MistralForCausalLM
# load with AutoClass or model-specific class
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", , torch_dtype="auto", device_map="auto")
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", , torch_dtype="auto", device_map="auto")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype="auto", device_map="auto")
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype="auto", device_map="auto")
```
</hfoption>
@ -272,6 +272,7 @@ Explicitly set the [torch_dtype](https://pytorch.org/docs/stable/tensor_attribut
<hfoption id="specific dtype">
```py
import torch
from transformers import AutoModelForCausalLM
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype=torch.float16)

View File

@ -13,9 +13,15 @@ rendered properly in your Markdown viewer.
-->
# Distributed GPU inference
# Tensor parallelism in transformers
[Tensor parallelism](./perf_train_gpu_many#tensor-parallelism) shards a model onto multiple GPUs and parallelizes computations such as matrix multiplication. It enables fitting larger model sizes into memory and is faster because each GPU can process a tensor slice.
This document assumes that you are already familiar with the basics of tensor parallelism. If you are not, please refer to the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism) section on tensor parallelism.
> [!TIP]
> Tensor parallelism is very communication intensive, therefore it is reccomended to use it on a single machine with multiple GPUs, utilizing fast intra-node communication. For multi-node training, methods as pipeline or data parallelism are more efficient (depending on your use case).
Tensor parallelism requires slight changes to the model parameters, therefore in transformers, we support some of the popular models out of the box.
> [!TIP]
> Expand the list below to see which models support tensor parallelism. Open a GitHub issue or pull request to add support for a model not currently below.
@ -37,9 +43,218 @@ rendered properly in your Markdown viewer.
</details>
Set `tp_plan="auto"` in [`~AutoModel.from_pretrained`] to enable tensor parallelism for inference.
## Using 🤗 transformers
```py
Transformers provides a simple interface to use for tensor parallelism. We provide multiple classes implementing different partitioning
strategies and a simple entrypoint to parallelize `nn.Module` instance. You won't have to interact with this interface directly, everything is done in `PretrainedModel.from_pretrained` method for you. This section will first talk about the partitioning strategies
we support, then the user interface you will be interacting with, and finally it will teach you how to extend it with your own partitioning
strategies.
### Partitioning strategies
In transformers, partitioning strategies reside in a class `ParallelInterface` which works like a mapping from string to the strategy implementation.
```python
class ParallelInterface(MutableMapping):
"""
Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
it needs to declare a new instance of this class inside the `modeling_<model>.py`, and declare it on that instance.
"""
_global_mapping = {
"colwise": ColwiseParallel(),
"rowwise": RowwiseParallel(),
"colwise_rep": ColwiseParallel(output_layouts=Replicate()),
"rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
"local_colwise": ColwiseParallel(use_dtensor=False),
"local_rowwise": RowwiseParallel(use_dtensor=False),
"local": IsolatedParallel(),
"gather": GatherParallel(),
"local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
"sequence_parallel": SequenceParallel(),
"replicate": ReplicateParallel(),
}
```
We support the following strategies:
- `ColwiseParallel` - A simple column-wise partitioning, being able to handle both weights and biases, does exactly what we've discussed before.
- `RowwiseParallel` - Again, row-wise partitioning as dicussed before, supports weights and biases, on top of that it also supports `nn.Embedding` modules.
- `SequenceParallel` - Sequence parallel implementation, for support of `LayerNorm` and `Dropout` layers. Also supports Python implementation of `RMSNorm` (see [this](https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34))
- `PackedColwiseParallel` - A variant of column-wise partitioning, however it works on packed weights (i.e. `up_proj` and `gate_proj` being packed together). For more details, see [this comment](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/tensor_parallel.py#L79-#L108)
- `PackedRowwiseParallel` - A variant of row-wise partitioning, works on packed weights, for more details check the comment linked above.
- `GatherParallel` - A very simple class, that only makes the outputs of the module to be gathered across devices.
- `IsolatedParallel` - This is a special case, where we want to *isolate* the module from the rest of the devices (world). This is used for Experts in MoE layers, basically creating Expert parallelism of sorts.
- `ReplicateParallel` - Many `torch.distributed` APIs break if model is partially sharded, so this class is used to replicate the module across all devices.
### Sharding a model
We provide two ways to shard a model, first one is to use `auto` tensor parallelism plan, which will automatically shard the model based on our predefined configuration. This requires the model to have predefined tensor parallel plan in transformers.
```python
from transformers import AutoModelForCausalLM
# model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # better for smaller number of GPUs
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" # better to visualize all the possible strategies
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan="auto")
print(model._tp_plan)
```
> [!TIP]
> For a list of models that support tensor parallelism, see the [Supported models](#supported-models) section above.
The second way is to manually specify your own partitioning plan.
```python
from transformers import AutoModelForCausalLM
tp_plan = {
"model.layers.*.self_attn.q_proj": "colwise",
"model.layers.*.self_attn.k_proj": "colwise",
"model.layers.*.self_attn.v_proj": "colwise",
"model.layers.*.self_attn.o_proj": "rowwise",
...
}
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan=tp_plan)
print(model._tp_plan)
```
You might have noticed that there are some special cases in the `ParallelInterface` mapping, let's now talk about them. This will help you understand their purpose and help with extending to other strategies.
### PackedRowwiseParallel
This class is a special case of `RowwiseParallel`, it's used to shard packed weights. Weight packing is a common technique used in models. It's a technique where we pack multiple linear layers into a single, bigger one.
For example in `Llama4` model, we pack `up_proj` and `gate_proj` into a single `gate_up_proj` module.
```python
class Llama4TextExperts(nn.Module):
...
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
```
Then in forward, we can use batch matrix multiplication to compute the output of the `gate_up_proj` module.
```python
def forward(self, hidden_states):
...
gate_up = torch.bmm(hidden_states, self.gate_up_proj) # Compute the output of the gate_up_proj module
gate, up = gate_up.chunk(2, dim=-1) # Split the output into gate and up
```
In this case, we need to use the `PackedRowwiseParallel` strategy to shard the `gate_up_proj` module, as using a simple `RowwiseParallel` will shard the layers wrongly.
> [!TIP]
> If this is a bit difficult to wrap your head around, check out [this comment](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/tensor_parallel.py#L79-#L108) for an amazing visual representation of why `Packed*` needs to be used.
### `local*` strategies
You could have noticed that there are `local*` strategies, which use the same layers as `*` strategy, but don't use `DTensor` at all.
This is because `DTensor` is not supported for some of the operations: such as `torch.chunk`. Therefore, sometimes we need to use the `local*` strategies, which use vanilla `torch.Tensor` and do some of the distributed logic manually.
<!---
Readd this when I get the exact error message
> [!TIP]
> If you are using a custom partitioning strategy, and it's not working with `... is not supported` error, try using the `local*` strategies to see if they work better.
-->
> [!WARNING]
> Manually specifying your own partitiong plan requires a good understanding of the model architecture and how the partitioning strategies interact together. If you are not sure about this, the resulting model can be very slow, even failing or incorrect. Again, refer to the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism) which can teach you everything required.
### Extending the interface with your own partitioning strategies
This is a very advanced topic, which requires a good understanding of distributed collectives and the model architecture.
Your custom partitioning strategy should inherit from `TensorParallelLayer` defined in [integrations/tensor_parallel.py](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/tensor_parallel.py) and implement: `partition_tensor`, `_prepare_input_fn` and `_prepare_output_fn`. Then it should be registered in the `ParallelInterface` mapping, so our dispatching logic can find it when specified in the `tp_plan`.
Let's go through this workflow step by step, on an already existing example: `ColwiseParallel`.
1. Inherit from `TensorParallelLayer` and initialization
```python
class ColwiseParallel(TensorParallelLayer):
def __init__(
self,
*,
input_layouts: Optional[Placement] = None, # The input layout coming from the previous layer
output_layouts: Optional[Placement] = None, # The output layout we want to achieve
use_local_output: bool = True, # Whether to use local output or not
use_dtensor=True, # Whether to use DTensor or not
):
self.input_layouts = (input_layouts or Replicate(),) # The input sharding coming from the previous layer
self.output_layouts = (output_layouts or Shard(-1),) # Desired output sharding
self.desired_input_layouts = (Replicate(),) # Desired input sharding, inputs should be replicated across GPUs
self.use_local_output = use_local_output
self.use_dtensor = use_dtensor
```
In the `__init__` method, we define these attributes, where `input_layouts` and `output_layouts` describing, how the input and output tensors should be placed on the devices. `desired_input_layouts` is used to specify, how the input *SHOULD* be placed on the devices.
2a. Implement `partition_tensor` method
```python
def partition_tensor(
self,
param, # Full tensor of the parameter
empty_param, # Empty tensor of the parameter, will be filled with the partitioned tensor
param_type, # Type of the parameter, `bias` or `weight`
param_casting_dtype, # The type to cast the parameter to
to_contiguous, # Whether to convert the tensor to a contiguous memory layout
rank, # The rank of the current device
device_mesh, # The device mesh
) -> nn.Parameter: # Return the partitioned parameter
...
```
This method is used to partition the tensor, and fill the `empty_param` with the partitioned tensor.
We provide some utility functions to help you with this, such as `get_tensor_shard` which will get you the correct shard of the original parameter for this rank or `get_packed_weights` to help with packed weights.
2b. Implement `_prepare_input_fn` and `_prepare_output_fn` methods
These methods are used as [`pre-forward`](https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_pre_hook.html) and [`forward`](https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_hook.html) hooks respectively. Their purpose is to re-distribute the inputs and outputs to the desired layout, passed in the `__init__` method.
```python
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
...
# Do some custom logic, cast to DTensor etc.
...
return inputs.redistribute(placements=desired_input_layouts, device_mesh=device_mesh)
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
...
# Do some custom logic, cast to DTensor etc.
...
return outputs.redistribute(placements=output_layouts, device_mesh=device_mesh)
```
3. Register the strategy
Congratulations! You've implemented your own partitioning strategy. Now, to use it with your own `tp_plan`, you need to register it in the `ParallelInterface` mapping.
```python
from transformers.integrations.tensor_parallel import ParallelInterface
ParallelInterface.register_strategy("colwise_custom", ColwiseParallel)
```
And now you can use it in your `tp_plan` as such:
```python
tp_plan = {
"model.layers.*.self_attn.q_proj": "colwise_custom",
...
}
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan=tp_plan)
```
## Full example
Let's go through a full example of inference with tensor parallelism.
```python
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
@ -66,17 +281,49 @@ Launch the inference script above on [torchrun](https://pytorch.org/docs/stable/
torchrun --nproc-per-node 4 demo.py
```
For CPU, please binding different socket on each rank. For example, if you are using Intel 4th Gen Xeon:
```bash
export OMP_NUM_THREADS=56
numactl -C 0-55 -m 0 torchrun --nnodes=2 --node_rank=0 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & numactl -C 56-111 -m 1 torchrun --nnodes=2 --node_rank=1 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & wait
```
The CPU benchmark data will be released soon.
You can benefit from considerable speed ups for inference, especially for inputs with large batch size or long sequences.
For a single forward pass on [Llama](./model_doc/llama) with a sequence length of 512 and various batch sizes, you can expect the following speed ups.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Meta-Llama-3-8B-Instruct%2C%20seqlen%20%3D%20512%2C%20python%2C%20w_%20compile.png">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Meta-Llama-3-8B-Instruct%2C%20seqlen%20%3D%20512%2C%20python%2C%20w_%20compile.png">
</div>
## Tensor parallelism in-depth
Our implementation of tensor parallelism is framework-agnostic in design, but the specific implementations we've developed rely on the torch.distributed package. We heavily utilize abstractions such as `DeviceMesh` or `DTensor` to provide a simple and extensible interface to the user.
### DeviceMesh
Imagine `DeviceMesh` as a multi-dimensional grid of devices that communicate together. Different parallelization strategies require different types of communication patterns, therefore we can create a `DeviceMesh` with multiple submeshes:
```python
from torch.distributed.device_mesh import init_device_mesh
# Create a 1D mesh of 4 GPUs
device_mesh = init_device_mesh("cuda", (4,), mesh_dim_names=["tp"])
```
Then, most of the `torch.distributed` defined parallelization strategies can be applied to a mesh itself, or its submesh, automatically handling the communication patterns.
### DTensor
Abbreviation for Distributed Tensor, `DTensor` is a tensor subclass that handles the distributed logic on-top of the usual tensor operations. Most of the model weights in case of tensor parallelism are stored as `DTensor`s (with some exceptions, more on that later).
The most important part of DTensor, that is crucial to understand, is the `placement` attribute. It's an attribute that tells PyTorch how is the tensor placed on the devices of the `DeviceMesh`.
It can have the following values:
- `Shard(dimension)` - Annotates that this `DTensor` is sharded across a given dimension, over the `DeviceMesh` it was constructed under. For example, if we would like to shard weights for column-wise partitioning, we would do:
```python
weight = ...
weight = DTensor.from_local(weight, device_mesh["tp"], placements=[Shard(0)]) # Shard across the 1st (column-wise) dimension
bias = ...
bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Shard(-1)]) # Shard across the ONLY dimension
```
To give another example, for row-wise partitioning, we would do:
```python
weight = ...
weight = DTensor.from_local(weight, device_mesh["tp"], placements=[Shard(1)]) # Shard across the 2nd (row-wise) dimension
bias = ...
bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Replicate()]) # Replicate bias across all GPUs
```
- `Replicate()` - Annotates that this `DTensor` is replicated across the `DeviceMesh`. Very straight-forward, only creates a full copy of the tensor on each device.
- `Partial()` - This placement is mostly of no interest to us, it's used to annotate that this tensor is pending a reduction operation.

View File

@ -106,6 +106,8 @@ dataset[0]["text"]
Remember to resample the sampling rate to match the pretrained models required sampling rate.
```py
from datasets import Audio
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
```

View File

@ -29,8 +29,6 @@
- sections:
- isExpanded: false
sections:
- local: tasks/sequence_classification
title: テキストの分類
- local: tasks/token_classification
title: トークンの分類
- local: tasks/question_answering

View File

@ -47,7 +47,7 @@ ALBERTモデルは、「[ALBERT: A Lite BERT for Self-supervised Learning of Lan
## 参考資料
- [テキスト分類タスクガイド](../tasks/sequence_classification)
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
- [トークン分類タスクガイド](../tasks/token_classification)
- [質問応答タスクガイド](../tasks/question_answering)
- [マスクされた言語モデルタスクガイド](../tasks/masked_language_modeling)

View File

@ -129,7 +129,7 @@ BART を始めるのに役立つ公式 Hugging Face およびコミュニティ
- [翻訳タスクガイド](../tasks/translation)
以下も参照してください。
- [テキスト分類タスクガイド](../tasks/sequence_classification)
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
- [質問回答タスク ガイド](../tasks/question_answering)
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
- [抽出されたチェックポイント](https://huggingface.co/models?search=distilbart) は、この [論文](https://arxiv.org/abs/2010.13002) で説明されています。

View File

@ -76,7 +76,7 @@ BERT を始めるのに役立つ公式 Hugging Face およびコミュニティ
- [`BertForSequenceClassification`] は、この [サンプル スクリプト](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification) および [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification.ipynb)。
- [`TFBertForSequenceClassification`] は、この [サンプル スクリプト](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/text-classification) および [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification-tf.ipynb)。
- [`FlaxBertForSequenceClassification`] は、この [サンプル スクリプト](https://github.com/huggingface/transformers/tree/main/examples/flax/text-classification) および [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification_flax.ipynb)。
- [テキスト分類タスクガイド](../tasks/sequence_classification)
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
<PipelineTag pipeline="token-classification"/>

View File

@ -58,7 +58,7 @@ BigBird は、質問応答や要約などのさまざまな NLP タスクのパ
## ドキュメント リソース
- [テキスト分類タスクガイド](../tasks/sequence_classification)
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
- [トークン分類タスクガイド](../tasks/token_classification)
- [質問回答タスク ガイド](../tasks/question_answering)
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)

View File

@ -58,7 +58,7 @@ BigBird は、質問応答や要約などのさまざまな NLP タスクのパ
## ドキュメント リソース
- [テキスト分類タスクガイド](../tasks/sequence_classification)
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
- [質問回答タスク ガイド](../tasks/question_answering)
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
- [翻訳タスクガイド](../tasks/translation)

View File

@ -39,7 +39,7 @@ BLOOM を使い始めるのに役立つ公式 Hugging Face およびコミュニ
以下も参照してください。
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
- [テキスト分類タスクガイド](../tasks/sequence_classification)
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
- [トークン分類タスクガイド](../tasks/token_classification)
- [質問回答タスク ガイド](../tasks/question_answering)

View File

@ -46,7 +46,7 @@ Bi-direction Encoders for Transformers (BERT) のフランス語版である Cam
## Resources
- [テキスト分類タスクガイド](../tasks/sequence_classification)
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
- [トークン分類タスクガイド](../tasks/token_classification)
- [質問回答タスク ガイド](../tasks/question_answering)
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)

View File

@ -98,7 +98,7 @@ CANINE は生の文字で動作するため、**トークナイザーなし**で
## Resources
- [テキスト分類タスクガイド](../tasks/sequence_classification)
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
- [トークン分類タスクガイド](../tasks/token_classification)
- [質問回答タスク ガイド](../tasks/question_answering)
- [多肢選択タスク ガイド](../tasks/multiple_choice)

View File

@ -53,7 +53,7 @@ ConvBERT トレーニングのヒントは BERT のヒントと似ています
## Resources
- [テキスト分類タスクガイド](../tasks/sequence_classification)
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
- [トークン分類タスクガイド](../tasks/token_classification)
- [質問回答タスク ガイド](../tasks/question_answering)
- [マスクされた言語モデリング タスク ガイド](../tasks/masked_lang_modeling)

View File

@ -61,7 +61,7 @@ CTRL モデルは、Nitish Shirish Keskar*、Bryan McCann*、Lav R. Varshney、C
## Resources
- [テキスト分類タスクガイド](../tasks/sequence_classification)
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
## CTRLConfig

View File

@ -58,7 +58,7 @@ Data2Vec の使用を開始するのに役立つ公式 Hugging Face およびコ
- カスタム データセットで [`TFData2VecVisionForImageClassification`] を微調整するには、[このノートブック](https://colab.research.google.com/github/sayakpaul/TF-2.0-Hacks/blob/master/data2vec_vision_image_classification.ipynb) を参照してください。 )。
**Data2VecText ドキュメント リソース**
- [テキスト分類タスクガイド](../tasks/sequence_classification)
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
- [トークン分類タスクガイド](../tasks/token_classification)
- [質問回答タスク ガイド](../tasks/question_answering)
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)

View File

@ -61,7 +61,7 @@ v2 の新機能:
[kamalkraj](https://huggingface.co/kamalkraj) による投稿。元のコードは [こちら](https://github.com/microsoft/DeBERTa) にあります。
## Resources
- [テキスト分類タスクガイド](../tasks/sequence_classification)
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
- [トークン分類タスクガイド](../tasks/token_classification)
- [質問回答タスク ガイド](../tasks/question_answering)
- [マスク言語モデリング タスク ガイド](../tasks/masked_language_modeling)

View File

@ -52,7 +52,7 @@ DeBERTa を使い始めるのに役立つ公式 Hugging Face およびコミュ
- DeBERTa による [機械学習によるスーパーチャージされた顧客サービス](https://huggingface.co/blog/supercharge-customer-service-with-machine-learning) に関するブログ投稿。
- [`DebertaForSequenceClassification`] は、この [サンプル スクリプト](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification) および [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification.ipynb)。
- [`TFDebertaForSequenceClassification`] は、この [サンプル スクリプト](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/text-classification) および [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification-tf.ipynb)。
- [テキスト分類タスクガイド](../tasks/sequence_classification)
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
<PipelineTag pipeline="token-classification" />

View File

@ -1,604 +0,0 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Sequence classification
[[open-in-colab]]
<Youtube id="dKE8SIt9C-w"/>
セマンティック セグメンテーションでは、画像の個々のピクセルにラベルまたはクラスを割り当てます。セグメンテーションにはいくつかのタイプがありますが、セマンティック セグメンテーションの場合、同じオブジェクトの一意のインスタンス間の区別は行われません。両方のオブジェクトに同じラベルが付けられます (たとえば、「car-1」と「car-2」の代わりに「car」)。セマンティック セグメンテーションの一般的な現実世界のアプリケーションには、歩行者や重要な交通情報を識別するための自動運転車のトレーニング、医療画像内の細胞と異常の識別、衛星画像からの環境変化の監視などが含まれます。
このガイドでは、次の方法を説明します。
1. [SceneParse150](https://huggingface.co/datasets/scene_parse_150) データセットの [SegFormer](https://huggingface.co/docs/transformers/main/en/model_doc/segformer#segformer) を微調整します。
2. 微調整したモデルを推論に使用します。
<Tip>
このタスクと互換性のあるすべてのアーキテクチャとチェックポイントを確認するには、[タスクページ](https://huggingface.co/tasks/text-classification) を確認することをお勧めします。
</Tip>
始める前に、必要なライブラリがすべてインストールされていることを確認してください。
```bash
pip install -q datasets transformers evaluate
```
モデルをアップロードしてコミュニティと共有できるように、Hugging Face アカウントにログインすることをお勧めします。プロンプトが表示されたら、トークンを入力してログインします。
```py
>>> from huggingface_hub import notebook_login
>>> notebook_login()
```
## Load SceneParse150 dataset
まず、SceneParse150 データセットの小さいサブセットを 🤗 データセット ライブラリから読み込みます。これにより、完全なデータセットのトレーニングにさらに時間を費やす前に、実験してすべてが機能することを確認する機会が得られます。
```py
>>> from datasets import load_dataset
>>> ds = load_dataset("scene_parse_150", split="train[:50]")
```
[`~datasets.Dataset.train_test_split`] メソッドを使用して、データセットの `train` 分割をトレイン セットとテスト セットに分割します。
```py
>>> ds = ds.train_test_split(test_size=0.2)
>>> train_ds = ds["train"]
>>> test_ds = ds["test"]
```
次に、例を見てみましょう。
```py
>>> train_ds[0]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x683 at 0x7F9B0C201F90>,
'annotation': <PIL.PngImagePlugin.PngImageFile image mode=L size=512x683 at 0x7F9B0C201DD0>,
'scene_category': 368}
```
- `image`: シーンの PIL イメージ。
- `annotation`: セグメンテーション マップの PIL イメージ。モデルのターゲットでもあります。
- `scene_category`: 「キッチン」や「オフィス」などの画像シーンを説明するカテゴリ ID。このガイドでは、「image」と「annotation」のみが必要になります。どちらも PIL イメージです。
また、ラベル ID をラベル クラスにマップする辞書を作成することもできます。これは、後でモデルを設定するときに役立ちます。ハブからマッピングをダウンロードし、`id2label` および `label2id` ディクショナリを作成します。
```py
>>> import json
>>> from pathlib import Path
>>> from huggingface_hub import hf_hub_download
>>> repo_id = "huggingface/label-files"
>>> filename = "ade20k-id2label.json"
>>> id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text())
>>> id2label = {int(k): v for k, v in id2label.items()}
>>> label2id = {v: k for k, v in id2label.items()}
>>> num_labels = len(id2label)
```
## Preprocess
次のステップでは、SegFormer 画像プロセッサをロードして、モデルの画像と注釈を準備します。このデータセットのような一部のデータセットは、バックグラウンド クラスとしてゼロインデックスを使用します。ただし、実際には背景クラスは 150 個のクラスに含まれていないため、`do_reduce_labels=True`を設定してすべてのラベルから 1 つを引く必要があります。ゼロインデックスは `255` に置き換えられるため、SegFormer の損失関数によって無視されます。
```py
>>> from transformers import AutoImageProcessor
>>> checkpoint = "nvidia/mit-b0"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint, do_reduce_labels=True)
```
<frameworkcontent>
<pt>
モデルを過学習に対してより堅牢にするために、画像データセットにいくつかのデータ拡張を適用するのが一般的です。このガイドでは、[torchvision](https://pytorch.org) の [`ColorJitter`](https://pytorch.org/vision/stable/generated/torchvision.transforms.ColorJitter.html) 関数を使用します。 /vision/stable/index.html) を使用して画像の色のプロパティをランダムに変更しますが、任意の画像ライブラリを使用することもできます。
```py
>>> from torchvision.transforms import ColorJitter
>>> jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
```
次に、モデルの画像と注釈を準備するための 2 つの前処理関数を作成します。これらの関数は、画像を`pixel_values`に変換し、注釈を`labels`に変換します。トレーニング セットの場合、画像を画像プロセッサに提供する前に`jitter`が適用されます。テスト セットの場合、テスト中にデータ拡張が適用されないため、画像プロセッサは`images`を切り取って正規化し、`labels` のみを切り取ります。
```py
>>> def train_transforms(example_batch):
... images = [jitter(x) for x in example_batch["image"]]
... labels = [x for x in example_batch["annotation"]]
... inputs = image_processor(images, labels)
... return inputs
>>> def val_transforms(example_batch):
... images = [x for x in example_batch["image"]]
... labels = [x for x in example_batch["annotation"]]
... inputs = image_processor(images, labels)
... return inputs
```
データセット全体に`jitter`を適用するには、🤗 Datasets [`~datasets.Dataset.set_transform`] 関数を使用します。変換はオンザフライで適用されるため、高速で消費するディスク容量が少なくなります。
```py
>>> train_ds.set_transform(train_transforms)
>>> test_ds.set_transform(val_transforms)
```
</pt>
</frameworkcontent>
<frameworkcontent>
<tf>
モデルを過学習に対してより堅牢にするために、画像データセットにいくつかのデータ拡張を適用するのが一般的です。
このガイドでは、[`tf.image`](https://www.tensorflow.org/api_docs/python/tf/image) を使用して画像の色のプロパティをランダムに変更しますが、任意のプロパティを使用することもできます。画像
好きな図書館。
2 つの別々の変換関数を定義します。
- 画像拡張を含むトレーニング データ変換
- 🤗 Transformers のコンピューター ビジョン モデルはチャネル優先のレイアウトを想定しているため、画像を転置するだけの検証データ変換
```py
>>> import tensorflow as tf
>>> def aug_transforms(image):
... image = tf.keras.utils.img_to_array(image)
... image = tf.image.random_brightness(image, 0.25)
... image = tf.image.random_contrast(image, 0.5, 2.0)
... image = tf.image.random_saturation(image, 0.75, 1.25)
... image = tf.image.random_hue(image, 0.1)
... image = tf.transpose(image, (2, 0, 1))
... return image
>>> def transforms(image):
... image = tf.keras.utils.img_to_array(image)
... image = tf.transpose(image, (2, 0, 1))
... return image
```
次に、モデルの画像と注釈のバッチを準備する 2 つの前処理関数を作成します。これらの機能が適用されます
画像変換を行い、以前にロードされた `image_processor` を使用して画像を `pixel_values` に変換し、
`labels`への注釈。 `ImageProcessor` は、画像のサイズ変更と正規化も処理します。
```py
>>> def train_transforms(example_batch):
... images = [aug_transforms(x.convert("RGB")) for x in example_batch["image"]]
... labels = [x for x in example_batch["annotation"]]
... inputs = image_processor(images, labels)
... return inputs
>>> def val_transforms(example_batch):
... images = [transforms(x.convert("RGB")) for x in example_batch["image"]]
... labels = [x for x in example_batch["annotation"]]
... inputs = image_processor(images, labels)
... return inputs
```
データセット全体に前処理変換を適用するには、🤗 Datasets [`~datasets.Dataset.set_transform`] 関数を使用します。
変換はオンザフライで適用されるため、高速で消費するディスク容量が少なくなります。
```py
>>> train_ds.set_transform(train_transforms)
>>> test_ds.set_transform(val_transforms)
```
</tf>
</frameworkcontent>
## Evaluate
トレーニング中にメトリクスを含めると、多くの場合、モデルのパフォーマンスを評価するのに役立ちます。 🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) ライブラリを使用して、評価メソッドをすばやくロードできます。このタスクでは、[Mean Intersection over Union](https://huggingface.co/spaces/evaluate-metric/accuracy) (IoU) メトリックをロードします (🤗 Evaluate [クイック ツアー](https://huggingface.co) を参照してください) /docs/evaluate/a_quick_tour) を参照して、メトリクスをロードして計算する方法の詳細を確認してください)。
```py
>>> import evaluate
>>> metric = evaluate.load("mean_iou")
```
次に、メトリクスを [`~evaluate.EvaluationModule.compute`] する関数を作成します。予測を次のように変換する必要があります
最初にロジットを作成し、次に [`~evaluate.EvaluationModule.compute`] を呼び出す前にラベルのサイズに一致するように再形成します。
<frameworkcontent>
<pt>
```py
>>> import numpy as np
>>> import torch
>>> from torch import nn
>>> def compute_metrics(eval_pred):
... with torch.no_grad():
... logits, labels = eval_pred
... logits_tensor = torch.from_numpy(logits)
... logits_tensor = nn.functional.interpolate(
... logits_tensor,
... size=labels.shape[-2:],
... mode="bilinear",
... align_corners=False,
... ).argmax(dim=1)
... pred_labels = logits_tensor.detach().cpu().numpy()
... metrics = metric.compute(
... predictions=pred_labels,
... references=labels,
... num_labels=num_labels,
... ignore_index=255,
... reduce_labels=False,
... )
... for key, value in metrics.items():
... if type(value) is np.ndarray:
... metrics[key] = value.tolist()
... return metrics
```
</pt>
</frameworkcontent>
<frameworkcontent>
<tf>
```py
>>> def compute_metrics(eval_pred):
... logits, labels = eval_pred
... logits = tf.transpose(logits, perm=[0, 2, 3, 1])
... logits_resized = tf.image.resize(
... logits,
... size=tf.shape(labels)[1:],
... method="bilinear",
... )
... pred_labels = tf.argmax(logits_resized, axis=-1)
... metrics = metric.compute(
... predictions=pred_labels,
... references=labels,
... num_labels=num_labels,
... ignore_index=-1,
... reduce_labels=image_processor.do_reduce_labels,
... )
... per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
... per_category_iou = metrics.pop("per_category_iou").tolist()
... metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
... metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})
... return {"val_" + k: v for k, v in metrics.items()}
```
</tf>
</frameworkcontent>
これで`compute_metrics`関数の準備が整いました。トレーニングをセットアップするときにこの関数に戻ります。
## Train
<frameworkcontent>
<pt>
<Tip>
[`Trainer`] を使用したモデルの微調整に慣れていない場合は、[こちら](../training#finetune-with-trainer) の基本的なチュートリアルをご覧ください。
</Tip>
これでモデルのトレーニングを開始する準備が整いました。 [`AutoModelForSemanticSegmentation`] を使用して SegFormer をロードし、ラベル ID とラベル クラス間のマッピングをモデルに渡します。
```py
>>> from transformers import AutoModelForSemanticSegmentation, TrainingArguments, Trainer
>>> model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)
```
この時点で残っている手順は次の 3 つだけです。
1. [`TrainingArguments`] でトレーニング ハイパーパラメータを定義します。 `image` 列が削除されるため、未使用の列を削除しないことが重要です。 `image` 列がないと、`pixel_values` を作成できません。この動作を防ぐには、`remove_unused_columns=False`を設定してください。他に必要なパラメータは、モデルの保存場所を指定する `output_dir` だけです。 `push_to_hub=True`を設定して、このモデルをハブにプッシュします (モデルをアップロードするには、Hugging Face にサインインする必要があります)。各エポックの終了時に、[`Trainer`] は IoU メトリックを評価し、トレーニング チェックポイントを保存します。
2. トレーニング引数を、モデル、データセット、トークナイザー、データ照合器、および `compute_metrics` 関数とともに [`Trainer`] に渡します。
3. [`~Trainer.train`] を呼び出してモデルを微調整します。
```py
>>> training_args = TrainingArguments(
... output_dir="segformer-b0-scene-parse-150",
... learning_rate=6e-5,
... num_train_epochs=50,
... per_device_train_batch_size=2,
... per_device_eval_batch_size=2,
... save_total_limit=3,
... eval_strategy="steps",
... save_strategy="steps",
... save_steps=20,
... eval_steps=20,
... logging_steps=1,
... eval_accumulation_steps=5,
... remove_unused_columns=False,
... push_to_hub=True,
... )
>>> trainer = Trainer(
... model=model,
... args=training_args,
... train_dataset=train_ds,
... eval_dataset=test_ds,
... compute_metrics=compute_metrics,
... )
>>> trainer.train()
```
トレーニングが完了したら、 [`~transformers.Trainer.push_to_hub`] メソッドを使用してモデルをハブに共有し、誰もがモデルを使用できるようにします。
```py
>>> trainer.push_to_hub()
```
</pt>
</frameworkcontent>
<frameworkcontent>
<tf>
<Tip>
Keras を使用したモデルの微調整に慣れていない場合は、まず [基本チュートリアル](./training#train-a-tensorflow-model-with-keras) を確認してください。
</Tip>
TensorFlow でモデルを微調整するには、次の手順に従います。
1. トレーニングのハイパーパラメータを定義し、オプティマイザーと学習率スケジュールを設定します。
2. 事前トレーニングされたモデルをインスタンス化します。
3. 🤗 データセットを `tf.data.Dataset` に変換します。
4. モデルをコンパイルします。
5. コールバックを追加してメトリクスを計算し、モデルを 🤗 Hub にアップロードします
6. `fit()` メソッドを使用してトレーニングを実行します。
まず、ハイパーパラメーター、オプティマイザー、学習率スケジュールを定義します。
```py
>>> from transformers import create_optimizer
>>> batch_size = 2
>>> num_epochs = 50
>>> num_train_steps = len(train_ds) * num_epochs
>>> learning_rate = 6e-5
>>> weight_decay_rate = 0.01
>>> optimizer, lr_schedule = create_optimizer(
... init_lr=learning_rate,
... num_train_steps=num_train_steps,
... weight_decay_rate=weight_decay_rate,
... num_warmup_steps=0,
... )
```
次に、ラベル マッピングとともに [`TFAutoModelForSemanticSegmentation`] を使用して SegFormer をロードし、それをコンパイルします。
オプティマイザ。 Transformers モデルにはすべてデフォルトのタスク関連の損失関数があるため、次の場合を除き、損失関数を指定する必要はないことに注意してください。
```py
>>> from transformers import TFAutoModelForSemanticSegmentation
>>> model = TFAutoModelForSemanticSegmentation.from_pretrained(
... checkpoint,
... id2label=id2label,
... label2id=label2id,
... )
>>> model.compile(optimizer=optimizer) # No loss argument!
```
[`~datasets.Dataset.to_tf_dataset`] と [`DefaultDataCollator`] を使用して、データセットを `tf.data.Dataset` 形式に変換します。
```py
>>> from transformers import DefaultDataCollator
>>> data_collator = DefaultDataCollator(return_tensors="tf")
>>> tf_train_dataset = train_ds.to_tf_dataset(
... columns=["pixel_values", "label"],
... shuffle=True,
... batch_size=batch_size,
... collate_fn=data_collator,
... )
>>> tf_eval_dataset = test_ds.to_tf_dataset(
... columns=["pixel_values", "label"],
... shuffle=True,
... batch_size=batch_size,
... collate_fn=data_collator,
... )
```
予測から精度を計算し、モデルを 🤗 ハブにプッシュするには、[Keras callbacks](../main_classes/keras_callbacks) を使用します。
`compute_metrics` 関数を [`KerasMetricCallback`] に渡します。
そして [`PushToHubCallback`] を使用してモデルをアップロードします。
```py
>>> from transformers.keras_callbacks import KerasMetricCallback, PushToHubCallback
>>> metric_callback = KerasMetricCallback(
... metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"]
... )
>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", image_processor=image_processor)
>>> callbacks = [metric_callback, push_to_hub_callback]
```
ついに、モデルをトレーニングする準備が整いました。`fit()`トレーニングおよび検証データセット、エポック数、
モデルを微調整するためのコールバック:
```py
>>> model.fit(
... tf_train_dataset,
... validation_data=tf_eval_dataset,
... callbacks=callbacks,
... epochs=num_epochs,
... )
```
おめでとう!モデルを微調整し、🤗 Hub で共有しました。これで推論に使用できるようになりました。
</tf>
</frameworkcontent>
## Inference
モデルを微調整したので、それを推論に使用できるようになりました。
推論のために画像をロードします。
```py
>>> image = ds[0]["image"]
>>> image
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/semantic-seg-image.png" alt="Image of bedroom"/>
</div>
<frameworkcontent>
<pt>
推論用に微調整されたモデルを試す最も簡単な方法は、それを [`pipeline`] で使用することです。モデルを使用して画像セグメンテーション用の `pipeline` をインスタンス化し、それに画像を渡します。
```py
>>> from transformers import pipeline
>>> segmenter = pipeline("image-segmentation", model="my_awesome_seg_model")
>>> segmenter(image)
[{'score': None,
'label': 'wall',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062690>},
{'score': None,
'label': 'sky',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062A50>},
{'score': None,
'label': 'floor',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062B50>},
{'score': None,
'label': 'ceiling',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062A10>},
{'score': None,
'label': 'bed ',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062E90>},
{'score': None,
'label': 'windowpane',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062390>},
{'score': None,
'label': 'cabinet',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062550>},
{'score': None,
'label': 'chair',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062D90>},
{'score': None,
'label': 'armchair',
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062E10>}]
```
必要に応じて、`pipeline` の結果を手動で複製することもできます。画像プロセッサで画像を処理し、`pixel_values`を GPU に配置します。
```py
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # use GPU if available, otherwise use a CPU
>>> encoding = image_processor(image, return_tensors="pt")
>>> pixel_values = encoding.pixel_values.to(device)
```
入力をモデルに渡し、「logits」を返します。
```py
>>> outputs = model(pixel_values=pixel_values)
>>> logits = outputs.logits.cpu()
```
次に、ロジットを元の画像サイズに再スケールします。
```py
>>> upsampled_logits = nn.functional.interpolate(
... logits,
... size=image.size[::-1],
... mode="bilinear",
... align_corners=False,
... )
>>> pred_seg = upsampled_logits.argmax(dim=1)[0]
```
</pt>
</frameworkcontent>
<frameworkcontent>
<tf>
画像プロセッサをロードして画像を前処理し、入力を TensorFlow テンソルとして返します。
```py
>>> from transformers import AutoImageProcessor
>>> image_processor = AutoImageProcessor.from_pretrained("MariaK/scene_segmentation")
>>> inputs = image_processor(image, return_tensors="tf")
```
入力をモデルに渡し、`logits`を返します。
```py
>>> from transformers import TFAutoModelForSemanticSegmentation
>>> model = TFAutoModelForSemanticSegmentation.from_pretrained("MariaK/scene_segmentation")
>>> logits = model(**inputs).logits
```
次に、ロジットを元の画像サイズに再スケールし、クラス次元に argmax を適用します。
```py
>>> logits = tf.transpose(logits, [0, 2, 3, 1])
>>> upsampled_logits = tf.image.resize(
... logits,
... # We reverse the shape of `image` because `image.size` returns width and height.
... image.size[::-1],
... )
>>> pred_seg = tf.math.argmax(upsampled_logits, axis=-1)[0]
```
</tf>
</frameworkcontent>
結果を視覚化するには、[データセット カラー パレット](https://github.com/tensorflow/models/blob/3f1ca33afe3c1631b733ea7e40c294273b9e406d/research/deeplab/utils/get_dataset_colormap.py#L51) を、それぞれをマップする `ade_palette()` としてロードします。クラスを RGB 値に変換します。次に、画像と予測されたセグメンテーション マップを組み合わせてプロットできます。
```py
>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> color_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3), dtype=np.uint8)
>>> palette = np.array(ade_palette())
>>> for label, color in enumerate(palette):
... color_seg[pred_seg == label, :] = color
>>> color_seg = color_seg[..., ::-1] # convert to BGR
>>> img = np.array(image) * 0.5 + color_seg * 0.5 # plot the image with the segmentation map
>>> img = img.astype(np.uint8)
>>> plt.figure(figsize=(15, 10))
>>> plt.imshow(img)
>>> plt.show()
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/semantic-seg-preds.png" alt="Image of bedroom overlaid with segmentation map"/>
</div>

View File

@ -221,7 +221,7 @@ Transformerは最初に機械翻訳のために設計され、それ以降、ほ
事前訓練済みモデルをテキスト分類に使用するには、ベースのBERTモデルの上にシーケンス分類ヘッドを追加します。シーケンス分類ヘッドは最終的な隠れた状態を受け入れ、それらをロジットに変換するための線形層です。クロスエントロピー損失は、ロジットとターゲット間で最も可能性の高いラベルを見つけるために計算されます。
テキスト分類を試してみる準備はできましたかDistilBERTを微調整し、推論に使用する方法を学ぶために、完全な[テキスト分類ガイド](tasks/sequence_classification)をチェックしてみてください!
テキスト分類を試してみる準備はできましたかDistilBERTを微調整し、推論に使用する方法を学ぶために、完全な[テキスト分類ガイド(英語版)](../en/tasks/sequence_classification)をチェックしてみてください!
### Token classification

View File

@ -157,5 +157,8 @@
title: 通用工具
- local: internal/time_series_utils
title: 时序数据工具
- sections:
- local: model_doc/bert
title: BERT
title: 内部辅助工具
title: 应用程序接口 (API)
title: 应用程序接口 (API)

View File

@ -0,0 +1,258 @@
<!--Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
</div>
# BERT
[BERT](https://huggingface.co/papers/1810.04805) 是一个在无标签的文本数据上预训练的双向 transformer用于预测句子中被掩码的masked token以及预测一个句子是否跟随在另一个句子之后。其主要思想是在预训练过程中通过随机掩码一些 token让模型利用左右上下文的信息预测它们从而获得更全面深入的理解。此外BERT 具有很强的通用性,其学习到的语言表示可以通过额外的层或头进行微调,从而适配其他下游 NLP 任务。
你可以在 [BERT](https://huggingface.co/collections/google/bert-release-64ff5e7a4be99045d1896dbc) 集合下找到 BERT 的所有原始 checkpoint。
> [!TIP]
> 点击右侧边栏中的 BERT 模型,以查看将 BERT 应用于不同语言任务的更多示例。
下面的示例演示了如何使用 [`Pipeline`], [`AutoModel`] 和命令行预测 `[MASK]` token。
<hfoptions id="usage">
<hfoption id="Pipeline">
```py
import torch
from transformers import pipeline
pipeline = pipeline(
task="fill-mask",
model="google-bert/bert-base-uncased",
torch_dtype=torch.float16,
device=0
)
pipeline("Plants create [MASK] through a process known as photosynthesis.")
```
</hfoption>
<hfoption id="AutoModel">
```py
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
"google-bert/bert-base-uncased",
)
model = AutoModelForMaskedLM.from_pretrained(
"google-bert/bert-base-uncased",
torch_dtype=torch.float16,
device_map="auto",
attn_implementation="sdpa"
)
inputs = tokenizer("Plants create [MASK] through a process known as photosynthesis.", return_tensors="pt").to("cuda")
with torch.no_grad():
outputs = model(**inputs)
predictions = outputs.logits
masked_index = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1]
predicted_token_id = predictions[0, masked_index].argmax(dim=-1)
predicted_token = tokenizer.decode(predicted_token_id)
print(f"The predicted token is: {predicted_token}")
```
</hfoption>
<hfoption id="transformers-cli">
```bash
echo -e "Plants create [MASK] through a process known as photosynthesis." | transformers-cli run --task fill-mask --model google-bert/bert-base-uncased --device 0
```
</hfoption>
</hfoptions>
## 注意
- 输入内容应在右侧进行填充,因为 BERT 使用绝对位置嵌入。
## BertConfig
[[autodoc]] BertConfig
- all
## BertTokenizer
[[autodoc]] BertTokenizer
- build_inputs_with_special_tokens
- get_special_tokens_mask
- create_token_type_ids_from_sequences
- save_vocabulary
## BertTokenizerFast
[[autodoc]] BertTokenizerFast
## BertModel
[[autodoc]] BertModel
- forward
## BertForPreTraining
[[autodoc]] BertForPreTraining
- forward
## BertLMHeadModel
[[autodoc]] BertLMHeadModel
- forward
## BertForMaskedLM
[[autodoc]] BertForMaskedLM
- forward
## BertForNextSentencePrediction
[[autodoc]] BertForNextSentencePrediction
- forward
## BertForSequenceClassification
[[autodoc]] BertForSequenceClassification
- forward
## BertForMultipleChoice
[[autodoc]] BertForMultipleChoice
- forward
## BertForTokenClassification
[[autodoc]] BertForTokenClassification
- forward
## BertForQuestionAnswering
[[autodoc]] BertForQuestionAnswering
- forward
## TFBertTokenizer
[[autodoc]] TFBertTokenizer
## TFBertModel
[[autodoc]] TFBertModel
- call
## TFBertForPreTraining
[[autodoc]] TFBertForPreTraining
- call
## TFBertModelLMHeadModel
[[autodoc]] TFBertLMHeadModel
- call
## TFBertForMaskedLM
[[autodoc]] TFBertForMaskedLM
- call
## TFBertForNextSentencePrediction
[[autodoc]] TFBertForNextSentencePrediction
- call
## TFBertForSequenceClassification
[[autodoc]] TFBertForSequenceClassification
- call
## TFBertForMultipleChoice
[[autodoc]] TFBertForMultipleChoice
- call
## TFBertForTokenClassification
[[autodoc]] TFBertForTokenClassification
- call
## TFBertForQuestionAnswering
[[autodoc]] TFBertForQuestionAnswering
- call
## FlaxBertModel
[[autodoc]] FlaxBertModel
- __call__
## FlaxBertForPreTraining
[[autodoc]] FlaxBertForPreTraining
- __call__
## FlaxBertForCausalLM
[[autodoc]] FlaxBertForCausalLM
- __call__
## FlaxBertForMaskedLM
[[autodoc]] FlaxBertForMaskedLM
- __call__
## FlaxBertForNextSentencePrediction
[[autodoc]] FlaxBertForNextSentencePrediction
- __call__
## FlaxBertForSequenceClassification
[[autodoc]] FlaxBertForSequenceClassification
- __call__
## FlaxBertForMultipleChoice
[[autodoc]] FlaxBertForMultipleChoice
- __call__
## FlaxBertForTokenClassification
[[autodoc]] FlaxBertForTokenClassification
- __call__
## FlaxBertForQuestionAnswering
[[autodoc]] FlaxBertForQuestionAnswering
- __call__
## Bert specific outputs
[[autodoc]] models.bert.modeling_bert.BertForPreTrainingOutput
[[autodoc]] models.bert.modeling_tf_bert.TFBertForPreTrainingOutput
[[autodoc]] models.bert.modeling_flax_bert.FlaxBertForPreTrainingOutput

435
examples/3D_parallel.py Normal file
View File

@ -0,0 +1,435 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""":
This script is used to test training a model using Tensor Parallelism and Data Parallelism.
Usage:
export CUDA_VISIBLE_DEVICES=0,1,2,3
export CUDA_VISIBLE_DEVICES=4,5,6,7
export CUDA_VISIBLE_DEVICES=5,6,7
TP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 examples/3D_parallel.py
CP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 examples/3D_parallel.py
CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=4 examples/3D_parallel.py
DP_SIZE=2 CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=8 examples/3D_parallel.py
TP_SIZE=1 CP_SIZE=4 torchrun --nproc_per_node=4 examples/3D_parallel.py
TP_SIZE=1 DP_SIZE=4 torchrun --nproc_per_node=4 examples/3D_parallel.py
TP_SIZE=4 DP_SIZE=1 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 examples/3D_parallel.py
IGNORE_SANITY=1 CP_SIZE=1 TP_SIZE=1 DP_SIZE=1 torchrun --nproc_per_node=1 --rdzv_endpoint=localhost:29504 examples/3D_parallel.py
ocalhost:29504 test_train.py
"""
import logging
import os
from contextlib import nullcontext
from typing import Iterable
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.optim as optim
import wandb
from datasets import load_dataset
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer
# torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
# Set up logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
# from torch.distributed.tensor.experimental._attention import set_rotate_method
# set_rotate_method("alltoall") # CP rotate shards using all-to-all
def main():
tp_size = int(os.environ.get("TP_SIZE", 1))
dp_size = int(os.environ.get("DP_SIZE", 1))
cp_size = int(os.environ.get("CP_SIZE", 1)) # Add CP size configuration
sdpa_backend = SDPBackend.FLASH_ATTENTION # For CP
# sdpa_backend = SDPBackend.MATH # For CP
global_batch_size = 8 # Desired global batch size
seq_len = 1024 # Sequence length
num_train_steps = 10000 # Number of training steps
LR = 1e-5
model_name = "HuggingFaceTB/SmolLM2-1.7B"
# model_name = "unsloth/Llama-3.2-1B"
CHECKPOINT_DIR = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}"
# Initialize distributed environment
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
assert world_size == tp_size * dp_size * cp_size, (
f"World size ({world_size}) must equal TP size ({tp_size}) * DP size ({dp_size}) * CP size ({cp_size})"
)
mesh = torch.arange(world_size).reshape(dp_size, tp_size, cp_size)
world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp"))
tp_mesh = world_mesh["tp"]
dp_mesh = world_mesh["dp"]
cp_mesh = world_mesh["cp"]
world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
logger.info(f"Created DeviceMesh: {world_mesh}")
logger.info(
f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}, DP: {dp_mesh.get_local_rank()}, TP: {tp_mesh.get_local_rank()}, CP: {cp_mesh.get_local_rank()}"
)
if dist.get_rank() == 0:
wandb.init(
project="tp_dp_test",
config={
"tp_size": tp_size,
"dp_size": dp_size,
"cp_size": cp_size,
"global_batch_size": global_batch_size,
"model_name": model_name,
"dataset": "roneneldan/TinyStories-1M",
"seq_len": seq_len,
"lr": LR,
"weight_decay": 0.1,
},
name=f"llama_tp{tp_size}_dp{dp_size}_cp{cp_size}"
if model_name == "unsloth/Llama-3.2-1B"
else f"tp{tp_size}_dp{dp_size}_cp{cp_size}",
)
logger.info("Wandb initialized.")
# Log the current file to wandb
wandb.save("test_train.py")
# Load model and tokenizer
logger.info(f"Loading model and tokenizer from {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info(f"Set pad_token to eos_token: {tokenizer.pad_token}")
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_mesh=tp_mesh if dist.is_initialized() else None,
tp_plan="auto",
torch_dtype=torch.bfloat16,
)
logger.info(f"Model loaded onto device mesh: {tp_mesh}")
device = torch.device(f"cuda:{local_rank}")
logger.info(f"Using device: {device} for non-model tensors")
use_ddp = False
if dist.is_initialized() and dp_mesh.size() > 1:
model = FSDP(model, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD)
use_ddp = True
pass
model.train()
logger.info("Loading TinyStories dataset...")
raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") # Use 1% for faster testing
def tokenize_function(examples):
# Tokenize the text without padding
tokenized_batch = tokenizer(
examples["text"], padding=False, truncation=True, max_length=seq_len, return_tensors=None
)
# Set labels to be the same as input_ids for Causal LM
tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
return tokenized_batch
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
logger.info(f"Dataset loaded and tokenized. Size: {len(tokenized_dataset)}")
# Create packed sequences
def create_packed_sequences(examples):
# Flatten all sequences
all_tokens = []
for input_ids in examples["input_ids"]:
all_tokens.extend(input_ids)
# Split into sequences of seq_len + 1 (for input + label)
num_sequences = len(all_tokens) // (seq_len + 1)
packed_input_ids = []
packed_labels = []
for i in range(num_sequences):
start_idx = i * (seq_len + 1)
end_idx = start_idx + (seq_len + 1)
# Get the full sequence
full_sequence = all_tokens[start_idx:end_idx]
# For input_ids, remove the last token
packed_input_ids.append(full_sequence[:-1])
# For labels, remove the first token
packed_labels.append(full_sequence[1:])
return {"input_ids": packed_input_ids, "labels": packed_labels}
# Apply packing to the dataset
packed_dataset = tokenized_dataset.map(
create_packed_sequences,
batched=True,
remove_columns=tokenized_dataset.column_names,
batch_size=1000, # Process in batches for efficiency
num_proc=60,
)
logger.info(f"Dataset packed. New size: {len(packed_dataset)}")
# Shuffle the packed dataset
packed_dataset = packed_dataset.shuffle(seed=42)
logger.info("Packed dataset shuffled")
# Calculate local batch size
if dist.is_initialized():
assert global_batch_size % dp_mesh.size() == 0, (
f"Global batch size ({global_batch_size}) must be divisible by DP size ({dp_mesh.size()})"
)
local_batch_size = global_batch_size // dp_mesh.size()
else:
local_batch_size = global_batch_size
logger.info(
f"Global batch size: {global_batch_size}, DP size: {dp_size if dist.is_initialized() else 1}, Local batch size: {local_batch_size}"
)
# Simple collate function since sequences are already packed
def collate_fn(batch):
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
return {"input_ids": input_ids, "labels": labels}
if dist.is_initialized():
sampler = DistributedSampler(
packed_dataset, num_replicas=dp_mesh.size(), rank=dp_mesh.get_local_rank(), shuffle=False
)
else:
sampler = None
dataloader = DataLoader(
packed_dataset,
batch_size=local_batch_size,
sampler=sampler,
shuffle=False,
collate_fn=collate_fn,
pin_memory=True,
)
logger.info(f"DataLoader created. Distributed: {dist.is_initialized()}")
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1)
# Training loop
logger.info(f"Starting training for {num_train_steps} steps...")
model.train()
step = 0
while step < num_train_steps:
for batch in dataloader:
if step >= num_train_steps:
break # Exit loop if max steps reached
# Move batch to appropriate device
batch = {k: v.to(device) for k, v in batch.items()}
optimizer.zero_grad()
# Add position_ids to batch before CP sharding
batch_size = batch["input_ids"].shape[0]
position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
batch["position_ids"] = position_ids
from torch.distributed.tensor.experimental._attention import _cp_options
_cp_options.enable_load_balance = False
with sdpa_kernel(sdpa_backend): # TODO: ideally move this to attention implementation
cp_context = (
nullcontext()
if cp_mesh.size() == 1
else context_parallel(
cp_mesh,
buffers=[
batch["input_ids"],
batch["labels"],
batch["position_ids"],
],
buffer_seq_dims=[1, 1, 1],
)
)
with cp_context:
# Pop labels from batch before model forward pass
labels = batch.pop("labels")
outputs = model(**batch) # [mbs, seq_len/cp]
loss = outputs.loss
logits = outputs.logits
# Compute loss with shifted labels
loss = model.loss_function(
logits=logits, labels=None, shift_labels=labels, vocab_size=model.config.vocab_size
)
loss.backward()
# all reduce grads across dp_cp if applicable
all_reduce_grads(model, world_mesh, use_ddp=use_ddp)
if hasattr(model, "clip_grad_norm_"):
gradnorm = model.clip_grad_norm_(max_norm=1.0, norm_type=2.0) # TODO: fix reported gradnorm
else:
# only works with FSDP's NO_SHARD otherwise we should use FSDP's clip_grad_norm_
assert len(list(model.parameters())) > 5, "No parameters found in model. Probably DDP bug.."
gradnorm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0, foreach=True)
optimizer.step()
# allreduce loss across cp_dp before logging
if dist.is_initialized() and (cp_mesh.size() > 1 or dp_mesh.size() > 1):
dist.all_reduce(loss, group=world_mesh["dp_cp"].get_group(), op=dist.ReduceOp.AVG)
current_loss = loss.item()
# Log loss and gradnorm to wandb (only on rank 0 of dp group)
if not dist.is_initialized() or dist.get_rank() == 0:
logger.info(
f"Step: {step} | GBS: {global_batch_size} | DP: {dp_mesh.size()} | TP: {tp_mesh.size()} | CP: {cp_mesh.size()} | Loss: {current_loss} | Gradnorm: {gradnorm} | lr: {LR}"
)
wandb.log(
{
"train/loss": current_loss,
"train/gradnorm": gradnorm,
"step": step,
"lr": LR,
"GBS": global_batch_size,
}
)
step += 1 # Increment step count
logger.info("Training loop finished.")
# Save model using DCP (only if distributed)
if dist.is_initialized():
state_dict = {"app": AppState(model, optimizer)}
dcp.save(
state_dict=state_dict,
checkpoint_id=CHECKPOINT_DIR,
)
logger.info(f"Saved checkpoint to {CHECKPOINT_DIR}")
else:
# Fallback to regular save for non-distributed case
save_dir = "test_model_nondist"
model.save_pretrained(save_dir, safe_serialization=False)
tokenizer.save_pretrained(save_dir) # Save tokenizer too
logger.info(f"Saved model to {save_dir}")
dist.destroy_process_group()
logger.info("Cleaned up distributed process group")
# Finish wandb run on rank 0
if dist.get_rank() == 0:
wandb.finish()
logger.info("Wandb run finished.")
def all_reduce_grads(model, world_mesh, use_ddp):
"""All reduce gradients across dp_cp if applicable."""
cp_mesh = world_mesh["cp"]
if use_ddp:
# DDP/FSDP takes care of syncing grads
mesh = cp_mesh
else:
mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
if dist.is_initialized() and mesh.size() > 1:
for name, param in model.named_parameters():
if param.grad is not None:
# Workaround for cross-mesh communication limitation with DTensor gradients
if isinstance(param.grad, DTensor):
local_grad = param.grad.to_local()
# Ensure grad requires grad for inplace modification checks (might not be needed)
# local_grad = local_grad.detach().requires_grad_(True)
torch.distributed.all_reduce(local_grad, op=torch.distributed.ReduceOp.SUM, group=mesh.get_group())
local_grad = local_grad / mesh.size()
# Assign averaged grad back - need careful handling if DTensor structure is complex
# This simple assignment might work if the grad structure matches param structure
param.grad = DTensor.from_local(
local_grad, device_mesh=param.grad.device_mesh, placements=param.grad.placements
)
else:
# Handle regular tensors if any exist (e.g. buffers not converted to DTensor)
torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG, group=mesh.get_group())
class AppState(Stateful):
"""Wrapper for checkpointing the Application State including model and optimizer."""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {"model": model_state_dict, "optim": optimizer_state_dict}
def load_state_dict(self, state_dict):
set_state_dict(
self.model, self.optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"]
)
def clip_grad_norm_(
parameters: Iterable[torch.Tensor],
max_norm: float,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
foreach: bool | None = None,
) -> torch.Tensor:
"""
Clip the gradient norm of an iterable of parameters.
"""
# Filter out parameters with no gradients
parameters = [p for p in parameters if p.grad is not None]
assert len(parameters) > 0, "No parameters with gradients found"
# Calculate total norm
if norm_type == float("inf"):
total_norm = max(p.grad.detach().abs().max() for p in parameters)
else:
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type)
# Convert DTensor to local tensor if needed
if isinstance(total_norm, DTensor):
total_norm = total_norm.full_tensor()
# Clip gradients
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.detach().mul_(clip_coef)
return total_norm
if __name__ == "__main__":
main()

View File

@ -60,7 +60,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset

View File

@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")

View File

@ -55,7 +55,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -0,0 +1,793 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""":
This script is used to test training a model using Tensor Parallelism and Data Parallelism.
Usage:
export CUDA_VISIBLE_DEVICES=0,1,2,3
export CUDA_VISIBLE_DEVICES=4,5,6,7
export CUDA_VISIBLE_DEVICES=5,6,7
TP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 test_train.py
CP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 test_train.py
CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=4 test_train.py
TP_SIZE=1 CP_SIZE=4 torchrun --nproc_per_node=4 test_train.py
TP_SIZE=1 DP_SIZE=4 torchrun --nproc_per_node=4 test_train.py
TP_SIZE=4 DP_SIZE=1 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 test_train.py
IGNORE_SANITY=1 CP_SIZE=1 TP_SIZE=1 DP_SIZE=1 torchrun --nproc_per_node=1 --rdzv_endpoint=l
ocalhost:29504 test_train.py
"""
import logging
import os
from contextlib import nullcontext
from typing import Dict, Iterable, Optional
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torch.optim as optim
import wandb
from datasets import load_dataset
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils.data import DataLoader, default_collate
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer
ignore_sanity_checks = int(os.environ.get("IGNORE_SANITY", 0)) == 1
# torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
# Set up logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
# from torch.distributed.tensor.experimental._attention import set_rotate_method
# set_rotate_method("alltoall") # rotate shards using all-to-all
def main():
tp_size = int(os.environ.get("TP_SIZE", 1))
dp_size = int(os.environ.get("DP_SIZE", 4))
cp_size = int(os.environ.get("CP_SIZE", 1)) # Add CP size configuration
sdpa_backend = SDPBackend.FLASH_ATTENTION # For CP
# sdpa_backend = SDPBackend.MATH # For CP
global_batch_size = 8 # Desired global batch size
seq_len = 1024 # Sequence length
num_train_steps = 10000 # Number of training steps
LR = 1e-5
model_name = "HuggingFaceTB/SmolLM2-1.7B"
# model_name = "unsloth/Llama-3.2-1B"
CHECKPOINT_DIR = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}"
# Initialize distributed environment
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
assert world_size == tp_size * dp_size * cp_size, (
f"World size ({world_size}) must equal TP size ({tp_size}) * DP size ({dp_size}) * CP size ({cp_size})"
)
mesh = torch.arange(world_size).reshape(dp_size, tp_size, cp_size)
world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp"))
tp_mesh = world_mesh["tp"]
dp_mesh = world_mesh["dp"]
cp_mesh = world_mesh["cp"]
world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
logger.info(f"Created DeviceMesh: {world_mesh}")
logger.info(
f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}, DP: {dp_mesh.get_local_rank()}, TP: {tp_mesh.get_local_rank()}, CP: {cp_mesh.get_local_rank()}"
)
if dist.get_rank() == 0:
wandb.init(
project="tp_dp_test",
config={
"tp_size": tp_size,
"dp_size": dp_size,
"cp_size": cp_size,
"global_batch_size": global_batch_size,
"model_name": model_name,
"dataset": "roneneldan/TinyStories-1M",
"seq_len": seq_len,
"lr": LR,
"weight_decay": 0.1,
},
name=f"llama_tp{tp_size}_dp{dp_size}_cp{cp_size}"
if model_name == "unsloth/Llama-3.2-1B"
else f"tp{tp_size}_dp{dp_size}_cp{cp_size}",
)
logger.info(f"ignore_sanity_checks is set to: {ignore_sanity_checks}")
logger.info("Wandb initialized.")
# Log the current file to wandb
wandb.save("test_train.py")
else:
logger.info("Running in non-distributed mode. DeviceMesh not applicable.")
rank = 0
world_size = 1
local_rank = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wandb.init(
project="tp_dp_test",
config={
"tp_size": 1,
"dp_size": 1,
"global_batch_size": global_batch_size,
"model_name": model_name,
"dataset": "roneneldan/TinyStories-1M",
"seq_len": seq_len,
},
name="llama_tp1_dp1_nondist" if model_name == "unsloth/Llama-3.2-1B" else "tp1_dp1_nondist",
)
logger.info("Wandb initialized for non-distributed run.")
# Load model and tokenizer
logger.info(f"Loading model and tokenizer from {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info(f"Set pad_token to eos_token: {tokenizer.pad_token}")
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_mesh=tp_mesh if dist.is_initialized() else None,
tp_plan="auto",
torch_dtype=torch.bfloat16,
)
logger.info(f"Model loaded onto device mesh: {tp_mesh}")
if dist.is_initialized():
assert model.config.num_key_value_heads % tp_mesh.size() == 0, (
f"num_key_value_heads={model.config.num_key_value_heads} must be divisible by tp_size={tp_mesh.size()}"
)
device = torch.device(f"cuda:{local_rank}")
else:
model = model.to(device)
logger.info(f"Using device: {device} for non-model tensors")
use_ddp = False
if dist.is_initialized() and dp_mesh.size() > 1:
# FSDP1
model = FSDP(model, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD)
# FSDP2
# for transformer_block in model.model.layers:
# fully_shard(transformer_block, mesh=dp_mesh, reshard_after_forward=False)
# fully_shard(model.model, mesh=dp_mesh, reshard_after_forward=False)
# DDP
# replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
# assert len(list(model.parameters()))>5, "No parameters found in model. Probably DDP/FSDP bug.." # TODO: we should be cautious abt using model.parameters()
use_ddp = True
model.train()
assert len(list(model.parameters())) > 0, "No parameters found in model. Probably DDP bug.."
assert len([p for p in model.parameters() if p.requires_grad]) > 0, (
"No gradients found in model. Probably DDP bug.."
)
if dist.is_initialized() and not ignore_sanity_checks:
# assert model is replicated across all dp
for name, param in model.named_parameters():
sanity_check_tensor_sync(param, dp_mesh)
# assert model is different across tp (only for sharded params)
for name, param in model.named_parameters():
if isinstance(param, DTensor) and param.placements[0].is_shard():
# Only check sharded parameters for non-sync across TP
sanity_check_tensor_sync(param, tp_mesh, not_sync=True)
elif isinstance(param, DTensor) and param.placements[0].is_replicate():
# Replicated parameters should be the same across TP
sanity_check_tensor_sync(param, tp_mesh)
# assert model is replicated across cp
for name, param in model.named_parameters():
sanity_check_tensor_sync(param, cp_mesh)
# Load and preprocess TinyStories dataset
logger.info("Loading TinyStories dataset...")
raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") # Use 1% for faster testing
def tokenize_function(examples):
# Tokenize the text without padding
tokenized_batch = tokenizer(
examples["text"], padding=False, truncation=True, max_length=seq_len, return_tensors=None
)
# Set labels to be the same as input_ids for Causal LM
tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
return tokenized_batch
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
logger.info(f"Dataset loaded and tokenized. Size: {len(tokenized_dataset)}")
# Create packed sequences
def create_packed_sequences(examples):
# Flatten all sequences
all_tokens = []
for input_ids in examples["input_ids"]:
all_tokens.extend(input_ids)
# Split into sequences of seq_len + 1 (for input + label)
num_sequences = len(all_tokens) // (seq_len + 1)
packed_input_ids = []
packed_labels = []
for i in range(num_sequences):
start_idx = i * (seq_len + 1)
end_idx = start_idx + (seq_len + 1)
# Get the full sequence
full_sequence = all_tokens[start_idx:end_idx]
# For input_ids, remove the last token
packed_input_ids.append(full_sequence[:-1])
# For labels, remove the first token
packed_labels.append(full_sequence[1:])
return {"input_ids": packed_input_ids, "labels": packed_labels}
# Apply packing to the dataset
packed_dataset = tokenized_dataset.map(
create_packed_sequences,
batched=True,
remove_columns=tokenized_dataset.column_names,
batch_size=1000, # Process in batches for efficiency
num_proc=60,
)
logger.info(f"Dataset packed. New size: {len(packed_dataset)}")
# Shuffle the packed dataset
packed_dataset = packed_dataset.shuffle(seed=42)
logger.info("Packed dataset shuffled")
# Calculate local batch size
if dist.is_initialized():
assert global_batch_size % dp_mesh.size() == 0, (
f"Global batch size ({global_batch_size}) must be divisible by DP size ({dp_mesh.size()})"
)
local_batch_size = global_batch_size // dp_mesh.size()
else:
local_batch_size = global_batch_size
logger.info(
f"Global batch size: {global_batch_size}, DP size: {dp_size if dist.is_initialized() else 1}, Local batch size: {local_batch_size}"
)
# Simple collate function since sequences are already packed
def collate_fn(batch):
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
return {"input_ids": input_ids, "labels": labels}
if dist.is_initialized():
sampler = DistributedSampler(
packed_dataset, num_replicas=dp_mesh.size(), rank=dp_mesh.get_local_rank(), shuffle=False
)
else:
sampler = None
dataloader = DataLoader(
packed_dataset,
batch_size=local_batch_size,
sampler=sampler,
shuffle=False,
collate_fn=collate_fn,
)
logger.info(f"DataLoader created. Distributed: {dist.is_initialized()}")
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1)
# Training loop
logger.info(f"Starting training for {num_train_steps} steps...")
model.train()
step = 0
while step < num_train_steps:
for batch in dataloader:
if step >= num_train_steps:
break # Exit loop if max steps reached
# Move batch to appropriate device
batch = {k: v.to(device) for k, v in batch.items()}
# Sanity checks for batch distribution (only if distributed)
if dist.is_initialized() and not ignore_sanity_checks:
# check batch is same across all tp
sanity_check_tensor_sync(batch["input_ids"], tp_mesh)
# check batch is different across dp
sanity_check_tensor_sync(batch["input_ids"], dp_mesh, not_sync=True)
optimizer.zero_grad()
# Add position_ids to batch before CP sharding
batch_size = batch["input_ids"].shape[0]
position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
batch["position_ids"] = position_ids
from torch.distributed.tensor.experimental._attention import _cp_options
_cp_options.enable_load_balance = False
with sdpa_kernel(sdpa_backend): # TODO: ideally move this to attention implementation
cp_context = (
nullcontext()
if cp_mesh.size() == 1
else context_parallel(
cp_mesh,
buffers=[
batch["input_ids"],
batch["labels"],
batch["position_ids"],
], # TODO: need to add attention mask
buffer_seq_dims=[1, 1, 1],
)
)
with cp_context:
# Pop labels from batch before model forward pass
labels = batch.pop("labels")
outputs = model(**batch) # [mbs, seq_len/cp]
loss = outputs.loss
logits = outputs.logits
# Compute loss with shifted labels
loss = model.loss_function(
logits=logits, labels=None, shift_labels=labels, vocab_size=model.config.vocab_size
)
# Sanity checks for logits
if dist.is_initialized() and not ignore_sanity_checks:
# sanity_check_tensor_sync(logits, tp_mesh) # TODO: only true without sequence parallel
sanity_check_tensor_sync(logits, dp_mesh, not_sync=True)
sanity_check_tensor_sync(logits, cp_mesh, not_sync=True)
loss.backward()
# all reduce grads across dp_cp if applicable
all_reduce_grads(model, world_mesh, use_ddp=use_ddp)
# Sanity checks for gradients (only if distributed)
if dist.is_initialized() and not ignore_sanity_checks:
# check grads are not same across all tp (for sharded grads)
for name, param in model.named_parameters():
if param.grad is not None and isinstance(param.grad, DTensor):
if param.grad.placements[0].is_shard():
sanity_check_tensor_sync(param.grad, tp_mesh, not_sync=True)
elif param.grad.placements[0].is_replicate():
sanity_check_tensor_sync(param.grad, tp_mesh)
# check grads are same across dp
for name, param in model.named_parameters():
if param.grad is not None and dp_mesh.size() > 1:
sanity_check_tensor_sync(param.grad, dp_mesh)
# check grads are same across cp
for name, param in model.named_parameters():
if param.grad is not None and cp_mesh.size() > 1:
sanity_check_tensor_sync(param.grad, cp_mesh)
# Calculate gradient norm and clip gradients
if hasattr(model, "clip_grad_norm_"):
# when using FSDP or DDP, model.parameters() doesn't work
gradnorm = model.clip_grad_norm_(max_norm=1.0, norm_type=2.0)
else:
assert len(list(model.parameters())) > 2, "No parameters found in model. Probably DDP bug.."
assert len([p for p in model.parameters() if p.requires_grad]) > 2, (
"No gradients found in model. Probably DDP bug.."
)
assert len([p for p in model.parameters() if p.grad is not None]) > 2, (
"No gradients found in model. Probably DDP bug.."
)
# only works with FSDP's NO_SHARD otherwise we should use FSDP's clip_grad_norm_
gradnorm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0, foreach=True)
optimizer.step()
# Sanity checks for updated model parameters (only if distributed)
if dist.is_initialized() and not ignore_sanity_checks:
# check updated model is different across all tp (for sharded params)
for name, param in model.named_parameters():
if isinstance(param, DTensor):
if param.placements[0].is_shard():
sanity_check_tensor_sync(param, tp_mesh, not_sync=True)
elif param.placements[0].is_replicate():
sanity_check_tensor_sync(param, tp_mesh)
# check updated model is same across dp
for name, param in model.named_parameters():
sanity_check_tensor_sync(param, dp_mesh)
# check updated model is same across cp
for name, param in model.named_parameters():
sanity_check_tensor_sync(param, cp_mesh)
# allreduce loss across cp_dp before logging
if dist.is_initialized() and (cp_mesh.size() > 1 or dp_mesh.size() > 1):
dist.all_reduce(loss, group=world_mesh["dp_cp"].get_group(), op=dist.ReduceOp.AVG)
current_loss = loss.item()
# Log loss and gradnorm to wandb (only on rank 0 of dp group)
if not dist.is_initialized() or dist.get_rank() == 0:
logger.info(
f"Step: {step} | GBS: {global_batch_size} | DP: {dp_mesh.size()} | TP: {tp_mesh.size()} | CP: {cp_mesh.size()} | Loss: {current_loss} | Gradnorm: {gradnorm} | lr: {LR}"
)
wandb.log(
{
"train/loss": current_loss,
"train/gradnorm": gradnorm,
"step": step,
"lr": LR,
"GBS": global_batch_size,
}
)
step += 1 # Increment step count
logger.info("Training loop finished.")
# Save model using DCP (only if distributed)
if dist.is_initialized():
state_dict = {"app": AppState(model, optimizer)}
dcp.save(
state_dict=state_dict,
checkpoint_id=CHECKPOINT_DIR,
)
logger.info(f"Saved checkpoint to {CHECKPOINT_DIR}")
else:
# Fallback to regular save for non-distributed case
save_dir = "test_model_nondist"
model.save_pretrained(save_dir, safe_serialization=False)
tokenizer.save_pretrained(save_dir) # Save tokenizer too
logger.info(f"Saved model to {save_dir}")
# Example of loading the checkpoint (only if distributed)
if dist.is_initialized():
# Create a new model instance
logger.info("Creating new model instance for verification")
new_model = AutoModelForCausalLM.from_pretrained(
model_name,
device_mesh=tp_mesh,
torch_dtype=torch.bfloat16, # Use same dtype
)
new_optimizer = optim.AdamW(new_model.parameters(), lr=LR)
# Load checkpoint into new model
state_dict = {"app": AppState(new_model, new_optimizer)}
dcp.load(
state_dict=state_dict,
checkpoint_id=CHECKPOINT_DIR,
)
logger.info("Loaded checkpoint into new model")
# Verify model weights match
logger.info("Verifying model weights match...")
for (name1, param1), (name2, param2) in zip(model.named_parameters(), new_model.named_parameters()):
torch.testing.assert_close(
param1.to_local(),
param2.to_local(),
rtol=1e-3,
atol=1e-3,
msg=f"Weights mismatch in {name1} vs {name2}",
)
# Verify optimizer states match
logger.info("Verifying optimizer states match...")
for name1, state1 in optimizer.state_dict().items():
state2 = new_optimizer.state_dict()[name1]
if name1 == "state":
# Compare state dictionaries for each parameter
for param_id, param_state1 in state1.items():
param_state2 = state2[param_id]
# Compare each state component (step, exp_avg, exp_avg_sq)
for key, value1 in param_state1.items():
value2 = param_state2[key]
if isinstance(value1, DTensor):
# Convert DTensors to local tensors for comparison
torch.testing.assert_close(
value1.to_local(),
value2.to_local(),
rtol=1e-5,
atol=1e-5,
msg=f"Optimizer state mismatch in state[{param_id}][{key}]",
)
else:
torch.testing.assert_close(
value1,
value2,
rtol=1e-5,
atol=1e-5,
msg=f"Optimizer state mismatch in state[{param_id}][{key}]",
)
elif name1 == "param_groups":
# Compare param_groups (excluding the actual params list)
for i, (group1, group2) in enumerate(zip(state1, state2)):
for key in group1:
if key != "params": # Skip comparing the params list
assert group1[key] == group2[key], f"Param group mismatch in param_groups[{i}][{key}]"
# Run a forward pass with both models to verify outputs match
logger.info("Running forward pass verification...")
with torch.no_grad():
# Use the last batch for verification
batch = {k: v.to(device) for k, v in batch.items()} # Ensure batch is on correct device
original_outputs = model(**batch)
new_outputs = new_model(**batch)
torch.testing.assert_close(
original_outputs.logits.to_local(),
new_outputs.logits.to_local(),
rtol=1e-3,
atol=1e-3,
msg="Model outputs do not match!",
) # Increased tolerance slightly for bf16
# Clean up distributed environment and finish wandb run
if dist.is_initialized():
dist.destroy_process_group()
logger.info("Cleaned up distributed process group")
# Finish wandb run on rank 0
if dist.get_rank() == 0:
wandb.finish()
logger.info("Wandb run finished.")
else:
wandb.finish()
logger.info("Wandb run finished.")
def all_reduce_grads(model, world_mesh, use_ddp):
"""All reduce gradients across dp_cp if applicable."""
cp_mesh = world_mesh["cp"]
if use_ddp:
# DDP takes care of syncing grads
mesh = cp_mesh
else:
mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
if dist.is_initialized() and mesh.size() > 1:
for name, param in model.named_parameters():
if param.grad is not None:
# Workaround for cross-mesh communication limitation with DTensor gradients
if isinstance(param.grad, DTensor):
local_grad = param.grad.to_local()
# Ensure grad requires grad for inplace modification checks (might not be needed)
# local_grad = local_grad.detach().requires_grad_(True)
torch.distributed.all_reduce(local_grad, op=torch.distributed.ReduceOp.SUM, group=mesh.get_group())
local_grad = local_grad / mesh.size()
# Assign averaged grad back - need careful handling if DTensor structure is complex
# This simple assignment might work if the grad structure matches param structure
param.grad = DTensor.from_local(
local_grad, device_mesh=param.grad.device_mesh, placements=param.grad.placements
)
else:
# Handle regular tensors if any exist (e.g. buffers not converted to DTensor)
torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG, group=mesh.get_group())
class ContextParallelCollator:
"""Collator for context parallel training that splits sequences into chunks."""
def __init__(self, cp_mesh: Optional[DeviceMesh] = None):
self.cp_mesh = cp_mesh
def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
batch = default_collate(batch)
if self.cp_mesh is not None and self.cp_mesh.size() > 1:
# Get sequence length from the input batch
seq_len = batch["input_ids"].shape[1]
assert seq_len % self.cp_mesh.size() == 0, (
f"Sequence length {seq_len} must be divisible by CP size {self.cp_mesh.size()}"
)
chunk_size = seq_len // self.cp_mesh.size()
cp_rank = self.cp_mesh.get_local_rank()
start_idx = cp_rank * chunk_size
end_idx = start_idx + chunk_size
# Keep only the local chunk of the sequence
batch["input_ids"] = batch["input_ids"][:, start_idx:end_idx]
batch["attention_mask"] = batch["attention_mask"][:, start_idx:end_idx]
batch["labels"] = batch["labels"][:, start_idx:end_idx]
return batch
class AppState(Stateful):
"""Wrapper for checkpointing the Application State including model and optimizer."""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {"model": model_state_dict, "optim": optimizer_state_dict}
def load_state_dict(self, state_dict):
set_state_dict(
self.model, self.optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"]
)
def sanity_check_tensor_sync(
tensor: torch.Tensor, mesh: DeviceMesh, rtol: float = 1e-4, atol: float = 1e-4, not_sync: bool = False
) -> None:
"""
Verify that a tensor is synchronized (or not synchronized) across all processes in the mesh's process group.
Handles both regular tensors and DTensors.
Args:
tensor (torch.Tensor): The tensor to check for synchronization (can be DTensor)
mesh (DeviceMesh): The device mesh containing the process group
rtol (float): Relative tolerance for comparison
atol (float): Absolute tolerance for comparison
not_sync (bool): If True, asserts that tensors are NOT synchronized. If False, asserts they are synchronized.
"""
if not dist.is_initialized() or mesh.size() == 1:
return # No need to check in non-distributed mode
# Get the process group from the mesh
pg = mesh.get_group()
# Convert DTensor to local tensor if needed
if hasattr(tensor, "to_local"):
local_tensor = tensor.to_local()
else:
local_tensor = tensor
# Gather tensors from all processes
world_size = dist.get_world_size(pg)
gathered_tensors = [torch.empty_like(local_tensor) for _ in range(world_size)]
dist.all_gather(gathered_tensors, local_tensor, group=pg)
# Compare each tensor with the first one
for i in range(1, world_size):
try:
torch.testing.assert_close(gathered_tensors[0], gathered_tensors[i], rtol=rtol, atol=atol)
except AssertionError as e:
if not_sync:
continue
# # Add detailed debugging for logit synchronization issues
# print(f"\nLogit synchronization error between rank 0 and rank {i}:")
# print(f"Tensor shape: {gathered_tensors[0].shape}")
# print(f"Number of mismatched elements: {(gathered_tensors[0] != gathered_tensors[i]).sum()}")
# print(f"Percentage of mismatched elements: {((gathered_tensors[0] != gathered_tensors[i]).sum() / gathered_tensors[0].numel() * 100):.2f}%")
# # Find the first few mismatches
# mismatches = torch.nonzero(gathered_tensors[0] != gathered_tensors[i])
# print("\nFirst few mismatches:")
# for idx in mismatches[:5]:
# idx = tuple(idx.tolist())
# print(f"Index {idx}:")
# print(f"Rank 0 value: {gathered_tensors[0][idx]}")
# print(f"Rank {i} value: {gathered_tensors[i][idx]}")
# print(f"Absolute difference: {abs(gathered_tensors[0][idx] - gathered_tensors[i][idx])}")
# print(f"Relative difference: {abs(gathered_tensors[0][idx] - gathered_tensors[i][idx]) / max(abs(gathered_tensors[0][idx]), abs(gathered_tensors[i][idx]))}")
# # Check if differences are systematic (e.g., all positive or negative)
# diff = gathered_tensors[0] - gathered_tensors[i]
# print(f"\nDifference statistics:")
# print(f"Mean difference: {diff.mean()}")
# print(f"Std difference: {diff.std()}")
# print(f"Max positive difference: {diff.max()}")
# print(f"Max negative difference: {diff.min()}")
raise e
def clip_grad_norm_(
parameters: Iterable[torch.Tensor],
max_norm: float,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
foreach: bool | None = None,
) -> torch.Tensor:
"""
Clip the gradient norm of an iterable of parameters.
"""
# Filter out parameters with no gradients
parameters = [p for p in parameters if p.grad is not None]
assert len(parameters) > 0, "No parameters with gradients found"
# Calculate total norm
if norm_type == float("inf"):
total_norm = max(p.grad.detach().abs().max() for p in parameters)
else:
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type)
# Convert DTensor to local tensor if needed
if isinstance(total_norm, DTensor):
total_norm = total_norm.full_tensor()
# Clip gradients
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.detach().mul_(clip_coef)
return total_norm
def check_params_sync(model_params, original_params):
"""
Check if original_params are being updated in sync with model parameters.
Args:
model_params: Iterator of model parameters after update
original_params: List of original parameters before DDP wrapping
"""
for mp, op in zip(model_params, original_params):
if isinstance(mp, DTensor):
mp = mp.to_local()
if isinstance(op, DTensor):
op = op.to_local()
if not torch.allclose(mp.data, op.data, rtol=0, atol=0):
raise RuntimeError(f"Parameters out of sync: model param {mp.data} != original param {op.data}")
return True
def get_parameters(model: nn.Module) -> Iterable[torch.Tensor]:
"""
Get all parameters from a model by iterating over its modules.
This is an alternative to model.parameters() that works with DTensor models.
Args:
model (nn.Module): The model to get parameters from
Returns:
Iterable[torch.Tensor]: An iterator over all parameters in the model
"""
for name, module in model._modules.items():
# Look for parameters in module attributes
for attr_name, attr in module.__dict__.items():
if isinstance(attr, torch.Tensor) and attr.requires_grad:
yield attr
# Recursively get parameters from submodules
for param in get_parameters(module):
yield param
def update_model_parameters(model: nn.Module) -> None:
"""
Update model._parameters using named_modules() to ensure all parameters are properly tracked.
Args:
model (nn.Module): The model to update parameters for
"""
# Clear existing parameters
model._parameters = {}
# Add parameters from named_modules
for name, module in model.named_modules():
# Skip the root module itself
if name == "":
continue
# Get the parameter name by removing 'module.' prefix if it exists
param_name = name.replace("module.", "")
# Add weight and bias parameters if they exist
if hasattr(module, "weight") and module.weight is not None:
model._parameters[f"{param_name}.weight"] = module.weight
if hasattr(module, "bias") and module.bias is not None:
model._parameters[f"{param_name}.bias"] = module.bias
if __name__ == "__main__":
main()

View File

@ -44,7 +44,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

View File

@ -0,0 +1,94 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss
world_size = int(os.environ.get("WORLD_SIZE", "1"))
cp_mesh = init_device_mesh("cuda", (world_size,))
rank = torch.distributed.get_node_local_rank()
device = "cuda"
dtype = torch.bfloat16
sdpa_backend = SDPBackend.FLASH_ATTENTION
# prepare inputs
batch_size = 1
seq_len = 128
input_ids = torch.randint(low=8, high=64, size=(batch_size, seq_len), device=device)
ignore_index = -100
# When using CP, we need to use `shift_labels`
shift_labels = torch.nn.functional.pad(input_ids, (0, 1), value=ignore_index)
shift_labels = shift_labels[..., 1:].contiguous()
position_ids = (
torch.cumsum(torch.ones(size=input_ids.size(), dtype=input_ids.dtype, device=input_ids.device), dim=1) - 1
)
# sync input as they are created randomly
dist.broadcast(input_ids, src=0)
dist.broadcast(shift_labels, src=0)
dist.broadcast(position_ids, src=0)
# model and optimizer
repo_id = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=dtype, device_map=device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
model.train()
model.zero_grad()
optimizer.zero_grad()
# For loss
vocab_size = model.config.vocab_size
# so training could be synced
model = DDP(model, device_ids=[rank])
# prepare for CP
buffers = (input_ids, shift_labels, position_ids)
buffer_seq_dims = (1, 1, 1)
# `no_restore_buffers=set(buffers)` is required if `loss.backward` is outside `context_parallel`.
# no_restore_buffers = set(buffers)
no_restore_buffers = None
# run with CP
with sdpa_kernel(sdpa_backend):
with context_parallel(
cp_mesh,
buffers=buffers,
buffer_seq_dims=buffer_seq_dims,
no_restore_buffers=no_restore_buffers,
):
outputs = model(input_ids, shift_labels=shift_labels, position_ids=position_ids)
print(outputs.logits.shape)
# So far we need to compute `loss` outside `model.forward` when using `shift_labels`
# loss = outputs.loss
loss = ForCausalLMLoss(logits=outputs.logits, labels=None, shift_labels=shift_labels, vocab_size=vocab_size)
# This could be outside `context_parallel` context if `no_restore_buffers` is specified
loss.backward()
optimizer.step()

View File

@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)

View File

@ -42,7 +42,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -47,7 +47,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -52,7 +52,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

View File

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

View File

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)

View File

@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

View File

@ -45,7 +45,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = logging.getLogger(__name__)

View File

@ -53,7 +53,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)
# You should update this to your particular problem to have better documentation of `model_type`

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")

View File

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logging.basicConfig(level=logging.INFO)
logger = get_logger(__name__)

View File

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")

View File

@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")

View File

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)

View File

@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)

View File

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -34,7 +34,6 @@ from transformers import (
GPT2Tokenizer,
GPTJForCausalLM,
LlamaForCausalLM,
LlamaTokenizer,
OpenAIGPTLMHeadModel,
OpenAIGPTTokenizer,
OPTForCausalLM,
@ -63,7 +62,7 @@ MODEL_CLASSES = {
"xlm": (XLMWithLMHeadModel, XLMTokenizer),
"gptj": (GPTJForCausalLM, AutoTokenizer),
"bloom": (BloomForCausalLM, BloomTokenizerFast),
"llama": (LlamaForCausalLM, LlamaTokenizer),
"llama": (LlamaForCausalLM, AutoTokenizer),
"opt": (OPTForCausalLM, GPT2Tokenizer),
}

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version(
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"

View File

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@ -49,7 +49,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = logging.getLogger(__name__)

View File

@ -61,7 +61,7 @@ except (ModuleNotFoundError, ImportError):
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = logging.getLogger(__name__)

View File

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
# region Checking dependencies
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
task_to_keys = {
"cola": ("sentence", None),

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
# region Dependencies and constants
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -125,7 +125,7 @@ _deps = [
"jaxlib>=0.4.1,<=0.4.13",
"jieba",
"jinja2>=3.1.0",
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
"kenlm",
# Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
"keras>2.9,<2.16",
"keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
@ -315,7 +315,7 @@ extras["audio"] = deps_list(
"librosa",
"pyctcdecode",
"phonemizer",
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
"kenlm",
)
# `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead
extras["speech"] = deps_list("torchaudio") + extras["audio"]
@ -451,7 +451,7 @@ install_requires = [
setup(
name="transformers",
version="4.52.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.53.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.52.0.dev0"
__version__ = "4.53.0.dev0"
from pathlib import Path
from typing import TYPE_CHECKING
@ -445,6 +445,7 @@ else:
_import_structure["modeling_outputs"] = []
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"]
_import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
_import_structure["masking_utils"] = ["AttentionMaskInterface"]
_import_structure["optimization"] = [
"Adafactor",
"get_constant_schedule",
@ -914,6 +915,7 @@ if TYPE_CHECKING:
TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
)
from .masking_utils import AttentionMaskInterface
from .model_debugging_utils import (
model_addition_debugger_context,
)

View File

@ -21,6 +21,104 @@ if is_hqq_available():
logger = logging.get_logger(__name__)
# Utility functions for static/sliding cache update logic
def _static_cache_update(
k_cache: torch.Tensor,
v_cache: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_position: Optional[torch.LongTensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the static cache tensors in place.
Args:
k_cache (`torch.Tensor`): The key cache tensor to update.
v_cache (`torch.Tensor`): The value cache tensor to update.
key_states (`torch.Tensor`): The new key states to add.
value_states (`torch.Tensor`): The new value states to add.
cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted.
If None, the entire cache is overwritten (prefill).
Returns:
Tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place).
"""
if cache_position is None:
# Prefill phase where seq_len potentially equals max_cache_len. Directly copy.
k_cache.copy_(key_states)
v_cache.copy_(value_states)
else:
# Generation phase. Update specific positions.
# Use index_copy_ for in-place update (compile-friendly).
try:
k_cache.index_copy_(2, cache_position, key_states)
v_cache.index_copy_(2, cache_position, value_states)
except NotImplementedError:
# Fallback for devices like MPS where index_copy_ might not be supported.
k_cache[:, :, cache_position] = key_states
v_cache[:, :, cache_position] = value_states
return k_cache, v_cache
def _sliding_cache_update(
k_cache: torch.Tensor,
v_cache: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_position: torch.LongTensor,
max_cache_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the sliding window cache tensors, returning the potentially modified tensors.
Args:
k_cache (`torch.Tensor`): The key cache tensor to update.
v_cache (`torch.Tensor`): The value cache tensor to update.
key_states (`torch.Tensor`): The new key states to add.
value_states (`torch.Tensor`): The new value states to add.
cache_position (`torch.LongTensor`): The position indices where the new states should be inserted.
max_cache_len (`int`): The maximum length of the sliding window cache.
Returns:
Tuple[`torch.Tensor`, `torch.Tensor`]: The key and value tensors representing the cache state after the update.
For prefill > window, these are the full input states.
Otherwise, they are the updated cache tensors.
"""
# Handle prefill phase when prompt length > sliding_window_size
if cache_position.shape[0] > max_cache_len:
new_k = key_states[:, :, -max_cache_len:, :]
new_v = value_states[:, :, -max_cache_len:, :]
k_cache.copy_(new_k)
v_cache.copy_(new_v)
return key_states, value_states
# Sliding window logic for generation phase or prefill < window
slicing = torch.arange(max_cache_len, device=value_states.device)
current_seq_len = cache_position[-1] + 1 # Use last position to determine current length
to_shift = current_seq_len > max_cache_len
indices = (slicing + to_shift.sum()) % max_cache_len
k_out_shifted = k_cache[:, :, indices]
v_out_shifted = v_cache[:, :, indices]
# Clamp cache_position to determine the *target index* within the shifted cache view
update_position = cache_position.clamp(min=0, max=max_cache_len - 1)
try:
k_out_updated = k_out_shifted.index_copy(2, update_position, key_states)
v_out_updated = v_out_shifted.index_copy(2, update_position, value_states)
except NotImplementedError:
# Fallback for MPS: clone and modify the clone
k_out_updated = k_out_shifted.clone()
v_out_updated = v_out_shifted.clone()
k_out_updated[:, :, update_position] = key_states
v_out_updated[:, :, update_position] = value_states
k_cache.copy_(k_out_updated)
v_cache.copy_(v_out_updated)
return k_out_updated, v_out_updated
class Cache:
"""
Base, abstract class for all caches. The actual data structure is specific to each subclass.
@ -98,6 +196,18 @@ class Cache:
else:
return None
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
the given layer at `layer_idx`.
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
for each layer.
"""
query_length = cache_position.shape[0]
past_seen_tokens = self.get_seq_length()
kv_length = query_length + past_seen_tokens
return kv_length, 0
@dataclass
class CacheConfig:
@ -986,8 +1096,6 @@ class SinkCache(Cache):
```
"""
is_sliding = True
def __init__(self, window_length: int, num_sink_tokens: int) -> None:
super().__init__()
self.key_cache: List[torch.Tensor] = []
@ -1264,28 +1372,16 @@ class StaticCache(Cache):
"""
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)
if cache_position is None:
k_out.copy_(key_states)
v_out.copy_(value_states)
else:
# Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
# `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
# operation, that avoids copies and uses less memory.
try:
k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)
except NotImplementedError:
# The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
return k_out, v_out
key_states = key_states.to(self.key_cache[layer_idx].dtype)
value_states = value_states.to(self.value_cache[layer_idx].dtype)
return _static_cache_update(
self.key_cache[layer_idx],
self.value_cache[layer_idx],
key_states,
value_states,
cache_kwargs.get("cache_position"),
)
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model."""
@ -1304,6 +1400,16 @@ class StaticCache(Cache):
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
the given layer at `layer_idx`.
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
for each layer.
"""
kv_length = self.get_max_cache_shape()
return kv_length, 0
class SlidingWindowCache(StaticCache):
"""
@ -1314,7 +1420,7 @@ class SlidingWindowCache(StaticCache):
The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:
indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
indices = (slicing + to_shift[-1].sum()-1) % self.config.sliding_window
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
@ -1360,7 +1466,6 @@ class SlidingWindowCache(StaticCache):
```
"""
is_sliding = True
is_compileable = True
def __init__(
@ -1379,6 +1484,7 @@ class SlidingWindowCache(StaticCache):
"config and it's not set to None."
)
max_cache_len = min(config.sliding_window, max_cache_len)
self.sliding_window = config.sliding_window
super().__init__(
config=config,
max_batch_size=max_batch_size,
@ -1398,46 +1504,21 @@ class SlidingWindowCache(StaticCache):
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)
# assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
if cache_position.shape[0] >= self.max_cache_len:
k_out = key_states[:, :, -self.max_cache_len :, :]
v_out = value_states[:, :, -self.max_cache_len :, :]
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return key_states, value_states
if cache_position is None:
raise ValueError("`cache_position` must be provided for SlidingWindowCache.")
slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
to_shift = cache_position > self.max_cache_len - 1
cache_position = cache_position.clamp(0, self.max_cache_len - 1)
indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len
key_states = key_states.to(self.key_cache[layer_idx].dtype)
value_states = value_states.to(self.value_cache[layer_idx].dtype)
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]
try:
k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)
except NotImplementedError:
# The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
return k_out, v_out
return _sliding_cache_update(
self.key_cache[layer_idx],
self.value_cache[layer_idx],
key_states,
value_states,
cache_position,
self.max_cache_len,
)
def get_max_cache_shape(self) -> Optional[int]:
return self.max_cache_len
@ -1448,6 +1529,21 @@ class SlidingWindowCache(StaticCache):
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
the given layer at `layer_idx`.
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
for each layer.
"""
query_length = cache_position.shape[0]
first_cache_position = cache_position[0]
# torch.clamp() is equivalent to max() but should be compile-friendly/exportable as first_cache_position is a Tensor
kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
# This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
kv_length = max(query_length, self.get_max_cache_shape())
return kv_length, kv_offset
class EncoderDecoderCache(Cache):
"""
@ -1680,12 +1776,13 @@ class HybridCache(Cache):
super().__init__()
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
"Setting `cache_implementation` to 'hybrid' requires the model config supporting "
"sliding window attention, please check if there is a `sliding_window` field in the model "
"config and it's not set to None."
)
self.max_cache_len = max_cache_len
self._sliding_window_max_len = min(config.sliding_window, max_cache_len)
self.max_cache_len = max_cache_len if max_cache_len is not None else config.max_position_embeddings
# Sliding layers can't be larger than the overall max cache len
self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
self.max_batch_size = max_batch_size
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
@ -1694,22 +1791,22 @@ class HybridCache(Cache):
self._dtype = dtype
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
self.is_sliding = torch.tensor(
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
)
# If the attribute does not exist in the config, fallback to a simple StaticCache
if hasattr(config, "layer_types"):
self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types]
else:
self.is_sliding = [False] * config.num_hidden_layers
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
sliding_cache_shape = (
self.max_batch_size,
self.num_key_value_heads,
self._sliding_window_max_len,
self.head_dim,
)
global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim)
self.sliding_window = min(config.sliding_window, max_cache_len)
device = torch.device(device) if device is not None else None
for i in range(config.num_hidden_layers):
if layer_device_map is not None:
@ -1718,7 +1815,7 @@ class HybridCache(Cache):
layer_device = device
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
torch._dynamo.mark_static_address(new_layer_key_cache)
@ -1726,42 +1823,6 @@ class HybridCache(Cache):
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
if cache_position.shape[0] >= max_cache_len:
k_out = key_states[:, :, -max_cache_len:, :]
v_out = value_states[:, :, -max_cache_len:, :]
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return key_states, value_states
slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
to_shift = cache_position > max_cache_len - 1
cache_position = cache_position.clamp(0, max_cache_len - 1)
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
return k_out, v_out
def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
return k_out, v_out
def update(
self,
key_states: torch.Tensor,
@ -1772,7 +1833,10 @@ class HybridCache(Cache):
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position")
sliding_window = cache_kwargs.get("sliding_window")
if cache_position is None:
raise ValueError("`cache_position` must be provided for HybridCache.")
is_sliding_layer = self.is_sliding[layer_idx]
# These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
# when the cache is initialized in the forward pass (e.g. Gemma2)
@ -1781,25 +1845,22 @@ class HybridCache(Cache):
if self.value_cache[layer_idx].device != value_states.device:
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)
k_cache = self.key_cache[layer_idx]
v_cache = self.value_cache[layer_idx]
key_states = key_states.to(k_cache.dtype)
value_states = value_states.to(v_cache.dtype)
if sliding_window:
update_fn = self._sliding_update
if is_sliding_layer:
return _sliding_cache_update(
k_cache,
v_cache,
key_states,
value_states,
cache_position,
k_cache.shape[2], # Use actual cache dim as max cache len
)
else:
update_fn = self._static_update
return update_fn(
cache_position,
layer_idx,
key_states,
value_states,
k_out,
v_out,
k_out.shape[2],
)
return _static_cache_update(k_cache, v_cache, key_states, value_states, cache_position)
def get_max_cache_shape(self) -> Optional[int]:
return self.max_cache_len
@ -1822,6 +1883,26 @@ class HybridCache(Cache):
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
the given layer at `layer_idx`.
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
for each layer.
"""
if self.is_sliding[layer_idx]:
query_length = cache_position.shape[0]
first_cache_position = cache_position[0]
local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
# This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
local_mask_kv_length = max(query_length, self.sliding_window)
return local_mask_kv_length, local_mask_kv_offset
full_mask_kv_offset = 0
full_mask_kv_length = self.get_max_cache_shape()
return full_mask_kv_length, full_mask_kv_offset
class HybridChunkedCache(Cache):
"""
@ -1891,11 +1972,11 @@ class HybridChunkedCache(Cache):
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self._dtype = dtype
if hasattr(config.get_text_config(), "no_rope_layers"):
self.is_sliding = config.no_rope_layers
# If the attribute does not exist in the config, fallback to a simple StaticCache
if hasattr(config, "layer_types"):
self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types]
else:
layer_switch = getattr(config, "sliding_window_pattern", 2)
self.is_sliding = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
self.is_sliding = [False] * config.num_hidden_layers
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
@ -1978,11 +2059,7 @@ class HybridChunkedCache(Cache):
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)
if self.is_sliding[layer_idx]:
update_fn = self._sliding_update
else:
update_fn = self._static_update
update_fn = self._sliding_update if self.is_sliding[layer_idx] else self._static_update
return update_fn(
cache_position,
layer_idx,
@ -2017,6 +2094,37 @@ class HybridChunkedCache(Cache):
self.value_cache[layer_idx].zero_()
self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
the given layer at `layer_idx`.
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
for each layer.
"""
if self.is_sliding[layer_idx]:
query_length = cache_position.shape[0]
first_cache_position = cache_position[0]
local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
# This is the true general case for any Cache using local attention (sliding or chunked)
if first_cache_position >= self.sliding_window:
# Here the Cache is already full
local_mask_kv_length = self.sliding_window + query_length - 1
elif (
first_cache_position < self.sliding_window
and first_cache_position + query_length > self.sliding_window
):
# Here the Cache becomes full with the new input
local_mask_kv_length = first_cache_position + query_length
else:
# Here the Cache is still smaller than the local size, but we return the local size as it's static
local_mask_kv_length = self.sliding_window
return local_mask_kv_length, local_mask_kv_offset
full_mask_kv_offset = 0
full_mask_kv_length = self.get_max_cache_shape()
return full_mask_kv_length, full_mask_kv_offset
class OffloadedHybridCache(HybridChunkedCache):
def __init__(
@ -2033,7 +2141,7 @@ class OffloadedHybridCache(HybridChunkedCache):
# TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps
# track of the original device of each layer
unique_devices = set(layer_device_map.values())
unique_devices = set(layer_device_map.values()) if layer_device_map else set()
if len(unique_devices) > 1:
raise ValueError(f"OffloadedHybridCache does not support multiple devices. Got devices: {unique_devices}")
@ -2292,7 +2400,7 @@ class OffloadedStaticCache(StaticCache):
# TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps
# track of the original device of each layer
unique_devices = set(layer_device_map.values())
unique_devices = set(layer_device_map.values()) if layer_device_map else set()
if len(unique_devices) > 1:
raise ValueError(f"OffloadedStaticCache does not support multiple devices. Got devices: {unique_devices}")
@ -2369,6 +2477,9 @@ class OffloadedStaticCache(StaticCache):
A tuple containing the updated key and value states.
"""
key_states = key_states.to(self.key_cache[layer_idx].dtype)
value_states = value_states.to(self.value_cache[layer_idx].dtype)
if layer_idx == 0:
# Update seen tokens.
# TODO(gante): Remove this.

View File

@ -408,6 +408,10 @@ class PretrainedConfig(PushToHubMixin):
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)
# This attribute is important to know on load, but should not be serialized on save.
if "transformers_weights" in self:
delattr(self, "transformers_weights")
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
if self._auto_class is not None:
@ -1205,3 +1209,16 @@ if PretrainedConfig.push_to_hub.__doc__ is not None:
PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
object="config", object_class="AutoConfig", object_files="configuration file"
)
ALLOWED_LAYER_TYPES = (
"full_attention",
"sliding_attention",
"chunked_attention",
)
def layer_type_validation(layer_types: list[str]):
"""Check that each entry in `layer_types` are allowed."""
if not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in layer_types):
raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES}")

View File

@ -32,7 +32,7 @@ deps = {
"jaxlib": "jaxlib>=0.4.1,<=0.4.13",
"jieba": "jieba",
"jinja2": "jinja2>=3.1.0",
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5": "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
"kenlm": "kenlm",
"keras": "keras>2.9,<2.16",
"keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
"kernels": "kernels>=0.4.4,<0.5",

View File

@ -35,6 +35,7 @@ from ..utils import (
is_torch_available,
logging,
)
from ..utils.deprecation import deprecate_kwarg
if TYPE_CHECKING:
@ -279,10 +280,6 @@ class GenerationConfig(PushToHubMixin):
begin_suppress_tokens (`List[int]`, *optional*):
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit
processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[List[int]]`, *optional*):
A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token
of index 123.
sequence_bias (`Dict[Tuple[int], float]`, *optional*)):
Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
sequence being selected, while negative biases do the opposite. Check
@ -387,12 +384,6 @@ class GenerationConfig(PushToHubMixin):
Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when
specific criteria are met, including using a compilable cache. Please open an issue if you find the
need to use this flag.
> Wild card
generation_kwargs:
Additional generation kwargs will be forwarded to the `generate` function of the model. Kwargs that are not
present in `generate`'s signature will be used in the model forward pass.
"""
extra_output_flags = ("output_attentions", "output_hidden_states", "output_scores", "output_logits")
@ -448,7 +439,6 @@ class GenerationConfig(PushToHubMixin):
self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
self.suppress_tokens = kwargs.pop("suppress_tokens", None)
self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
self.sequence_bias = kwargs.pop("sequence_bias", None)
self.token_healing = kwargs.pop("token_healing", False)
self.guidance_scale = kwargs.pop("guidance_scale", None)
@ -493,8 +483,6 @@ class GenerationConfig(PushToHubMixin):
# Performance
self.compile_config = kwargs.pop("compile_config", None)
self.disable_compile = kwargs.pop("disable_compile", False)
# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
# interface.
@ -514,7 +502,7 @@ class GenerationConfig(PushToHubMixin):
raise err
# Validate the values of the attributes
self.validate(is_init=True)
self.validate()
def __hash__(self):
return hash(self.to_json_string(ignore_metadata=True))
@ -576,9 +564,10 @@ class GenerationConfig(PushToHubMixin):
if generation_mode in ("greedy_search", "sample"):
generation_mode = GenerationMode.ASSISTED_GENERATION
else:
raise ValueError(
logger.warning(
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
"is only supported with Greedy Search and Sample."
"is only supported with Greedy Search and Sample. However, the base decoding mode (based on "
f"current flags) is {generation_mode} -- some of the set flags will be ignored."
)
# DoLa generation may extend some generation modes
@ -586,13 +575,15 @@ class GenerationConfig(PushToHubMixin):
if generation_mode in ("greedy_search", "sample"):
generation_mode = GenerationMode.DOLA_GENERATION
else:
raise ValueError(
logger.warning(
"You've set `dola_layers`, which triggers DoLa generate. Currently, DoLa generate "
"is only supported with Greedy Search and Sample."
"is only supported with Greedy Search and Sample. However, the base decoding mode (based on "
f"current flags) is {generation_mode} -- some of the set flags will be ignored."
)
return generation_mode
def validate(self, is_init=False):
@deprecate_kwarg("is_init", version="4.54.0")
def validate(self, strict=False):
"""
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
of parameterization that can be detected as incorrect from the configuration instance alone.
@ -600,174 +591,24 @@ class GenerationConfig(PushToHubMixin):
Note that some parameters not validated here are best validated at generate runtime, as they may depend on
other inputs and/or the model, such as parameters related to the generation length.
Arg:
is_init (`bool`, *optional*, defaults to `False`):
Whether the validation is performed during the initialization of the instance.
Args:
strict (bool): If True, raise an exception for any issues found. If False, only log issues.
"""
minor_issues = {} # format: {attribute_name: issue_description}
# Validation of individual attributes
# 1. Validation of individual attributes
# 1.1. Decoding attributes
if self.early_stopping not in {True, False, "never"}:
raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
if self.max_new_tokens is not None and self.max_new_tokens <= 0:
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
if self.pad_token_id is not None and self.pad_token_id < 0:
warnings.warn(
minor_issues["pad_token_id"] = (
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
"generating, if there is padding. Please set `pad_token_id` explicitly as "
"`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation"
)
# Validation of attribute relations:
fix_location = ""
if is_init:
fix_location = (
" This was detected when initializing the generation config instance, which means the corresponding "
"file may hold incorrect parameterization and should be fixed."
)
# 1. detect sampling-only parameterization when not in sampling mode
if self.do_sample is False:
greedy_wrong_parameter_msg = (
"`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
+ fix_location
)
if self.temperature is not None and self.temperature != 1.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
UserWarning,
)
if self.top_p is not None and self.top_p != 1.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
UserWarning,
)
if self.min_p is not None:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p),
UserWarning,
)
if self.typical_p is not None and self.typical_p != 1.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
UserWarning,
)
if (
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
): # contrastive search uses top_k
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
UserWarning,
)
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
UserWarning,
)
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
warnings.warn(
greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
UserWarning,
)
# 2. detect beam-only parameterization when not in beam mode
if self.num_beams is None:
warnings.warn("`num_beams` is set to None - defaulting to 1.", UserWarning)
self.num_beams = 1
if self.num_beams == 1:
single_beam_wrong_parameter_msg = (
"`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`." + fix_location
)
if self.early_stopping is not False:
warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
UserWarning,
)
if self.num_beam_groups is not None and self.num_beam_groups != 1:
warnings.warn(
single_beam_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups
),
UserWarning,
)
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
warnings.warn(
single_beam_wrong_parameter_msg.format(
flag_name="diversity_penalty", flag_value=self.diversity_penalty
),
UserWarning,
)
if self.length_penalty is not None and self.length_penalty != 1.0:
warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
UserWarning,
)
if self.constraints is not None:
warnings.warn(
single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints),
UserWarning,
)
# 3. detect incorrect parameterization specific to advanced beam modes
else:
# constrained beam search
if self.constraints is not None or self.force_words_ids is not None:
constrained_wrong_parameter_msg = (
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. However, "
"`{flag_name}` is set to `{flag_value}`, which is incompatible with this generation mode. Set "
"`constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue." + fix_location
)
if self.do_sample is True:
raise ValueError(
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
)
if self.num_beam_groups is not None and self.num_beam_groups != 1:
raise ValueError(
constrained_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups
)
)
# group beam search
if self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
group_error_prefix = (
"`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
"this generation mode, "
)
if self.do_sample is True:
raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
if self.num_beams % self.num_beam_groups != 0:
raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
if self.diversity_penalty == 0.0:
raise ValueError(
group_error_prefix
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
)
# DoLa generation
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
warnings.warn(
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
"DoLa decoding is `repetition_penalty>=1.2`.",
UserWarning,
)
# 4. check `num_return_sequences`
if self.num_return_sequences != 1:
if self.num_beams == 1:
if self.do_sample is False:
raise ValueError(
"Greedy methods without beam search do not support `num_return_sequences` different than 1 "
f"(got {self.num_return_sequences})."
)
elif self.num_return_sequences > self.num_beams:
raise ValueError(
f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
f"({self.num_beams})."
)
# 5. check cache-related arguments
# 1.2. Cache attributes
if self.cache_implementation is not None and self.cache_implementation not in ALL_CACHE_IMPLEMENTATIONS:
raise ValueError(
f"Invalid `cache_implementation` ({self.cache_implementation}). Choose one of: "
@ -784,6 +625,141 @@ class GenerationConfig(PushToHubMixin):
if not isinstance(self.cache_config, cache_class):
self.cache_config = cache_class.from_dict(self.cache_config)
self.cache_config.validate()
# 1.3. Performance attributes
if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
raise ValueError(
f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
"instance of `CompileConfig`."
)
# 1.4. Watermarking attributes
if self.watermarking_config is not None:
if not (
isinstance(self.watermarking_config, WatermarkingConfig)
or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
):
minor_issues["watermarking_config"] = (
"`watermarking_config` as a dict is deprecated and will be removed in v4.54.0. Please construct "
"`watermarking_config` object with `WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class."
)
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
self.watermarking_config.validate()
# 2. Validation of attribute combinations
# 2.1. detect sampling-only parameterization when not in sampling mode
if self.do_sample is False:
greedy_wrong_parameter_msg = (
"`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
)
if self.temperature is not None and self.temperature != 1.0:
minor_issues["temperature"] = greedy_wrong_parameter_msg.format(
flag_name="temperature", flag_value=self.temperature
)
if self.top_p is not None and self.top_p != 1.0:
minor_issues["top_p"] = greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p)
if self.min_p is not None:
minor_issues["min_p"] = greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p)
if self.typical_p is not None and self.typical_p != 1.0:
minor_issues["typical_p"] = greedy_wrong_parameter_msg.format(
flag_name="typical_p", flag_value=self.typical_p
)
if (
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
): # contrastive search uses top_k
minor_issues["top_k"] = greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k)
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
minor_issues["epsilon_cutoff"] = greedy_wrong_parameter_msg.format(
flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff
)
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
minor_issues["eta_cutoff"] = greedy_wrong_parameter_msg.format(
flag_name="eta_cutoff", flag_value=self.eta_cutoff
)
# 2.2. detect beam-only parameterization when not in beam mode
if self.num_beams == 1:
single_beam_wrong_parameter_msg = (
"`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`."
)
if self.early_stopping is not False:
minor_issues["early_stopping"] = single_beam_wrong_parameter_msg.format(
flag_name="early_stopping", flag_value=self.early_stopping
)
if self.num_beam_groups is not None and self.num_beam_groups != 1:
minor_issues["num_beam_groups"] = single_beam_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups
)
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
minor_issues["diversity_penalty"] = single_beam_wrong_parameter_msg.format(
flag_name="diversity_penalty", flag_value=self.diversity_penalty
)
if self.length_penalty is not None and self.length_penalty != 1.0:
minor_issues["length_penalty"] = single_beam_wrong_parameter_msg.format(
flag_name="length_penalty", flag_value=self.length_penalty
)
if self.constraints is not None:
minor_issues["constraints"] = single_beam_wrong_parameter_msg.format(
flag_name="constraints", flag_value=self.constraints
)
# DoLa generation needs num_beams == 1
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
minor_issues["repetition_penalty"] = (
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
"DoLa decoding is `repetition_penalty>=1.2`.",
)
# 2.3. detect incorrect parameterization specific to advanced beam modes
else:
# constrained beam search
if self.constraints is not None or self.force_words_ids is not None:
constrained_wrong_parameter_msg = (
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. "
"However, `{flag_name}` is set to `{flag_value}`, which is incompatible with this generation "
"mode. Set `constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue."
)
if self.do_sample is True:
raise ValueError(
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
)
if self.num_beam_groups is not None and self.num_beam_groups != 1:
raise ValueError(
constrained_wrong_parameter_msg.format(
flag_name="num_beam_groups", flag_value=self.num_beam_groups
)
)
# group beam search
elif self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
group_error_prefix = (
"`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
"this generation mode, "
)
if self.do_sample is True:
raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
if self.num_beams % self.num_beam_groups != 0:
raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
if self.diversity_penalty == 0.0:
raise ValueError(
group_error_prefix
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
)
# 2.4. check `num_return_sequences`
if self.num_return_sequences != 1:
if self.num_beams == 1:
if self.do_sample is False:
raise ValueError(
"Greedy methods without beam search do not support `num_return_sequences` different than 1 "
f"(got {self.num_return_sequences})."
)
elif self.num_return_sequences > self.num_beams:
raise ValueError(
f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
f"({self.num_beams})."
)
# 2.5. check cache-related arguments
if self.use_cache is False:
# In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often used
# passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error
@ -794,42 +770,20 @@ class GenerationConfig(PushToHubMixin):
)
for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
if getattr(self, arg_name) is not None:
logger.warning_once(
no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name))
minor_issues[arg_name] = no_cache_warning.format(
cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)
)
# 6. check watermarking arguments
if self.watermarking_config is not None:
if not (
isinstance(self.watermarking_config, WatermarkingConfig)
or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
):
warnings.warn(
"`watermarking_config` as a dict is deprecated. Please construct `watermarking_config` object with "
"`WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class.",
FutureWarning,
)
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
self.watermarking_config.validate()
# 7. performances arguments
if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
raise ValueError(
f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
"instance of `CompileConfig`."
)
# 8. other incorrect combinations
# 2.6. other incorrect combinations
if self.return_dict_in_generate is not True:
for extra_output_flag in self.extra_output_flags:
if getattr(self, extra_output_flag) is True:
warnings.warn(
minor_issues[extra_output_flag] = (
f"`return_dict_in_generate` is NOT set to `True`, but `{extra_output_flag}` is. When "
f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored.",
UserWarning,
f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored."
)
# 8. check common issue: passing `generate` arguments inside the generation config
# 3. Check common issue: passing `generate` arguments inside the generation config
generate_arguments = (
"logits_processor",
"stopping_criteria",
@ -839,6 +793,7 @@ class GenerationConfig(PushToHubMixin):
"streamer",
"negative_prompt_ids",
"negative_prompt_attention_mask",
"use_model_defaults",
)
for arg in generate_arguments:
if hasattr(self, arg):
@ -847,6 +802,30 @@ class GenerationConfig(PushToHubMixin):
"`generate()` (or a pipeline) directly."
)
# Finally, handle caught minor issues. With default parameterization, we will throw a minimal warning.
if len(minor_issues) > 0:
# Full list of issues with potential fixes
info_message = []
for attribute_name, issue_description in minor_issues.items():
info_message.append(f"- `{attribute_name}`: {issue_description}")
info_message = "\n".join(info_message)
info_message += (
"\nIf you're using a pretrained model, note that some of these attributes may be set through the "
"model's `generation_config.json` file."
)
if strict:
raise ValueError("GenerationConfig is invalid: \n" + info_message)
else:
attributes_with_issues = list(minor_issues.keys())
warning_message = (
f"The following generation flags are not valid and may be ignored: {attributes_with_issues}."
)
if logger.getEffectiveLevel() >= logging.WARNING:
warning_message += " Set `TRANSFORMERS_VERBOSITY=info` for more details."
logger.warning(warning_message)
logger.info(info_message)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
@ -871,18 +850,13 @@ class GenerationConfig(PushToHubMixin):
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
# At save time, validate the instance -- if any warning/exception is thrown, we refuse to save the instance.
# At save time, validate the instance enforcing strictness -- if any warning/exception would be thrown, we
# refuse to save the instance.
# This strictness is enforced to prevent bad configurations from being saved and re-used.
try:
with warnings.catch_warnings(record=True) as caught_warnings:
self.validate()
if len(caught_warnings) > 0:
raise ValueError(str([w.message for w in caught_warnings]))
self.validate(strict=True)
except ValueError as exc:
raise ValueError(
"The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
"Fix these issues to save the configuration.\n\nThrown during validation:\n" + str(exc)
)
raise ValueError(str(exc) + "\n\nFix these issues to save the configuration.")
use_auth_token = kwargs.pop("use_auth_token", None)

View File

@ -15,7 +15,7 @@
import inspect
import math
from typing import Callable, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
@ -25,6 +25,10 @@ from ..utils import add_start_docstrings
from ..utils.logging import get_logger
# TODO (joao): We shouldn't need this, but there would be a circular import
if TYPE_CHECKING:
from ..generation.configuration_utils import GenerationConfig
logger = get_logger(__name__)
@ -1906,8 +1910,10 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
predicting timestamps that are too far in the future.
begin_index (`Optional`, *optional*): Token index of the first token that is generated by the model.
_detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.
begin_index (`int`):
Token index of the first token that is generated by the model.
_detect_timestamp_from_logprob (`bool`, *optional*):
Whether timestamps can be predicted from logprobs over all timestamps.
Examples:
``` python
@ -1940,8 +1946,8 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
def __init__(
self,
generate_config,
begin_index: Optional[int] = None,
generate_config: "GenerationConfig",
begin_index: int,
_detect_timestamp_from_logprob: Optional[bool] = None,
): # support for the kwargs
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
@ -1954,11 +1960,13 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
if _detect_timestamp_from_logprob is not None
else getattr(generate_config, "_detect_timestamp_from_logprob", True)
)
num_forced_ids = (
len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
)
self.begin_index = begin_index or (num_forced_ids + 1)
self.begin_index = begin_index
if begin_index is None:
raise ValueError(
"`forced_decoder_ids` is deprecated in favor of `task` and `language` and, as such, `begin_index` "
"must be provided to `WhisperTimeStampLogitsProcessor`. The previous default value of `begin_index` "
"was `len(generate_config.forced_decoder_ids)`"
)
self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
# TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50

View File

@ -46,6 +46,7 @@ from ..dynamic_module_utils import (
)
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..integrations.fsdp import is_fsdp_managed_module
from ..masking_utils import create_masks_for_generate
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..pytorch_utils import isin_mps_friendly
from ..tokenization_utils import ExtensionsTrie
@ -74,6 +75,7 @@ from .candidate_generator import (
from .configuration_utils import (
NEED_SETUP_CACHE_CLASSES_MAPPING,
QUANT_BACKEND_CLASSES_MAPPING,
CompileConfig,
GenerationConfig,
GenerationMode,
)
@ -649,12 +651,22 @@ class GenerationMixin:
causal_mask_creation_function = getattr(
decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None
)
# If it's not defined, it means the model uses the new general mask API
if causal_mask_creation_function is None: # can't be found
logger.warning_once(
f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
"defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
"writing code, see Llama for an example implementation. If you're a user, please report this "
"issue on GitHub."
output_attentions = kwargs.get("output_attentions", False)
token_type_ids = getattr(model_input, "token_type_ids", None)
# Some models may overwrite the general one
causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
attention_mask = causal_mask_creation_function(
config=self.config,
# we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
input_embeds=torch.empty((batch_size, sequence_length), dtype=self.dtype),
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
output_attentions=output_attentions,
token_type_ids=token_type_ids,
)
else:
attention_mask = causal_mask_creation_function(
@ -1246,12 +1258,6 @@ class GenerationMixin:
device=device,
)
)
if generation_config.forced_decoder_ids is not None:
# TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PT
raise ValueError(
"You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument "
"in favour of `input_ids` or `decoder_input_ids` respectively.",
)
# TODO (joao): find a strategy to specify the order of the processors
processors = self._merge_criteria_processor_list(processors, logits_processor)
@ -1752,16 +1758,21 @@ class GenerationMixin:
use_model_defaults is None and model_base_version >= version.parse("4.50.0")
):
modified_values = {}
default_generation_config = GenerationConfig()
for key, default_value in default_generation_config.__dict__.items():
global_default_generation_config = GenerationConfig()
model_generation_config = self.generation_config
# we iterate over the model's generation config: it may hold custom keys, which we'll want to copy
for key, model_gen_config_value in model_generation_config.__dict__.items():
if key.startswith("_") or key == "transformers_version": # metadata
continue
custom_gen_config_value = getattr(generation_config, key)
model_gen_config_value = getattr(self.generation_config, key)
if custom_gen_config_value == default_value and model_gen_config_value != default_value:
global_default_value = getattr(global_default_generation_config, key, None)
custom_gen_config_value = getattr(generation_config, key, None)
if (
custom_gen_config_value == global_default_value
and model_gen_config_value != global_default_value
):
modified_values[key] = model_gen_config_value
setattr(generation_config, key, model_gen_config_value)
if len(modified_values) > 0:
if use_model_defaults is None and len(modified_values) > 0:
logger.warning_once(
f"`generation_config` default values have been modified to match model-specific defaults: "
f"{modified_values}. If this is not desired, please set these values explicitly."
@ -1980,7 +1991,9 @@ class GenerationMixin:
instantiated, writes it to `model_kwargs`, under the name expected by the model.
"""
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
is_hybrid_cache = any(class_name in self.__class__.__name__.lower() for class_name in ["mamba", "falconh1"])
cache_name = "past_key_values" if not is_hybrid_cache else "cache_params"
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)
@ -3532,6 +3545,19 @@ class GenerationMixin:
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
if compile_forward:
os.environ["TOKENIZERS_PARALLELISM"] = "0"
# If we use FA2 and a static cache, we cannot compile with fullgraph
if self.config._attn_implementation == "flash_attention_2" and getattr(
model_kwargs.get("past_key_values"), "is_compileable", False
):
if generation_config.compile_config is None:
generation_config.compile_config = CompileConfig(fullgraph=False)
# only raise warning if the user passed an explicit compile-config (otherwise, simply change the default without confusing the user)
elif generation_config.compile_config.fullgraph:
logger.warning_once(
"When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
"FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
)
generation_config.compile_config.fullgraph = False
model_forward = self.get_compiled_call(generation_config.compile_config)
if generation_config.prefill_chunk_size is not None:

View File

@ -142,7 +142,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["tensor_parallel"] = [
"shard_and_distribute_module",
"SUPPORTED_TP_STYLES",
"ALL_PARALLEL_STYLES",
"translate_to_torch_parallel_style",
]
try:
@ -271,7 +271,7 @@ if TYPE_CHECKING:
pass
else:
from .tensor_parallel import (
SUPPORTED_TP_STYLES,
ALL_PARALLEL_STYLES,
shard_and_distribute_module,
translate_to_torch_parallel_style,
)

Some files were not shown because too many files have changed in this diff Show More