mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Merge branch 'main' into add-aimv2-model
This commit is contained in:
commit
1a11e86dac
51
.github/workflows/check_failed_model_tests.yml
vendored
51
.github/workflows/check_failed_model_tests.yml
vendored
@ -39,55 +39,100 @@ jobs:
|
||||
name: ci_results_run_models_gpu
|
||||
path: /transformers/ci_results_run_models_gpu
|
||||
|
||||
- name: Check file
|
||||
working-directory: /transformers
|
||||
run: |
|
||||
if [ -f ci_results_run_models_gpu/new_model_failures.json ]; then
|
||||
echo "`ci_results_run_models_gpu/new_model_failures.json` exists, continue ..."
|
||||
echo "process=true" >> $GITHUB_ENV
|
||||
else
|
||||
echo "`ci_results_run_models_gpu/new_model_failures.json` doesn't exist, abort."
|
||||
echo "process=false" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- uses: actions/download-artifact@v4
|
||||
if: ${{ env.process == 'true' }}
|
||||
with:
|
||||
pattern: setup_values*
|
||||
path: setup_values
|
||||
merge-multiple: true
|
||||
|
||||
- name: Prepare some setup values
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: |
|
||||
if [ -f setup_values/prev_workflow_run_id.txt ]; then
|
||||
echo "PREV_WORKFLOW_RUN_ID=$(cat setup_values/prev_workflow_run_id.txt)" >> $GITHUB_ENV
|
||||
else
|
||||
echo "PREV_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
if [ -f setup_values/other_workflow_run_id.txt ]; then
|
||||
echo "OTHER_WORKFLOW_RUN_ID=$(cat setup_values/other_workflow_run_id.txt)" >> $GITHUB_ENV
|
||||
else
|
||||
echo "OTHER_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Update clone
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: git fetch && git checkout ${{ github.sha }}
|
||||
|
||||
- name: Get target commit
|
||||
working-directory: /transformers/utils
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: |
|
||||
echo "END_SHA=$(TOKEN=${{ secrets.ACCESS_REPO_INFO_TOKEN }} python3 -c 'import os; from get_previous_daily_ci import get_last_daily_ci_run_commit; commit=get_last_daily_ci_run_commit(token=os.environ["TOKEN"]); print(commit)')" >> $GITHUB_ENV
|
||||
echo "END_SHA=$(TOKEN=${{ secrets.ACCESS_REPO_INFO_TOKEN }} python3 -c 'import os; from get_previous_daily_ci import get_last_daily_ci_run_commit; commit=get_last_daily_ci_run_commit(token=os.environ["TOKEN"], workflow_run_id=os.environ["PREV_WORKFLOW_RUN_ID"]); print(commit)')" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout to `start_sha`
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: git fetch && git checkout ${{ inputs.start_sha }}
|
||||
|
||||
- name: Reinstall transformers in edit mode (remove the one installed during docker image build)
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: python3 -m pip uninstall -y transformers && python3 -m pip install -e .
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: |
|
||||
nvidia-smi
|
||||
|
||||
- name: Environment
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: |
|
||||
python3 utils/print_env.py
|
||||
|
||||
- name: Show installed libraries and their versions
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: pip freeze
|
||||
|
||||
- name: Check failed tests
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: python3 utils/check_bad_commit.py --start_commit ${{ inputs.start_sha }} --end_commit ${{ env.END_SHA }} --file ci_results_run_models_gpu/new_model_failures.json --output_file new_model_failures_with_bad_commit.json
|
||||
|
||||
- name: Show results
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: |
|
||||
ls -l new_model_failures_with_bad_commit.json
|
||||
cat new_model_failures_with_bad_commit.json
|
||||
|
||||
- name: Checkout back
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
run: |
|
||||
git checkout ${{ inputs.start_sha }}
|
||||
|
||||
- name: Process report
|
||||
shell: bash
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
env:
|
||||
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
|
||||
TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN: ${{ secrets.TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN }}
|
||||
run: |
|
||||
python3 utils/process_bad_commit_report.py
|
||||
@ -95,7 +140,9 @@ jobs:
|
||||
- name: Process report
|
||||
shell: bash
|
||||
working-directory: /transformers
|
||||
if: ${{ env.process == 'true' }}
|
||||
env:
|
||||
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
|
||||
TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN: ${{ secrets.TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN }}
|
||||
run: |
|
||||
{
|
||||
@ -105,7 +152,7 @@ jobs:
|
||||
} >> "$GITHUB_ENV"
|
||||
|
||||
- name: Send processed report
|
||||
if: ${{ !endsWith(env.REPORT_TEXT, '{}') }}
|
||||
if: ${{ env.process == 'true' && !endsWith(env.REPORT_TEXT, '{}') }}
|
||||
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
|
||||
with:
|
||||
# Slack channel id, channel name, or user id to post message.
|
||||
|
35
.github/workflows/self-scheduled-caller.yml
vendored
35
.github/workflows/self-scheduled-caller.yml
vendored
@ -8,8 +8,43 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- run_scheduled_ci*
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
prev_workflow_run_id:
|
||||
description: 'previous workflow run id to compare'
|
||||
type: string
|
||||
required: false
|
||||
default: ""
|
||||
other_workflow_run_id:
|
||||
description: 'other workflow run id to compare'
|
||||
type: string
|
||||
required: false
|
||||
default: ""
|
||||
|
||||
|
||||
# Used for `push` to easily modiffy the target workflow runs to compare against
|
||||
env:
|
||||
prev_workflow_run_id: ""
|
||||
other_workflow_run_id: ""
|
||||
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
name: Setup
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Setup
|
||||
run: |
|
||||
mkdir "setup_values"
|
||||
echo "${{ inputs.prev_workflow_run_id || env.prev_workflow_run_id }}" > "setup_values/prev_workflow_run_id.txt"
|
||||
echo "${{ inputs.other_workflow_run_id || env.other_workflow_run_id }}" > "setup_values/other_workflow_run_id.txt"
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: setup_values
|
||||
path: setup_values
|
||||
|
||||
model-ci:
|
||||
name: Model CI
|
||||
uses: ./.github/workflows/self-scheduled.yml
|
||||
|
18
.github/workflows/slack-report.yml
vendored
18
.github/workflows/slack-report.yml
vendored
@ -39,6 +39,21 @@ jobs:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/download-artifact@v4
|
||||
|
||||
- name: Prepare some setup values
|
||||
run: |
|
||||
if [ -f setup_values/prev_workflow_run_id.txt ]; then
|
||||
echo "PREV_WORKFLOW_RUN_ID=$(cat setup_values/prev_workflow_run_id.txt)" >> $GITHUB_ENV
|
||||
else
|
||||
echo "PREV_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
if [ -f setup_values/other_workflow_run_id.txt ]; then
|
||||
echo "OTHER_WORKFLOW_RUN_ID=$(cat setup_values/other_workflow_run_id.txt)" >> $GITHUB_ENV
|
||||
else
|
||||
echo "OTHER_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Send message to Slack
|
||||
if: ${{ inputs.job != 'run_quantization_torch_gpu' }}
|
||||
env:
|
||||
@ -50,7 +65,6 @@ jobs:
|
||||
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
|
||||
CI_EVENT: ${{ inputs.ci_event }}
|
||||
CI_SHA: ${{ github.sha }}
|
||||
CI_WORKFLOW_REF: ${{ github.workflow_ref }}
|
||||
CI_TEST_JOB: ${{ inputs.job }}
|
||||
SETUP_STATUS: ${{ inputs.setup_status }}
|
||||
# We pass `needs.setup.outputs.matrix` as the argument. A processing in `notification_service.py` to change
|
||||
@ -58,7 +72,6 @@ jobs:
|
||||
# For a job that doesn't depend on (i.e. `needs`) `setup`, the value for `inputs.folder_slices` would be an
|
||||
# empty string, and the called script still get one argument (which is the emtpy string).
|
||||
run: |
|
||||
sudo apt-get install -y curl
|
||||
pip install huggingface_hub
|
||||
pip install slack_sdk
|
||||
pip show slack_sdk
|
||||
@ -86,7 +99,6 @@ jobs:
|
||||
# We pass `needs.setup.outputs.quantization_matrix` as the argument. A processing in `notification_service_quantization.py` to change
|
||||
# `quantization/bnb` to `quantization_bnb` is required, as the artifact names use `_` instead of `/`.
|
||||
run: |
|
||||
sudo apt-get install -y curl
|
||||
pip install huggingface_hub
|
||||
pip install slack_sdk
|
||||
pip show slack_sdk
|
||||
|
@ -71,6 +71,9 @@ RUN python3 -m pip install --no-cache-dir g2p-en
|
||||
# For Some bitsandbytes tests
|
||||
RUN python3 -m pip install --no-cache-dir einops
|
||||
|
||||
# For Some tests with `@require_liger_kernel`
|
||||
RUN python3 -m pip install --no-cache-dir liger-kernel
|
||||
|
||||
# `kernels` may give different outputs (within 1e-5 range) even with the same model (weights) and the same inputs
|
||||
RUN python3 -m pip uninstall -y kernels
|
||||
|
||||
|
@ -455,6 +455,8 @@
|
||||
title: Falcon
|
||||
- local: model_doc/falcon3
|
||||
title: Falcon3
|
||||
- local: model_doc/falcon_h1
|
||||
title: FalconH1
|
||||
- local: model_doc/falcon_mamba
|
||||
title: FalconMamba
|
||||
- local: model_doc/flan-t5
|
||||
|
@ -125,4 +125,44 @@ would expect from a usual Python dictionary:
|
||||
|
||||
# You can also globally `register` a new function directly on it
|
||||
>>> ALL_ATTENTION_FUNCTIONS.register("new_func", new_func)
|
||||
```
|
||||
```
|
||||
|
||||
## Attention Mask Interface
|
||||
|
||||
Having a new attention function may mean that you need a new format of attention mask to decide what key and value tokens
|
||||
the query tokens should attend to. This is now possible with the `AttentionMaskInterface`! It works in the same way as
|
||||
the `AttentionInterface`:
|
||||
|
||||
```python
|
||||
from transformers import AttentionMaskInterface
|
||||
from transformers.masking_utils import sdpa_mask
|
||||
import torch
|
||||
|
||||
def my_new_sdpa_mask(*args, **kwargs):
|
||||
print("I just entered the attention mask computation")
|
||||
return sdpa_mask(*args, **kwargs)
|
||||
|
||||
AttentionMaskInterface.register("my_new_sdpa_mask", my_new_sdpa_mask)
|
||||
```
|
||||
|
||||
The reason you have to register it is because we need to automatically correct your mask format based on the attention implementation (for example, flex attention uses a BlockMask format, while sdpa uses a 4D tensor).
|
||||
By default, if you do not register an attention mask function along with your attention function, mask creation will be skipped
|
||||
and `attention_mask=None` will be passed along to the Attention layers.
|
||||
|
||||
The default signature of the attention mask functions is the following:
|
||||
|
||||
```python
|
||||
def custom_attention_mask(
|
||||
batch_size: int, # required arg
|
||||
cache_position: torch.Tensor, # required arg
|
||||
kv_length: int, # required arg
|
||||
kv_offset: int = 0, # required arg
|
||||
mask_function: Callable = causal_mask_function, # required arg
|
||||
attention_mask: Optional[torch.Tensor] = None, # required arg
|
||||
**kwargs, # a few additional args may be passed as kwargs, especially the model's config is always passed
|
||||
) -> Optional[torch.Tensor]:
|
||||
```
|
||||
|
||||
It mostly works thanks to the `mask_function`, which is a `Callable` in the form of [torch's mask_mod functions](https://pytorch.org/blog/flexattention/), taking 4 indices as input and returning a boolean to indicate if this position should take part in the attention computation.
|
||||
|
||||
If you cannot use the `mask_function` to create your mask for some reason, you can try to work around it by doing something similar to our [torch export workaround](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py).
|
@ -90,11 +90,6 @@ class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
|
||||
|
||||
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
|
||||
|
||||
if self.use_rel_pos:
|
||||
attn_weights = self.add_decomposed_rel_pos(
|
||||
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
|
||||
)
|
||||
|
||||
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
|
||||
@ -114,13 +109,14 @@ Load the model with [`~PreTrainedModel.from_pretrained`].
|
||||
|
||||
```py
|
||||
from transformers import SamModel
|
||||
from transformers.models.sam import modeling_sam
|
||||
|
||||
# replace the attention class in the modeling_sam module
|
||||
modeling_sam.SamVisionAttention = SamVisionAttentionSplit
|
||||
|
||||
# load the pretrained SAM model
|
||||
model = SamModel.from_pretrained("facebook/sam-vit-base")
|
||||
|
||||
# replace the attention class in the vision_encoder module
|
||||
for layer in model.vision_encoder.layers:
|
||||
if hasattr(layer, "attn"):
|
||||
layer.attn = SamVisionAttentionSplit(model.config.vision_config, model.config.vision_config.window_size)
|
||||
```
|
||||
|
||||
## LoRA
|
||||
@ -138,7 +134,7 @@ config = LoraConfig(
|
||||
# apply LoRA to q and v
|
||||
target_modules=["q", "v"],
|
||||
lora_dropout=0.1,
|
||||
task_type="mask-generation"
|
||||
task_type="FEATURE_EXTRACTION"
|
||||
)
|
||||
```
|
||||
|
||||
@ -152,5 +148,5 @@ Call [print_trainable_parameters](https://huggingface.co/docs/peft/package_refer
|
||||
|
||||
```py
|
||||
model.print_trainable_parameters()
|
||||
"trainable params: 608,256 || all params: 94,343,728 || trainable%: 0.6447"
|
||||
"trainable params: 589,824 || all params: 94,274,096 || trainable%: 0.6256"
|
||||
```
|
@ -29,6 +29,11 @@ Most of those are only useful if you are studying the code of the models in the
|
||||
[[autodoc]] AttentionInterface
|
||||
- register
|
||||
|
||||
## Attention Mask Functions
|
||||
|
||||
[[autodoc]] AttentionMaskInterface
|
||||
- register
|
||||
|
||||
## Rotary Position Embedding Functions
|
||||
|
||||
[[autodoc]] dynamic_rope_update
|
||||
|
@ -39,7 +39,7 @@ Checkout all Bamba-9B model checkpoints [here](https://github.com/foundation-mod
|
||||
<!---
|
||||
## Usage Tips
|
||||
|
||||
Tips:
|
||||
Tips:
|
||||
|
||||
- The architecture is based on Mamba-2 models.
|
||||
|
||||
@ -63,7 +63,35 @@ response = model.generate(**inputs, max_new_tokens=64)
|
||||
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
|
||||
```
|
||||
|
||||
|
||||
## Padding-Free Training
|
||||
|
||||
Bamba supports padding-free training in which distinct training examples can be concatenated
|
||||
together while nevertheless processing the inputs as though they belonged to separate batches. When
|
||||
the examples are of varying lengths, padding-free training can provide significant speed ups and
|
||||
memory savings compared to batching the examples together and using padding, as the unnecessary
|
||||
compute and memory due to padding is avoided entirely. The performance gains depend on factors such
|
||||
as the model and the data distribution, but throughput gains up to [~2x are commonly
|
||||
seen](https://github.com/huggingface/transformers/pull/35861#issue-2807873129).
|
||||
|
||||
Using padding-free training with Bamba requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d`
|
||||
packages, and the following arguments must be passed to the model in addition to `input_ids` and
|
||||
`labels`:
|
||||
* `position_ids: torch.LongTensor`: the position index of each token in each sequence.
|
||||
* `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
|
||||
* Each of the [`FlashAttentionKwargs`]
|
||||
* `cu_seq_lens_q: torch.LongTensor`: The cumulative sequence lengths of all queries.
|
||||
* `cu_seq_lens_k: torch.LongTensor`: The cumulative sequence lengths of all keys.
|
||||
* `max_length_q: int`: the longest query length in the batch.
|
||||
* `max_length_k: int`: the longest key length in the batch.
|
||||
|
||||
The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] can be used
|
||||
to programmatically generate the above set of additional arguments using `return_seq_idx=True` and
|
||||
`return_flash_attn_kwargs=True`. See [this blog post](https://huggingface.co/blog/packing-with-FA2)
|
||||
for additional information.
|
||||
|
||||
|
||||
[[autodoc]] BambaForCausalLM
|
||||
- forward
|
||||
|
||||
This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
|
||||
This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
|
||||
|
65
docs/source/en/model_doc/falcon_h1.md
Normal file
65
docs/source/en/model_doc/falcon_h1.md
Normal file
@ -0,0 +1,65 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
-->
|
||||
|
||||
# FalconH1
|
||||
|
||||
## Overview
|
||||
|
||||
The FalconH1 model was developed by the TII Pretraining team. A comprehensive research paper covering the architecture, pretraining dynamics, experimental results, and conclusions is forthcoming. You can read more about this series in [this website](https://github.com/tiiuae/Falcon-H1).
|
||||
|
||||
## Contributors
|
||||
|
||||
This model was contributed by [DhiyaEddine](https://huggingface.co/DhiyaEddine), [ybelkada](https://huggingface.co/ybelkada), [JingweiZuo](https://huggingface.co/JingweiZuo), [IlyasChahed](https://huggingface.co/IChahed), and [MaksimVelikanov](https://huggingface.co/yellowvm).
|
||||
The original code can be found [here](https://github.com/tiiuae/Falcon-H1).
|
||||
|
||||
|
||||
## FalconH1Config
|
||||
|
||||
| Model | Depth | Dim | Attn Heads | KV | Mamba Heads | d_head | d_state | Ctx Len |
|
||||
|-----------|--------|------|------------|----|--------------|--------------|------|-----------------|
|
||||
| H1 0.5B | 36 | 1024 | 8 | 2 | 24 | 64 / 64 | 128 | 4K, 16K-SFT |
|
||||
| H1 1.5B | 24 | 2048 | 8 | 2 | 48 | 128 / 64 | 256 | 128K |
|
||||
| H1 1.5B-d | 66 | 1280 | 6 | 2 | 24 | 128 / 64 | 256 | 128K |
|
||||
| H1 3B | 32 | 2560 | 10 | 2 | 32 | 128 / 128 | 256 | 128K |
|
||||
| H1 7B | 44 | 3072 | 12 | 2 | 24 | 128 / 128 | 256 | 256K |
|
||||
| H1 34B | 72 | 5120 | 20 | 4 | 32 | 128 / 128 | 256 | 256K |
|
||||
|
||||
|
||||
|
||||
[[autodoc]] FalconH1Config
|
||||
|
||||
<!---
|
||||
## Usage Tips
|
||||
Tips:
|
||||
- The architecture is based on Mamba-2 models.
|
||||
## FalconH1Model
|
||||
[[autodoc]] FalconH1Model
|
||||
- forward
|
||||
-->
|
||||
|
||||
## FalconH1ForCausalLM
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("tiiuae/Falcon-H1-7B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("tiiuae/Falcon-H1-7B-Instruct")
|
||||
|
||||
message = ["Mamba is a snake with following properties "]
|
||||
inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
|
||||
response = model.generate(**inputs, max_new_tokens=64)
|
||||
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
|
||||
```
|
||||
|
||||
[[autodoc]] FalconH1ForCausalLM
|
||||
- forward
|
||||
|
||||
This HF implementation is contributed by [younesbelkada](https://github.com/younesbelkada) and [DhiaEddineRhaiem](https://github.com/dhiaEddineRhaiem).
|
@ -147,7 +147,7 @@ print(processor.decode(output[0], skip_special_tokens=True))
|
||||
|
||||
### Multi image inference
|
||||
|
||||
LLaVa-OneVision can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). For that you have to use checkpoints with an "ov" suffix. Here is how you can do it:
|
||||
LLaVa-OneVision can perform inference with multiple images as input, where images either belong to the same prompt or different prompts (in batched inference). For that you have to use checkpoints with an "ov" suffix. For multi-image cases, we recommend using a **nested list of images** as input. Otherwise, every image will be patchified and consume a lot of memory. Here is how you can do it:
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
@ -14,85 +14,124 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Mamba
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Mamba
|
||||
|
||||
The Mamba model was proposed in [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao.
|
||||
[Mamba](https://huggingface.co/papers/2312.00752) is a selective structured state space model (SSMs) designed to work around Transformers computational inefficiency when dealing with long sequences. It is a completely attention-free architecture, and comprised of a combination of H3 and gated MLP blocks (Mamba block). Mamba's "content-based reasoning" allows it to focus on specific parts of an input depending on the current token. Mamba also uses a new hardware-aware parallel algorithm to compensate for the lack of convolutional operations. As a result, Mamba has fast inference and can scale to very long sequences.
|
||||
|
||||
This model is a new paradigm architecture based on `state-space-models`. You can read more about the intuition behind these [here](https://srush.github.io/annotated-s4/).
|
||||
You can find all the original Mamba checkpoints under the [State Space Models](https://huggingface.co/state-spaces) organization.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Foundation models, now powering most of the exciting applications in deep learning, are almost universally based on the Transformer architecture and its core attention module. Many subquadratic-time architectures such as linear attention, gated convolution and recurrent models, and structured state space models (SSMs) have been developed to address Transformers' computational inefficiency on long sequences, but they have not performed as well as attention on important modalities such as language. We identify that a key weakness of such models is their inability to perform content-based reasoning, and make several improvements. First, simply letting the SSM parameters be functions of the input addresses their weakness with discrete modalities, allowing the model to selectively propagate or forget information along the sequence length dimension depending on the current token. Second, even though this change prevents the use of efficient convolutions, we design a hardware-aware parallel algorithm in recurrent mode. We integrate these selective SSMs into a simplified end-to-end neural network architecture without attention or even MLP blocks (Mamba). Mamba enjoys fast inference (5× higher throughput than Transformers) and linear scaling in sequence length, and its performance improves on real data up to million-length sequences. As a general sequence model backbone, Mamba achieves state-of-the-art performance across several modalities such as language, audio, and genomics. On language modeling, our Mamba-3B model outperforms Transformers of the same size and matches Transformers twice its size, both in pretraining and downstream evaluation.*
|
||||
> [!TIP]
|
||||
> Click on the Mamba models in the right sidebar for more examples of how to apply Mamba to different language tasks.
|
||||
|
||||
Tips:
|
||||
The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
|
||||
- Mamba is a new `state space model` architecture that rivals the classic Transformers. It is based on the line of progress on structured state space models, with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
|
||||
- Mamba stacks `mixer` layers, which are the equivalent of `Attention` layers. The core logic of `mamba` is held in the `MambaMixer` class.
|
||||
- Two implementations cohabit: one is optimized and uses fast cuda kernels, while the other one is naive but can run on any device!
|
||||
- The current implementation leverages the original cuda kernels: the equivalent of flash attention for Mamba are hosted in the [`mamba-ssm`](https://github.com/state-spaces/mamba) and the [`causal_conv1d`](https://github.com/Dao-AILab/causal-conv1d) repositories. Make sure to install them if your hardware supports them!
|
||||
- Contributions to make the naive path faster are welcome 🤗
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
This model was contributed by [ArthurZ](https://huggingface.co/ArthurZ).
|
||||
The original code can be found [here](https://github.com/state-spaces/mamba).
|
||||
|
||||
# Usage
|
||||
|
||||
### A simple generation example:
|
||||
```python
|
||||
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(
|
||||
task="text-generation",
|
||||
model="state-spaces/mamba-130m-hf",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("Plants create energy through a process known as")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
|
||||
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
|
||||
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]
|
||||
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float16, device_map="auto",)
|
||||
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
|
||||
|
||||
out = model.generate(input_ids, max_new_tokens=10)
|
||||
print(tokenizer.batch_decode(out))
|
||||
output = model.generate(**input_ids)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
```
|
||||
|
||||
### Peft finetuning
|
||||
The slow version is not very stable for training, and the fast one needs `float32`!
|
||||
</hfoption>
|
||||
<hfoption id="transformers CLI">
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
|
||||
model_id = "state-spaces/mamba-130m-hf"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
dataset = load_dataset("Abirate/english_quotes", split="train")
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=4,
|
||||
logging_dir='./logs',
|
||||
logging_steps=10,
|
||||
learning_rate=2e-3
|
||||
)
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
|
||||
task_type="CAUSAL_LM",
|
||||
bias="none"
|
||||
)
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
args=training_args,
|
||||
peft_config=lora_config,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="quote",
|
||||
)
|
||||
trainer.train()
|
||||
```bash
|
||||
echo -e "Plants create energy through a process known as" | transformers run --task text-generation --model state-spaces/mamba-130m-hf --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [torchao](../quantization/torchao) to only quantize the weights to 4-bit integers.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
|
||||
quantization_config = Int4WeightOnlyConfig(group_size=128)
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-2.8b-hf")
|
||||
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-2.8b-hf", torch_dtype=torch.bfloat16, quantization_config=quantization_config, device_map="auto",)
|
||||
input_ids = tokenizer("Plants create energy through a process known as", return_tensors="pt").to("cuda")
|
||||
|
||||
output = model.generate(**input_ids)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
## Notes
|
||||
|
||||
- The current implementation uses the original CUDA kernels. The FlashAttention equivalent implementation is hosted in the [mamba-ssm](https://github.com/state-spaces/mamba) and [causal_conv1d](https://github.com/Dao-AILab/causal-conv1d) repositories. Make sure to install them if your hardware supports it!
|
||||
- Mamba stacks `mixer` layers which are equivalent to `Attention` layers. You can find the main logic of Mamba in the `MambaMixer` class.
|
||||
- The example below demonstrates how to fine-tune Mamba with [PEFT](https://huggingface.co/docs/peft).
|
||||
|
||||
```py
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
|
||||
|
||||
model_id = "state-spaces/mamba-130m-hf"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
dataset = load_dataset("Abirate/english_quotes", split="train")
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=4,
|
||||
logging_dir='./logs',
|
||||
logging_steps=10,
|
||||
learning_rate=2e-3
|
||||
)
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
|
||||
task_type="CAUSAL_LM",
|
||||
bias="none"
|
||||
)
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
args=training_args,
|
||||
peft_config=lora_config,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="quote",
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## MambaConfig
|
||||
|
||||
[[autodoc]] MambaConfig
|
||||
|
@ -43,8 +43,8 @@ import requests
|
||||
from transformers import SamHQModel, SamHQProcessor
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
|
||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
||||
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base").to(device)
|
||||
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||
|
||||
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||
@ -69,8 +69,8 @@ import requests
|
||||
from transformers import SamHQModel, SamHQProcessor
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
|
||||
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
|
||||
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base").to(device)
|
||||
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
|
||||
|
||||
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||
|
@ -14,59 +14,77 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Swin Transformer
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# Swin Transformer
|
||||
|
||||
The Swin Transformer was proposed in [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)
|
||||
by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo.
|
||||
[Swin Transformer](https://huggingface.co/papers/2103.14030) is a hierarchical vision transformer. Images are processed in patches and windowed self-attention is used to capture local information. These windows are shifted across the image to allow for cross-window connections, capturing global information more efficiently. This hierarchical approach with shifted windows allows the Swin Transformer to process images effectively at different scales and achieve linear computational complexity relative to image size, making it a versatile backbone for various vision tasks like image classification and object detection.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
You can find all official Swin Transformer checkpoints under the [Microsoft](https://huggingface.co/microsoft?search_models=swin) organization.
|
||||
|
||||
*This paper presents a new vision Transformer, called Swin Transformer, that capably serves as a general-purpose backbone
|
||||
for computer vision. Challenges in adapting Transformer from language to vision arise from differences between the two domains,
|
||||
such as large variations in the scale of visual entities and the high resolution of pixels in images compared to words in text.
|
||||
To address these differences, we propose a hierarchical Transformer whose representation is computed with \bold{S}hifted
|
||||
\bold{win}dows. The shifted windowing scheme brings greater efficiency by limiting self-attention computation to non-overlapping
|
||||
local windows while also allowing for cross-window connection. This hierarchical architecture has the flexibility to model at
|
||||
various scales and has linear computational complexity with respect to image size. These qualities of Swin Transformer make it
|
||||
compatible with a broad range of vision tasks, including image classification (87.3 top-1 accuracy on ImageNet-1K) and dense
|
||||
prediction tasks such as object detection (58.7 box AP and 51.1 mask AP on COCO test-dev) and semantic segmentation
|
||||
(53.5 mIoU on ADE20K val). Its performance surpasses the previous state-of-the-art by a large margin of +2.7 box AP and
|
||||
+2.6 mask AP on COCO, and +3.2 mIoU on ADE20K, demonstrating the potential of Transformer-based models as vision backbones.
|
||||
The hierarchical design and the shifted window approach also prove beneficial for all-MLP architectures.*
|
||||
> [!TIP]
|
||||
> Click on the Swin Transformer models in the right sidebar for more examples of how to apply Swin Transformer to different image tasks.
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/swin_transformer_architecture.png"
|
||||
alt="drawing" width="600"/>
|
||||
The example below demonstrates how to classify an image with [`Pipeline`] or the [`AutoModel`] class.
|
||||
|
||||
<small> Swin Transformer architecture. Taken from the <a href="https://arxiv.org/abs/2102.03334">original paper</a>.</small>
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
This model was contributed by [novice03](https://huggingface.co/novice03). The Tensorflow version of this model was contributed by [amyeroberts](https://huggingface.co/amyeroberts). The original code can be found [here](https://github.com/microsoft/Swin-Transformer).
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
## Usage tips
|
||||
pipeline = pipeline(
|
||||
task="image-classification",
|
||||
model="microsoft/swin-tiny-patch4-window7-224",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline(images="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg")
|
||||
```
|
||||
</hfoption>
|
||||
|
||||
- Swin pads the inputs supporting any input height and width (if divisible by `32`).
|
||||
- Swin can be used as a *backbone*. When `output_hidden_states = True`, it will output both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`.
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
## Resources
|
||||
```py
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
from transformers import AutoModelForImageClassification, AutoImageProcessor
|
||||
|
||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Swin Transformer.
|
||||
image_processor = AutoImageProcessor.from_pretrained(
|
||||
"microsoft/swin-tiny-patch4-window7-224",
|
||||
use_fast=True,
|
||||
)
|
||||
model = AutoModelForImageClassification.from_pretrained(
|
||||
"microsoft/swin-tiny-patch4-window7-224",
|
||||
device_map="cuda"
|
||||
)
|
||||
|
||||
<PipelineTag pipeline="image-classification"/>
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
inputs = image_processor(image, return_tensors="pt").to("cuda")
|
||||
|
||||
- [`SwinForImageClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb).
|
||||
- See also: [Image classification task guide](../tasks/image_classification)
|
||||
with torch.no_grad():
|
||||
logits = model(**inputs).logits
|
||||
predicted_class_id = logits.argmax(dim=-1).item()
|
||||
|
||||
Besides that:
|
||||
class_labels = model.config.id2label
|
||||
predicted_class_label = class_labels[predicted_class_id]
|
||||
print(f"The predicted class label is: {predicted_class_label}")
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
- [`SwinForMaskedImageModeling`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
|
||||
## Notes
|
||||
|
||||
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||
- Swin can pad the inputs for any input height and width divisible by `32`.
|
||||
- Swin can be used as a [backbone](../backbones). When `output_hidden_states = True`, it outputs both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`.
|
||||
|
||||
## SwinConfig
|
||||
|
||||
|
@ -95,7 +95,7 @@ transcription[0]
|
||||
|
||||
## Notes
|
||||
|
||||
- Whisper relies on [`~GenerationMixin.generate`] for inference.
|
||||
- Whisper relies a custom [`generate`] for inference, make sure to check the docs below.
|
||||
- The [`WhisperProcessor`] can be used for preparing audio and decoding predicted ids back into text.
|
||||
|
||||
## WhisperConfig
|
||||
|
@ -54,8 +54,8 @@ For each model type, there is a separate class for each machine learning framewo
|
||||
from transformers import AutoModelForCausalLM, MistralForCausalLM
|
||||
|
||||
# load with AutoClass or model-specific class
|
||||
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", , torch_dtype="auto", device_map="auto")
|
||||
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", , torch_dtype="auto", device_map="auto")
|
||||
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype="auto", device_map="auto")
|
||||
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype="auto", device_map="auto")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
@ -272,6 +272,7 @@ Explicitly set the [torch_dtype](https://pytorch.org/docs/stable/tensor_attribut
|
||||
<hfoption id="specific dtype">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype=torch.float16)
|
||||
|
@ -13,9 +13,15 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Distributed GPU inference
|
||||
# Tensor parallelism in transformers
|
||||
|
||||
[Tensor parallelism](./perf_train_gpu_many#tensor-parallelism) shards a model onto multiple GPUs and parallelizes computations such as matrix multiplication. It enables fitting larger model sizes into memory and is faster because each GPU can process a tensor slice.
|
||||
This document assumes that you are already familiar with the basics of tensor parallelism. If you are not, please refer to the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism) section on tensor parallelism.
|
||||
|
||||
> [!TIP]
|
||||
> Tensor parallelism is very communication intensive, therefore it is reccomended to use it on a single machine with multiple GPUs, utilizing fast intra-node communication. For multi-node training, methods as pipeline or data parallelism are more efficient (depending on your use case).
|
||||
|
||||
Tensor parallelism requires slight changes to the model parameters, therefore in transformers, we support some of the popular models out of the box.
|
||||
|
||||
> [!TIP]
|
||||
> Expand the list below to see which models support tensor parallelism. Open a GitHub issue or pull request to add support for a model not currently below.
|
||||
@ -37,9 +43,218 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
</details>
|
||||
|
||||
Set `tp_plan="auto"` in [`~AutoModel.from_pretrained`] to enable tensor parallelism for inference.
|
||||
## Using 🤗 transformers
|
||||
|
||||
```py
|
||||
Transformers provides a simple interface to use for tensor parallelism. We provide multiple classes implementing different partitioning
|
||||
strategies and a simple entrypoint to parallelize `nn.Module` instance. You won't have to interact with this interface directly, everything is done in `PretrainedModel.from_pretrained` method for you. This section will first talk about the partitioning strategies
|
||||
we support, then the user interface you will be interacting with, and finally it will teach you how to extend it with your own partitioning
|
||||
strategies.
|
||||
|
||||
### Partitioning strategies
|
||||
|
||||
In transformers, partitioning strategies reside in a class `ParallelInterface` which works like a mapping from string to the strategy implementation.
|
||||
|
||||
|
||||
```python
|
||||
class ParallelInterface(MutableMapping):
|
||||
"""
|
||||
Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
|
||||
with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
|
||||
it needs to declare a new instance of this class inside the `modeling_<model>.py`, and declare it on that instance.
|
||||
"""
|
||||
_global_mapping = {
|
||||
"colwise": ColwiseParallel(),
|
||||
"rowwise": RowwiseParallel(),
|
||||
"colwise_rep": ColwiseParallel(output_layouts=Replicate()),
|
||||
"rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
|
||||
"local_colwise": ColwiseParallel(use_dtensor=False),
|
||||
"local_rowwise": RowwiseParallel(use_dtensor=False),
|
||||
"local": IsolatedParallel(),
|
||||
"gather": GatherParallel(),
|
||||
"local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
|
||||
"sequence_parallel": SequenceParallel(),
|
||||
"replicate": ReplicateParallel(),
|
||||
}
|
||||
```
|
||||
|
||||
We support the following strategies:
|
||||
|
||||
- `ColwiseParallel` - A simple column-wise partitioning, being able to handle both weights and biases, does exactly what we've discussed before.
|
||||
- `RowwiseParallel` - Again, row-wise partitioning as dicussed before, supports weights and biases, on top of that it also supports `nn.Embedding` modules.
|
||||
- `SequenceParallel` - Sequence parallel implementation, for support of `LayerNorm` and `Dropout` layers. Also supports Python implementation of `RMSNorm` (see [this](https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34))
|
||||
- `PackedColwiseParallel` - A variant of column-wise partitioning, however it works on packed weights (i.e. `up_proj` and `gate_proj` being packed together). For more details, see [this comment](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/tensor_parallel.py#L79-#L108)
|
||||
- `PackedRowwiseParallel` - A variant of row-wise partitioning, works on packed weights, for more details check the comment linked above.
|
||||
- `GatherParallel` - A very simple class, that only makes the outputs of the module to be gathered across devices.
|
||||
- `IsolatedParallel` - This is a special case, where we want to *isolate* the module from the rest of the devices (world). This is used for Experts in MoE layers, basically creating Expert parallelism of sorts.
|
||||
- `ReplicateParallel` - Many `torch.distributed` APIs break if model is partially sharded, so this class is used to replicate the module across all devices.
|
||||
|
||||
### Sharding a model
|
||||
|
||||
We provide two ways to shard a model, first one is to use `auto` tensor parallelism plan, which will automatically shard the model based on our predefined configuration. This requires the model to have predefined tensor parallel plan in transformers.
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
# model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # better for smaller number of GPUs
|
||||
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" # better to visualize all the possible strategies
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan="auto")
|
||||
|
||||
print(model._tp_plan)
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> For a list of models that support tensor parallelism, see the [Supported models](#supported-models) section above.
|
||||
|
||||
The second way is to manually specify your own partitioning plan.
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
tp_plan = {
|
||||
"model.layers.*.self_attn.q_proj": "colwise",
|
||||
"model.layers.*.self_attn.k_proj": "colwise",
|
||||
"model.layers.*.self_attn.v_proj": "colwise",
|
||||
"model.layers.*.self_attn.o_proj": "rowwise",
|
||||
...
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan=tp_plan)
|
||||
|
||||
print(model._tp_plan)
|
||||
```
|
||||
|
||||
You might have noticed that there are some special cases in the `ParallelInterface` mapping, let's now talk about them. This will help you understand their purpose and help with extending to other strategies.
|
||||
|
||||
### PackedRowwiseParallel
|
||||
This class is a special case of `RowwiseParallel`, it's used to shard packed weights. Weight packing is a common technique used in models. It's a technique where we pack multiple linear layers into a single, bigger one.
|
||||
|
||||
For example in `Llama4` model, we pack `up_proj` and `gate_proj` into a single `gate_up_proj` module.
|
||||
```python
|
||||
class Llama4TextExperts(nn.Module):
|
||||
...
|
||||
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
|
||||
```
|
||||
|
||||
Then in forward, we can use batch matrix multiplication to compute the output of the `gate_up_proj` module.
|
||||
|
||||
```python
|
||||
def forward(self, hidden_states):
|
||||
...
|
||||
gate_up = torch.bmm(hidden_states, self.gate_up_proj) # Compute the output of the gate_up_proj module
|
||||
gate, up = gate_up.chunk(2, dim=-1) # Split the output into gate and up
|
||||
```
|
||||
|
||||
In this case, we need to use the `PackedRowwiseParallel` strategy to shard the `gate_up_proj` module, as using a simple `RowwiseParallel` will shard the layers wrongly.
|
||||
|
||||
> [!TIP]
|
||||
> If this is a bit difficult to wrap your head around, check out [this comment](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/tensor_parallel.py#L79-#L108) for an amazing visual representation of why `Packed*` needs to be used.
|
||||
|
||||
|
||||
### `local*` strategies
|
||||
|
||||
You could have noticed that there are `local*` strategies, which use the same layers as `*` strategy, but don't use `DTensor` at all.
|
||||
This is because `DTensor` is not supported for some of the operations: such as `torch.chunk`. Therefore, sometimes we need to use the `local*` strategies, which use vanilla `torch.Tensor` and do some of the distributed logic manually.
|
||||
|
||||
<!---
|
||||
Readd this when I get the exact error message
|
||||
> [!TIP]
|
||||
> If you are using a custom partitioning strategy, and it's not working with `... is not supported` error, try using the `local*` strategies to see if they work better.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> Manually specifying your own partitiong plan requires a good understanding of the model architecture and how the partitioning strategies interact together. If you are not sure about this, the resulting model can be very slow, even failing or incorrect. Again, refer to the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism) which can teach you everything required.
|
||||
|
||||
### Extending the interface with your own partitioning strategies
|
||||
|
||||
This is a very advanced topic, which requires a good understanding of distributed collectives and the model architecture.
|
||||
Your custom partitioning strategy should inherit from `TensorParallelLayer` defined in [integrations/tensor_parallel.py](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/tensor_parallel.py) and implement: `partition_tensor`, `_prepare_input_fn` and `_prepare_output_fn`. Then it should be registered in the `ParallelInterface` mapping, so our dispatching logic can find it when specified in the `tp_plan`.
|
||||
|
||||
Let's go through this workflow step by step, on an already existing example: `ColwiseParallel`.
|
||||
|
||||
1. Inherit from `TensorParallelLayer` and initialization
|
||||
|
||||
```python
|
||||
class ColwiseParallel(TensorParallelLayer):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_layouts: Optional[Placement] = None, # The input layout coming from the previous layer
|
||||
output_layouts: Optional[Placement] = None, # The output layout we want to achieve
|
||||
use_local_output: bool = True, # Whether to use local output or not
|
||||
use_dtensor=True, # Whether to use DTensor or not
|
||||
):
|
||||
self.input_layouts = (input_layouts or Replicate(),) # The input sharding coming from the previous layer
|
||||
self.output_layouts = (output_layouts or Shard(-1),) # Desired output sharding
|
||||
self.desired_input_layouts = (Replicate(),) # Desired input sharding, inputs should be replicated across GPUs
|
||||
self.use_local_output = use_local_output
|
||||
self.use_dtensor = use_dtensor
|
||||
```
|
||||
|
||||
In the `__init__` method, we define these attributes, where `input_layouts` and `output_layouts` describing, how the input and output tensors should be placed on the devices. `desired_input_layouts` is used to specify, how the input *SHOULD* be placed on the devices.
|
||||
|
||||
2a. Implement `partition_tensor` method
|
||||
|
||||
```python
|
||||
def partition_tensor(
|
||||
self,
|
||||
param, # Full tensor of the parameter
|
||||
empty_param, # Empty tensor of the parameter, will be filled with the partitioned tensor
|
||||
param_type, # Type of the parameter, `bias` or `weight`
|
||||
param_casting_dtype, # The type to cast the parameter to
|
||||
to_contiguous, # Whether to convert the tensor to a contiguous memory layout
|
||||
rank, # The rank of the current device
|
||||
device_mesh, # The device mesh
|
||||
) -> nn.Parameter: # Return the partitioned parameter
|
||||
...
|
||||
```
|
||||
|
||||
This method is used to partition the tensor, and fill the `empty_param` with the partitioned tensor.
|
||||
We provide some utility functions to help you with this, such as `get_tensor_shard` which will get you the correct shard of the original parameter for this rank or `get_packed_weights` to help with packed weights.
|
||||
|
||||
2b. Implement `_prepare_input_fn` and `_prepare_output_fn` methods
|
||||
|
||||
These methods are used as [`pre-forward`](https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_pre_hook.html) and [`forward`](https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_hook.html) hooks respectively. Their purpose is to re-distribute the inputs and outputs to the desired layout, passed in the `__init__` method.
|
||||
|
||||
```python
|
||||
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
|
||||
...
|
||||
# Do some custom logic, cast to DTensor etc.
|
||||
...
|
||||
return inputs.redistribute(placements=desired_input_layouts, device_mesh=device_mesh)
|
||||
|
||||
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
|
||||
...
|
||||
# Do some custom logic, cast to DTensor etc.
|
||||
...
|
||||
return outputs.redistribute(placements=output_layouts, device_mesh=device_mesh)
|
||||
```
|
||||
|
||||
3. Register the strategy
|
||||
Congratulations! You've implemented your own partitioning strategy. Now, to use it with your own `tp_plan`, you need to register it in the `ParallelInterface` mapping.
|
||||
|
||||
```python
|
||||
from transformers.integrations.tensor_parallel import ParallelInterface
|
||||
|
||||
ParallelInterface.register_strategy("colwise_custom", ColwiseParallel)
|
||||
```
|
||||
|
||||
And now you can use it in your `tp_plan` as such:
|
||||
|
||||
```python
|
||||
tp_plan = {
|
||||
"model.layers.*.self_attn.q_proj": "colwise_custom",
|
||||
...
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan=tp_plan)
|
||||
```
|
||||
|
||||
|
||||
## Full example
|
||||
|
||||
Let's go through a full example of inference with tensor parallelism.
|
||||
```python
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
@ -66,17 +281,49 @@ Launch the inference script above on [torchrun](https://pytorch.org/docs/stable/
|
||||
torchrun --nproc-per-node 4 demo.py
|
||||
```
|
||||
|
||||
For CPU, please binding different socket on each rank. For example, if you are using Intel 4th Gen Xeon:
|
||||
```bash
|
||||
export OMP_NUM_THREADS=56
|
||||
numactl -C 0-55 -m 0 torchrun --nnodes=2 --node_rank=0 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & numactl -C 56-111 -m 1 torchrun --nnodes=2 --node_rank=1 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & wait
|
||||
```
|
||||
The CPU benchmark data will be released soon.
|
||||
|
||||
You can benefit from considerable speed ups for inference, especially for inputs with large batch size or long sequences.
|
||||
|
||||
For a single forward pass on [Llama](./model_doc/llama) with a sequence length of 512 and various batch sizes, you can expect the following speed ups.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Meta-Llama-3-8B-Instruct%2C%20seqlen%20%3D%20512%2C%20python%2C%20w_%20compile.png">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Meta-Llama-3-8B-Instruct%2C%20seqlen%20%3D%20512%2C%20python%2C%20w_%20compile.png">
|
||||
</div>
|
||||
|
||||
## Tensor parallelism in-depth
|
||||
Our implementation of tensor parallelism is framework-agnostic in design, but the specific implementations we've developed rely on the torch.distributed package. We heavily utilize abstractions such as `DeviceMesh` or `DTensor` to provide a simple and extensible interface to the user.
|
||||
|
||||
### DeviceMesh
|
||||
Imagine `DeviceMesh` as a multi-dimensional grid of devices that communicate together. Different parallelization strategies require different types of communication patterns, therefore we can create a `DeviceMesh` with multiple submeshes:
|
||||
```python
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
|
||||
# Create a 1D mesh of 4 GPUs
|
||||
device_mesh = init_device_mesh("cuda", (4,), mesh_dim_names=["tp"])
|
||||
```
|
||||
Then, most of the `torch.distributed` defined parallelization strategies can be applied to a mesh itself, or its submesh, automatically handling the communication patterns.
|
||||
|
||||
### DTensor
|
||||
|
||||
Abbreviation for Distributed Tensor, `DTensor` is a tensor subclass that handles the distributed logic on-top of the usual tensor operations. Most of the model weights in case of tensor parallelism are stored as `DTensor`s (with some exceptions, more on that later).
|
||||
The most important part of DTensor, that is crucial to understand, is the `placement` attribute. It's an attribute that tells PyTorch how is the tensor placed on the devices of the `DeviceMesh`.
|
||||
|
||||
It can have the following values:
|
||||
|
||||
- `Shard(dimension)` - Annotates that this `DTensor` is sharded across a given dimension, over the `DeviceMesh` it was constructed under. For example, if we would like to shard weights for column-wise partitioning, we would do:
|
||||
```python
|
||||
weight = ...
|
||||
weight = DTensor.from_local(weight, device_mesh["tp"], placements=[Shard(0)]) # Shard across the 1st (column-wise) dimension
|
||||
bias = ...
|
||||
bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Shard(-1)]) # Shard across the ONLY dimension
|
||||
```
|
||||
|
||||
To give another example, for row-wise partitioning, we would do:
|
||||
```python
|
||||
weight = ...
|
||||
weight = DTensor.from_local(weight, device_mesh["tp"], placements=[Shard(1)]) # Shard across the 2nd (row-wise) dimension
|
||||
bias = ...
|
||||
bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Replicate()]) # Replicate bias across all GPUs
|
||||
```
|
||||
|
||||
- `Replicate()` - Annotates that this `DTensor` is replicated across the `DeviceMesh`. Very straight-forward, only creates a full copy of the tensor on each device.
|
||||
- `Partial()` - This placement is mostly of no interest to us, it's used to annotate that this tensor is pending a reduction operation.
|
||||
|
@ -106,6 +106,8 @@ dataset[0]["text"]
|
||||
Remember to resample the sampling rate to match the pretrained models required sampling rate.
|
||||
|
||||
```py
|
||||
from datasets import Audio
|
||||
|
||||
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
|
||||
```
|
||||
|
||||
|
@ -29,8 +29,6 @@
|
||||
- sections:
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: tasks/sequence_classification
|
||||
title: テキストの分類
|
||||
- local: tasks/token_classification
|
||||
title: トークンの分類
|
||||
- local: tasks/question_answering
|
||||
|
@ -47,7 +47,7 @@ ALBERTモデルは、「[ALBERT: A Lite BERT for Self-supervised Learning of Lan
|
||||
|
||||
## 参考資料
|
||||
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [トークン分類タスクガイド](../tasks/token_classification)
|
||||
- [質問応答タスクガイド](../tasks/question_answering)
|
||||
- [マスクされた言語モデルタスクガイド](../tasks/masked_language_modeling)
|
||||
|
@ -129,7 +129,7 @@ BART を始めるのに役立つ公式 Hugging Face およびコミュニティ
|
||||
- [翻訳タスクガイド](../tasks/translation)
|
||||
|
||||
以下も参照してください。
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
|
||||
- [抽出されたチェックポイント](https://huggingface.co/models?search=distilbart) は、この [論文](https://arxiv.org/abs/2010.13002) で説明されています。
|
||||
|
@ -76,7 +76,7 @@ BERT を始めるのに役立つ公式 Hugging Face およびコミュニティ
|
||||
- [`BertForSequenceClassification`] は、この [サンプル スクリプト](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification) および [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification.ipynb)。
|
||||
- [`TFBertForSequenceClassification`] は、この [サンプル スクリプト](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/text-classification) および [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification-tf.ipynb)。
|
||||
- [`FlaxBertForSequenceClassification`] は、この [サンプル スクリプト](https://github.com/huggingface/transformers/tree/main/examples/flax/text-classification) および [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification_flax.ipynb)。
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
|
||||
<PipelineTag pipeline="token-classification"/>
|
||||
|
||||
|
@ -58,7 +58,7 @@ BigBird は、質問応答や要約などのさまざまな NLP タスクのパ
|
||||
|
||||
## ドキュメント リソース
|
||||
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [トークン分類タスクガイド](../tasks/token_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
|
||||
|
@ -58,7 +58,7 @@ BigBird は、質問応答や要約などのさまざまな NLP タスクのパ
|
||||
|
||||
## ドキュメント リソース
|
||||
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
|
||||
- [翻訳タスクガイド](../tasks/translation)
|
||||
|
@ -39,7 +39,7 @@ BLOOM を使い始めるのに役立つ公式 Hugging Face およびコミュニ
|
||||
|
||||
以下も参照してください。
|
||||
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [トークン分類タスクガイド](../tasks/token_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
|
||||
|
@ -46,7 +46,7 @@ Bi-direction Encoders for Transformers (BERT) のフランス語版である Cam
|
||||
|
||||
## Resources
|
||||
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [トークン分類タスクガイド](../tasks/token_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
|
||||
|
@ -98,7 +98,7 @@ CANINE は生の文字で動作するため、**トークナイザーなし**で
|
||||
|
||||
## Resources
|
||||
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [トークン分類タスクガイド](../tasks/token_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
- [多肢選択タスク ガイド](../tasks/multiple_choice)
|
||||
|
@ -53,7 +53,7 @@ ConvBERT トレーニングのヒントは BERT のヒントと似ています
|
||||
|
||||
## Resources
|
||||
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [トークン分類タスクガイド](../tasks/token_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
- [マスクされた言語モデリング タスク ガイド](../tasks/masked_lang_modeling)
|
||||
|
@ -61,7 +61,7 @@ CTRL モデルは、Nitish Shirish Keskar*、Bryan McCann*、Lav R. Varshney、C
|
||||
|
||||
## Resources
|
||||
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
|
||||
|
||||
## CTRLConfig
|
||||
|
@ -58,7 +58,7 @@ Data2Vec の使用を開始するのに役立つ公式 Hugging Face およびコ
|
||||
- カスタム データセットで [`TFData2VecVisionForImageClassification`] を微調整するには、[このノートブック](https://colab.research.google.com/github/sayakpaul/TF-2.0-Hacks/blob/master/data2vec_vision_image_classification.ipynb) を参照してください。 )。
|
||||
|
||||
**Data2VecText ドキュメント リソース**
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [トークン分類タスクガイド](../tasks/token_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
- [因果言語モデリング タスク ガイド](../tasks/language_modeling)
|
||||
|
@ -61,7 +61,7 @@ v2 の新機能:
|
||||
[kamalkraj](https://huggingface.co/kamalkraj) による投稿。元のコードは [こちら](https://github.com/microsoft/DeBERTa) にあります。
|
||||
|
||||
## Resources
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
- [トークン分類タスクガイド](../tasks/token_classification)
|
||||
- [質問回答タスク ガイド](../tasks/question_answering)
|
||||
- [マスク言語モデリング タスク ガイド](../tasks/masked_language_modeling)
|
||||
|
@ -52,7 +52,7 @@ DeBERTa を使い始めるのに役立つ公式 Hugging Face およびコミュ
|
||||
- DeBERTa による [機械学習によるスーパーチャージされた顧客サービス](https://huggingface.co/blog/supercharge-customer-service-with-machine-learning) に関するブログ投稿。
|
||||
- [`DebertaForSequenceClassification`] は、この [サンプル スクリプト](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification) および [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification.ipynb)。
|
||||
- [`TFDebertaForSequenceClassification`] は、この [サンプル スクリプト](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/text-classification) および [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification-tf.ipynb)。
|
||||
- [テキスト分類タスクガイド](../tasks/sequence_classification)
|
||||
- [テキスト分類タスクガイド(英語版)](../../en/tasks/sequence_classification)
|
||||
|
||||
<PipelineTag pipeline="token-classification" />
|
||||
|
||||
|
@ -1,604 +0,0 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Sequence classification
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
<Youtube id="dKE8SIt9C-w"/>
|
||||
|
||||
セマンティック セグメンテーションでは、画像の個々のピクセルにラベルまたはクラスを割り当てます。セグメンテーションにはいくつかのタイプがありますが、セマンティック セグメンテーションの場合、同じオブジェクトの一意のインスタンス間の区別は行われません。両方のオブジェクトに同じラベルが付けられます (たとえば、「car-1」と「car-2」の代わりに「car」)。セマンティック セグメンテーションの一般的な現実世界のアプリケーションには、歩行者や重要な交通情報を識別するための自動運転車のトレーニング、医療画像内の細胞と異常の識別、衛星画像からの環境変化の監視などが含まれます。
|
||||
|
||||
このガイドでは、次の方法を説明します。
|
||||
|
||||
1. [SceneParse150](https://huggingface.co/datasets/scene_parse_150) データセットの [SegFormer](https://huggingface.co/docs/transformers/main/en/model_doc/segformer#segformer) を微調整します。
|
||||
2. 微調整したモデルを推論に使用します。
|
||||
|
||||
<Tip>
|
||||
|
||||
このタスクと互換性のあるすべてのアーキテクチャとチェックポイントを確認するには、[タスクページ](https://huggingface.co/tasks/text-classification) を確認することをお勧めします。
|
||||
|
||||
</Tip>
|
||||
|
||||
始める前に、必要なライブラリがすべてインストールされていることを確認してください。
|
||||
|
||||
```bash
|
||||
pip install -q datasets transformers evaluate
|
||||
```
|
||||
|
||||
モデルをアップロードしてコミュニティと共有できるように、Hugging Face アカウントにログインすることをお勧めします。プロンプトが表示されたら、トークンを入力してログインします。
|
||||
|
||||
```py
|
||||
>>> from huggingface_hub import notebook_login
|
||||
|
||||
>>> notebook_login()
|
||||
```
|
||||
|
||||
## Load SceneParse150 dataset
|
||||
|
||||
|
||||
まず、SceneParse150 データセットの小さいサブセットを 🤗 データセット ライブラリから読み込みます。これにより、完全なデータセットのトレーニングにさらに時間を費やす前に、実験してすべてが機能することを確認する機会が得られます。
|
||||
|
||||
```py
|
||||
>>> from datasets import load_dataset
|
||||
|
||||
>>> ds = load_dataset("scene_parse_150", split="train[:50]")
|
||||
```
|
||||
|
||||
[`~datasets.Dataset.train_test_split`] メソッドを使用して、データセットの `train` 分割をトレイン セットとテスト セットに分割します。
|
||||
|
||||
```py
|
||||
>>> ds = ds.train_test_split(test_size=0.2)
|
||||
>>> train_ds = ds["train"]
|
||||
>>> test_ds = ds["test"]
|
||||
```
|
||||
|
||||
次に、例を見てみましょう。
|
||||
|
||||
```py
|
||||
>>> train_ds[0]
|
||||
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x683 at 0x7F9B0C201F90>,
|
||||
'annotation': <PIL.PngImagePlugin.PngImageFile image mode=L size=512x683 at 0x7F9B0C201DD0>,
|
||||
'scene_category': 368}
|
||||
```
|
||||
|
||||
- `image`: シーンの PIL イメージ。
|
||||
- `annotation`: セグメンテーション マップの PIL イメージ。モデルのターゲットでもあります。
|
||||
- `scene_category`: 「キッチン」や「オフィス」などの画像シーンを説明するカテゴリ ID。このガイドでは、「image」と「annotation」のみが必要になります。どちらも PIL イメージです。
|
||||
|
||||
また、ラベル ID をラベル クラスにマップする辞書を作成することもできます。これは、後でモデルを設定するときに役立ちます。ハブからマッピングをダウンロードし、`id2label` および `label2id` ディクショナリを作成します。
|
||||
|
||||
```py
|
||||
>>> import json
|
||||
>>> from pathlib import Path
|
||||
>>> from huggingface_hub import hf_hub_download
|
||||
|
||||
>>> repo_id = "huggingface/label-files"
|
||||
>>> filename = "ade20k-id2label.json"
|
||||
>>> id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text())
|
||||
>>> id2label = {int(k): v for k, v in id2label.items()}
|
||||
>>> label2id = {v: k for k, v in id2label.items()}
|
||||
>>> num_labels = len(id2label)
|
||||
```
|
||||
|
||||
## Preprocess
|
||||
|
||||
次のステップでは、SegFormer 画像プロセッサをロードして、モデルの画像と注釈を準備します。このデータセットのような一部のデータセットは、バックグラウンド クラスとしてゼロインデックスを使用します。ただし、実際には背景クラスは 150 個のクラスに含まれていないため、`do_reduce_labels=True`を設定してすべてのラベルから 1 つを引く必要があります。ゼロインデックスは `255` に置き換えられるため、SegFormer の損失関数によって無視されます。
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoImageProcessor
|
||||
|
||||
>>> checkpoint = "nvidia/mit-b0"
|
||||
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint, do_reduce_labels=True)
|
||||
```
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
モデルを過学習に対してより堅牢にするために、画像データセットにいくつかのデータ拡張を適用するのが一般的です。このガイドでは、[torchvision](https://pytorch.org) の [`ColorJitter`](https://pytorch.org/vision/stable/generated/torchvision.transforms.ColorJitter.html) 関数を使用します。 /vision/stable/index.html) を使用して画像の色のプロパティをランダムに変更しますが、任意の画像ライブラリを使用することもできます。
|
||||
|
||||
```py
|
||||
>>> from torchvision.transforms import ColorJitter
|
||||
|
||||
>>> jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
|
||||
```
|
||||
|
||||
次に、モデルの画像と注釈を準備するための 2 つの前処理関数を作成します。これらの関数は、画像を`pixel_values`に変換し、注釈を`labels`に変換します。トレーニング セットの場合、画像を画像プロセッサに提供する前に`jitter`が適用されます。テスト セットの場合、テスト中にデータ拡張が適用されないため、画像プロセッサは`images`を切り取って正規化し、`labels` のみを切り取ります。
|
||||
|
||||
```py
|
||||
>>> def train_transforms(example_batch):
|
||||
... images = [jitter(x) for x in example_batch["image"]]
|
||||
... labels = [x for x in example_batch["annotation"]]
|
||||
... inputs = image_processor(images, labels)
|
||||
... return inputs
|
||||
|
||||
|
||||
>>> def val_transforms(example_batch):
|
||||
... images = [x for x in example_batch["image"]]
|
||||
... labels = [x for x in example_batch["annotation"]]
|
||||
... inputs = image_processor(images, labels)
|
||||
... return inputs
|
||||
```
|
||||
|
||||
データセット全体に`jitter`を適用するには、🤗 Datasets [`~datasets.Dataset.set_transform`] 関数を使用します。変換はオンザフライで適用されるため、高速で消費するディスク容量が少なくなります。
|
||||
|
||||
```py
|
||||
>>> train_ds.set_transform(train_transforms)
|
||||
>>> test_ds.set_transform(val_transforms)
|
||||
```
|
||||
|
||||
</pt>
|
||||
</frameworkcontent>
|
||||
|
||||
<frameworkcontent>
|
||||
<tf>
|
||||
|
||||
モデルを過学習に対してより堅牢にするために、画像データセットにいくつかのデータ拡張を適用するのが一般的です。
|
||||
このガイドでは、[`tf.image`](https://www.tensorflow.org/api_docs/python/tf/image) を使用して画像の色のプロパティをランダムに変更しますが、任意のプロパティを使用することもできます。画像
|
||||
好きな図書館。
|
||||
2 つの別々の変換関数を定義します。
|
||||
- 画像拡張を含むトレーニング データ変換
|
||||
- 🤗 Transformers のコンピューター ビジョン モデルはチャネル優先のレイアウトを想定しているため、画像を転置するだけの検証データ変換
|
||||
|
||||
```py
|
||||
>>> import tensorflow as tf
|
||||
|
||||
|
||||
>>> def aug_transforms(image):
|
||||
... image = tf.keras.utils.img_to_array(image)
|
||||
... image = tf.image.random_brightness(image, 0.25)
|
||||
... image = tf.image.random_contrast(image, 0.5, 2.0)
|
||||
... image = tf.image.random_saturation(image, 0.75, 1.25)
|
||||
... image = tf.image.random_hue(image, 0.1)
|
||||
... image = tf.transpose(image, (2, 0, 1))
|
||||
... return image
|
||||
|
||||
|
||||
>>> def transforms(image):
|
||||
... image = tf.keras.utils.img_to_array(image)
|
||||
... image = tf.transpose(image, (2, 0, 1))
|
||||
... return image
|
||||
```
|
||||
|
||||
次に、モデルの画像と注釈のバッチを準備する 2 つの前処理関数を作成します。これらの機能が適用されます
|
||||
画像変換を行い、以前にロードされた `image_processor` を使用して画像を `pixel_values` に変換し、
|
||||
`labels`への注釈。 `ImageProcessor` は、画像のサイズ変更と正規化も処理します。
|
||||
|
||||
```py
|
||||
>>> def train_transforms(example_batch):
|
||||
... images = [aug_transforms(x.convert("RGB")) for x in example_batch["image"]]
|
||||
... labels = [x for x in example_batch["annotation"]]
|
||||
... inputs = image_processor(images, labels)
|
||||
... return inputs
|
||||
|
||||
|
||||
>>> def val_transforms(example_batch):
|
||||
... images = [transforms(x.convert("RGB")) for x in example_batch["image"]]
|
||||
... labels = [x for x in example_batch["annotation"]]
|
||||
... inputs = image_processor(images, labels)
|
||||
... return inputs
|
||||
```
|
||||
|
||||
データセット全体に前処理変換を適用するには、🤗 Datasets [`~datasets.Dataset.set_transform`] 関数を使用します。
|
||||
変換はオンザフライで適用されるため、高速で消費するディスク容量が少なくなります。
|
||||
|
||||
```py
|
||||
>>> train_ds.set_transform(train_transforms)
|
||||
>>> test_ds.set_transform(val_transforms)
|
||||
```
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
|
||||
## Evaluate
|
||||
|
||||
トレーニング中にメトリクスを含めると、多くの場合、モデルのパフォーマンスを評価するのに役立ちます。 🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) ライブラリを使用して、評価メソッドをすばやくロードできます。このタスクでは、[Mean Intersection over Union](https://huggingface.co/spaces/evaluate-metric/accuracy) (IoU) メトリックをロードします (🤗 Evaluate [クイック ツアー](https://huggingface.co) を参照してください) /docs/evaluate/a_quick_tour) を参照して、メトリクスをロードして計算する方法の詳細を確認してください)。
|
||||
|
||||
```py
|
||||
>>> import evaluate
|
||||
|
||||
>>> metric = evaluate.load("mean_iou")
|
||||
```
|
||||
|
||||
次に、メトリクスを [`~evaluate.EvaluationModule.compute`] する関数を作成します。予測を次のように変換する必要があります
|
||||
最初にロジットを作成し、次に [`~evaluate.EvaluationModule.compute`] を呼び出す前にラベルのサイズに一致するように再形成します。
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
```py
|
||||
>>> import numpy as np
|
||||
>>> import torch
|
||||
>>> from torch import nn
|
||||
|
||||
>>> def compute_metrics(eval_pred):
|
||||
... with torch.no_grad():
|
||||
... logits, labels = eval_pred
|
||||
... logits_tensor = torch.from_numpy(logits)
|
||||
... logits_tensor = nn.functional.interpolate(
|
||||
... logits_tensor,
|
||||
... size=labels.shape[-2:],
|
||||
... mode="bilinear",
|
||||
... align_corners=False,
|
||||
... ).argmax(dim=1)
|
||||
|
||||
... pred_labels = logits_tensor.detach().cpu().numpy()
|
||||
... metrics = metric.compute(
|
||||
... predictions=pred_labels,
|
||||
... references=labels,
|
||||
... num_labels=num_labels,
|
||||
... ignore_index=255,
|
||||
... reduce_labels=False,
|
||||
... )
|
||||
... for key, value in metrics.items():
|
||||
... if type(value) is np.ndarray:
|
||||
... metrics[key] = value.tolist()
|
||||
... return metrics
|
||||
```
|
||||
|
||||
</pt>
|
||||
</frameworkcontent>
|
||||
|
||||
|
||||
<frameworkcontent>
|
||||
<tf>
|
||||
|
||||
```py
|
||||
>>> def compute_metrics(eval_pred):
|
||||
... logits, labels = eval_pred
|
||||
... logits = tf.transpose(logits, perm=[0, 2, 3, 1])
|
||||
... logits_resized = tf.image.resize(
|
||||
... logits,
|
||||
... size=tf.shape(labels)[1:],
|
||||
... method="bilinear",
|
||||
... )
|
||||
|
||||
... pred_labels = tf.argmax(logits_resized, axis=-1)
|
||||
... metrics = metric.compute(
|
||||
... predictions=pred_labels,
|
||||
... references=labels,
|
||||
... num_labels=num_labels,
|
||||
... ignore_index=-1,
|
||||
... reduce_labels=image_processor.do_reduce_labels,
|
||||
... )
|
||||
|
||||
... per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
|
||||
... per_category_iou = metrics.pop("per_category_iou").tolist()
|
||||
|
||||
... metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
|
||||
... metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})
|
||||
... return {"val_" + k: v for k, v in metrics.items()}
|
||||
```
|
||||
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
|
||||
これで`compute_metrics`関数の準備が整いました。トレーニングをセットアップするときにこの関数に戻ります。
|
||||
|
||||
## Train
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
<Tip>
|
||||
|
||||
[`Trainer`] を使用したモデルの微調整に慣れていない場合は、[こちら](../training#finetune-with-trainer) の基本的なチュートリアルをご覧ください。
|
||||
|
||||
|
||||
</Tip>
|
||||
|
||||
これでモデルのトレーニングを開始する準備が整いました。 [`AutoModelForSemanticSegmentation`] を使用して SegFormer をロードし、ラベル ID とラベル クラス間のマッピングをモデルに渡します。
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoModelForSemanticSegmentation, TrainingArguments, Trainer
|
||||
|
||||
>>> model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)
|
||||
```
|
||||
|
||||
この時点で残っている手順は次の 3 つだけです。
|
||||
|
||||
1. [`TrainingArguments`] でトレーニング ハイパーパラメータを定義します。 `image` 列が削除されるため、未使用の列を削除しないことが重要です。 `image` 列がないと、`pixel_values` を作成できません。この動作を防ぐには、`remove_unused_columns=False`を設定してください。他に必要なパラメータは、モデルの保存場所を指定する `output_dir` だけです。 `push_to_hub=True`を設定して、このモデルをハブにプッシュします (モデルをアップロードするには、Hugging Face にサインインする必要があります)。各エポックの終了時に、[`Trainer`] は IoU メトリックを評価し、トレーニング チェックポイントを保存します。
|
||||
2. トレーニング引数を、モデル、データセット、トークナイザー、データ照合器、および `compute_metrics` 関数とともに [`Trainer`] に渡します。
|
||||
3. [`~Trainer.train`] を呼び出してモデルを微調整します。
|
||||
|
||||
|
||||
```py
|
||||
>>> training_args = TrainingArguments(
|
||||
... output_dir="segformer-b0-scene-parse-150",
|
||||
... learning_rate=6e-5,
|
||||
... num_train_epochs=50,
|
||||
... per_device_train_batch_size=2,
|
||||
... per_device_eval_batch_size=2,
|
||||
... save_total_limit=3,
|
||||
... eval_strategy="steps",
|
||||
... save_strategy="steps",
|
||||
... save_steps=20,
|
||||
... eval_steps=20,
|
||||
... logging_steps=1,
|
||||
... eval_accumulation_steps=5,
|
||||
... remove_unused_columns=False,
|
||||
... push_to_hub=True,
|
||||
... )
|
||||
|
||||
>>> trainer = Trainer(
|
||||
... model=model,
|
||||
... args=training_args,
|
||||
... train_dataset=train_ds,
|
||||
... eval_dataset=test_ds,
|
||||
... compute_metrics=compute_metrics,
|
||||
... )
|
||||
|
||||
>>> trainer.train()
|
||||
```
|
||||
|
||||
トレーニングが完了したら、 [`~transformers.Trainer.push_to_hub`] メソッドを使用してモデルをハブに共有し、誰もがモデルを使用できるようにします。
|
||||
|
||||
```py
|
||||
>>> trainer.push_to_hub()
|
||||
```
|
||||
</pt>
|
||||
</frameworkcontent>
|
||||
|
||||
<frameworkcontent>
|
||||
<tf>
|
||||
<Tip>
|
||||
|
||||
Keras を使用したモデルの微調整に慣れていない場合は、まず [基本チュートリアル](./training#train-a-tensorflow-model-with-keras) を確認してください。
|
||||
|
||||
</Tip>
|
||||
|
||||
TensorFlow でモデルを微調整するには、次の手順に従います。
|
||||
1. トレーニングのハイパーパラメータを定義し、オプティマイザーと学習率スケジュールを設定します。
|
||||
2. 事前トレーニングされたモデルをインスタンス化します。
|
||||
3. 🤗 データセットを `tf.data.Dataset` に変換します。
|
||||
4. モデルをコンパイルします。
|
||||
5. コールバックを追加してメトリクスを計算し、モデルを 🤗 Hub にアップロードします
|
||||
6. `fit()` メソッドを使用してトレーニングを実行します。
|
||||
|
||||
まず、ハイパーパラメーター、オプティマイザー、学習率スケジュールを定義します。
|
||||
|
||||
|
||||
```py
|
||||
>>> from transformers import create_optimizer
|
||||
|
||||
>>> batch_size = 2
|
||||
>>> num_epochs = 50
|
||||
>>> num_train_steps = len(train_ds) * num_epochs
|
||||
>>> learning_rate = 6e-5
|
||||
>>> weight_decay_rate = 0.01
|
||||
|
||||
>>> optimizer, lr_schedule = create_optimizer(
|
||||
... init_lr=learning_rate,
|
||||
... num_train_steps=num_train_steps,
|
||||
... weight_decay_rate=weight_decay_rate,
|
||||
... num_warmup_steps=0,
|
||||
... )
|
||||
```
|
||||
|
||||
次に、ラベル マッピングとともに [`TFAutoModelForSemanticSegmentation`] を使用して SegFormer をロードし、それをコンパイルします。
|
||||
オプティマイザ。 Transformers モデルにはすべてデフォルトのタスク関連の損失関数があるため、次の場合を除き、損失関数を指定する必要はないことに注意してください。
|
||||
|
||||
```py
|
||||
>>> from transformers import TFAutoModelForSemanticSegmentation
|
||||
|
||||
>>> model = TFAutoModelForSemanticSegmentation.from_pretrained(
|
||||
... checkpoint,
|
||||
... id2label=id2label,
|
||||
... label2id=label2id,
|
||||
... )
|
||||
>>> model.compile(optimizer=optimizer) # No loss argument!
|
||||
```
|
||||
|
||||
[`~datasets.Dataset.to_tf_dataset`] と [`DefaultDataCollator`] を使用して、データセットを `tf.data.Dataset` 形式に変換します。
|
||||
|
||||
```py
|
||||
>>> from transformers import DefaultDataCollator
|
||||
|
||||
>>> data_collator = DefaultDataCollator(return_tensors="tf")
|
||||
|
||||
>>> tf_train_dataset = train_ds.to_tf_dataset(
|
||||
... columns=["pixel_values", "label"],
|
||||
... shuffle=True,
|
||||
... batch_size=batch_size,
|
||||
... collate_fn=data_collator,
|
||||
... )
|
||||
|
||||
>>> tf_eval_dataset = test_ds.to_tf_dataset(
|
||||
... columns=["pixel_values", "label"],
|
||||
... shuffle=True,
|
||||
... batch_size=batch_size,
|
||||
... collate_fn=data_collator,
|
||||
... )
|
||||
```
|
||||
|
||||
予測から精度を計算し、モデルを 🤗 ハブにプッシュするには、[Keras callbacks](../main_classes/keras_callbacks) を使用します。
|
||||
`compute_metrics` 関数を [`KerasMetricCallback`] に渡します。
|
||||
そして [`PushToHubCallback`] を使用してモデルをアップロードします。
|
||||
|
||||
```py
|
||||
>>> from transformers.keras_callbacks import KerasMetricCallback, PushToHubCallback
|
||||
|
||||
>>> metric_callback = KerasMetricCallback(
|
||||
... metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"]
|
||||
... )
|
||||
|
||||
>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", image_processor=image_processor)
|
||||
|
||||
>>> callbacks = [metric_callback, push_to_hub_callback]
|
||||
```
|
||||
|
||||
ついに、モデルをトレーニングする準備が整いました。`fit()`トレーニングおよび検証データセット、エポック数、
|
||||
モデルを微調整するためのコールバック:
|
||||
|
||||
```py
|
||||
>>> model.fit(
|
||||
... tf_train_dataset,
|
||||
... validation_data=tf_eval_dataset,
|
||||
... callbacks=callbacks,
|
||||
... epochs=num_epochs,
|
||||
... )
|
||||
```
|
||||
|
||||
おめでとう!モデルを微調整し、🤗 Hub で共有しました。これで推論に使用できるようになりました。
|
||||
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
|
||||
|
||||
## Inference
|
||||
|
||||
モデルを微調整したので、それを推論に使用できるようになりました。
|
||||
|
||||
推論のために画像をロードします。
|
||||
|
||||
```py
|
||||
>>> image = ds[0]["image"]
|
||||
>>> image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/semantic-seg-image.png" alt="Image of bedroom"/>
|
||||
</div>
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
推論用に微調整されたモデルを試す最も簡単な方法は、それを [`pipeline`] で使用することです。モデルを使用して画像セグメンテーション用の `pipeline` をインスタンス化し、それに画像を渡します。
|
||||
|
||||
```py
|
||||
>>> from transformers import pipeline
|
||||
|
||||
>>> segmenter = pipeline("image-segmentation", model="my_awesome_seg_model")
|
||||
>>> segmenter(image)
|
||||
[{'score': None,
|
||||
'label': 'wall',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062690>},
|
||||
{'score': None,
|
||||
'label': 'sky',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062A50>},
|
||||
{'score': None,
|
||||
'label': 'floor',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062B50>},
|
||||
{'score': None,
|
||||
'label': 'ceiling',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062A10>},
|
||||
{'score': None,
|
||||
'label': 'bed ',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062E90>},
|
||||
{'score': None,
|
||||
'label': 'windowpane',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062390>},
|
||||
{'score': None,
|
||||
'label': 'cabinet',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062550>},
|
||||
{'score': None,
|
||||
'label': 'chair',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062D90>},
|
||||
{'score': None,
|
||||
'label': 'armchair',
|
||||
'mask': <PIL.Image.Image image mode=L size=640x427 at 0x7FD5B2062E10>}]
|
||||
```
|
||||
|
||||
必要に応じて、`pipeline` の結果を手動で複製することもできます。画像プロセッサで画像を処理し、`pixel_values`を GPU に配置します。
|
||||
|
||||
```py
|
||||
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # use GPU if available, otherwise use a CPU
|
||||
>>> encoding = image_processor(image, return_tensors="pt")
|
||||
>>> pixel_values = encoding.pixel_values.to(device)
|
||||
```
|
||||
|
||||
入力をモデルに渡し、「logits」を返します。
|
||||
|
||||
```py
|
||||
>>> outputs = model(pixel_values=pixel_values)
|
||||
>>> logits = outputs.logits.cpu()
|
||||
```
|
||||
|
||||
次に、ロジットを元の画像サイズに再スケールします。
|
||||
|
||||
|
||||
```py
|
||||
>>> upsampled_logits = nn.functional.interpolate(
|
||||
... logits,
|
||||
... size=image.size[::-1],
|
||||
... mode="bilinear",
|
||||
... align_corners=False,
|
||||
... )
|
||||
|
||||
>>> pred_seg = upsampled_logits.argmax(dim=1)[0]
|
||||
```
|
||||
|
||||
</pt>
|
||||
</frameworkcontent>
|
||||
|
||||
<frameworkcontent>
|
||||
<tf>
|
||||
|
||||
画像プロセッサをロードして画像を前処理し、入力を TensorFlow テンソルとして返します。
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoImageProcessor
|
||||
|
||||
>>> image_processor = AutoImageProcessor.from_pretrained("MariaK/scene_segmentation")
|
||||
>>> inputs = image_processor(image, return_tensors="tf")
|
||||
```
|
||||
|
||||
入力をモデルに渡し、`logits`を返します。
|
||||
|
||||
```py
|
||||
>>> from transformers import TFAutoModelForSemanticSegmentation
|
||||
|
||||
>>> model = TFAutoModelForSemanticSegmentation.from_pretrained("MariaK/scene_segmentation")
|
||||
>>> logits = model(**inputs).logits
|
||||
```
|
||||
|
||||
次に、ロジットを元の画像サイズに再スケールし、クラス次元に argmax を適用します。
|
||||
|
||||
```py
|
||||
>>> logits = tf.transpose(logits, [0, 2, 3, 1])
|
||||
|
||||
>>> upsampled_logits = tf.image.resize(
|
||||
... logits,
|
||||
... # We reverse the shape of `image` because `image.size` returns width and height.
|
||||
... image.size[::-1],
|
||||
... )
|
||||
|
||||
>>> pred_seg = tf.math.argmax(upsampled_logits, axis=-1)[0]
|
||||
```
|
||||
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
|
||||
結果を視覚化するには、[データセット カラー パレット](https://github.com/tensorflow/models/blob/3f1ca33afe3c1631b733ea7e40c294273b9e406d/research/deeplab/utils/get_dataset_colormap.py#L51) を、それぞれをマップする `ade_palette()` としてロードします。クラスを RGB 値に変換します。次に、画像と予測されたセグメンテーション マップを組み合わせてプロットできます。
|
||||
|
||||
```py
|
||||
>>> import matplotlib.pyplot as plt
|
||||
>>> import numpy as np
|
||||
|
||||
>>> color_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3), dtype=np.uint8)
|
||||
>>> palette = np.array(ade_palette())
|
||||
>>> for label, color in enumerate(palette):
|
||||
... color_seg[pred_seg == label, :] = color
|
||||
>>> color_seg = color_seg[..., ::-1] # convert to BGR
|
||||
|
||||
>>> img = np.array(image) * 0.5 + color_seg * 0.5 # plot the image with the segmentation map
|
||||
>>> img = img.astype(np.uint8)
|
||||
|
||||
>>> plt.figure(figsize=(15, 10))
|
||||
>>> plt.imshow(img)
|
||||
>>> plt.show()
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/semantic-seg-preds.png" alt="Image of bedroom overlaid with segmentation map"/>
|
||||
</div>
|
@ -221,7 +221,7 @@ Transformerは最初に機械翻訳のために設計され、それ以降、ほ
|
||||
|
||||
事前訓練済みモデルをテキスト分類に使用するには、ベースのBERTモデルの上にシーケンス分類ヘッドを追加します。シーケンス分類ヘッドは最終的な隠れた状態を受け入れ、それらをロジットに変換するための線形層です。クロスエントロピー損失は、ロジットとターゲット間で最も可能性の高いラベルを見つけるために計算されます。
|
||||
|
||||
テキスト分類を試してみる準備はできましたか?DistilBERTを微調整し、推論に使用する方法を学ぶために、完全な[テキスト分類ガイド](tasks/sequence_classification)をチェックしてみてください!
|
||||
テキスト分類を試してみる準備はできましたか?DistilBERTを微調整し、推論に使用する方法を学ぶために、完全な[テキスト分類ガイド(英語版)](../en/tasks/sequence_classification)をチェックしてみてください!
|
||||
|
||||
### Token classification
|
||||
|
||||
|
@ -157,5 +157,8 @@
|
||||
title: 通用工具
|
||||
- local: internal/time_series_utils
|
||||
title: 时序数据工具
|
||||
- sections:
|
||||
- local: model_doc/bert
|
||||
title: BERT
|
||||
title: 内部辅助工具
|
||||
title: 应用程序接口 (API)
|
||||
title: 应用程序接口 (API)
|
258
docs/source/zh/model_doc/bert.md
Normal file
258
docs/source/zh/model_doc/bert.md
Normal file
@ -0,0 +1,258 @@
|
||||
<!--Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
|
||||
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=
|
||||
">
|
||||
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# BERT
|
||||
|
||||
[BERT](https://huggingface.co/papers/1810.04805) 是一个在无标签的文本数据上预训练的双向 transformer,用于预测句子中被掩码的(masked) token,以及预测一个句子是否跟随在另一个句子之后。其主要思想是,在预训练过程中,通过随机掩码一些 token,让模型利用左右上下文的信息预测它们,从而获得更全面深入的理解。此外,BERT 具有很强的通用性,其学习到的语言表示可以通过额外的层或头进行微调,从而适配其他下游 NLP 任务。
|
||||
|
||||
你可以在 [BERT](https://huggingface.co/collections/google/bert-release-64ff5e7a4be99045d1896dbc) 集合下找到 BERT 的所有原始 checkpoint。
|
||||
|
||||
> [!TIP]
|
||||
> 点击右侧边栏中的 BERT 模型,以查看将 BERT 应用于不同语言任务的更多示例。
|
||||
|
||||
下面的示例演示了如何使用 [`Pipeline`], [`AutoModel`] 和命令行预测 `[MASK]` token。
|
||||
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(
|
||||
task="fill-mask",
|
||||
model="google-bert/bert-base-uncased",
|
||||
torch_dtype=torch.float16,
|
||||
device=0
|
||||
)
|
||||
pipeline("Plants create [MASK] through a process known as photosynthesis.")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"google-bert/bert-base-uncased",
|
||||
)
|
||||
model = AutoModelForMaskedLM.from_pretrained(
|
||||
"google-bert/bert-base-uncased",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
inputs = tokenizer("Plants create [MASK] through a process known as photosynthesis.", return_tensors="pt").to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
predictions = outputs.logits
|
||||
|
||||
masked_index = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1]
|
||||
predicted_token_id = predictions[0, masked_index].argmax(dim=-1)
|
||||
predicted_token = tokenizer.decode(predicted_token_id)
|
||||
|
||||
print(f"The predicted token is: {predicted_token}")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
echo -e "Plants create [MASK] through a process known as photosynthesis." | transformers-cli run --task fill-mask --model google-bert/bert-base-uncased --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## 注意
|
||||
|
||||
- 输入内容应在右侧进行填充,因为 BERT 使用绝对位置嵌入。
|
||||
## BertConfig
|
||||
|
||||
[[autodoc]] BertConfig
|
||||
- all
|
||||
|
||||
## BertTokenizer
|
||||
|
||||
[[autodoc]] BertTokenizer
|
||||
- build_inputs_with_special_tokens
|
||||
- get_special_tokens_mask
|
||||
- create_token_type_ids_from_sequences
|
||||
- save_vocabulary
|
||||
|
||||
## BertTokenizerFast
|
||||
|
||||
[[autodoc]] BertTokenizerFast
|
||||
|
||||
## BertModel
|
||||
|
||||
[[autodoc]] BertModel
|
||||
- forward
|
||||
|
||||
## BertForPreTraining
|
||||
|
||||
[[autodoc]] BertForPreTraining
|
||||
- forward
|
||||
|
||||
## BertLMHeadModel
|
||||
|
||||
[[autodoc]] BertLMHeadModel
|
||||
- forward
|
||||
|
||||
## BertForMaskedLM
|
||||
|
||||
[[autodoc]] BertForMaskedLM
|
||||
- forward
|
||||
|
||||
## BertForNextSentencePrediction
|
||||
|
||||
[[autodoc]] BertForNextSentencePrediction
|
||||
- forward
|
||||
|
||||
## BertForSequenceClassification
|
||||
|
||||
[[autodoc]] BertForSequenceClassification
|
||||
- forward
|
||||
|
||||
## BertForMultipleChoice
|
||||
|
||||
[[autodoc]] BertForMultipleChoice
|
||||
- forward
|
||||
|
||||
## BertForTokenClassification
|
||||
|
||||
[[autodoc]] BertForTokenClassification
|
||||
- forward
|
||||
|
||||
## BertForQuestionAnswering
|
||||
|
||||
[[autodoc]] BertForQuestionAnswering
|
||||
- forward
|
||||
|
||||
## TFBertTokenizer
|
||||
|
||||
[[autodoc]] TFBertTokenizer
|
||||
|
||||
## TFBertModel
|
||||
|
||||
[[autodoc]] TFBertModel
|
||||
- call
|
||||
|
||||
## TFBertForPreTraining
|
||||
|
||||
[[autodoc]] TFBertForPreTraining
|
||||
- call
|
||||
|
||||
## TFBertModelLMHeadModel
|
||||
|
||||
[[autodoc]] TFBertLMHeadModel
|
||||
- call
|
||||
|
||||
## TFBertForMaskedLM
|
||||
|
||||
[[autodoc]] TFBertForMaskedLM
|
||||
- call
|
||||
|
||||
## TFBertForNextSentencePrediction
|
||||
|
||||
[[autodoc]] TFBertForNextSentencePrediction
|
||||
- call
|
||||
|
||||
## TFBertForSequenceClassification
|
||||
|
||||
[[autodoc]] TFBertForSequenceClassification
|
||||
- call
|
||||
|
||||
## TFBertForMultipleChoice
|
||||
|
||||
[[autodoc]] TFBertForMultipleChoice
|
||||
- call
|
||||
|
||||
## TFBertForTokenClassification
|
||||
|
||||
[[autodoc]] TFBertForTokenClassification
|
||||
- call
|
||||
|
||||
## TFBertForQuestionAnswering
|
||||
|
||||
[[autodoc]] TFBertForQuestionAnswering
|
||||
- call
|
||||
|
||||
## FlaxBertModel
|
||||
|
||||
[[autodoc]] FlaxBertModel
|
||||
- __call__
|
||||
|
||||
## FlaxBertForPreTraining
|
||||
|
||||
[[autodoc]] FlaxBertForPreTraining
|
||||
- __call__
|
||||
|
||||
## FlaxBertForCausalLM
|
||||
|
||||
[[autodoc]] FlaxBertForCausalLM
|
||||
- __call__
|
||||
|
||||
## FlaxBertForMaskedLM
|
||||
|
||||
[[autodoc]] FlaxBertForMaskedLM
|
||||
- __call__
|
||||
|
||||
## FlaxBertForNextSentencePrediction
|
||||
|
||||
[[autodoc]] FlaxBertForNextSentencePrediction
|
||||
- __call__
|
||||
|
||||
## FlaxBertForSequenceClassification
|
||||
|
||||
[[autodoc]] FlaxBertForSequenceClassification
|
||||
- __call__
|
||||
|
||||
## FlaxBertForMultipleChoice
|
||||
|
||||
[[autodoc]] FlaxBertForMultipleChoice
|
||||
- __call__
|
||||
|
||||
## FlaxBertForTokenClassification
|
||||
|
||||
[[autodoc]] FlaxBertForTokenClassification
|
||||
- __call__
|
||||
|
||||
## FlaxBertForQuestionAnswering
|
||||
|
||||
[[autodoc]] FlaxBertForQuestionAnswering
|
||||
- __call__
|
||||
|
||||
## Bert specific outputs
|
||||
|
||||
[[autodoc]] models.bert.modeling_bert.BertForPreTrainingOutput
|
||||
|
||||
[[autodoc]] models.bert.modeling_tf_bert.TFBertForPreTrainingOutput
|
||||
|
||||
[[autodoc]] models.bert.modeling_flax_bert.FlaxBertForPreTrainingOutput
|
435
examples/3D_parallel.py
Normal file
435
examples/3D_parallel.py
Normal file
@ -0,0 +1,435 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""":
|
||||
This script is used to test training a model using Tensor Parallelism and Data Parallelism.
|
||||
|
||||
Usage:
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
export CUDA_VISIBLE_DEVICES=4,5,6,7
|
||||
export CUDA_VISIBLE_DEVICES=5,6,7
|
||||
TP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 examples/3D_parallel.py
|
||||
CP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 examples/3D_parallel.py
|
||||
CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=4 examples/3D_parallel.py
|
||||
DP_SIZE=2 CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=8 examples/3D_parallel.py
|
||||
|
||||
TP_SIZE=1 CP_SIZE=4 torchrun --nproc_per_node=4 examples/3D_parallel.py
|
||||
TP_SIZE=1 DP_SIZE=4 torchrun --nproc_per_node=4 examples/3D_parallel.py
|
||||
TP_SIZE=4 DP_SIZE=1 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 examples/3D_parallel.py
|
||||
IGNORE_SANITY=1 CP_SIZE=1 TP_SIZE=1 DP_SIZE=1 torchrun --nproc_per_node=1 --rdzv_endpoint=localhost:29504 examples/3D_parallel.py
|
||||
ocalhost:29504 test_train.py
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed.checkpoint as dcp
|
||||
import torch.optim as optim
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
|
||||
from torch.distributed.checkpoint.stateful import Stateful
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp import ShardingStrategy
|
||||
from torch.distributed.tensor import DTensor
|
||||
from torch.distributed.tensor.experimental import context_parallel
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# from torch.distributed.tensor.experimental._attention import set_rotate_method
|
||||
|
||||
# set_rotate_method("alltoall") # CP rotate shards using all-to-all
|
||||
|
||||
|
||||
def main():
|
||||
tp_size = int(os.environ.get("TP_SIZE", 1))
|
||||
dp_size = int(os.environ.get("DP_SIZE", 1))
|
||||
cp_size = int(os.environ.get("CP_SIZE", 1)) # Add CP size configuration
|
||||
sdpa_backend = SDPBackend.FLASH_ATTENTION # For CP
|
||||
# sdpa_backend = SDPBackend.MATH # For CP
|
||||
global_batch_size = 8 # Desired global batch size
|
||||
seq_len = 1024 # Sequence length
|
||||
num_train_steps = 10000 # Number of training steps
|
||||
LR = 1e-5
|
||||
model_name = "HuggingFaceTB/SmolLM2-1.7B"
|
||||
# model_name = "unsloth/Llama-3.2-1B"
|
||||
|
||||
CHECKPOINT_DIR = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}"
|
||||
|
||||
# Initialize distributed environment
|
||||
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||
dist.init_process_group("nccl")
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
assert world_size == tp_size * dp_size * cp_size, (
|
||||
f"World size ({world_size}) must equal TP size ({tp_size}) * DP size ({dp_size}) * CP size ({cp_size})"
|
||||
)
|
||||
|
||||
mesh = torch.arange(world_size).reshape(dp_size, tp_size, cp_size)
|
||||
world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp"))
|
||||
tp_mesh = world_mesh["tp"]
|
||||
dp_mesh = world_mesh["dp"]
|
||||
cp_mesh = world_mesh["cp"]
|
||||
world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
|
||||
logger.info(f"Created DeviceMesh: {world_mesh}")
|
||||
logger.info(
|
||||
f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}, DP: {dp_mesh.get_local_rank()}, TP: {tp_mesh.get_local_rank()}, CP: {cp_mesh.get_local_rank()}"
|
||||
)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
wandb.init(
|
||||
project="tp_dp_test",
|
||||
config={
|
||||
"tp_size": tp_size,
|
||||
"dp_size": dp_size,
|
||||
"cp_size": cp_size,
|
||||
"global_batch_size": global_batch_size,
|
||||
"model_name": model_name,
|
||||
"dataset": "roneneldan/TinyStories-1M",
|
||||
"seq_len": seq_len,
|
||||
"lr": LR,
|
||||
"weight_decay": 0.1,
|
||||
},
|
||||
name=f"llama_tp{tp_size}_dp{dp_size}_cp{cp_size}"
|
||||
if model_name == "unsloth/Llama-3.2-1B"
|
||||
else f"tp{tp_size}_dp{dp_size}_cp{cp_size}",
|
||||
)
|
||||
logger.info("Wandb initialized.")
|
||||
# Log the current file to wandb
|
||||
wandb.save("test_train.py")
|
||||
|
||||
# Load model and tokenizer
|
||||
logger.info(f"Loading model and tokenizer from {model_name}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
logger.info(f"Set pad_token to eos_token: {tokenizer.pad_token}")
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
device_mesh=tp_mesh if dist.is_initialized() else None,
|
||||
tp_plan="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
logger.info(f"Model loaded onto device mesh: {tp_mesh}")
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
logger.info(f"Using device: {device} for non-model tensors")
|
||||
use_ddp = False
|
||||
if dist.is_initialized() and dp_mesh.size() > 1:
|
||||
model = FSDP(model, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD)
|
||||
use_ddp = True
|
||||
pass
|
||||
|
||||
model.train()
|
||||
|
||||
logger.info("Loading TinyStories dataset...")
|
||||
raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") # Use 1% for faster testing
|
||||
|
||||
def tokenize_function(examples):
|
||||
# Tokenize the text without padding
|
||||
tokenized_batch = tokenizer(
|
||||
examples["text"], padding=False, truncation=True, max_length=seq_len, return_tensors=None
|
||||
)
|
||||
# Set labels to be the same as input_ids for Causal LM
|
||||
tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
|
||||
return tokenized_batch
|
||||
|
||||
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
|
||||
logger.info(f"Dataset loaded and tokenized. Size: {len(tokenized_dataset)}")
|
||||
|
||||
# Create packed sequences
|
||||
def create_packed_sequences(examples):
|
||||
# Flatten all sequences
|
||||
all_tokens = []
|
||||
for input_ids in examples["input_ids"]:
|
||||
all_tokens.extend(input_ids)
|
||||
|
||||
# Split into sequences of seq_len + 1 (for input + label)
|
||||
num_sequences = len(all_tokens) // (seq_len + 1)
|
||||
packed_input_ids = []
|
||||
packed_labels = []
|
||||
|
||||
for i in range(num_sequences):
|
||||
start_idx = i * (seq_len + 1)
|
||||
end_idx = start_idx + (seq_len + 1)
|
||||
# Get the full sequence
|
||||
full_sequence = all_tokens[start_idx:end_idx]
|
||||
# For input_ids, remove the last token
|
||||
packed_input_ids.append(full_sequence[:-1])
|
||||
# For labels, remove the first token
|
||||
packed_labels.append(full_sequence[1:])
|
||||
|
||||
return {"input_ids": packed_input_ids, "labels": packed_labels}
|
||||
|
||||
# Apply packing to the dataset
|
||||
packed_dataset = tokenized_dataset.map(
|
||||
create_packed_sequences,
|
||||
batched=True,
|
||||
remove_columns=tokenized_dataset.column_names,
|
||||
batch_size=1000, # Process in batches for efficiency
|
||||
num_proc=60,
|
||||
)
|
||||
logger.info(f"Dataset packed. New size: {len(packed_dataset)}")
|
||||
|
||||
# Shuffle the packed dataset
|
||||
packed_dataset = packed_dataset.shuffle(seed=42)
|
||||
logger.info("Packed dataset shuffled")
|
||||
|
||||
# Calculate local batch size
|
||||
if dist.is_initialized():
|
||||
assert global_batch_size % dp_mesh.size() == 0, (
|
||||
f"Global batch size ({global_batch_size}) must be divisible by DP size ({dp_mesh.size()})"
|
||||
)
|
||||
local_batch_size = global_batch_size // dp_mesh.size()
|
||||
else:
|
||||
local_batch_size = global_batch_size
|
||||
|
||||
logger.info(
|
||||
f"Global batch size: {global_batch_size}, DP size: {dp_size if dist.is_initialized() else 1}, Local batch size: {local_batch_size}"
|
||||
)
|
||||
|
||||
# Simple collate function since sequences are already packed
|
||||
def collate_fn(batch):
|
||||
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
|
||||
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
|
||||
if dist.is_initialized():
|
||||
sampler = DistributedSampler(
|
||||
packed_dataset, num_replicas=dp_mesh.size(), rank=dp_mesh.get_local_rank(), shuffle=False
|
||||
)
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
dataloader = DataLoader(
|
||||
packed_dataset,
|
||||
batch_size=local_batch_size,
|
||||
sampler=sampler,
|
||||
shuffle=False,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=True,
|
||||
)
|
||||
logger.info(f"DataLoader created. Distributed: {dist.is_initialized()}")
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1)
|
||||
|
||||
# Training loop
|
||||
logger.info(f"Starting training for {num_train_steps} steps...")
|
||||
model.train()
|
||||
step = 0
|
||||
while step < num_train_steps:
|
||||
for batch in dataloader:
|
||||
if step >= num_train_steps:
|
||||
break # Exit loop if max steps reached
|
||||
|
||||
# Move batch to appropriate device
|
||||
batch = {k: v.to(device) for k, v in batch.items()}
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Add position_ids to batch before CP sharding
|
||||
batch_size = batch["input_ids"].shape[0]
|
||||
position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
|
||||
batch["position_ids"] = position_ids
|
||||
from torch.distributed.tensor.experimental._attention import _cp_options
|
||||
|
||||
_cp_options.enable_load_balance = False
|
||||
|
||||
with sdpa_kernel(sdpa_backend): # TODO: ideally move this to attention implementation
|
||||
cp_context = (
|
||||
nullcontext()
|
||||
if cp_mesh.size() == 1
|
||||
else context_parallel(
|
||||
cp_mesh,
|
||||
buffers=[
|
||||
batch["input_ids"],
|
||||
batch["labels"],
|
||||
batch["position_ids"],
|
||||
],
|
||||
buffer_seq_dims=[1, 1, 1],
|
||||
)
|
||||
)
|
||||
with cp_context:
|
||||
# Pop labels from batch before model forward pass
|
||||
labels = batch.pop("labels")
|
||||
outputs = model(**batch) # [mbs, seq_len/cp]
|
||||
loss = outputs.loss
|
||||
logits = outputs.logits
|
||||
|
||||
# Compute loss with shifted labels
|
||||
loss = model.loss_function(
|
||||
logits=logits, labels=None, shift_labels=labels, vocab_size=model.config.vocab_size
|
||||
)
|
||||
loss.backward()
|
||||
|
||||
# all reduce grads across dp_cp if applicable
|
||||
all_reduce_grads(model, world_mesh, use_ddp=use_ddp)
|
||||
|
||||
if hasattr(model, "clip_grad_norm_"):
|
||||
gradnorm = model.clip_grad_norm_(max_norm=1.0, norm_type=2.0) # TODO: fix reported gradnorm
|
||||
else:
|
||||
# only works with FSDP's NO_SHARD otherwise we should use FSDP's clip_grad_norm_
|
||||
assert len(list(model.parameters())) > 5, "No parameters found in model. Probably DDP bug.."
|
||||
gradnorm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0, foreach=True)
|
||||
|
||||
optimizer.step()
|
||||
# allreduce loss across cp_dp before logging
|
||||
if dist.is_initialized() and (cp_mesh.size() > 1 or dp_mesh.size() > 1):
|
||||
dist.all_reduce(loss, group=world_mesh["dp_cp"].get_group(), op=dist.ReduceOp.AVG)
|
||||
current_loss = loss.item()
|
||||
|
||||
# Log loss and gradnorm to wandb (only on rank 0 of dp group)
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
logger.info(
|
||||
f"Step: {step} | GBS: {global_batch_size} | DP: {dp_mesh.size()} | TP: {tp_mesh.size()} | CP: {cp_mesh.size()} | Loss: {current_loss} | Gradnorm: {gradnorm} | lr: {LR}"
|
||||
)
|
||||
wandb.log(
|
||||
{
|
||||
"train/loss": current_loss,
|
||||
"train/gradnorm": gradnorm,
|
||||
"step": step,
|
||||
"lr": LR,
|
||||
"GBS": global_batch_size,
|
||||
}
|
||||
)
|
||||
|
||||
step += 1 # Increment step count
|
||||
|
||||
logger.info("Training loop finished.")
|
||||
|
||||
# Save model using DCP (only if distributed)
|
||||
if dist.is_initialized():
|
||||
state_dict = {"app": AppState(model, optimizer)}
|
||||
dcp.save(
|
||||
state_dict=state_dict,
|
||||
checkpoint_id=CHECKPOINT_DIR,
|
||||
)
|
||||
logger.info(f"Saved checkpoint to {CHECKPOINT_DIR}")
|
||||
else:
|
||||
# Fallback to regular save for non-distributed case
|
||||
save_dir = "test_model_nondist"
|
||||
model.save_pretrained(save_dir, safe_serialization=False)
|
||||
tokenizer.save_pretrained(save_dir) # Save tokenizer too
|
||||
logger.info(f"Saved model to {save_dir}")
|
||||
|
||||
dist.destroy_process_group()
|
||||
logger.info("Cleaned up distributed process group")
|
||||
# Finish wandb run on rank 0
|
||||
if dist.get_rank() == 0:
|
||||
wandb.finish()
|
||||
logger.info("Wandb run finished.")
|
||||
|
||||
|
||||
def all_reduce_grads(model, world_mesh, use_ddp):
|
||||
"""All reduce gradients across dp_cp if applicable."""
|
||||
cp_mesh = world_mesh["cp"]
|
||||
if use_ddp:
|
||||
# DDP/FSDP takes care of syncing grads
|
||||
mesh = cp_mesh
|
||||
else:
|
||||
mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
|
||||
if dist.is_initialized() and mesh.size() > 1:
|
||||
for name, param in model.named_parameters():
|
||||
if param.grad is not None:
|
||||
# Workaround for cross-mesh communication limitation with DTensor gradients
|
||||
if isinstance(param.grad, DTensor):
|
||||
local_grad = param.grad.to_local()
|
||||
# Ensure grad requires grad for inplace modification checks (might not be needed)
|
||||
# local_grad = local_grad.detach().requires_grad_(True)
|
||||
torch.distributed.all_reduce(local_grad, op=torch.distributed.ReduceOp.SUM, group=mesh.get_group())
|
||||
local_grad = local_grad / mesh.size()
|
||||
# Assign averaged grad back - need careful handling if DTensor structure is complex
|
||||
# This simple assignment might work if the grad structure matches param structure
|
||||
param.grad = DTensor.from_local(
|
||||
local_grad, device_mesh=param.grad.device_mesh, placements=param.grad.placements
|
||||
)
|
||||
else:
|
||||
# Handle regular tensors if any exist (e.g. buffers not converted to DTensor)
|
||||
torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG, group=mesh.get_group())
|
||||
|
||||
|
||||
class AppState(Stateful):
|
||||
"""Wrapper for checkpointing the Application State including model and optimizer."""
|
||||
|
||||
def __init__(self, model, optimizer=None):
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
|
||||
def state_dict(self):
|
||||
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
|
||||
return {"model": model_state_dict, "optim": optimizer_state_dict}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
set_state_dict(
|
||||
self.model, self.optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"]
|
||||
)
|
||||
|
||||
|
||||
def clip_grad_norm_(
|
||||
parameters: Iterable[torch.Tensor],
|
||||
max_norm: float,
|
||||
norm_type: float = 2.0,
|
||||
error_if_nonfinite: bool = False,
|
||||
foreach: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Clip the gradient norm of an iterable of parameters.
|
||||
"""
|
||||
# Filter out parameters with no gradients
|
||||
parameters = [p for p in parameters if p.grad is not None]
|
||||
assert len(parameters) > 0, "No parameters with gradients found"
|
||||
|
||||
# Calculate total norm
|
||||
if norm_type == float("inf"):
|
||||
total_norm = max(p.grad.detach().abs().max() for p in parameters)
|
||||
else:
|
||||
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type)
|
||||
|
||||
# Convert DTensor to local tensor if needed
|
||||
if isinstance(total_norm, DTensor):
|
||||
total_norm = total_norm.full_tensor()
|
||||
|
||||
# Clip gradients
|
||||
clip_coef = max_norm / (total_norm + 1e-6)
|
||||
if clip_coef < 1:
|
||||
for p in parameters:
|
||||
p.grad.detach().mul_(clip_coef)
|
||||
|
||||
return total_norm
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -60,7 +60,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
793
examples/pytorch/3d_parallel_checks.py
Normal file
793
examples/pytorch/3d_parallel_checks.py
Normal file
@ -0,0 +1,793 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""":
|
||||
This script is used to test training a model using Tensor Parallelism and Data Parallelism.
|
||||
|
||||
Usage:
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
export CUDA_VISIBLE_DEVICES=4,5,6,7
|
||||
export CUDA_VISIBLE_DEVICES=5,6,7
|
||||
TP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 test_train.py
|
||||
CP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 test_train.py
|
||||
CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=4 test_train.py
|
||||
|
||||
TP_SIZE=1 CP_SIZE=4 torchrun --nproc_per_node=4 test_train.py
|
||||
TP_SIZE=1 DP_SIZE=4 torchrun --nproc_per_node=4 test_train.py
|
||||
TP_SIZE=4 DP_SIZE=1 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 test_train.py
|
||||
IGNORE_SANITY=1 CP_SIZE=1 TP_SIZE=1 DP_SIZE=1 torchrun --nproc_per_node=1 --rdzv_endpoint=l
|
||||
ocalhost:29504 test_train.py
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Dict, Iterable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed.checkpoint as dcp
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
|
||||
from torch.distributed.checkpoint.stateful import Stateful
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp import ShardingStrategy
|
||||
from torch.distributed.tensor import DTensor
|
||||
from torch.distributed.tensor.experimental import context_parallel
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
from torch.utils.data import DataLoader, default_collate
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
ignore_sanity_checks = int(os.environ.get("IGNORE_SANITY", 0)) == 1
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# from torch.distributed.tensor.experimental._attention import set_rotate_method
|
||||
|
||||
# set_rotate_method("alltoall") # rotate shards using all-to-all
|
||||
|
||||
|
||||
def main():
|
||||
tp_size = int(os.environ.get("TP_SIZE", 1))
|
||||
dp_size = int(os.environ.get("DP_SIZE", 4))
|
||||
cp_size = int(os.environ.get("CP_SIZE", 1)) # Add CP size configuration
|
||||
sdpa_backend = SDPBackend.FLASH_ATTENTION # For CP
|
||||
# sdpa_backend = SDPBackend.MATH # For CP
|
||||
global_batch_size = 8 # Desired global batch size
|
||||
seq_len = 1024 # Sequence length
|
||||
num_train_steps = 10000 # Number of training steps
|
||||
LR = 1e-5
|
||||
model_name = "HuggingFaceTB/SmolLM2-1.7B"
|
||||
# model_name = "unsloth/Llama-3.2-1B"
|
||||
|
||||
CHECKPOINT_DIR = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}"
|
||||
|
||||
# Initialize distributed environment
|
||||
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||
dist.init_process_group("nccl")
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
assert world_size == tp_size * dp_size * cp_size, (
|
||||
f"World size ({world_size}) must equal TP size ({tp_size}) * DP size ({dp_size}) * CP size ({cp_size})"
|
||||
)
|
||||
|
||||
mesh = torch.arange(world_size).reshape(dp_size, tp_size, cp_size)
|
||||
world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp"))
|
||||
tp_mesh = world_mesh["tp"]
|
||||
dp_mesh = world_mesh["dp"]
|
||||
cp_mesh = world_mesh["cp"]
|
||||
world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
|
||||
logger.info(f"Created DeviceMesh: {world_mesh}")
|
||||
logger.info(
|
||||
f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}, DP: {dp_mesh.get_local_rank()}, TP: {tp_mesh.get_local_rank()}, CP: {cp_mesh.get_local_rank()}"
|
||||
)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
wandb.init(
|
||||
project="tp_dp_test",
|
||||
config={
|
||||
"tp_size": tp_size,
|
||||
"dp_size": dp_size,
|
||||
"cp_size": cp_size,
|
||||
"global_batch_size": global_batch_size,
|
||||
"model_name": model_name,
|
||||
"dataset": "roneneldan/TinyStories-1M",
|
||||
"seq_len": seq_len,
|
||||
"lr": LR,
|
||||
"weight_decay": 0.1,
|
||||
},
|
||||
name=f"llama_tp{tp_size}_dp{dp_size}_cp{cp_size}"
|
||||
if model_name == "unsloth/Llama-3.2-1B"
|
||||
else f"tp{tp_size}_dp{dp_size}_cp{cp_size}",
|
||||
)
|
||||
logger.info(f"ignore_sanity_checks is set to: {ignore_sanity_checks}")
|
||||
logger.info("Wandb initialized.")
|
||||
# Log the current file to wandb
|
||||
wandb.save("test_train.py")
|
||||
|
||||
else:
|
||||
logger.info("Running in non-distributed mode. DeviceMesh not applicable.")
|
||||
rank = 0
|
||||
world_size = 1
|
||||
local_rank = 0
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
wandb.init(
|
||||
project="tp_dp_test",
|
||||
config={
|
||||
"tp_size": 1,
|
||||
"dp_size": 1,
|
||||
"global_batch_size": global_batch_size,
|
||||
"model_name": model_name,
|
||||
"dataset": "roneneldan/TinyStories-1M",
|
||||
"seq_len": seq_len,
|
||||
},
|
||||
name="llama_tp1_dp1_nondist" if model_name == "unsloth/Llama-3.2-1B" else "tp1_dp1_nondist",
|
||||
)
|
||||
logger.info("Wandb initialized for non-distributed run.")
|
||||
|
||||
# Load model and tokenizer
|
||||
logger.info(f"Loading model and tokenizer from {model_name}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
logger.info(f"Set pad_token to eos_token: {tokenizer.pad_token}")
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
device_mesh=tp_mesh if dist.is_initialized() else None,
|
||||
tp_plan="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
logger.info(f"Model loaded onto device mesh: {tp_mesh}")
|
||||
|
||||
if dist.is_initialized():
|
||||
assert model.config.num_key_value_heads % tp_mesh.size() == 0, (
|
||||
f"num_key_value_heads={model.config.num_key_value_heads} must be divisible by tp_size={tp_mesh.size()}"
|
||||
)
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
else:
|
||||
model = model.to(device)
|
||||
|
||||
logger.info(f"Using device: {device} for non-model tensors")
|
||||
use_ddp = False
|
||||
if dist.is_initialized() and dp_mesh.size() > 1:
|
||||
# FSDP1
|
||||
model = FSDP(model, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD)
|
||||
# FSDP2
|
||||
# for transformer_block in model.model.layers:
|
||||
# fully_shard(transformer_block, mesh=dp_mesh, reshard_after_forward=False)
|
||||
# fully_shard(model.model, mesh=dp_mesh, reshard_after_forward=False)
|
||||
# DDP
|
||||
# replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
|
||||
# assert len(list(model.parameters()))>5, "No parameters found in model. Probably DDP/FSDP bug.." # TODO: we should be cautious abt using model.parameters()
|
||||
use_ddp = True
|
||||
|
||||
model.train()
|
||||
assert len(list(model.parameters())) > 0, "No parameters found in model. Probably DDP bug.."
|
||||
assert len([p for p in model.parameters() if p.requires_grad]) > 0, (
|
||||
"No gradients found in model. Probably DDP bug.."
|
||||
)
|
||||
|
||||
if dist.is_initialized() and not ignore_sanity_checks:
|
||||
# assert model is replicated across all dp
|
||||
for name, param in model.named_parameters():
|
||||
sanity_check_tensor_sync(param, dp_mesh)
|
||||
|
||||
# assert model is different across tp (only for sharded params)
|
||||
for name, param in model.named_parameters():
|
||||
if isinstance(param, DTensor) and param.placements[0].is_shard():
|
||||
# Only check sharded parameters for non-sync across TP
|
||||
sanity_check_tensor_sync(param, tp_mesh, not_sync=True)
|
||||
elif isinstance(param, DTensor) and param.placements[0].is_replicate():
|
||||
# Replicated parameters should be the same across TP
|
||||
sanity_check_tensor_sync(param, tp_mesh)
|
||||
|
||||
# assert model is replicated across cp
|
||||
for name, param in model.named_parameters():
|
||||
sanity_check_tensor_sync(param, cp_mesh)
|
||||
|
||||
# Load and preprocess TinyStories dataset
|
||||
logger.info("Loading TinyStories dataset...")
|
||||
raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") # Use 1% for faster testing
|
||||
|
||||
def tokenize_function(examples):
|
||||
# Tokenize the text without padding
|
||||
tokenized_batch = tokenizer(
|
||||
examples["text"], padding=False, truncation=True, max_length=seq_len, return_tensors=None
|
||||
)
|
||||
# Set labels to be the same as input_ids for Causal LM
|
||||
tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
|
||||
return tokenized_batch
|
||||
|
||||
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
|
||||
logger.info(f"Dataset loaded and tokenized. Size: {len(tokenized_dataset)}")
|
||||
|
||||
# Create packed sequences
|
||||
def create_packed_sequences(examples):
|
||||
# Flatten all sequences
|
||||
all_tokens = []
|
||||
for input_ids in examples["input_ids"]:
|
||||
all_tokens.extend(input_ids)
|
||||
|
||||
# Split into sequences of seq_len + 1 (for input + label)
|
||||
num_sequences = len(all_tokens) // (seq_len + 1)
|
||||
packed_input_ids = []
|
||||
packed_labels = []
|
||||
|
||||
for i in range(num_sequences):
|
||||
start_idx = i * (seq_len + 1)
|
||||
end_idx = start_idx + (seq_len + 1)
|
||||
# Get the full sequence
|
||||
full_sequence = all_tokens[start_idx:end_idx]
|
||||
# For input_ids, remove the last token
|
||||
packed_input_ids.append(full_sequence[:-1])
|
||||
# For labels, remove the first token
|
||||
packed_labels.append(full_sequence[1:])
|
||||
|
||||
return {"input_ids": packed_input_ids, "labels": packed_labels}
|
||||
|
||||
# Apply packing to the dataset
|
||||
packed_dataset = tokenized_dataset.map(
|
||||
create_packed_sequences,
|
||||
batched=True,
|
||||
remove_columns=tokenized_dataset.column_names,
|
||||
batch_size=1000, # Process in batches for efficiency
|
||||
num_proc=60,
|
||||
)
|
||||
logger.info(f"Dataset packed. New size: {len(packed_dataset)}")
|
||||
|
||||
# Shuffle the packed dataset
|
||||
packed_dataset = packed_dataset.shuffle(seed=42)
|
||||
logger.info("Packed dataset shuffled")
|
||||
|
||||
# Calculate local batch size
|
||||
if dist.is_initialized():
|
||||
assert global_batch_size % dp_mesh.size() == 0, (
|
||||
f"Global batch size ({global_batch_size}) must be divisible by DP size ({dp_mesh.size()})"
|
||||
)
|
||||
local_batch_size = global_batch_size // dp_mesh.size()
|
||||
else:
|
||||
local_batch_size = global_batch_size
|
||||
|
||||
logger.info(
|
||||
f"Global batch size: {global_batch_size}, DP size: {dp_size if dist.is_initialized() else 1}, Local batch size: {local_batch_size}"
|
||||
)
|
||||
|
||||
# Simple collate function since sequences are already packed
|
||||
def collate_fn(batch):
|
||||
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
|
||||
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
|
||||
if dist.is_initialized():
|
||||
sampler = DistributedSampler(
|
||||
packed_dataset, num_replicas=dp_mesh.size(), rank=dp_mesh.get_local_rank(), shuffle=False
|
||||
)
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
dataloader = DataLoader(
|
||||
packed_dataset,
|
||||
batch_size=local_batch_size,
|
||||
sampler=sampler,
|
||||
shuffle=False,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
logger.info(f"DataLoader created. Distributed: {dist.is_initialized()}")
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1)
|
||||
|
||||
# Training loop
|
||||
logger.info(f"Starting training for {num_train_steps} steps...")
|
||||
model.train()
|
||||
step = 0
|
||||
while step < num_train_steps:
|
||||
for batch in dataloader:
|
||||
if step >= num_train_steps:
|
||||
break # Exit loop if max steps reached
|
||||
|
||||
# Move batch to appropriate device
|
||||
batch = {k: v.to(device) for k, v in batch.items()}
|
||||
|
||||
# Sanity checks for batch distribution (only if distributed)
|
||||
if dist.is_initialized() and not ignore_sanity_checks:
|
||||
# check batch is same across all tp
|
||||
sanity_check_tensor_sync(batch["input_ids"], tp_mesh)
|
||||
# check batch is different across dp
|
||||
sanity_check_tensor_sync(batch["input_ids"], dp_mesh, not_sync=True)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Add position_ids to batch before CP sharding
|
||||
batch_size = batch["input_ids"].shape[0]
|
||||
position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
|
||||
batch["position_ids"] = position_ids
|
||||
from torch.distributed.tensor.experimental._attention import _cp_options
|
||||
|
||||
_cp_options.enable_load_balance = False
|
||||
|
||||
with sdpa_kernel(sdpa_backend): # TODO: ideally move this to attention implementation
|
||||
cp_context = (
|
||||
nullcontext()
|
||||
if cp_mesh.size() == 1
|
||||
else context_parallel(
|
||||
cp_mesh,
|
||||
buffers=[
|
||||
batch["input_ids"],
|
||||
batch["labels"],
|
||||
batch["position_ids"],
|
||||
], # TODO: need to add attention mask
|
||||
buffer_seq_dims=[1, 1, 1],
|
||||
)
|
||||
)
|
||||
with cp_context:
|
||||
# Pop labels from batch before model forward pass
|
||||
labels = batch.pop("labels")
|
||||
outputs = model(**batch) # [mbs, seq_len/cp]
|
||||
loss = outputs.loss
|
||||
logits = outputs.logits
|
||||
|
||||
# Compute loss with shifted labels
|
||||
loss = model.loss_function(
|
||||
logits=logits, labels=None, shift_labels=labels, vocab_size=model.config.vocab_size
|
||||
)
|
||||
|
||||
# Sanity checks for logits
|
||||
if dist.is_initialized() and not ignore_sanity_checks:
|
||||
# sanity_check_tensor_sync(logits, tp_mesh) # TODO: only true without sequence parallel
|
||||
sanity_check_tensor_sync(logits, dp_mesh, not_sync=True)
|
||||
sanity_check_tensor_sync(logits, cp_mesh, not_sync=True)
|
||||
|
||||
loss.backward()
|
||||
|
||||
# all reduce grads across dp_cp if applicable
|
||||
all_reduce_grads(model, world_mesh, use_ddp=use_ddp)
|
||||
|
||||
# Sanity checks for gradients (only if distributed)
|
||||
if dist.is_initialized() and not ignore_sanity_checks:
|
||||
# check grads are not same across all tp (for sharded grads)
|
||||
for name, param in model.named_parameters():
|
||||
if param.grad is not None and isinstance(param.grad, DTensor):
|
||||
if param.grad.placements[0].is_shard():
|
||||
sanity_check_tensor_sync(param.grad, tp_mesh, not_sync=True)
|
||||
elif param.grad.placements[0].is_replicate():
|
||||
sanity_check_tensor_sync(param.grad, tp_mesh)
|
||||
# check grads are same across dp
|
||||
for name, param in model.named_parameters():
|
||||
if param.grad is not None and dp_mesh.size() > 1:
|
||||
sanity_check_tensor_sync(param.grad, dp_mesh)
|
||||
# check grads are same across cp
|
||||
for name, param in model.named_parameters():
|
||||
if param.grad is not None and cp_mesh.size() > 1:
|
||||
sanity_check_tensor_sync(param.grad, cp_mesh)
|
||||
|
||||
# Calculate gradient norm and clip gradients
|
||||
if hasattr(model, "clip_grad_norm_"):
|
||||
# when using FSDP or DDP, model.parameters() doesn't work
|
||||
gradnorm = model.clip_grad_norm_(max_norm=1.0, norm_type=2.0)
|
||||
else:
|
||||
assert len(list(model.parameters())) > 2, "No parameters found in model. Probably DDP bug.."
|
||||
assert len([p for p in model.parameters() if p.requires_grad]) > 2, (
|
||||
"No gradients found in model. Probably DDP bug.."
|
||||
)
|
||||
assert len([p for p in model.parameters() if p.grad is not None]) > 2, (
|
||||
"No gradients found in model. Probably DDP bug.."
|
||||
)
|
||||
# only works with FSDP's NO_SHARD otherwise we should use FSDP's clip_grad_norm_
|
||||
gradnorm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0, foreach=True)
|
||||
|
||||
optimizer.step()
|
||||
# Sanity checks for updated model parameters (only if distributed)
|
||||
if dist.is_initialized() and not ignore_sanity_checks:
|
||||
# check updated model is different across all tp (for sharded params)
|
||||
for name, param in model.named_parameters():
|
||||
if isinstance(param, DTensor):
|
||||
if param.placements[0].is_shard():
|
||||
sanity_check_tensor_sync(param, tp_mesh, not_sync=True)
|
||||
elif param.placements[0].is_replicate():
|
||||
sanity_check_tensor_sync(param, tp_mesh)
|
||||
# check updated model is same across dp
|
||||
for name, param in model.named_parameters():
|
||||
sanity_check_tensor_sync(param, dp_mesh)
|
||||
# check updated model is same across cp
|
||||
for name, param in model.named_parameters():
|
||||
sanity_check_tensor_sync(param, cp_mesh)
|
||||
|
||||
# allreduce loss across cp_dp before logging
|
||||
if dist.is_initialized() and (cp_mesh.size() > 1 or dp_mesh.size() > 1):
|
||||
dist.all_reduce(loss, group=world_mesh["dp_cp"].get_group(), op=dist.ReduceOp.AVG)
|
||||
current_loss = loss.item()
|
||||
|
||||
# Log loss and gradnorm to wandb (only on rank 0 of dp group)
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
logger.info(
|
||||
f"Step: {step} | GBS: {global_batch_size} | DP: {dp_mesh.size()} | TP: {tp_mesh.size()} | CP: {cp_mesh.size()} | Loss: {current_loss} | Gradnorm: {gradnorm} | lr: {LR}"
|
||||
)
|
||||
wandb.log(
|
||||
{
|
||||
"train/loss": current_loss,
|
||||
"train/gradnorm": gradnorm,
|
||||
"step": step,
|
||||
"lr": LR,
|
||||
"GBS": global_batch_size,
|
||||
}
|
||||
)
|
||||
|
||||
step += 1 # Increment step count
|
||||
|
||||
logger.info("Training loop finished.")
|
||||
|
||||
# Save model using DCP (only if distributed)
|
||||
if dist.is_initialized():
|
||||
state_dict = {"app": AppState(model, optimizer)}
|
||||
dcp.save(
|
||||
state_dict=state_dict,
|
||||
checkpoint_id=CHECKPOINT_DIR,
|
||||
)
|
||||
logger.info(f"Saved checkpoint to {CHECKPOINT_DIR}")
|
||||
else:
|
||||
# Fallback to regular save for non-distributed case
|
||||
save_dir = "test_model_nondist"
|
||||
model.save_pretrained(save_dir, safe_serialization=False)
|
||||
tokenizer.save_pretrained(save_dir) # Save tokenizer too
|
||||
logger.info(f"Saved model to {save_dir}")
|
||||
|
||||
# Example of loading the checkpoint (only if distributed)
|
||||
if dist.is_initialized():
|
||||
# Create a new model instance
|
||||
logger.info("Creating new model instance for verification")
|
||||
new_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
device_mesh=tp_mesh,
|
||||
torch_dtype=torch.bfloat16, # Use same dtype
|
||||
)
|
||||
new_optimizer = optim.AdamW(new_model.parameters(), lr=LR)
|
||||
|
||||
# Load checkpoint into new model
|
||||
state_dict = {"app": AppState(new_model, new_optimizer)}
|
||||
dcp.load(
|
||||
state_dict=state_dict,
|
||||
checkpoint_id=CHECKPOINT_DIR,
|
||||
)
|
||||
logger.info("Loaded checkpoint into new model")
|
||||
|
||||
# Verify model weights match
|
||||
logger.info("Verifying model weights match...")
|
||||
for (name1, param1), (name2, param2) in zip(model.named_parameters(), new_model.named_parameters()):
|
||||
torch.testing.assert_close(
|
||||
param1.to_local(),
|
||||
param2.to_local(),
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
msg=f"Weights mismatch in {name1} vs {name2}",
|
||||
)
|
||||
|
||||
# Verify optimizer states match
|
||||
logger.info("Verifying optimizer states match...")
|
||||
for name1, state1 in optimizer.state_dict().items():
|
||||
state2 = new_optimizer.state_dict()[name1]
|
||||
if name1 == "state":
|
||||
# Compare state dictionaries for each parameter
|
||||
for param_id, param_state1 in state1.items():
|
||||
param_state2 = state2[param_id]
|
||||
# Compare each state component (step, exp_avg, exp_avg_sq)
|
||||
for key, value1 in param_state1.items():
|
||||
value2 = param_state2[key]
|
||||
if isinstance(value1, DTensor):
|
||||
# Convert DTensors to local tensors for comparison
|
||||
torch.testing.assert_close(
|
||||
value1.to_local(),
|
||||
value2.to_local(),
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
msg=f"Optimizer state mismatch in state[{param_id}][{key}]",
|
||||
)
|
||||
else:
|
||||
torch.testing.assert_close(
|
||||
value1,
|
||||
value2,
|
||||
rtol=1e-5,
|
||||
atol=1e-5,
|
||||
msg=f"Optimizer state mismatch in state[{param_id}][{key}]",
|
||||
)
|
||||
elif name1 == "param_groups":
|
||||
# Compare param_groups (excluding the actual params list)
|
||||
for i, (group1, group2) in enumerate(zip(state1, state2)):
|
||||
for key in group1:
|
||||
if key != "params": # Skip comparing the params list
|
||||
assert group1[key] == group2[key], f"Param group mismatch in param_groups[{i}][{key}]"
|
||||
|
||||
# Run a forward pass with both models to verify outputs match
|
||||
logger.info("Running forward pass verification...")
|
||||
with torch.no_grad():
|
||||
# Use the last batch for verification
|
||||
batch = {k: v.to(device) for k, v in batch.items()} # Ensure batch is on correct device
|
||||
original_outputs = model(**batch)
|
||||
new_outputs = new_model(**batch)
|
||||
torch.testing.assert_close(
|
||||
original_outputs.logits.to_local(),
|
||||
new_outputs.logits.to_local(),
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
msg="Model outputs do not match!",
|
||||
) # Increased tolerance slightly for bf16
|
||||
|
||||
# Clean up distributed environment and finish wandb run
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
logger.info("Cleaned up distributed process group")
|
||||
# Finish wandb run on rank 0
|
||||
if dist.get_rank() == 0:
|
||||
wandb.finish()
|
||||
logger.info("Wandb run finished.")
|
||||
else:
|
||||
wandb.finish()
|
||||
logger.info("Wandb run finished.")
|
||||
|
||||
|
||||
def all_reduce_grads(model, world_mesh, use_ddp):
|
||||
"""All reduce gradients across dp_cp if applicable."""
|
||||
cp_mesh = world_mesh["cp"]
|
||||
if use_ddp:
|
||||
# DDP takes care of syncing grads
|
||||
mesh = cp_mesh
|
||||
else:
|
||||
mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
|
||||
if dist.is_initialized() and mesh.size() > 1:
|
||||
for name, param in model.named_parameters():
|
||||
if param.grad is not None:
|
||||
# Workaround for cross-mesh communication limitation with DTensor gradients
|
||||
if isinstance(param.grad, DTensor):
|
||||
local_grad = param.grad.to_local()
|
||||
# Ensure grad requires grad for inplace modification checks (might not be needed)
|
||||
# local_grad = local_grad.detach().requires_grad_(True)
|
||||
torch.distributed.all_reduce(local_grad, op=torch.distributed.ReduceOp.SUM, group=mesh.get_group())
|
||||
local_grad = local_grad / mesh.size()
|
||||
# Assign averaged grad back - need careful handling if DTensor structure is complex
|
||||
# This simple assignment might work if the grad structure matches param structure
|
||||
param.grad = DTensor.from_local(
|
||||
local_grad, device_mesh=param.grad.device_mesh, placements=param.grad.placements
|
||||
)
|
||||
else:
|
||||
# Handle regular tensors if any exist (e.g. buffers not converted to DTensor)
|
||||
torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG, group=mesh.get_group())
|
||||
|
||||
|
||||
class ContextParallelCollator:
|
||||
"""Collator for context parallel training that splits sequences into chunks."""
|
||||
|
||||
def __init__(self, cp_mesh: Optional[DeviceMesh] = None):
|
||||
self.cp_mesh = cp_mesh
|
||||
|
||||
def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
batch = default_collate(batch)
|
||||
if self.cp_mesh is not None and self.cp_mesh.size() > 1:
|
||||
# Get sequence length from the input batch
|
||||
seq_len = batch["input_ids"].shape[1]
|
||||
assert seq_len % self.cp_mesh.size() == 0, (
|
||||
f"Sequence length {seq_len} must be divisible by CP size {self.cp_mesh.size()}"
|
||||
)
|
||||
chunk_size = seq_len // self.cp_mesh.size()
|
||||
cp_rank = self.cp_mesh.get_local_rank()
|
||||
start_idx = cp_rank * chunk_size
|
||||
end_idx = start_idx + chunk_size
|
||||
|
||||
# Keep only the local chunk of the sequence
|
||||
batch["input_ids"] = batch["input_ids"][:, start_idx:end_idx]
|
||||
batch["attention_mask"] = batch["attention_mask"][:, start_idx:end_idx]
|
||||
batch["labels"] = batch["labels"][:, start_idx:end_idx]
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class AppState(Stateful):
|
||||
"""Wrapper for checkpointing the Application State including model and optimizer."""
|
||||
|
||||
def __init__(self, model, optimizer=None):
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
|
||||
def state_dict(self):
|
||||
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
|
||||
return {"model": model_state_dict, "optim": optimizer_state_dict}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
set_state_dict(
|
||||
self.model, self.optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"]
|
||||
)
|
||||
|
||||
|
||||
def sanity_check_tensor_sync(
|
||||
tensor: torch.Tensor, mesh: DeviceMesh, rtol: float = 1e-4, atol: float = 1e-4, not_sync: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Verify that a tensor is synchronized (or not synchronized) across all processes in the mesh's process group.
|
||||
Handles both regular tensors and DTensors.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The tensor to check for synchronization (can be DTensor)
|
||||
mesh (DeviceMesh): The device mesh containing the process group
|
||||
rtol (float): Relative tolerance for comparison
|
||||
atol (float): Absolute tolerance for comparison
|
||||
not_sync (bool): If True, asserts that tensors are NOT synchronized. If False, asserts they are synchronized.
|
||||
"""
|
||||
if not dist.is_initialized() or mesh.size() == 1:
|
||||
return # No need to check in non-distributed mode
|
||||
|
||||
# Get the process group from the mesh
|
||||
pg = mesh.get_group()
|
||||
|
||||
# Convert DTensor to local tensor if needed
|
||||
if hasattr(tensor, "to_local"):
|
||||
local_tensor = tensor.to_local()
|
||||
else:
|
||||
local_tensor = tensor
|
||||
|
||||
# Gather tensors from all processes
|
||||
world_size = dist.get_world_size(pg)
|
||||
gathered_tensors = [torch.empty_like(local_tensor) for _ in range(world_size)]
|
||||
dist.all_gather(gathered_tensors, local_tensor, group=pg)
|
||||
|
||||
# Compare each tensor with the first one
|
||||
for i in range(1, world_size):
|
||||
try:
|
||||
torch.testing.assert_close(gathered_tensors[0], gathered_tensors[i], rtol=rtol, atol=atol)
|
||||
except AssertionError as e:
|
||||
if not_sync:
|
||||
continue
|
||||
# # Add detailed debugging for logit synchronization issues
|
||||
# print(f"\nLogit synchronization error between rank 0 and rank {i}:")
|
||||
# print(f"Tensor shape: {gathered_tensors[0].shape}")
|
||||
# print(f"Number of mismatched elements: {(gathered_tensors[0] != gathered_tensors[i]).sum()}")
|
||||
# print(f"Percentage of mismatched elements: {((gathered_tensors[0] != gathered_tensors[i]).sum() / gathered_tensors[0].numel() * 100):.2f}%")
|
||||
|
||||
# # Find the first few mismatches
|
||||
# mismatches = torch.nonzero(gathered_tensors[0] != gathered_tensors[i])
|
||||
# print("\nFirst few mismatches:")
|
||||
# for idx in mismatches[:5]:
|
||||
# idx = tuple(idx.tolist())
|
||||
# print(f"Index {idx}:")
|
||||
# print(f"Rank 0 value: {gathered_tensors[0][idx]}")
|
||||
# print(f"Rank {i} value: {gathered_tensors[i][idx]}")
|
||||
# print(f"Absolute difference: {abs(gathered_tensors[0][idx] - gathered_tensors[i][idx])}")
|
||||
# print(f"Relative difference: {abs(gathered_tensors[0][idx] - gathered_tensors[i][idx]) / max(abs(gathered_tensors[0][idx]), abs(gathered_tensors[i][idx]))}")
|
||||
|
||||
# # Check if differences are systematic (e.g., all positive or negative)
|
||||
# diff = gathered_tensors[0] - gathered_tensors[i]
|
||||
# print(f"\nDifference statistics:")
|
||||
# print(f"Mean difference: {diff.mean()}")
|
||||
# print(f"Std difference: {diff.std()}")
|
||||
# print(f"Max positive difference: {diff.max()}")
|
||||
# print(f"Max negative difference: {diff.min()}")
|
||||
raise e
|
||||
|
||||
|
||||
def clip_grad_norm_(
|
||||
parameters: Iterable[torch.Tensor],
|
||||
max_norm: float,
|
||||
norm_type: float = 2.0,
|
||||
error_if_nonfinite: bool = False,
|
||||
foreach: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Clip the gradient norm of an iterable of parameters.
|
||||
"""
|
||||
# Filter out parameters with no gradients
|
||||
parameters = [p for p in parameters if p.grad is not None]
|
||||
assert len(parameters) > 0, "No parameters with gradients found"
|
||||
|
||||
# Calculate total norm
|
||||
if norm_type == float("inf"):
|
||||
total_norm = max(p.grad.detach().abs().max() for p in parameters)
|
||||
else:
|
||||
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type)
|
||||
|
||||
# Convert DTensor to local tensor if needed
|
||||
if isinstance(total_norm, DTensor):
|
||||
total_norm = total_norm.full_tensor()
|
||||
|
||||
# Clip gradients
|
||||
clip_coef = max_norm / (total_norm + 1e-6)
|
||||
if clip_coef < 1:
|
||||
for p in parameters:
|
||||
p.grad.detach().mul_(clip_coef)
|
||||
|
||||
return total_norm
|
||||
|
||||
|
||||
def check_params_sync(model_params, original_params):
|
||||
"""
|
||||
Check if original_params are being updated in sync with model parameters.
|
||||
|
||||
Args:
|
||||
model_params: Iterator of model parameters after update
|
||||
original_params: List of original parameters before DDP wrapping
|
||||
"""
|
||||
for mp, op in zip(model_params, original_params):
|
||||
if isinstance(mp, DTensor):
|
||||
mp = mp.to_local()
|
||||
if isinstance(op, DTensor):
|
||||
op = op.to_local()
|
||||
if not torch.allclose(mp.data, op.data, rtol=0, atol=0):
|
||||
raise RuntimeError(f"Parameters out of sync: model param {mp.data} != original param {op.data}")
|
||||
return True
|
||||
|
||||
|
||||
def get_parameters(model: nn.Module) -> Iterable[torch.Tensor]:
|
||||
"""
|
||||
Get all parameters from a model by iterating over its modules.
|
||||
This is an alternative to model.parameters() that works with DTensor models.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to get parameters from
|
||||
|
||||
Returns:
|
||||
Iterable[torch.Tensor]: An iterator over all parameters in the model
|
||||
"""
|
||||
for name, module in model._modules.items():
|
||||
# Look for parameters in module attributes
|
||||
for attr_name, attr in module.__dict__.items():
|
||||
if isinstance(attr, torch.Tensor) and attr.requires_grad:
|
||||
yield attr
|
||||
# Recursively get parameters from submodules
|
||||
for param in get_parameters(module):
|
||||
yield param
|
||||
|
||||
|
||||
def update_model_parameters(model: nn.Module) -> None:
|
||||
"""
|
||||
Update model._parameters using named_modules() to ensure all parameters are properly tracked.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to update parameters for
|
||||
"""
|
||||
# Clear existing parameters
|
||||
model._parameters = {}
|
||||
|
||||
# Add parameters from named_modules
|
||||
for name, module in model.named_modules():
|
||||
# Skip the root module itself
|
||||
if name == "":
|
||||
continue
|
||||
|
||||
# Get the parameter name by removing 'module.' prefix if it exists
|
||||
param_name = name.replace("module.", "")
|
||||
|
||||
# Add weight and bias parameters if they exist
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
model._parameters[f"{param_name}.weight"] = module.weight
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
model._parameters[f"{param_name}.bias"] = module.bias
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -44,7 +44,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
|
||||
|
||||
|
94
examples/pytorch/context_parallel.py
Normal file
94
examples/pytorch/context_parallel.py
Normal file
@ -0,0 +1,94 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
from torch.distributed.tensor.experimental import context_parallel
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers.loss.loss_utils import ForCausalLMLoss
|
||||
|
||||
|
||||
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||
cp_mesh = init_device_mesh("cuda", (world_size,))
|
||||
rank = torch.distributed.get_node_local_rank()
|
||||
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
sdpa_backend = SDPBackend.FLASH_ATTENTION
|
||||
|
||||
# prepare inputs
|
||||
batch_size = 1
|
||||
seq_len = 128
|
||||
|
||||
input_ids = torch.randint(low=8, high=64, size=(batch_size, seq_len), device=device)
|
||||
|
||||
ignore_index = -100
|
||||
# When using CP, we need to use `shift_labels`
|
||||
shift_labels = torch.nn.functional.pad(input_ids, (0, 1), value=ignore_index)
|
||||
shift_labels = shift_labels[..., 1:].contiguous()
|
||||
|
||||
position_ids = (
|
||||
torch.cumsum(torch.ones(size=input_ids.size(), dtype=input_ids.dtype, device=input_ids.device), dim=1) - 1
|
||||
)
|
||||
|
||||
# sync input as they are created randomly
|
||||
dist.broadcast(input_ids, src=0)
|
||||
dist.broadcast(shift_labels, src=0)
|
||||
dist.broadcast(position_ids, src=0)
|
||||
|
||||
# model and optimizer
|
||||
repo_id = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=dtype, device_map=device)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
|
||||
|
||||
model.train()
|
||||
model.zero_grad()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# For loss
|
||||
vocab_size = model.config.vocab_size
|
||||
|
||||
# so training could be synced
|
||||
model = DDP(model, device_ids=[rank])
|
||||
|
||||
# prepare for CP
|
||||
buffers = (input_ids, shift_labels, position_ids)
|
||||
buffer_seq_dims = (1, 1, 1)
|
||||
# `no_restore_buffers=set(buffers)` is required if `loss.backward` is outside `context_parallel`.
|
||||
# no_restore_buffers = set(buffers)
|
||||
no_restore_buffers = None
|
||||
|
||||
# run with CP
|
||||
with sdpa_kernel(sdpa_backend):
|
||||
with context_parallel(
|
||||
cp_mesh,
|
||||
buffers=buffers,
|
||||
buffer_seq_dims=buffer_seq_dims,
|
||||
no_restore_buffers=no_restore_buffers,
|
||||
):
|
||||
outputs = model(input_ids, shift_labels=shift_labels, position_ids=position_ids)
|
||||
print(outputs.logits.shape)
|
||||
|
||||
# So far we need to compute `loss` outside `model.forward` when using `shift_labels`
|
||||
# loss = outputs.loss
|
||||
loss = ForCausalLMLoss(logits=outputs.logits, labels=None, shift_labels=shift_labels, vocab_size=vocab_size)
|
||||
|
||||
# This could be outside `context_parallel` context if `no_restore_buffers` is specified
|
||||
loss.backward()
|
||||
optimizer.step()
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -42,7 +42,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -47,7 +47,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
@ -45,7 +45,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -53,7 +53,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
# You should update this to your particular problem to have better documentation of `model_type`
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")
|
||||
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = get_logger(__name__)
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
@ -34,7 +34,6 @@ from transformers import (
|
||||
GPT2Tokenizer,
|
||||
GPTJForCausalLM,
|
||||
LlamaForCausalLM,
|
||||
LlamaTokenizer,
|
||||
OpenAIGPTLMHeadModel,
|
||||
OpenAIGPTTokenizer,
|
||||
OPTForCausalLM,
|
||||
@ -63,7 +62,7 @@ MODEL_CLASSES = {
|
||||
"xlm": (XLMWithLMHeadModel, XLMTokenizer),
|
||||
"gptj": (GPTJForCausalLM, AutoTokenizer),
|
||||
"bloom": (BloomForCausalLM, BloomTokenizerFast),
|
||||
"llama": (LlamaForCausalLM, LlamaTokenizer),
|
||||
"llama": (LlamaForCausalLM, AutoTokenizer),
|
||||
"opt": (OPTForCausalLM, GPT2Tokenizer),
|
||||
}
|
||||
|
||||
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
||||
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version(
|
||||
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"
|
||||
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
@ -49,7 +49,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -61,7 +61,7 @@ except (ModuleNotFoundError, ImportError):
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Checking dependencies
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
@ -46,7 +46,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
task_to_keys = {
|
||||
"cola": ("sentence", None),
|
||||
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
||||
|
||||
# region Dependencies and constants
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.52.0.dev0")
|
||||
check_min_version("4.53.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
6
setup.py
6
setup.py
@ -125,7 +125,7 @@ _deps = [
|
||||
"jaxlib>=0.4.1,<=0.4.13",
|
||||
"jieba",
|
||||
"jinja2>=3.1.0",
|
||||
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
|
||||
"kenlm",
|
||||
# Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
|
||||
"keras>2.9,<2.16",
|
||||
"keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
|
||||
@ -315,7 +315,7 @@ extras["audio"] = deps_list(
|
||||
"librosa",
|
||||
"pyctcdecode",
|
||||
"phonemizer",
|
||||
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
|
||||
"kenlm",
|
||||
)
|
||||
# `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead
|
||||
extras["speech"] = deps_list("torchaudio") + extras["audio"]
|
||||
@ -451,7 +451,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.52.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.53.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.52.0.dev0"
|
||||
__version__ = "4.53.0.dev0"
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
@ -445,6 +445,7 @@ else:
|
||||
_import_structure["modeling_outputs"] = []
|
||||
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"]
|
||||
_import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
|
||||
_import_structure["masking_utils"] = ["AttentionMaskInterface"]
|
||||
_import_structure["optimization"] = [
|
||||
"Adafactor",
|
||||
"get_constant_schedule",
|
||||
@ -914,6 +915,7 @@ if TYPE_CHECKING:
|
||||
TorchExportableModuleWithStaticCache,
|
||||
convert_and_export_with_cache,
|
||||
)
|
||||
from .masking_utils import AttentionMaskInterface
|
||||
from .model_debugging_utils import (
|
||||
model_addition_debugger_context,
|
||||
)
|
||||
|
@ -21,6 +21,104 @@ if is_hqq_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Utility functions for static/sliding cache update logic
|
||||
def _static_cache_update(
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
cache_position: Optional[torch.LongTensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Updates the static cache tensors in place.
|
||||
|
||||
Args:
|
||||
k_cache (`torch.Tensor`): The key cache tensor to update.
|
||||
v_cache (`torch.Tensor`): The value cache tensor to update.
|
||||
key_states (`torch.Tensor`): The new key states to add.
|
||||
value_states (`torch.Tensor`): The new value states to add.
|
||||
cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted.
|
||||
If None, the entire cache is overwritten (prefill).
|
||||
|
||||
Returns:
|
||||
Tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place).
|
||||
"""
|
||||
if cache_position is None:
|
||||
# Prefill phase where seq_len potentially equals max_cache_len. Directly copy.
|
||||
k_cache.copy_(key_states)
|
||||
v_cache.copy_(value_states)
|
||||
else:
|
||||
# Generation phase. Update specific positions.
|
||||
# Use index_copy_ for in-place update (compile-friendly).
|
||||
try:
|
||||
k_cache.index_copy_(2, cache_position, key_states)
|
||||
v_cache.index_copy_(2, cache_position, value_states)
|
||||
except NotImplementedError:
|
||||
# Fallback for devices like MPS where index_copy_ might not be supported.
|
||||
k_cache[:, :, cache_position] = key_states
|
||||
v_cache[:, :, cache_position] = value_states
|
||||
return k_cache, v_cache
|
||||
|
||||
|
||||
def _sliding_cache_update(
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
cache_position: torch.LongTensor,
|
||||
max_cache_len: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Updates the sliding window cache tensors, returning the potentially modified tensors.
|
||||
|
||||
Args:
|
||||
k_cache (`torch.Tensor`): The key cache tensor to update.
|
||||
v_cache (`torch.Tensor`): The value cache tensor to update.
|
||||
key_states (`torch.Tensor`): The new key states to add.
|
||||
value_states (`torch.Tensor`): The new value states to add.
|
||||
cache_position (`torch.LongTensor`): The position indices where the new states should be inserted.
|
||||
max_cache_len (`int`): The maximum length of the sliding window cache.
|
||||
|
||||
Returns:
|
||||
Tuple[`torch.Tensor`, `torch.Tensor`]: The key and value tensors representing the cache state after the update.
|
||||
For prefill > window, these are the full input states.
|
||||
Otherwise, they are the updated cache tensors.
|
||||
"""
|
||||
# Handle prefill phase when prompt length > sliding_window_size
|
||||
if cache_position.shape[0] > max_cache_len:
|
||||
new_k = key_states[:, :, -max_cache_len:, :]
|
||||
new_v = value_states[:, :, -max_cache_len:, :]
|
||||
k_cache.copy_(new_k)
|
||||
v_cache.copy_(new_v)
|
||||
return key_states, value_states
|
||||
|
||||
# Sliding window logic for generation phase or prefill < window
|
||||
slicing = torch.arange(max_cache_len, device=value_states.device)
|
||||
current_seq_len = cache_position[-1] + 1 # Use last position to determine current length
|
||||
to_shift = current_seq_len > max_cache_len
|
||||
indices = (slicing + to_shift.sum()) % max_cache_len
|
||||
|
||||
k_out_shifted = k_cache[:, :, indices]
|
||||
v_out_shifted = v_cache[:, :, indices]
|
||||
|
||||
# Clamp cache_position to determine the *target index* within the shifted cache view
|
||||
update_position = cache_position.clamp(min=0, max=max_cache_len - 1)
|
||||
|
||||
try:
|
||||
k_out_updated = k_out_shifted.index_copy(2, update_position, key_states)
|
||||
v_out_updated = v_out_shifted.index_copy(2, update_position, value_states)
|
||||
except NotImplementedError:
|
||||
# Fallback for MPS: clone and modify the clone
|
||||
k_out_updated = k_out_shifted.clone()
|
||||
v_out_updated = v_out_shifted.clone()
|
||||
k_out_updated[:, :, update_position] = key_states
|
||||
v_out_updated[:, :, update_position] = value_states
|
||||
|
||||
k_cache.copy_(k_out_updated)
|
||||
v_cache.copy_(v_out_updated)
|
||||
return k_out_updated, v_out_updated
|
||||
|
||||
|
||||
class Cache:
|
||||
"""
|
||||
Base, abstract class for all caches. The actual data structure is specific to each subclass.
|
||||
@ -98,6 +196,18 @@ class Cache:
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
|
||||
"""
|
||||
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
|
||||
the given layer at `layer_idx`.
|
||||
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
|
||||
for each layer.
|
||||
"""
|
||||
query_length = cache_position.shape[0]
|
||||
past_seen_tokens = self.get_seq_length()
|
||||
kv_length = query_length + past_seen_tokens
|
||||
return kv_length, 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheConfig:
|
||||
@ -986,8 +1096,6 @@ class SinkCache(Cache):
|
||||
```
|
||||
"""
|
||||
|
||||
is_sliding = True
|
||||
|
||||
def __init__(self, window_length: int, num_sink_tokens: int) -> None:
|
||||
super().__init__()
|
||||
self.key_cache: List[torch.Tensor] = []
|
||||
@ -1264,28 +1372,16 @@ class StaticCache(Cache):
|
||||
"""
|
||||
if cache_kwargs is None:
|
||||
cache_kwargs = {}
|
||||
cache_position = cache_kwargs.get("cache_position")
|
||||
k_out = self.key_cache[layer_idx]
|
||||
v_out = self.value_cache[layer_idx]
|
||||
key_states = key_states.to(k_out.dtype)
|
||||
value_states = value_states.to(v_out.dtype)
|
||||
|
||||
if cache_position is None:
|
||||
k_out.copy_(key_states)
|
||||
v_out.copy_(value_states)
|
||||
else:
|
||||
# Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
|
||||
# `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
|
||||
# operation, that avoids copies and uses less memory.
|
||||
try:
|
||||
k_out.index_copy_(2, cache_position, key_states)
|
||||
v_out.index_copy_(2, cache_position, value_states)
|
||||
except NotImplementedError:
|
||||
# The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
|
||||
k_out[:, :, cache_position] = key_states
|
||||
v_out[:, :, cache_position] = value_states
|
||||
|
||||
return k_out, v_out
|
||||
key_states = key_states.to(self.key_cache[layer_idx].dtype)
|
||||
value_states = value_states.to(self.value_cache[layer_idx].dtype)
|
||||
return _static_cache_update(
|
||||
self.key_cache[layer_idx],
|
||||
self.value_cache[layer_idx],
|
||||
key_states,
|
||||
value_states,
|
||||
cache_kwargs.get("cache_position"),
|
||||
)
|
||||
|
||||
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||
"""Returns the sequence length of the cached states that were seen by the model."""
|
||||
@ -1304,6 +1400,16 @@ class StaticCache(Cache):
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
|
||||
"""
|
||||
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
|
||||
the given layer at `layer_idx`.
|
||||
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
|
||||
for each layer.
|
||||
"""
|
||||
kv_length = self.get_max_cache_shape()
|
||||
return kv_length, 0
|
||||
|
||||
|
||||
class SlidingWindowCache(StaticCache):
|
||||
"""
|
||||
@ -1314,7 +1420,7 @@ class SlidingWindowCache(StaticCache):
|
||||
|
||||
The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:
|
||||
|
||||
indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
|
||||
indices = (slicing + to_shift[-1].sum()-1) % self.config.sliding_window
|
||||
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
|
||||
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
|
||||
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
|
||||
@ -1360,7 +1466,6 @@ class SlidingWindowCache(StaticCache):
|
||||
```
|
||||
"""
|
||||
|
||||
is_sliding = True
|
||||
is_compileable = True
|
||||
|
||||
def __init__(
|
||||
@ -1379,6 +1484,7 @@ class SlidingWindowCache(StaticCache):
|
||||
"config and it's not set to None."
|
||||
)
|
||||
max_cache_len = min(config.sliding_window, max_cache_len)
|
||||
self.sliding_window = config.sliding_window
|
||||
super().__init__(
|
||||
config=config,
|
||||
max_batch_size=max_batch_size,
|
||||
@ -1398,46 +1504,21 @@ class SlidingWindowCache(StaticCache):
|
||||
if cache_kwargs is None:
|
||||
cache_kwargs = {}
|
||||
cache_position = cache_kwargs.get("cache_position")
|
||||
k_out = self.key_cache[layer_idx]
|
||||
v_out = self.value_cache[layer_idx]
|
||||
key_states = key_states.to(k_out.dtype)
|
||||
value_states = value_states.to(v_out.dtype)
|
||||
|
||||
# assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
|
||||
if cache_position.shape[0] >= self.max_cache_len:
|
||||
k_out = key_states[:, :, -self.max_cache_len :, :]
|
||||
v_out = value_states[:, :, -self.max_cache_len :, :]
|
||||
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
|
||||
self.key_cache[layer_idx] += k_out
|
||||
self.value_cache[layer_idx] += v_out
|
||||
# we should return the whole states instead of k_out, v_out to take the whole prompt
|
||||
# into consideration when building kv cache instead of just throwing away tokens outside of the window
|
||||
return key_states, value_states
|
||||
if cache_position is None:
|
||||
raise ValueError("`cache_position` must be provided for SlidingWindowCache.")
|
||||
|
||||
slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
|
||||
to_shift = cache_position > self.max_cache_len - 1
|
||||
cache_position = cache_position.clamp(0, self.max_cache_len - 1)
|
||||
indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len
|
||||
key_states = key_states.to(self.key_cache[layer_idx].dtype)
|
||||
value_states = value_states.to(self.value_cache[layer_idx].dtype)
|
||||
|
||||
k_out = k_out[:, :, indices]
|
||||
v_out = v_out[:, :, indices]
|
||||
|
||||
try:
|
||||
k_out.index_copy_(2, cache_position, key_states)
|
||||
v_out.index_copy_(2, cache_position, value_states)
|
||||
except NotImplementedError:
|
||||
# The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
|
||||
k_out[:, :, cache_position] = key_states
|
||||
v_out[:, :, cache_position] = value_states
|
||||
|
||||
# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
self.key_cache[layer_idx] += k_out
|
||||
self.value_cache[layer_idx] += v_out
|
||||
|
||||
return k_out, v_out
|
||||
return _sliding_cache_update(
|
||||
self.key_cache[layer_idx],
|
||||
self.value_cache[layer_idx],
|
||||
key_states,
|
||||
value_states,
|
||||
cache_position,
|
||||
self.max_cache_len,
|
||||
)
|
||||
|
||||
def get_max_cache_shape(self) -> Optional[int]:
|
||||
return self.max_cache_len
|
||||
@ -1448,6 +1529,21 @@ class SlidingWindowCache(StaticCache):
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
|
||||
"""
|
||||
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
|
||||
the given layer at `layer_idx`.
|
||||
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
|
||||
for each layer.
|
||||
"""
|
||||
query_length = cache_position.shape[0]
|
||||
first_cache_position = cache_position[0]
|
||||
# torch.clamp() is equivalent to max() but should be compile-friendly/exportable as first_cache_position is a Tensor
|
||||
kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
|
||||
# This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
|
||||
kv_length = max(query_length, self.get_max_cache_shape())
|
||||
return kv_length, kv_offset
|
||||
|
||||
|
||||
class EncoderDecoderCache(Cache):
|
||||
"""
|
||||
@ -1680,12 +1776,13 @@ class HybridCache(Cache):
|
||||
super().__init__()
|
||||
if not hasattr(config, "sliding_window") or config.sliding_window is None:
|
||||
raise ValueError(
|
||||
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
|
||||
"Setting `cache_implementation` to 'hybrid' requires the model config supporting "
|
||||
"sliding window attention, please check if there is a `sliding_window` field in the model "
|
||||
"config and it's not set to None."
|
||||
)
|
||||
self.max_cache_len = max_cache_len
|
||||
self._sliding_window_max_len = min(config.sliding_window, max_cache_len)
|
||||
self.max_cache_len = max_cache_len if max_cache_len is not None else config.max_position_embeddings
|
||||
# Sliding layers can't be larger than the overall max cache len
|
||||
self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
|
||||
self.max_batch_size = max_batch_size
|
||||
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
|
||||
self.head_dim = (
|
||||
@ -1694,22 +1791,22 @@ class HybridCache(Cache):
|
||||
|
||||
self._dtype = dtype
|
||||
self.num_key_value_heads = (
|
||||
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
|
||||
config.num_attention_heads
|
||||
if getattr(config, "num_key_value_heads", None) is None
|
||||
else config.num_key_value_heads
|
||||
)
|
||||
|
||||
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
|
||||
self.is_sliding = torch.tensor(
|
||||
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
|
||||
)
|
||||
# If the attribute does not exist in the config, fallback to a simple StaticCache
|
||||
if hasattr(config, "layer_types"):
|
||||
self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types]
|
||||
else:
|
||||
self.is_sliding = [False] * config.num_hidden_layers
|
||||
|
||||
self.key_cache: List[torch.Tensor] = []
|
||||
self.value_cache: List[torch.Tensor] = []
|
||||
global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
|
||||
sliding_cache_shape = (
|
||||
self.max_batch_size,
|
||||
self.num_key_value_heads,
|
||||
self._sliding_window_max_len,
|
||||
self.head_dim,
|
||||
)
|
||||
global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
|
||||
sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim)
|
||||
self.sliding_window = min(config.sliding_window, max_cache_len)
|
||||
device = torch.device(device) if device is not None else None
|
||||
for i in range(config.num_hidden_layers):
|
||||
if layer_device_map is not None:
|
||||
@ -1718,7 +1815,7 @@ class HybridCache(Cache):
|
||||
layer_device = device
|
||||
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
|
||||
# breaks when updating the cache.
|
||||
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
|
||||
cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
|
||||
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
|
||||
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
|
||||
torch._dynamo.mark_static_address(new_layer_key_cache)
|
||||
@ -1726,42 +1823,6 @@ class HybridCache(Cache):
|
||||
self.key_cache.append(new_layer_key_cache)
|
||||
self.value_cache.append(new_layer_value_cache)
|
||||
|
||||
def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
|
||||
if cache_position.shape[0] >= max_cache_len:
|
||||
k_out = key_states[:, :, -max_cache_len:, :]
|
||||
v_out = value_states[:, :, -max_cache_len:, :]
|
||||
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
|
||||
self.key_cache[layer_idx] += k_out
|
||||
self.value_cache[layer_idx] += v_out
|
||||
# we should return the whole states instead of k_out, v_out to take the whole prompt
|
||||
# into consideration when building kv cache instead of just throwing away tokens outside of the window
|
||||
return key_states, value_states
|
||||
|
||||
slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
|
||||
to_shift = cache_position > max_cache_len - 1
|
||||
cache_position = cache_position.clamp(0, max_cache_len - 1)
|
||||
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
|
||||
k_out = k_out[:, :, indices]
|
||||
v_out = v_out[:, :, indices]
|
||||
|
||||
k_out[:, :, cache_position] = key_states
|
||||
v_out[:, :, cache_position] = value_states
|
||||
# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
self.key_cache[layer_idx] += k_out
|
||||
self.value_cache[layer_idx] += v_out
|
||||
return k_out, v_out
|
||||
|
||||
def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
|
||||
k_out[:, :, cache_position] = key_states
|
||||
v_out[:, :, cache_position] = value_states
|
||||
|
||||
self.key_cache[layer_idx] = k_out
|
||||
self.value_cache[layer_idx] = v_out
|
||||
return k_out, v_out
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
@ -1772,7 +1833,10 @@ class HybridCache(Cache):
|
||||
if cache_kwargs is None:
|
||||
cache_kwargs = {}
|
||||
cache_position = cache_kwargs.get("cache_position")
|
||||
sliding_window = cache_kwargs.get("sliding_window")
|
||||
if cache_position is None:
|
||||
raise ValueError("`cache_position` must be provided for HybridCache.")
|
||||
|
||||
is_sliding_layer = self.is_sliding[layer_idx]
|
||||
|
||||
# These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
|
||||
# when the cache is initialized in the forward pass (e.g. Gemma2)
|
||||
@ -1781,25 +1845,22 @@ class HybridCache(Cache):
|
||||
if self.value_cache[layer_idx].device != value_states.device:
|
||||
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
|
||||
|
||||
k_out = self.key_cache[layer_idx]
|
||||
v_out = self.value_cache[layer_idx]
|
||||
key_states = key_states.to(k_out.dtype)
|
||||
value_states = value_states.to(v_out.dtype)
|
||||
k_cache = self.key_cache[layer_idx]
|
||||
v_cache = self.value_cache[layer_idx]
|
||||
key_states = key_states.to(k_cache.dtype)
|
||||
value_states = value_states.to(v_cache.dtype)
|
||||
|
||||
if sliding_window:
|
||||
update_fn = self._sliding_update
|
||||
if is_sliding_layer:
|
||||
return _sliding_cache_update(
|
||||
k_cache,
|
||||
v_cache,
|
||||
key_states,
|
||||
value_states,
|
||||
cache_position,
|
||||
k_cache.shape[2], # Use actual cache dim as max cache len
|
||||
)
|
||||
else:
|
||||
update_fn = self._static_update
|
||||
|
||||
return update_fn(
|
||||
cache_position,
|
||||
layer_idx,
|
||||
key_states,
|
||||
value_states,
|
||||
k_out,
|
||||
v_out,
|
||||
k_out.shape[2],
|
||||
)
|
||||
return _static_cache_update(k_cache, v_cache, key_states, value_states, cache_position)
|
||||
|
||||
def get_max_cache_shape(self) -> Optional[int]:
|
||||
return self.max_cache_len
|
||||
@ -1822,6 +1883,26 @@ class HybridCache(Cache):
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
|
||||
"""
|
||||
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
|
||||
the given layer at `layer_idx`.
|
||||
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
|
||||
for each layer.
|
||||
"""
|
||||
if self.is_sliding[layer_idx]:
|
||||
query_length = cache_position.shape[0]
|
||||
first_cache_position = cache_position[0]
|
||||
|
||||
local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
|
||||
# This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
|
||||
local_mask_kv_length = max(query_length, self.sliding_window)
|
||||
return local_mask_kv_length, local_mask_kv_offset
|
||||
|
||||
full_mask_kv_offset = 0
|
||||
full_mask_kv_length = self.get_max_cache_shape()
|
||||
return full_mask_kv_length, full_mask_kv_offset
|
||||
|
||||
|
||||
class HybridChunkedCache(Cache):
|
||||
"""
|
||||
@ -1891,11 +1972,11 @@ class HybridChunkedCache(Cache):
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self._dtype = dtype
|
||||
|
||||
if hasattr(config.get_text_config(), "no_rope_layers"):
|
||||
self.is_sliding = config.no_rope_layers
|
||||
# If the attribute does not exist in the config, fallback to a simple StaticCache
|
||||
if hasattr(config, "layer_types"):
|
||||
self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types]
|
||||
else:
|
||||
layer_switch = getattr(config, "sliding_window_pattern", 2)
|
||||
self.is_sliding = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
|
||||
self.is_sliding = [False] * config.num_hidden_layers
|
||||
|
||||
self.key_cache: List[torch.Tensor] = []
|
||||
self.value_cache: List[torch.Tensor] = []
|
||||
@ -1978,11 +2059,7 @@ class HybridChunkedCache(Cache):
|
||||
key_states = key_states.to(k_out.dtype)
|
||||
value_states = value_states.to(v_out.dtype)
|
||||
|
||||
if self.is_sliding[layer_idx]:
|
||||
update_fn = self._sliding_update
|
||||
else:
|
||||
update_fn = self._static_update
|
||||
|
||||
update_fn = self._sliding_update if self.is_sliding[layer_idx] else self._static_update
|
||||
return update_fn(
|
||||
cache_position,
|
||||
layer_idx,
|
||||
@ -2017,6 +2094,37 @@ class HybridChunkedCache(Cache):
|
||||
self.value_cache[layer_idx].zero_()
|
||||
self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]
|
||||
|
||||
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
|
||||
"""
|
||||
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
|
||||
the given layer at `layer_idx`.
|
||||
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
|
||||
for each layer.
|
||||
"""
|
||||
if self.is_sliding[layer_idx]:
|
||||
query_length = cache_position.shape[0]
|
||||
first_cache_position = cache_position[0]
|
||||
|
||||
local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
|
||||
# This is the true general case for any Cache using local attention (sliding or chunked)
|
||||
if first_cache_position >= self.sliding_window:
|
||||
# Here the Cache is already full
|
||||
local_mask_kv_length = self.sliding_window + query_length - 1
|
||||
elif (
|
||||
first_cache_position < self.sliding_window
|
||||
and first_cache_position + query_length > self.sliding_window
|
||||
):
|
||||
# Here the Cache becomes full with the new input
|
||||
local_mask_kv_length = first_cache_position + query_length
|
||||
else:
|
||||
# Here the Cache is still smaller than the local size, but we return the local size as it's static
|
||||
local_mask_kv_length = self.sliding_window
|
||||
return local_mask_kv_length, local_mask_kv_offset
|
||||
|
||||
full_mask_kv_offset = 0
|
||||
full_mask_kv_length = self.get_max_cache_shape()
|
||||
return full_mask_kv_length, full_mask_kv_offset
|
||||
|
||||
|
||||
class OffloadedHybridCache(HybridChunkedCache):
|
||||
def __init__(
|
||||
@ -2033,7 +2141,7 @@ class OffloadedHybridCache(HybridChunkedCache):
|
||||
|
||||
# TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps
|
||||
# track of the original device of each layer
|
||||
unique_devices = set(layer_device_map.values())
|
||||
unique_devices = set(layer_device_map.values()) if layer_device_map else set()
|
||||
if len(unique_devices) > 1:
|
||||
raise ValueError(f"OffloadedHybridCache does not support multiple devices. Got devices: {unique_devices}")
|
||||
|
||||
@ -2292,7 +2400,7 @@ class OffloadedStaticCache(StaticCache):
|
||||
|
||||
# TODO (joao): to enable this cache on multiple devicesuse the pattern from `OffloadedCache`, which keeps
|
||||
# track of the original device of each layer
|
||||
unique_devices = set(layer_device_map.values())
|
||||
unique_devices = set(layer_device_map.values()) if layer_device_map else set()
|
||||
if len(unique_devices) > 1:
|
||||
raise ValueError(f"OffloadedStaticCache does not support multiple devices. Got devices: {unique_devices}")
|
||||
|
||||
@ -2369,6 +2477,9 @@ class OffloadedStaticCache(StaticCache):
|
||||
A tuple containing the updated key and value states.
|
||||
"""
|
||||
|
||||
key_states = key_states.to(self.key_cache[layer_idx].dtype)
|
||||
value_states = value_states.to(self.value_cache[layer_idx].dtype)
|
||||
|
||||
if layer_idx == 0:
|
||||
# Update seen tokens.
|
||||
# TODO(gante): Remove this.
|
||||
|
@ -408,6 +408,10 @@ class PretrainedConfig(PushToHubMixin):
|
||||
repo_id = self._create_repo(repo_id, **kwargs)
|
||||
files_timestamps = self._get_files_timestamps(save_directory)
|
||||
|
||||
# This attribute is important to know on load, but should not be serialized on save.
|
||||
if "transformers_weights" in self:
|
||||
delattr(self, "transformers_weights")
|
||||
|
||||
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
|
||||
# loaded from the Hub.
|
||||
if self._auto_class is not None:
|
||||
@ -1205,3 +1209,16 @@ if PretrainedConfig.push_to_hub.__doc__ is not None:
|
||||
PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
|
||||
object="config", object_class="AutoConfig", object_files="configuration file"
|
||||
)
|
||||
|
||||
|
||||
ALLOWED_LAYER_TYPES = (
|
||||
"full_attention",
|
||||
"sliding_attention",
|
||||
"chunked_attention",
|
||||
)
|
||||
|
||||
|
||||
def layer_type_validation(layer_types: list[str]):
|
||||
"""Check that each entry in `layer_types` are allowed."""
|
||||
if not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in layer_types):
|
||||
raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES}")
|
||||
|
@ -32,7 +32,7 @@ deps = {
|
||||
"jaxlib": "jaxlib>=0.4.1,<=0.4.13",
|
||||
"jieba": "jieba",
|
||||
"jinja2": "jinja2>=3.1.0",
|
||||
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5": "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
|
||||
"kenlm": "kenlm",
|
||||
"keras": "keras>2.9,<2.16",
|
||||
"keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
|
||||
"kernels": "kernels>=0.4.4,<0.5",
|
||||
|
@ -35,6 +35,7 @@ from ..utils import (
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
from ..utils.deprecation import deprecate_kwarg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -279,10 +280,6 @@ class GenerationConfig(PushToHubMixin):
|
||||
begin_suppress_tokens (`List[int]`, *optional*):
|
||||
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit
|
||||
processor will set their log probs to `-inf` so that they are not sampled.
|
||||
forced_decoder_ids (`List[List[int]]`, *optional*):
|
||||
A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
|
||||
forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token
|
||||
of index 123.
|
||||
sequence_bias (`Dict[Tuple[int], float]`, *optional*)):
|
||||
Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
|
||||
sequence being selected, while negative biases do the opposite. Check
|
||||
@ -387,12 +384,6 @@ class GenerationConfig(PushToHubMixin):
|
||||
Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when
|
||||
specific criteria are met, including using a compilable cache. Please open an issue if you find the
|
||||
need to use this flag.
|
||||
|
||||
> Wild card
|
||||
|
||||
generation_kwargs:
|
||||
Additional generation kwargs will be forwarded to the `generate` function of the model. Kwargs that are not
|
||||
present in `generate`'s signature will be used in the model forward pass.
|
||||
"""
|
||||
|
||||
extra_output_flags = ("output_attentions", "output_hidden_states", "output_scores", "output_logits")
|
||||
@ -448,7 +439,6 @@ class GenerationConfig(PushToHubMixin):
|
||||
self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
|
||||
self.suppress_tokens = kwargs.pop("suppress_tokens", None)
|
||||
self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
|
||||
self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
|
||||
self.sequence_bias = kwargs.pop("sequence_bias", None)
|
||||
self.token_healing = kwargs.pop("token_healing", False)
|
||||
self.guidance_scale = kwargs.pop("guidance_scale", None)
|
||||
@ -493,8 +483,6 @@ class GenerationConfig(PushToHubMixin):
|
||||
# Performance
|
||||
self.compile_config = kwargs.pop("compile_config", None)
|
||||
self.disable_compile = kwargs.pop("disable_compile", False)
|
||||
# Wild card
|
||||
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
|
||||
|
||||
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
|
||||
# interface.
|
||||
@ -514,7 +502,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
raise err
|
||||
|
||||
# Validate the values of the attributes
|
||||
self.validate(is_init=True)
|
||||
self.validate()
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.to_json_string(ignore_metadata=True))
|
||||
@ -576,9 +564,10 @@ class GenerationConfig(PushToHubMixin):
|
||||
if generation_mode in ("greedy_search", "sample"):
|
||||
generation_mode = GenerationMode.ASSISTED_GENERATION
|
||||
else:
|
||||
raise ValueError(
|
||||
logger.warning(
|
||||
"You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
|
||||
"is only supported with Greedy Search and Sample."
|
||||
"is only supported with Greedy Search and Sample. However, the base decoding mode (based on "
|
||||
f"current flags) is {generation_mode} -- some of the set flags will be ignored."
|
||||
)
|
||||
|
||||
# DoLa generation may extend some generation modes
|
||||
@ -586,13 +575,15 @@ class GenerationConfig(PushToHubMixin):
|
||||
if generation_mode in ("greedy_search", "sample"):
|
||||
generation_mode = GenerationMode.DOLA_GENERATION
|
||||
else:
|
||||
raise ValueError(
|
||||
logger.warning(
|
||||
"You've set `dola_layers`, which triggers DoLa generate. Currently, DoLa generate "
|
||||
"is only supported with Greedy Search and Sample."
|
||||
"is only supported with Greedy Search and Sample. However, the base decoding mode (based on "
|
||||
f"current flags) is {generation_mode} -- some of the set flags will be ignored."
|
||||
)
|
||||
return generation_mode
|
||||
|
||||
def validate(self, is_init=False):
|
||||
@deprecate_kwarg("is_init", version="4.54.0")
|
||||
def validate(self, strict=False):
|
||||
"""
|
||||
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
|
||||
of parameterization that can be detected as incorrect from the configuration instance alone.
|
||||
@ -600,174 +591,24 @@ class GenerationConfig(PushToHubMixin):
|
||||
Note that some parameters not validated here are best validated at generate runtime, as they may depend on
|
||||
other inputs and/or the model, such as parameters related to the generation length.
|
||||
|
||||
Arg:
|
||||
is_init (`bool`, *optional*, defaults to `False`):
|
||||
Whether the validation is performed during the initialization of the instance.
|
||||
Args:
|
||||
strict (bool): If True, raise an exception for any issues found. If False, only log issues.
|
||||
"""
|
||||
minor_issues = {} # format: {attribute_name: issue_description}
|
||||
|
||||
# Validation of individual attributes
|
||||
# 1. Validation of individual attributes
|
||||
# 1.1. Decoding attributes
|
||||
if self.early_stopping not in {True, False, "never"}:
|
||||
raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
|
||||
if self.max_new_tokens is not None and self.max_new_tokens <= 0:
|
||||
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
|
||||
if self.pad_token_id is not None and self.pad_token_id < 0:
|
||||
warnings.warn(
|
||||
minor_issues["pad_token_id"] = (
|
||||
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
|
||||
"generating, if there is padding. Please set `pad_token_id` explicitly as "
|
||||
"`model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation"
|
||||
)
|
||||
|
||||
# Validation of attribute relations:
|
||||
fix_location = ""
|
||||
if is_init:
|
||||
fix_location = (
|
||||
" This was detected when initializing the generation config instance, which means the corresponding "
|
||||
"file may hold incorrect parameterization and should be fixed."
|
||||
)
|
||||
|
||||
# 1. detect sampling-only parameterization when not in sampling mode
|
||||
if self.do_sample is False:
|
||||
greedy_wrong_parameter_msg = (
|
||||
"`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
|
||||
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
|
||||
+ fix_location
|
||||
)
|
||||
if self.temperature is not None and self.temperature != 1.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="temperature", flag_value=self.temperature),
|
||||
UserWarning,
|
||||
)
|
||||
if self.top_p is not None and self.top_p != 1.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p),
|
||||
UserWarning,
|
||||
)
|
||||
if self.min_p is not None:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p),
|
||||
UserWarning,
|
||||
)
|
||||
if self.typical_p is not None and self.typical_p != 1.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="typical_p", flag_value=self.typical_p),
|
||||
UserWarning,
|
||||
)
|
||||
if (
|
||||
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
|
||||
): # contrastive search uses top_k
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k),
|
||||
UserWarning,
|
||||
)
|
||||
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff),
|
||||
UserWarning,
|
||||
)
|
||||
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
|
||||
warnings.warn(
|
||||
greedy_wrong_parameter_msg.format(flag_name="eta_cutoff", flag_value=self.eta_cutoff),
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
# 2. detect beam-only parameterization when not in beam mode
|
||||
if self.num_beams is None:
|
||||
warnings.warn("`num_beams` is set to None - defaulting to 1.", UserWarning)
|
||||
self.num_beams = 1
|
||||
|
||||
if self.num_beams == 1:
|
||||
single_beam_wrong_parameter_msg = (
|
||||
"`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
|
||||
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`." + fix_location
|
||||
)
|
||||
if self.early_stopping is not False:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="early_stopping", flag_value=self.early_stopping),
|
||||
UserWarning,
|
||||
)
|
||||
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(
|
||||
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
||||
),
|
||||
UserWarning,
|
||||
)
|
||||
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(
|
||||
flag_name="diversity_penalty", flag_value=self.diversity_penalty
|
||||
),
|
||||
UserWarning,
|
||||
)
|
||||
if self.length_penalty is not None and self.length_penalty != 1.0:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="length_penalty", flag_value=self.length_penalty),
|
||||
UserWarning,
|
||||
)
|
||||
if self.constraints is not None:
|
||||
warnings.warn(
|
||||
single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints),
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
# 3. detect incorrect parameterization specific to advanced beam modes
|
||||
else:
|
||||
# constrained beam search
|
||||
if self.constraints is not None or self.force_words_ids is not None:
|
||||
constrained_wrong_parameter_msg = (
|
||||
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. However, "
|
||||
"`{flag_name}` is set to `{flag_value}`, which is incompatible with this generation mode. Set "
|
||||
"`constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue." + fix_location
|
||||
)
|
||||
if self.do_sample is True:
|
||||
raise ValueError(
|
||||
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
|
||||
)
|
||||
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
||||
raise ValueError(
|
||||
constrained_wrong_parameter_msg.format(
|
||||
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
||||
)
|
||||
)
|
||||
# group beam search
|
||||
if self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
|
||||
group_error_prefix = (
|
||||
"`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
|
||||
"this generation mode, "
|
||||
)
|
||||
if self.do_sample is True:
|
||||
raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
|
||||
if self.num_beams % self.num_beam_groups != 0:
|
||||
raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
|
||||
if self.diversity_penalty == 0.0:
|
||||
raise ValueError(
|
||||
group_error_prefix
|
||||
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
|
||||
)
|
||||
# DoLa generation
|
||||
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
|
||||
warnings.warn(
|
||||
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
|
||||
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
|
||||
"DoLa decoding is `repetition_penalty>=1.2`.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
# 4. check `num_return_sequences`
|
||||
if self.num_return_sequences != 1:
|
||||
if self.num_beams == 1:
|
||||
if self.do_sample is False:
|
||||
raise ValueError(
|
||||
"Greedy methods without beam search do not support `num_return_sequences` different than 1 "
|
||||
f"(got {self.num_return_sequences})."
|
||||
)
|
||||
elif self.num_return_sequences > self.num_beams:
|
||||
raise ValueError(
|
||||
f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
|
||||
f"({self.num_beams})."
|
||||
)
|
||||
|
||||
# 5. check cache-related arguments
|
||||
# 1.2. Cache attributes
|
||||
if self.cache_implementation is not None and self.cache_implementation not in ALL_CACHE_IMPLEMENTATIONS:
|
||||
raise ValueError(
|
||||
f"Invalid `cache_implementation` ({self.cache_implementation}). Choose one of: "
|
||||
@ -784,6 +625,141 @@ class GenerationConfig(PushToHubMixin):
|
||||
if not isinstance(self.cache_config, cache_class):
|
||||
self.cache_config = cache_class.from_dict(self.cache_config)
|
||||
self.cache_config.validate()
|
||||
# 1.3. Performance attributes
|
||||
if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
|
||||
raise ValueError(
|
||||
f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
|
||||
"instance of `CompileConfig`."
|
||||
)
|
||||
# 1.4. Watermarking attributes
|
||||
if self.watermarking_config is not None:
|
||||
if not (
|
||||
isinstance(self.watermarking_config, WatermarkingConfig)
|
||||
or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
|
||||
):
|
||||
minor_issues["watermarking_config"] = (
|
||||
"`watermarking_config` as a dict is deprecated and will be removed in v4.54.0. Please construct "
|
||||
"`watermarking_config` object with `WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class."
|
||||
)
|
||||
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
|
||||
self.watermarking_config.validate()
|
||||
|
||||
# 2. Validation of attribute combinations
|
||||
# 2.1. detect sampling-only parameterization when not in sampling mode
|
||||
if self.do_sample is False:
|
||||
greedy_wrong_parameter_msg = (
|
||||
"`do_sample` is set to `False`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
|
||||
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
|
||||
)
|
||||
if self.temperature is not None and self.temperature != 1.0:
|
||||
minor_issues["temperature"] = greedy_wrong_parameter_msg.format(
|
||||
flag_name="temperature", flag_value=self.temperature
|
||||
)
|
||||
if self.top_p is not None and self.top_p != 1.0:
|
||||
minor_issues["top_p"] = greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p)
|
||||
if self.min_p is not None:
|
||||
minor_issues["min_p"] = greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p)
|
||||
if self.typical_p is not None and self.typical_p != 1.0:
|
||||
minor_issues["typical_p"] = greedy_wrong_parameter_msg.format(
|
||||
flag_name="typical_p", flag_value=self.typical_p
|
||||
)
|
||||
if (
|
||||
self.top_k is not None and self.top_k != 50 and self.penalty_alpha is None
|
||||
): # contrastive search uses top_k
|
||||
minor_issues["top_k"] = greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k)
|
||||
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
|
||||
minor_issues["epsilon_cutoff"] = greedy_wrong_parameter_msg.format(
|
||||
flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff
|
||||
)
|
||||
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
|
||||
minor_issues["eta_cutoff"] = greedy_wrong_parameter_msg.format(
|
||||
flag_name="eta_cutoff", flag_value=self.eta_cutoff
|
||||
)
|
||||
|
||||
# 2.2. detect beam-only parameterization when not in beam mode
|
||||
if self.num_beams == 1:
|
||||
single_beam_wrong_parameter_msg = (
|
||||
"`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
|
||||
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`."
|
||||
)
|
||||
if self.early_stopping is not False:
|
||||
minor_issues["early_stopping"] = single_beam_wrong_parameter_msg.format(
|
||||
flag_name="early_stopping", flag_value=self.early_stopping
|
||||
)
|
||||
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
||||
minor_issues["num_beam_groups"] = single_beam_wrong_parameter_msg.format(
|
||||
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
||||
)
|
||||
if self.diversity_penalty is not None and self.diversity_penalty != 0.0:
|
||||
minor_issues["diversity_penalty"] = single_beam_wrong_parameter_msg.format(
|
||||
flag_name="diversity_penalty", flag_value=self.diversity_penalty
|
||||
)
|
||||
if self.length_penalty is not None and self.length_penalty != 1.0:
|
||||
minor_issues["length_penalty"] = single_beam_wrong_parameter_msg.format(
|
||||
flag_name="length_penalty", flag_value=self.length_penalty
|
||||
)
|
||||
if self.constraints is not None:
|
||||
minor_issues["constraints"] = single_beam_wrong_parameter_msg.format(
|
||||
flag_name="constraints", flag_value=self.constraints
|
||||
)
|
||||
# DoLa generation needs num_beams == 1
|
||||
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
|
||||
minor_issues["repetition_penalty"] = (
|
||||
"`dola_layers` is set to trigger DoLa decoding, but `repetition_penalty` is set to a value of "
|
||||
f"{self.repetition_penalty}, which could induce unwanted repetition. The recommended value for "
|
||||
"DoLa decoding is `repetition_penalty>=1.2`.",
|
||||
)
|
||||
|
||||
# 2.3. detect incorrect parameterization specific to advanced beam modes
|
||||
else:
|
||||
# constrained beam search
|
||||
if self.constraints is not None or self.force_words_ids is not None:
|
||||
constrained_wrong_parameter_msg = (
|
||||
"one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. "
|
||||
"However, `{flag_name}` is set to `{flag_value}`, which is incompatible with this generation "
|
||||
"mode. Set `constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue."
|
||||
)
|
||||
if self.do_sample is True:
|
||||
raise ValueError(
|
||||
constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=self.do_sample)
|
||||
)
|
||||
if self.num_beam_groups is not None and self.num_beam_groups != 1:
|
||||
raise ValueError(
|
||||
constrained_wrong_parameter_msg.format(
|
||||
flag_name="num_beam_groups", flag_value=self.num_beam_groups
|
||||
)
|
||||
)
|
||||
# group beam search
|
||||
elif self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
|
||||
group_error_prefix = (
|
||||
"`diversity_penalty` is not 0.0 or `num_beam_groups` is not 1, triggering group beam search. In "
|
||||
"this generation mode, "
|
||||
)
|
||||
if self.do_sample is True:
|
||||
raise ValueError(group_error_prefix + "`do_sample` must be set to `False`")
|
||||
if self.num_beams % self.num_beam_groups != 0:
|
||||
raise ValueError(group_error_prefix + "`num_beams` should be divisible by `num_beam_groups`")
|
||||
if self.diversity_penalty == 0.0:
|
||||
raise ValueError(
|
||||
group_error_prefix
|
||||
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
|
||||
)
|
||||
|
||||
# 2.4. check `num_return_sequences`
|
||||
if self.num_return_sequences != 1:
|
||||
if self.num_beams == 1:
|
||||
if self.do_sample is False:
|
||||
raise ValueError(
|
||||
"Greedy methods without beam search do not support `num_return_sequences` different than 1 "
|
||||
f"(got {self.num_return_sequences})."
|
||||
)
|
||||
elif self.num_return_sequences > self.num_beams:
|
||||
raise ValueError(
|
||||
f"`num_return_sequences` ({self.num_return_sequences}) has to be smaller or equal to `num_beams` "
|
||||
f"({self.num_beams})."
|
||||
)
|
||||
|
||||
# 2.5. check cache-related arguments
|
||||
if self.use_cache is False:
|
||||
# In this case, all cache-related arguments should be unset. However, since `use_cache=False` is often used
|
||||
# passed to `generate` directly to hot-fix cache issues, let's raise a warning instead of an error
|
||||
@ -794,42 +770,20 @@ class GenerationConfig(PushToHubMixin):
|
||||
)
|
||||
for arg_name in ("cache_implementation", "cache_config", "return_legacy_cache"):
|
||||
if getattr(self, arg_name) is not None:
|
||||
logger.warning_once(
|
||||
no_cache_warning.format(cache_arg=arg_name, cache_arg_value=getattr(self, arg_name))
|
||||
minor_issues[arg_name] = no_cache_warning.format(
|
||||
cache_arg=arg_name, cache_arg_value=getattr(self, arg_name)
|
||||
)
|
||||
|
||||
# 6. check watermarking arguments
|
||||
if self.watermarking_config is not None:
|
||||
if not (
|
||||
isinstance(self.watermarking_config, WatermarkingConfig)
|
||||
or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
|
||||
):
|
||||
warnings.warn(
|
||||
"`watermarking_config` as a dict is deprecated. Please construct `watermarking_config` object with "
|
||||
"`WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class.",
|
||||
FutureWarning,
|
||||
)
|
||||
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
|
||||
self.watermarking_config.validate()
|
||||
|
||||
# 7. performances arguments
|
||||
if self.compile_config is not None and not isinstance(self.compile_config, CompileConfig):
|
||||
raise ValueError(
|
||||
f"You provided `compile_config` as an instance of {type(self.compile_config)}, but it must be an "
|
||||
"instance of `CompileConfig`."
|
||||
)
|
||||
|
||||
# 8. other incorrect combinations
|
||||
# 2.6. other incorrect combinations
|
||||
if self.return_dict_in_generate is not True:
|
||||
for extra_output_flag in self.extra_output_flags:
|
||||
if getattr(self, extra_output_flag) is True:
|
||||
warnings.warn(
|
||||
minor_issues[extra_output_flag] = (
|
||||
f"`return_dict_in_generate` is NOT set to `True`, but `{extra_output_flag}` is. When "
|
||||
f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored.",
|
||||
UserWarning,
|
||||
f"`return_dict_in_generate` is not `True`, `{extra_output_flag}` is ignored."
|
||||
)
|
||||
|
||||
# 8. check common issue: passing `generate` arguments inside the generation config
|
||||
# 3. Check common issue: passing `generate` arguments inside the generation config
|
||||
generate_arguments = (
|
||||
"logits_processor",
|
||||
"stopping_criteria",
|
||||
@ -839,6 +793,7 @@ class GenerationConfig(PushToHubMixin):
|
||||
"streamer",
|
||||
"negative_prompt_ids",
|
||||
"negative_prompt_attention_mask",
|
||||
"use_model_defaults",
|
||||
)
|
||||
for arg in generate_arguments:
|
||||
if hasattr(self, arg):
|
||||
@ -847,6 +802,30 @@ class GenerationConfig(PushToHubMixin):
|
||||
"`generate()` (or a pipeline) directly."
|
||||
)
|
||||
|
||||
# Finally, handle caught minor issues. With default parameterization, we will throw a minimal warning.
|
||||
if len(minor_issues) > 0:
|
||||
# Full list of issues with potential fixes
|
||||
info_message = []
|
||||
for attribute_name, issue_description in minor_issues.items():
|
||||
info_message.append(f"- `{attribute_name}`: {issue_description}")
|
||||
info_message = "\n".join(info_message)
|
||||
info_message += (
|
||||
"\nIf you're using a pretrained model, note that some of these attributes may be set through the "
|
||||
"model's `generation_config.json` file."
|
||||
)
|
||||
|
||||
if strict:
|
||||
raise ValueError("GenerationConfig is invalid: \n" + info_message)
|
||||
else:
|
||||
attributes_with_issues = list(minor_issues.keys())
|
||||
warning_message = (
|
||||
f"The following generation flags are not valid and may be ignored: {attributes_with_issues}."
|
||||
)
|
||||
if logger.getEffectiveLevel() >= logging.WARNING:
|
||||
warning_message += " Set `TRANSFORMERS_VERBOSITY=info` for more details."
|
||||
logger.warning(warning_message)
|
||||
logger.info(info_message)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
@ -871,18 +850,13 @@ class GenerationConfig(PushToHubMixin):
|
||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
|
||||
# At save time, validate the instance -- if any warning/exception is thrown, we refuse to save the instance.
|
||||
# At save time, validate the instance enforcing strictness -- if any warning/exception would be thrown, we
|
||||
# refuse to save the instance.
|
||||
# This strictness is enforced to prevent bad configurations from being saved and re-used.
|
||||
try:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.validate()
|
||||
if len(caught_warnings) > 0:
|
||||
raise ValueError(str([w.message for w in caught_warnings]))
|
||||
self.validate(strict=True)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
"The generation config instance is invalid -- `.validate()` throws warnings and/or exceptions. "
|
||||
"Fix these issues to save the configuration.\n\nThrown during validation:\n" + str(exc)
|
||||
)
|
||||
raise ValueError(str(exc) + "\n\nFix these issues to save the configuration.")
|
||||
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from typing import Callable, Iterable, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -25,6 +25,10 @@ from ..utils import add_start_docstrings
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
|
||||
# TODO (joao): We shouldn't need this, but there would be a circular import
|
||||
if TYPE_CHECKING:
|
||||
from ..generation.configuration_utils import GenerationConfig
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@ -1906,8 +1910,10 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
|
||||
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
|
||||
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
|
||||
predicting timestamps that are too far in the future.
|
||||
begin_index (`Optional`, *optional*): Token index of the first token that is generated by the model.
|
||||
_detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.
|
||||
begin_index (`int`):
|
||||
Token index of the first token that is generated by the model.
|
||||
_detect_timestamp_from_logprob (`bool`, *optional*):
|
||||
Whether timestamps can be predicted from logprobs over all timestamps.
|
||||
|
||||
Examples:
|
||||
``` python
|
||||
@ -1940,8 +1946,8 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
generate_config,
|
||||
begin_index: Optional[int] = None,
|
||||
generate_config: "GenerationConfig",
|
||||
begin_index: int,
|
||||
_detect_timestamp_from_logprob: Optional[bool] = None,
|
||||
): # support for the kwargs
|
||||
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
|
||||
@ -1954,11 +1960,13 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
|
||||
if _detect_timestamp_from_logprob is not None
|
||||
else getattr(generate_config, "_detect_timestamp_from_logprob", True)
|
||||
)
|
||||
|
||||
num_forced_ids = (
|
||||
len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
|
||||
)
|
||||
self.begin_index = begin_index or (num_forced_ids + 1)
|
||||
self.begin_index = begin_index
|
||||
if begin_index is None:
|
||||
raise ValueError(
|
||||
"`forced_decoder_ids` is deprecated in favor of `task` and `language` and, as such, `begin_index` "
|
||||
"must be provided to `WhisperTimeStampLogitsProcessor`. The previous default value of `begin_index` "
|
||||
"was `len(generate_config.forced_decoder_ids)`"
|
||||
)
|
||||
|
||||
self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
|
||||
# TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50
|
||||
|
@ -46,6 +46,7 @@ from ..dynamic_module_utils import (
|
||||
)
|
||||
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ..integrations.fsdp import is_fsdp_managed_module
|
||||
from ..masking_utils import create_masks_for_generate
|
||||
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
|
||||
from ..pytorch_utils import isin_mps_friendly
|
||||
from ..tokenization_utils import ExtensionsTrie
|
||||
@ -74,6 +75,7 @@ from .candidate_generator import (
|
||||
from .configuration_utils import (
|
||||
NEED_SETUP_CACHE_CLASSES_MAPPING,
|
||||
QUANT_BACKEND_CLASSES_MAPPING,
|
||||
CompileConfig,
|
||||
GenerationConfig,
|
||||
GenerationMode,
|
||||
)
|
||||
@ -649,12 +651,22 @@ class GenerationMixin:
|
||||
causal_mask_creation_function = getattr(
|
||||
decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None
|
||||
)
|
||||
|
||||
# If it's not defined, it means the model uses the new general mask API
|
||||
if causal_mask_creation_function is None: # can't be found
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
|
||||
"defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
|
||||
"writing code, see Llama for an example implementation. If you're a user, please report this "
|
||||
"issue on GitHub."
|
||||
output_attentions = kwargs.get("output_attentions", False)
|
||||
token_type_ids = getattr(model_input, "token_type_ids", None)
|
||||
# Some models may overwrite the general one
|
||||
causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
|
||||
attention_mask = causal_mask_creation_function(
|
||||
config=self.config,
|
||||
# we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
|
||||
input_embeds=torch.empty((batch_size, sequence_length), dtype=self.dtype),
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
token_type_ids=token_type_ids,
|
||||
)
|
||||
else:
|
||||
attention_mask = causal_mask_creation_function(
|
||||
@ -1246,12 +1258,6 @@ class GenerationMixin:
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
if generation_config.forced_decoder_ids is not None:
|
||||
# TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PT
|
||||
raise ValueError(
|
||||
"You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument "
|
||||
"in favour of `input_ids` or `decoder_input_ids` respectively.",
|
||||
)
|
||||
|
||||
# TODO (joao): find a strategy to specify the order of the processors
|
||||
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
||||
@ -1752,16 +1758,21 @@ class GenerationMixin:
|
||||
use_model_defaults is None and model_base_version >= version.parse("4.50.0")
|
||||
):
|
||||
modified_values = {}
|
||||
default_generation_config = GenerationConfig()
|
||||
for key, default_value in default_generation_config.__dict__.items():
|
||||
global_default_generation_config = GenerationConfig()
|
||||
model_generation_config = self.generation_config
|
||||
# we iterate over the model's generation config: it may hold custom keys, which we'll want to copy
|
||||
for key, model_gen_config_value in model_generation_config.__dict__.items():
|
||||
if key.startswith("_") or key == "transformers_version": # metadata
|
||||
continue
|
||||
custom_gen_config_value = getattr(generation_config, key)
|
||||
model_gen_config_value = getattr(self.generation_config, key)
|
||||
if custom_gen_config_value == default_value and model_gen_config_value != default_value:
|
||||
global_default_value = getattr(global_default_generation_config, key, None)
|
||||
custom_gen_config_value = getattr(generation_config, key, None)
|
||||
if (
|
||||
custom_gen_config_value == global_default_value
|
||||
and model_gen_config_value != global_default_value
|
||||
):
|
||||
modified_values[key] = model_gen_config_value
|
||||
setattr(generation_config, key, model_gen_config_value)
|
||||
if len(modified_values) > 0:
|
||||
if use_model_defaults is None and len(modified_values) > 0:
|
||||
logger.warning_once(
|
||||
f"`generation_config` default values have been modified to match model-specific defaults: "
|
||||
f"{modified_values}. If this is not desired, please set these values explicitly."
|
||||
@ -1980,7 +1991,9 @@ class GenerationMixin:
|
||||
instantiated, writes it to `model_kwargs`, under the name expected by the model.
|
||||
"""
|
||||
|
||||
cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
|
||||
is_hybrid_cache = any(class_name in self.__class__.__name__.lower() for class_name in ["mamba", "falconh1"])
|
||||
cache_name = "past_key_values" if not is_hybrid_cache else "cache_params"
|
||||
|
||||
requires_cross_attention_cache = (
|
||||
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
|
||||
)
|
||||
@ -3532,6 +3545,19 @@ class GenerationMixin:
|
||||
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
|
||||
if compile_forward:
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "0"
|
||||
# If we use FA2 and a static cache, we cannot compile with fullgraph
|
||||
if self.config._attn_implementation == "flash_attention_2" and getattr(
|
||||
model_kwargs.get("past_key_values"), "is_compileable", False
|
||||
):
|
||||
if generation_config.compile_config is None:
|
||||
generation_config.compile_config = CompileConfig(fullgraph=False)
|
||||
# only raise warning if the user passed an explicit compile-config (otherwise, simply change the default without confusing the user)
|
||||
elif generation_config.compile_config.fullgraph:
|
||||
logger.warning_once(
|
||||
"When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
|
||||
"FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
|
||||
)
|
||||
generation_config.compile_config.fullgraph = False
|
||||
model_forward = self.get_compiled_call(generation_config.compile_config)
|
||||
|
||||
if generation_config.prefill_chunk_size is not None:
|
||||
|
@ -142,7 +142,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["tensor_parallel"] = [
|
||||
"shard_and_distribute_module",
|
||||
"SUPPORTED_TP_STYLES",
|
||||
"ALL_PARALLEL_STYLES",
|
||||
"translate_to_torch_parallel_style",
|
||||
]
|
||||
try:
|
||||
@ -271,7 +271,7 @@ if TYPE_CHECKING:
|
||||
pass
|
||||
else:
|
||||
from .tensor_parallel import (
|
||||
SUPPORTED_TP_STYLES,
|
||||
ALL_PARALLEL_STYLES,
|
||||
shard_and_distribute_module,
|
||||
translate_to_torch_parallel_style,
|
||||
)
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user