mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00

Merge branch 'main' into fix-eomt-for-pipeline

Commit: 1357a10c78
.github/workflows/check_failed_tests.yml (vendored, 2 changed lines)

@@ -41,7 +41,7 @@ jobs:
check_new_failures:
name: " "
runs-on:
group: aws-g4dn-4xlarge-cache
group: aws-g5-4xlarge-cache
container:
image: ${{ inputs.docker }}
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
.github/workflows/doctest_job.yml (vendored, 2 changed lines)

@@ -28,7 +28,7 @@ jobs:
matrix:
split_keys: ${{ fromJson(inputs.split_keys) }}
runs-on:
group: aws-g4dn-4xlarge-cache
group: aws-g5-4xlarge-cache
container:
image: huggingface/transformers-all-latest-gpu
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
.github/workflows/doctests.yml (vendored, 2 changed lines)

@@ -15,7 +15,7 @@ jobs:
setup:
name: Setup
runs-on:
group: aws-g4dn-4xlarge-cache
group: aws-g5-4xlarge-cache
container:
image: huggingface/transformers-all-latest-gpu
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
.github/workflows/model_jobs.yml (vendored, 4 changed lines)

@@ -107,9 +107,9 @@ jobs:
run: |
echo "${{ inputs.machine_type }}"

if [ "${{ inputs.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
if [ "${{ inputs.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ inputs.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
elif [ "${{ inputs.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
machine_type=multi-gpu
else
machine_type=${{ inputs.machine_type }}
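As a reading aid only (not from the commit itself), the shell hunks in this and the following workflow files all implement the same runner-group to machine_type mapping after the g4dn to g5 migration. A minimal Python sketch of that mapping, with names taken from the hunks and everything else assumed:

```py
# Reading aid only; mirrors the bash if/elif/else blocks in the workflow hunks.
RUNNER_GROUP_TO_MACHINE_TYPE = {
    "aws-g5-4xlarge-cache": "single-gpu",
    "aws-g5-12xlarge-cache": "multi-gpu",
}

def resolve_machine_type(runner_group: str) -> str:
    # Unknown groups fall through unchanged, matching the `else` branch.
    return RUNNER_GROUP_TO_MACHINE_TYPE.get(runner_group, runner_group)

print(resolve_machine_type("aws-g5-4xlarge-cache"))   # single-gpu
print(resolve_machine_type("aws-g5-12xlarge-cache"))  # multi-gpu
```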
.github/workflows/self-comment-ci.yml (vendored, 12 changed lines)

@@ -185,7 +185,7 @@ jobs:
fail-fast: false
matrix:
folders: ${{ fromJson(needs.get-tests.outputs.models) }}
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:

@@ -239,9 +239,9 @@ jobs:
shell: bash
run: |
echo "${{ matrix.machine_type }}"
if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
machine_type=multi-gpu
else
machine_type=${{ matrix.machine_type }}

@@ -292,7 +292,7 @@ jobs:
fail-fast: false
matrix:
folders: ${{ fromJson(needs.get-tests.outputs.quantizations) }}
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:

@@ -338,9 +338,9 @@ jobs:
shell: bash
run: |
echo "${{ matrix.machine_type }}"
if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
machine_type=multi-gpu
else
machine_type=${{ matrix.machine_type }}
.github/workflows/self-push.yml (vendored, 26 changed lines)

@@ -31,7 +31,7 @@ jobs:
name: Setup
strategy:
matrix:
machine_type: [aws-g4dn-2xlarge-cache, aws-g4dn-12xlarge-cache]
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:

@@ -131,7 +131,7 @@ jobs:
fail-fast: false
matrix:
folders: ${{ fromJson(needs.setup.outputs.matrix) }}
machine_type: [aws-g4dn-2xlarge-cache]
machine_type: [aws-g5-4xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:

@@ -169,9 +169,9 @@ jobs:
run: |
echo "${{ matrix.machine_type }}"

if [ "${{ matrix.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
machine_type=multi-gpu
else
machine_type=${{ matrix.machine_type }}

@@ -244,7 +244,7 @@ jobs:
fail-fast: false
matrix:
folders: ${{ fromJson(needs.setup.outputs.matrix) }}
machine_type: [aws-g4dn-12xlarge-cache]
machine_type: [aws-g5-12xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:

@@ -282,9 +282,9 @@ jobs:
run: |
echo "${{ matrix.machine_type }}"

if [ "${{ matrix.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
machine_type=multi-gpu
else
machine_type=${{ matrix.machine_type }}

@@ -357,7 +357,7 @@ jobs:
strategy:
fail-fast: false
matrix:
machine_type: [aws-g4dn-2xlarge-cache]
machine_type: [aws-g5-4xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:

@@ -395,9 +395,9 @@ jobs:
run: |
echo "${{ matrix.machine_type }}"

if [ "${{ matrix.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
machine_type=multi-gpu
else
machine_type=${{ matrix.machine_type }}

@@ -467,7 +467,7 @@ jobs:
strategy:
fail-fast: false
matrix:
machine_type: [aws-g4dn-12xlarge-cache]
machine_type: [aws-g5-12xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:

@@ -505,9 +505,9 @@ jobs:
run: |
echo "${{ matrix.machine_type }}"

if [ "${{ matrix.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
machine_type=multi-gpu
else
machine_type=${{ matrix.machine_type }}
.github/workflows/self-scheduled.yml (vendored, 29 changed lines)

@@ -50,7 +50,7 @@ jobs:
name: Setup
strategy:
matrix:
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:

@@ -128,13 +128,14 @@ jobs:
strategy:
fail-fast: false
matrix:
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
slice_id: [0, 1]
uses: ./.github/workflows/model_jobs.yml
with:
folder_slices: ${{ needs.setup.outputs.folder_slices }}
machine_type: ${{ matrix.machine_type }}
slice_id: ${{ matrix.slice_id }}
runner_map: ${{ needs.setup.outputs.runner_map }}
docker: ${{ inputs.docker }}
report_name_prefix: run_trainer_and_fsdp_gpu
secrets: inherit

@@ -145,7 +146,7 @@ jobs:
strategy:
fail-fast: false
matrix:
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:

@@ -179,9 +180,9 @@ jobs:
run: |
echo "${{ matrix.machine_type }}"

if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
machine_type=multi-gpu
else
machine_type=${{ matrix.machine_type }}

@@ -213,7 +214,7 @@ jobs:
strategy:
fail-fast: false
matrix:
machine_type: [aws-g4dn-4xlarge-cache]
machine_type: [aws-g5-4xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:

@@ -247,9 +248,9 @@ jobs:
run: |
echo "${{ matrix.machine_type }}"

if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
machine_type=multi-gpu
else
machine_type=${{ matrix.machine_type }}

@@ -282,7 +283,7 @@ jobs:
strategy:
fail-fast: false
matrix:
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:

@@ -344,9 +345,9 @@ jobs:
run: |
echo "${{ matrix.machine_type }}"

if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
machine_type=multi-gpu
else
machine_type=${{ matrix.machine_type }}

@@ -381,7 +382,7 @@ jobs:
fail-fast: false
matrix:
folders: ${{ fromJson(needs.setup.outputs.quantization_matrix) }}
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
runs-on:
group: '${{ matrix.machine_type }}'
container:

@@ -424,9 +425,9 @@ jobs:
run: |
echo "${{ matrix.machine_type }}"

if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
machine_type=single-gpu
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
machine_type=multi-gpu
else
machine_type=${{ matrix.machine_type }}
@@ -288,7 +288,7 @@ Keywords: Music understanding, Music generation

## [dalle-flow](https://github.com/jina-ai/dalle-flow)

DALL·E Flow is an interactive workflow for generating high-definition images from a text prompt. Itt leverages DALL·E-Mega, GLID-3 XL, and Stable Diffusion to generate image candidates, and then calls CLIP-as-service to rank the candidates w.r.t. the prompt.
DALL·E Flow is an interactive workflow for generating high-definition images from a text prompt. It leverages DALL·E-Mega, GLID-3 XL, and Stable Diffusion to generate image candidates, and then calls CLIP-as-service to rank the candidates w.r.t. the prompt.
The preferred candidate is fed to GLID-3 XL for diffusion, which often enriches the texture and background. Finally, the candidate is upscaled to 1024x1024 via SwinIR.

Keywords: High-definition image generation, Stable Diffusion, DALL-E Mega, GLID-3 XL, CLIP, SwinIR

@@ -526,7 +526,7 @@ Keywords: Model deployment, CLoud, Mobile, Edge

## [underthesea](https://github.com/undertheseanlp/underthesea)

[underthesea](https://github.com/undertheseanlp/underthesea) is a Vietnamese NLP toolkit. Underthesea is a suite of open source Python modules data sets and tutorials supporting research and development in Vietnamese Natural Language Processing. We provides extremely easy API to quickly apply pretrained NLP models to your Vietnamese text, such as word segmentation, part-of-speech tagging (PoS), named entity recognition (NER), text classification and dependency parsing.
[underthesea](https://github.com/undertheseanlp/underthesea) is a Vietnamese NLP toolkit. Underthesea is a suite of open source Python modules data sets and tutorials supporting research and development in Vietnamese Natural Language Processing. We provide extremely easy API to quickly apply pretrained NLP models to your Vietnamese text, such as word segmentation, part-of-speech tagging (PoS), named entity recognition (NER), text classification and dependency parsing.

Keywords: Vietnamese, NLP
@@ -56,7 +56,7 @@ Create a [`ImageTextToTextPipeline`] and pass the chat to it. For large models,
import torch
from transformers import pipeline

pipeline = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf", device="cuda", torch_dtype=torch.float16)
pipeline = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf", device_map="auto", torch_dtype=torch.float16)
pipeline(text=messages, max_new_tokens=50, return_full_text=False)
[{'input_text': [{'role': 'system',
'content': [{'type': 'text',

@@ -175,7 +175,7 @@ processed_chat = processor.apply_chat_template(
add_generation_prompt=True,
tokenize=True,
return_dict=True,
video_fps=32,
video_fps=16,
video_load_backend="decord",
)
print(processed_chat.keys())
@@ -27,6 +27,9 @@ This guide shows you how to quickly start chatting with Transformers from the co

## transformers CLI

### Interactive chat session

After you've [installed Transformers](./installation.md), chat with a model directly from the command line as shown below. It launches an interactive session with a model, with a few base commands listed at the start of the session.

```bash

@@ -51,6 +54,68 @@ transformers chat -h

The chat is implemented on top of the [AutoClass](./model_doc/auto), using tooling from [text generation](./llm_tutorial) and [chat](./chat_templating).

### Serving a model and using MCP tools

> [!WARNING]
> This section is experimental and subject to changes in future versions

Powering the `chat` interface, we have a server that takes user messages and returns completions. The server has a chat completion API compatible with the OpenAI SDK, so you can also quickly experiment with `transformers` models on existing applications. To launch a server separately, use the `transformers serve` CLI:

```bash
transformers serve Menlo/Jan-nano
```
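As an illustration of the OpenAI-compatible API mentioned above (this sketch is not from the commit itself), a minimal way to query that server from Python with `huggingface_hub`, assuming the default `localhost:8000` address:

```py
from huggingface_hub import InferenceClient

# Assumes `transformers serve` is running locally on the default host/port.
client = InferenceClient("http://localhost:8000")

messages = [{"role": "user", "content": "Briefly explain what an MCP server is."}]

# Stream tokens from the chat completion endpoint as they arrive.
for chunk in client.chat_completion(messages, stream=True, max_tokens=128):
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
print()
```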
Under the hood, the `chat` CLI launches and uses `transformers serve`. This server is also an MCP client, which can receive information from available MCP servers (i.e. tools), massage their information into the model prompt, and prepare calls to these tools when the model decides to use them. Naturally, this requires a model that is trained to use tools.

At the moment, MCP tool usage in `transformers` has the following constraints:
- `chat` can't handle tools, but the [`tiny-agents`](https://huggingface.co/blog/python-tiny-agents) CLI can;
- Only the `qwen` family of models is supported.

The first step to use MCP tools is to let the model know which tools are available. As an example, let's consider a `tiny-agents` configuration file with a reference to an [image generation MCP server](https://evalstate-flux1-schnell.hf.space/).

> [!TIP]
> Many Hugging Face Spaces can be used as MCP servers. You can find all compatible Spaces [here](https://huggingface.co/spaces?filter=mcp-server).

```json
{
    "model": "http://localhost:8000",
    "provider": "local",
    "servers": [
        {
            "type": "sse",
            "config": {
                "url": "https://evalstate-flux1-schnell.hf.space/gradio_api/mcp/sse"
            }
        }
    ]
}
```

You can then launch your `tiny-agents` chat interface with the following command.

```bash
tiny-agents run path/to/your/config.json
```

If you have a server (from `transformers serve`) running in the background, you're ready to use MCP tools from a local model! For instance, here's an example of a chat session:

```bash
Agent loaded with 1 tools:
• flux1_schnell_infer
» Generate an image of a cat on the moon
<Tool req_0_tool_call>flux1_schnell_infer {"prompt": "a cat on the moon", "seed": 42, "randomize_seed": true, "width": 1024, "height": 1024, "num_inference_steps": 4}

Tool req_0_tool_call
[Binary Content: Image image/webp, 57732 bytes]
The task is complete and the content accessible to the User
Image URL: https://evalstate-flux1-schnell.hf.space/gradio_api/file=/tmp/gradio/3dbddc0e53b5a865ed56a4e3dbdd30f3f61cf3b8aabf1b456f43e5241bd968b8/image.webp
380576952

I have generated an image of a cat on the moon using the Flux 1 Schnell Image Generator. The image is 1024x1024 pixels and was created with 4 inference steps. Let me know if you would like to make any changes or need further assistance!
```

## TextGenerationPipeline

[`TextGenerationPipeline`] is a high-level text generation class with a "chat mode". Chat mode is enabled when a conversational model is detected and the chat prompt is [properly formatted](./llm_tutorial#wrong-prompt-format).
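For illustration only (not from the commit itself), a minimal sketch of that chat mode; the checkpoint name is an assumption, and any conversational model with a chat template should behave the same way:

```py
from transformers import pipeline

# Assumed checkpoint; swap in any chat-tuned model.
generator = pipeline("text-generation", model="Qwen/Qwen2.5-0.5B-Instruct")

chat = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Name one benefit of tensor parallelism."},
]

# Passing a list of chat messages (instead of a plain string) triggers chat mode.
out = generator(chat, max_new_tokens=50, return_full_text=False)
print(out[0]["generated_text"])
```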
@@ -26,6 +26,7 @@ Pass the audio signal, typically stored in `array`, to the feature extractor and
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
dataset = load_dataset("PolyAI/minds14", name="en-US", split="train")
processed_sample = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=16000)
processed_sample
{'input_values': [array([ 9.4472744e-05, 3.0777880e-03, -2.8888427e-03, ...,
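The fragment above relies on imports made earlier in that doc page. For readers following along here, a self-contained version of the same snippet (not part of the commit):

```py
from datasets import load_dataset
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
dataset = load_dataset("PolyAI/minds14", name="en-US", split="train")

# Pass the raw waveform array plus the sampling rate the model expects.
processed_sample = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=16000)
print(processed_sample["input_values"][0][:5])
```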
@@ -14,59 +14,123 @@ rendered properly in your Markdown viewer.

-->

# BigBirdPegasus

<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
</div>

## Overview
# BigBirdPegasus

The BigBird model was proposed in [Big Bird: Transformers for Longer Sequences](https://huggingface.co/papers/2007.14062) by
Zaheer, Manzil and Guruganesh, Guru and Dubey, Kumar Avinava and Ainslie, Joshua and Alberti, Chris and Ontanon,
Santiago and Pham, Philip and Ravula, Anirudh and Wang, Qifan and Yang, Li and others. BigBird, is a sparse-attention
based transformer which extends Transformer based models, such as BERT to much longer sequences. In addition to sparse
attention, BigBird also applies global attention as well as random attention to the input sequence. Theoretically, it
has been shown that applying sparse, global, and random attention approximates full attention, while being
computationally much more efficient for longer sequences. As a consequence of the capability to handle longer context,
BigBird has shown improved performance on various long document NLP tasks, such as question answering and
summarization, compared to BERT or RoBERTa.
[BigBirdPegasus](https://huggingface.co/papers/2007.14062) is an encoder-decoder (sequence-to-sequence) transformer model for long-input summarization. It extends the [BigBird](./big_bird) architecture with an additional pretraining objective borrowed from [Pegasus](./pegasus) called gap sequence generation (GSG). Whole sentences are masked and the model has to fill in the gaps in the document. BigBirdPegasus's ability to keep track of long contexts makes it effective at summarizing lengthy inputs, surpassing the performance of base Pegasus models.

The abstract from the paper is the following:
You can find all the original BigBirdPegasus checkpoints under the [Google](https://huggingface.co/google/models?search=bigbird-pegasus) organization.

*Transformers-based models, such as BERT, have been one of the most successful deep learning models for NLP.
Unfortunately, one of their core limitations is the quadratic dependency (mainly in terms of memory) on the sequence
length due to their full attention mechanism. To remedy this, we propose, BigBird, a sparse attention mechanism that
reduces this quadratic dependency to linear. We show that BigBird is a universal approximator of sequence functions and
is Turing complete, thereby preserving these properties of the quadratic, full attention model. Along the way, our
theoretical analysis reveals some of the benefits of having O(1) global tokens (such as CLS), that attend to the entire
sequence as part of the sparse attention mechanism. The proposed sparse attention can handle sequences of length up to
8x of what was previously possible using similar hardware. As a consequence of the capability to handle longer context,
BigBird drastically improves performance on various NLP tasks such as question answering and summarization. We also
propose novel applications to genomics data.*
> [!TIP]
> This model was contributed by [vasudevgupta](https://huggingface.co/vasudevgupta).
>
> Click on the BigBirdPegasus models in the right sidebar for more examples of how to apply BigBirdPegasus to different language tasks.

The original code can be found [here](https://github.com/google-research/bigbird).
The example below demonstrates how to summarize text with [`Pipeline`], [`AutoModel`], and from the command line.

## Usage tips
<hfoptions id="usage">
<hfoption id="Pipeline">

- For an in-detail explanation on how BigBird's attention works, see [this blog post](https://huggingface.co/blog/big-bird).
- BigBird comes with 2 implementations: **original_full** & **block_sparse**. For the sequence length < 1024, using
**original_full** is advised as there is no benefit in using **block_sparse** attention.
- The code currently uses window size of 3 blocks and 2 global blocks.
- Sequence length must be divisible by block size.
- Current implementation supports only **ITC**.
- Current implementation doesn't support **num_random_blocks = 0**.
- BigBirdPegasus uses the [PegasusTokenizer](https://github.com/huggingface/transformers/blob/main/src/transformers/models/pegasus/tokenization_pegasus.py).
- BigBird is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than
the left.
```py
import torch
from transformers import pipeline

pipeline = pipeline(
task="summarization",
model="google/bigbird-pegasus-large-arxiv",
torch_dtype=torch.float32,
device=0
)
pipeline("""Plants are among the most remarkable and essential life forms on Earth, possessing a unique ability to produce their own food through a process known as photosynthesis. This complex biochemical process is fundamental not only to plant life but to virtually all life on the planet.
Through photosynthesis, plants capture energy from sunlight using a green pigment called chlorophyll, which is located in specialized cell structures called chloroplasts. In the presence of light, plants absorb carbon dioxide from the atmosphere through small pores in their leaves called stomata, and take in water from the soil through their root systems.
These ingredients are then transformed into glucose, a type of sugar that serves as a source of chemical energy, and oxygen, which is released as a byproduct into the atmosphere. The glucose produced during photosynthesis is not just used immediately; plants also store it as starch or convert it into other organic compounds like cellulose, which is essential for building their cellular structure.
This energy reserve allows them to grow, develop leaves, produce flowers, bear fruit, and carry out various physiological processes throughout their lifecycle.""")
```
</hfoption>
<hfoption id="AutoModel">

```py
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained(
"google/bigbird-pegasus-large-arxiv"
)
model = AutoModelForSeq2SeqLM.from_pretrained(
"google/bigbird-pegasus-large-arxiv",
torch_dtype=torch.bfloat16,
device_map="auto",
)

input_text = """Plants are among the most remarkable and essential life forms on Earth, possessing a unique ability to produce their own food through a process known as photosynthesis. This complex biochemical process is fundamental not only to plant life but to virtually all life on the planet.
Through photosynthesis, plants capture energy from sunlight using a green pigment called chlorophyll, which is located in specialized cell structures called chloroplasts. In the presence of light, plants absorb carbon dioxide from the atmosphere through small pores in their leaves called stomata, and take in water from the soil through their root systems.
These ingredients are then transformed into glucose, a type of sugar that serves as a source of chemical energy, and oxygen, which is released as a byproduct into the atmosphere. The glucose produced during photosynthesis is not just used immediately; plants also store it as starch or convert it into other organic compounds like cellulose, which is essential for building their cellular structure.
This energy reserve allows them to grow, develop leaves, produce flowers, bear fruit, and carry out various physiological processes throughout their lifecycle."""
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

output = model.generate(**input_ids, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>
<hfoption id="transformers-cli">

```bash
echo -e "Plants are among the most remarkable and essential life forms on Earth, possessing a unique ability to produce their own food through a process known as photosynthesis. This complex biochemical process is fundamental not only to plant life but to virtually all life on the planet. Through photosynthesis, plants capture energy from sunlight using a green pigment called chlorophyll, which is located in specialized cell structures called chloroplasts." | transformers-cli run --task summarization --model google/bigbird-pegasus-large-arxiv --device 0
```

</hfoption>
</hfoptions>

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.

The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to int4.

```py
import torch
from transformers import BitsAndBytesConfig, AutoModelForSeq2SeqLM, AutoTokenizer

quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_type="nf4"
)
model = AutoModelForSeq2SeqLM.from_pretrained(
"google/bigbird-pegasus-large-arxiv",
torch_dtype=torch.bfloat16,
device_map="auto",
quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained(
"google/bigbird-pegasus-large-arxiv"
)

input_text = """Plants are among the most remarkable and essential life forms on Earth, possessing a unique ability to produce their own food through a process known as photosynthesis. This complex biochemical process is fundamental not only to plant life but to virtually all life on the planet.
Through photosynthesis, plants capture energy from sunlight using a green pigment called chlorophyll, which is located in specialized cell structures called chloroplasts. In the presence of light, plants absorb carbon dioxide from the atmosphere through small pores in their leaves called stomata, and take in water from the soil through their root systems.
These ingredients are then transformed into glucose, a type of sugar that serves as a source of chemical energy, and oxygen, which is released as a byproduct into the atmosphere. The glucose produced during photosynthesis is not just used immediately; plants also store it as starch or convert it into other organic compounds like cellulose, which is essential for building their cellular structure.
This energy reserve allows them to grow, develop leaves, produce flowers, bear fruit, and carry out various physiological processes throughout their lifecycle."""
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

output = model.generate(**input_ids, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

## Notes

- BigBirdPegasus also uses the [`PegasusTokenizer`].
- Inputs should be padded on the right because BigBird uses absolute position embeddings.
- BigBirdPegasus supports `original_full` and `block_sparse` attention. If the input sequence length is less than 1024, it is recommended to use `original_full` since sparse patterns don't offer much benefit for smaller inputs.
- The current implementation uses window size of 3 blocks and 2 global blocks, only supports the ITC-implementation, and doesn't support `num_random_blocks=0`.
- The sequence length must be divisible by the block size.

## Resources

- [Text classification task guide](../tasks/sequence_classification)
- [Question answering task guide](../tasks/question_answering)
- [Causal language modeling task guide](../tasks/language_modeling)
- [Translation task guide](../tasks/translation)
- [Summarization task guide](../tasks/summarization)
Read the [Understanding BigBird's Block Sparse Attention](https://huggingface.co/blog/big-bird) blog post for more details about how BigBird's attention works.

## BigBirdPegasusConfig
@@ -32,8 +32,8 @@ this model, including [Alternating Updates][altup] (AltUp), [Learned Augmented R
[MatFormer][matformer], Per-Layer Embeddings (PLE), activation sparsity, and KV cache sharing. The language model uses
a similar attention pattern to [Gemma 3](./gemma3.md) with alternating 4 local sliding window self-attention layers for
every global self-attention layer with a maximum context length of 32k tokens. Gemma 3n introduces
[MobileNet v5][mobilenetv5] as the vision encoder, using a default resolution of 768x768 pixels, and adds a
[Universal Speech Model][usm] (USM) as the audio encoder.
[MobileNet v5][mobilenetv5] as the vision encoder, using a default resolution of 768x768 pixels, and adds a newly
trained audio encoder based on the [Universal Speech Model][usm] (USM) architecture.

The instruction-tuned variant was post-trained with knowledge distillation and reinforcement learning.
@@ -15,9 +15,9 @@ rendered properly in your Markdown viewer.

# Distributed inference

When a model doesn't fit on a single GPU, distributed inference with [tensor parallelism](./perf_train_gpu_many#tensor-parallelism) can help. Tensor parallelism shards a model onto multiple GPUs and parallelizes computations such as matrix multiplication. It enables fitting larger model sizes into memory and is faster because each GPU can process a tensor slice.
When a model doesn't fit on a single GPU, distributed inference with [tensor parallelism](./perf_train_gpu_many#tensor-parallelism) can help. Tensor parallelism shards a model onto multiple accelerators (CUDA GPU, Intel XPU, etc.) and parallelizes computations such as matrix multiplication. It enables fitting larger model sizes into memory and is faster because each accelerator can process a tensor slice.

However, tensor parallelism adds communication overhead and should be used on single machine setups with multiple GPUs to take advantage of fast intra-node communication. For multi-node training, it may be more efficient to use pipeline or data parallelism depending on your use case.
However, tensor parallelism adds communication overhead and should be used on single machine setups with multiple accelerators to take advantage of fast intra-node communication. For multi-node training, it may be more efficient to use pipeline or data parallelism depending on your use case.

> [!TIP]
> Refer to the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism) section on tensor parallelism to learn more.
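For orientation, a minimal sketch of the tensor-parallel loading that doc page describes (not from the commit itself), assuming a multi-accelerator host, a checkpoint that ships a tensor-parallel plan, and a launch such as `torchrun --nproc-per-node 4 script.py`:

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # assumed checkpoint

# `tp_plan="auto"` shards supported layers across all visible accelerators.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Tensor parallelism lets us", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0], skip_special_tokens=True))
```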
@@ -308,4 +308,4 @@ The most important part of DTensor is the `placement` attribute because it tells
bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Replicate()]) # Replicate bias across all GPUs
```

- `Partial()` - Indicates a tensor is pending a reduction operation (not typically relevant for usage in Transformers).
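A small sketch of how the `Shard` and `Replicate` placements mentioned above pair up for a linear layer (not from the commit; assumes torch >= 2.5 and a script launched with `torchrun` so a process group exists):

```py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard

dist.init_process_group("nccl")
device_mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("tp",))

weight = torch.randn(1024, 1024, device="cuda")
bias = torch.randn(1024, device="cuda")

# Shard the weight along dim 0 across the "tp" mesh, replicate the bias everywhere.
weight_dt = DTensor.from_local(weight, device_mesh["tp"], placements=[Shard(0)])
bias_dt = DTensor.from_local(bias, device_mesh["tp"], placements=[Replicate()])
print(weight_dt.placements, bias_dt.placements)
```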
@@ -47,7 +47,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="

tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device.type)

output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))

@@ -49,6 +49,7 @@ Check the table below to see if your hardware is compatible.
| Component | Compatibility |
|----------|----------------|
| CUDA Versions | ✅ cu118, cu126, cu128 |
| XPU Versions | ✅ pytorch2.8 |
| CPU | ✅ change `device_map="cpu"` (see examples below) |

@@ -278,6 +279,71 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
</hfoption>
</hfoptions>

### Intel XPU
<hfoptions id="examples-Intel-XPU">
<hfoption id="int8-dynamic-and-weight-only">

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig

quant_config = Int8DynamicActivationInt8WeightConfig()
# or int8 weight only quantization
# quant_config = Int8WeightOnlyConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B-Instruct",
torch_dtype="auto",
device_map="auto",
quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("xpu")

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>

<hfoption id="int4-weight-only">

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import Int4XPULayout
from torchao.quantization.quant_primitives import ZeroPointDomain

quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT)
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B-Instruct",
torch_dtype="auto",
device_map="auto",
quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("xpu")

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>
</hfoptions>

### CPU
<hfoptions id="examples-CPU">
<hfoption id="int8-dynamic-and-weight-only">

@@ -363,7 +429,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)

# Manual Testing
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
inputs = tokenizer(prompt, return_tensors="pt").to(quantized_model.device.type)
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
output_text = tokenizer.batch_decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False

@@ -434,7 +500,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device.type)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")

@@ -474,7 +540,7 @@ tokenizer.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128")

## Loading quantized models

Loading a quantized model depends on the quantization scheme. For quantization schemes, like int8 and float8, you can quantize the model on any device and also load it on any device. The example below demonstrates quantizing a model on the CPU and then loading it on CUDA.
Loading a quantized model depends on the quantization scheme. For quantization schemes, like int8 and float8, you can quantize the model on any device and also load it on any device. The example below demonstrates quantizing a model on the CPU and then loading it on CUDA or XPU.
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

@@ -491,7 +557,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(
quantization_config=quantization_config
)
# save the quantized model
output_dir = "llama-3.1-8b-torchao-int8-cuda"
output_dir = "llama-3.1-8b-torchao-int8"
quantized_model.save_pretrained(output_dir, safe_serialization=False)

# reload the quantized model

@@ -502,7 +568,7 @@ reloaded_model = AutoModelForCausalLM.from_pretrained(
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
input_ids = tokenizer(input_text, return_tensors="pt").to(reloaded_model.device.type)

output = reloaded_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))
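The "Loading quantized models" hunks above only show the changed lines. A consolidated sketch of the flow they describe (not from the commit; assumes int8 weight-only quantization and the same Llama checkpoint):

```py
import torch
from torchao.quantization import Int8WeightOnlyConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

quantization_config = TorchAoConfig(quant_type=Int8WeightOnlyConfig())

# Quantize on CPU...
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype="auto",
    device_map="cpu",
    quantization_config=quantization_config,
)
output_dir = "llama-3.1-8b-torchao-int8"
quantized_model.save_pretrained(output_dir, safe_serialization=False)

# ...then reload on whatever accelerator is available (CUDA, XPU, or CPU).
reloaded_model = AutoModelForCausalLM.from_pretrained(output_dir, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_ids = tokenizer("What are we having for dinner?", return_tensors="pt").to(reloaded_model.device.type)
print(tokenizer.decode(reloaded_model.generate(**input_ids, max_new_tokens=10)[0], skip_special_tokens=True))
```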
setup.py (2 changed lines)

@@ -148,7 +148,7 @@ _deps = [
"protobuf",
"psutil",
"pyyaml>=5.1",
"pydantic",
"pydantic>=2",
"pytest>=7.2.0",
"pytest-asyncio",
"pytest-rerunfailures",
@@ -13,33 +13,30 @@
# limitations under the License.

import copy
import asyncio
import json
import os
import platform
import re
import string
import time
import warnings
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field
from threading import Thread
from typing import Optional
from typing import AsyncIterator, Optional

import yaml
from huggingface_hub.utils import disable_progress_bars
from huggingface_hub import AsyncInferenceClient, ChatCompletionStreamOutput

from transformers import (
AutoTokenizer,
GenerationConfig,
PreTrainedTokenizer,
TextIteratorStreamer,
logging,
)
from transformers.commands import BaseTransformersCLICommand
from transformers.commands.serving import ServeArguments, ServeCommand
from transformers.utils import is_rich_available, is_torch_available

from . import BaseTransformersCLICommand

if platform.system() != "Windows":
import pwd

@@ -52,8 +49,12 @@ if is_rich_available():
if is_torch_available():
import torch

from transformers import AutoModelForCausalLM, BitsAndBytesConfig, PreTrainedModel

from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
GenerationConfig,
)

ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
ALLOWED_VALUE_CHARS = set(

@@ -107,19 +108,6 @@ If you're a new user, check this basic flag guide: https://huggingface.co/docs/t
- **!exit**: closes the interface
"""

# format: (optional CLI arg being deprecated, its current default, corresponding `generate` flag)
_DEPRECATION_MAP = [
("max_new_tokens", 256, "max_new_tokens"),
("do_sample", True, "do_sample"),
("num_beams", 1, "num_beams"),
("temperature", 1.0, "temperature"),
("top_k", 50, "top_k"),
("top_p", 1.0, "top_p"),
("repetition_penalty", 1.0, "repetition_penalty"),
("eos_tokens", None, "eos_token_id"),
("eos_token_ids", None, "eos_token_id"),
]

class RichInterface:
def __init__(self, model_name: Optional[str] = None, user_name: Optional[str] = None):

@@ -133,21 +121,21 @@ class RichInterface:
else:
self.user_name = user_name

def stream_output(self, output_stream: TextIteratorStreamer) -> str:
"""Stream output from a role, and return the generated text after it's done steaming."""
# This method is originally from the FastChat CLI:
# https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
# Create a Live context for updating the console output
text = ""
async def stream_output(self, stream: AsyncIterator[ChatCompletionStreamOutput]) -> tuple[str, int]:
self._console.print(f"[bold blue]<{self.model_name}>:")
with Live(console=self._console, refresh_per_second=4) as live:
# Read lines from the stream
for i, outputs in enumerate(output_stream):
if not outputs or i == 0:
text = ""
async for token in await stream:
outputs = token.choices[0].delta.content
request_id = token.id

if not outputs:
continue

# Escapes single words encased in <>, e.g. <think> -> \<think\>, for proper rendering in Markdown.
# It only escapes single words that may have `_`, optionally following a `/` (e.g. </think>)
outputs = re.sub(r"<(/*)(\w*)>", r"\<\1\2\>", outputs)

text += outputs
# Render the accumulated text as Markdown
# NOTE: this is a workaround for the rendering "unstandard markdown"

@@ -160,6 +148,7 @@ class RichInterface:
# introduce trailing spaces (only) in code block, but it works well
# especially for console output, because in general the console does not
# care about trailing spaces.

lines = []
for line in text.splitlines():
lines.append(line)

@@ -169,11 +158,15 @@ class RichInterface:
lines.append("\n")
else:
lines.append(" \n")

markdown = Markdown("".join(lines).strip(), code_theme="github-dark")

# Update the Live console output
live.update(markdown)
live.update(markdown, refresh=True)

self._console.print()
return text

return text, request_id

def input(self) -> str:
"""Gets user input from the console."""

@@ -245,25 +238,6 @@ class ChatArguments:
),
},
)
# Deprecated CLI args start here
max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate."})
do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation."})
num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search."})
temperature: float = field(default=1.0, metadata={"help": "Temperature parameter for generation."})
top_k: int = field(default=50, metadata={"help": "Value of k for top-k sampling."})
top_p: float = field(default=1.0, metadata={"help": "Value of p for nucleus sampling."})
repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty."})
eos_tokens: Optional[str] = field(
default=None,
metadata={
"help": "EOS tokens (text format) to stop the generation. If multiple they should be comma separated."
},
)
eos_token_ids: Optional[str] = field(
default=None,
metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated."},
)
# Deprecated CLI args end here

# Model loading
model_revision: str = field(

@@ -300,6 +274,10 @@ class ChatArguments:
bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]})
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether to use nested quantization."})

# Serving settings
host: str = field(default="localhost", metadata={"help": "Interface the server will listen to.."})
port: int = field(default=8000, metadata={"help": "Port the server will listen to."})

def chat_command_factory(args: Namespace):
"""

@@ -322,7 +300,10 @@ class ChatCommand(BaseTransformersCLICommand):

group = chat_parser.add_argument_group("Positional arguments")
group.add_argument(
"model_name_or_path_positional", type=str, default=None, help="Name of the pre-trained model."
"model_name_or_path_or_address",
type=str,
default=None,
help="Name of the pre-trained model or address to connect to.",
)
group.add_argument(
"generate_flags",

@@ -332,57 +313,45 @@ class ChatCommand(BaseTransformersCLICommand):
"Flags to pass to `generate`, using a space as a separator between flags. Accepts booleans, numbers, "
"and lists of integers, more advanced parameterization should be set through --generation-config. "
"Example: `transformers chat <model_repo> max_new_tokens=100 do_sample=False eos_token_id=[1,2]`. "
"If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options"
"If you're a new user, check this basic flag guide: "
"https://huggingface.co/docs/transformers/llm_tutorial#common-options"
),
nargs="*",
)
chat_parser.set_defaults(func=chat_command_factory)

def __init__(self, args):
args = self._handle_deprecated_args(args)
if args.model_name_or_path_or_address is not None:
name = args.model_name_or_path_or_address
if name.startswith("http") or name.startswith("https") or name.startswith("localhost"):
self.spawn_backend = False

if args.host != "localhost" or args.port != 8000:
raise ValueError(
"Looks like you’ve set both a server address and a custom host/port. "
"Please pick just one way to specify the server."
)

args.host, args.port = args.model_name_or_path_or_address.rsplit(":", 1)
else:
self.spawn_backend = True
args.model_name_or_path = args.model_name_or_path_or_address

if not is_rich_available() and (not is_torch_available() and self.spawn_backend):
raise ImportError(
"You need to install rich to use the chat interface. Additionally, you have not specified a remote "
"endpoint and are therefore spawning a backend. Torch is required for this: (`pip install rich torch`)"
)
elif not is_rich_available():
raise ImportError("You need to install rich to use the chat interface. (`pip install rich`)")
elif not is_torch_available() and self.spawn_backend:
raise ImportError(
"You have not specified a remote endpoint and are therefore spawning a backend. Torch is required "
"for this: (`pip install rich torch`)"
)

self.args = args

def _handle_deprecated_args(self, args: ChatArguments) -> ChatArguments:
"""
Handles deprecated arguments and their deprecation cycle. To be removed after we fully migrated to the new
args.
"""
has_warnings = False

# 1. Model as a positional argument
args.model_name_or_path_positional = args.model_name_or_path_positional or args.model_name_or_path
if args.model_name_or_path_positional is None:
raise ValueError(
"One of the following must be provided:"
"\n- The positional argument containing the model repo, e.g. `transformers chat <model_repo>`"
"\n- the optional --model_name_or_path argument, containing the model repo (deprecated)"
)
elif args.model_name_or_path is not None:
has_warnings = True
warnings.warn(
"The --model_name_or_path argument is deprecated will be removed in v4.54.0. Use the positional "
"argument instead, e.g. `transformers chat <model_repo>`.",
FutureWarning,
)
# 2. Named generate option args
for deprecated_arg, default_value, new_arg in _DEPRECATION_MAP:
value = getattr(args, deprecated_arg)
if value != default_value:
has_warnings = True
warnings.warn(
f"The --{deprecated_arg} argument is deprecated will be removed in v4.54.0. There are two "
"alternative solutions to specify this generation option: \n"
"1. Pass `--generation-config <path_to_file/Hub repo>` to specify a generation config.\n"
"2. Pass `generate` flags through positional arguments, e.g. `transformers chat <model_repo> "
f"{new_arg}={value}`",
FutureWarning,
)

if has_warnings:
print("\n(Press enter to continue)")
input()
return args

# -----------------------------------------------------------------------------------------------------------------
# Chat session methods
@staticmethod

@@ -404,7 +373,7 @@ class ChatCommand(BaseTransformersCLICommand):

if filename is None:
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
filename = f"{args.model_name_or_path_positional}/chat_{time_str}.json"
filename = f"{args.model_name_or_path_or_address}/chat_{time_str}.json"
filename = os.path.join(folder, filename)

os.makedirs(os.path.dirname(filename), exist_ok=True)

@@ -477,40 +446,23 @@ class ChatCommand(BaseTransformersCLICommand):
)
return processed_generate_flags

def get_generation_parameterization(
self, args: ChatArguments, tokenizer: AutoTokenizer, model: PreTrainedModel
) -> tuple[GenerationConfig, dict]:
def get_generation_parameterization(self, args: ChatArguments) -> tuple[GenerationConfig, dict]:
"""
Returns a GenerationConfig object holding the generation parameters for the CLI command.
"""
# No generation config arg provided -> use default generation config, apply CLI defaults
if args.generation_config is None:
# We start off from the checkpoint's generation config
generation_config = copy.deepcopy(model.generation_config)
# Apply deprecated CLI args on top of the default generation config
pad_token_id, eos_token_ids = self.parse_eos_tokens(
tokenizer, generation_config, args.eos_tokens, args.eos_token_ids
)
deprecated_kwargs = {
"max_new_tokens": args.max_new_tokens,
"do_sample": args.do_sample,
"num_beams": args.num_beams,
"temperature": args.temperature,
"top_k": args.top_k,
"top_p": args.top_p,
"repetition_penalty": args.repetition_penalty,
"pad_token_id": pad_token_id,
"eos_token_id": eos_token_ids,
}
generation_config.update(**deprecated_kwargs)
# generation config arg provided -> use it as the base parameterization
else:
# No generation config arg provided -> use base generation config, apply CLI defaults
if args.generation_config is not None:
if ".json" in args.generation_config: # is a local file
dirname = os.path.dirname(args.generation_config)
filename = os.path.basename(args.generation_config)
generation_config = GenerationConfig.from_pretrained(dirname, filename)
else:
generation_config = GenerationConfig.from_pretrained(args.generation_config)
else:
# !!!!!!!!!
# This is a chat session, so we have a few non-standard defaults
# !!!!!!!!!
generation_config = GenerationConfig(do_sample=True, max_new_tokens=256)

# Finally: parse and apply `generate_flags`
parsed_generate_flags = self.parse_generate_flags(args.generate_flags)

@@ -664,7 +616,7 @@ class ChatCommand(BaseTransformersCLICommand):

elif user_input == "!status":
interface.print_status(
model_name=args.model_name_or_path_positional,
model_name=args.model_name_or_path,
generation_config=generation_config,
model_kwargs=model_kwargs,
)

@@ -679,10 +631,32 @@ class ChatCommand(BaseTransformersCLICommand):
# -----------------------------------------------------------------------------------------------------------------
# Main logic
def run(self):
if not is_rich_available():
raise ImportError("You need to install rich to use the chat interface. (`pip install rich`)")
if not is_torch_available():
raise ImportError("You need to install torch to use the chat interface. (`pip install torch`)")
asyncio.run(self._inner_run())

async def _inner_run(self):
if self.spawn_backend:
serve_args = ServeArguments(
device=self.args.device,
torch_dtype=self.args.torch_dtype,
trust_remote_code=self.args.trust_remote_code,
attn_implementation=self.args.attn_implementation,
load_in_8bit=self.args.load_in_8bit,
load_in_4bit=self.args.load_in_4bit,
bnb_4bit_quant_type=self.args.bnb_4bit_quant_type,
use_bnb_nested_quant=self.args.use_bnb_nested_quant,
host=self.args.host,
port=self.args.port,
log_level="error",
)
serve_command = ServeCommand(serve_args)

thread = Thread(target=serve_command.run)
thread.daemon = True
thread.start()

model = self.args.model_name_or_path + "@" + self.args.model_revision
host = "http://localhost" if self.args.host == "localhost" else self.args.host
client = AsyncInferenceClient(f"{host}:{self.args.port}")

args = self.args
if args.examples_path is None:

@@ -696,19 +670,14 @@ class ChatCommand(BaseTransformersCLICommand):
else:
user = args.user

model, tokenizer = self.load_model_and_tokenizer(args)
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer, model)
generation_config, model_kwargs = self.get_generation_parameterization(args)

# if not verbose -> disable warnings, progress bars, etc in the chat interface
if not args.verbose:
logging.set_verbosity_error()
disable_progress_bars()

interface = RichInterface(model_name=args.model_name_or_path_positional, user_name=user)
interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
interface.clear()
chat = self.clear_chat_history(args.system_prompt)

request_id = None

# Starts the session with a minimal help message at the top, so that a user doesn't get stuck
interface.print_help(minimal=True)
while True:

@@ -736,23 +705,29 @@ class ChatCommand(BaseTransformersCLICommand):
else:
chat.append({"role": "user", "content": user_input})

inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
model.device
stream = client.chat_completion(
chat,
stream=True,
extra_body={
"request_id": request_id,
"generation_config": {**generation_config.to_dict()},
"model": model,
},
)
attention_mask = torch.ones_like(inputs)
generation_kwargs = {
|
||||
"inputs": inputs,
|
||||
"attention_mask": attention_mask,
|
||||
"streamer": generation_streamer,
|
||||
"generation_config": generation_config,
|
||||
**model_kwargs,
|
||||
}
|
||||
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
model_output = interface.stream_output(generation_streamer)
|
||||
thread.join()
|
||||
model_output, request_id = await interface.stream_output(stream)
|
||||
|
||||
chat.append({"role": "assistant", "content": model_output})
|
||||
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = ChatArguments()
|
||||
args.model_name_or_path_or_address = "meta-llama/Llama-3.2-3b-Instruct"
|
||||
args.model_name_or_path_or_address = "http://localhost:8000"
|
||||
chat = ChatCommand(args)
|
||||
chat.run()
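# Equivalent CLI entry point for the flow above, with a hypothetical checkpoint:
#     transformers chat meta-llama/Llama-3.2-3b-Instruct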
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -11,33 +11,95 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Thread
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..pipelines import Pipeline, get_supported_tasks, pipeline
|
||||
from ..utils import logging
|
||||
from huggingface_hub import (
|
||||
ChatCompletionStreamOutputChoice,
|
||||
ChatCompletionStreamOutputDelta,
|
||||
ChatCompletionStreamOutputDeltaToolCall,
|
||||
ChatCompletionStreamOutputFunction,
|
||||
ModelInfo,
|
||||
model_info,
|
||||
)
|
||||
|
||||
from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
|
||||
|
||||
from .. import PreTrainedTokenizerFast, TextIteratorStreamer
|
||||
from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus
|
||||
from ..utils import is_torch_available, logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
try:
|
||||
from fastapi import Body, FastAPI, HTTPException
|
||||
from fastapi.routing import APIRoute
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
GenerationConfig,
|
||||
PreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
if is_pydantic_available() and is_fastapi_available() and is_uvicorn_available():
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import JSONResponse
|
||||
from uvicorn import run
|
||||
|
||||
_serve_dependencies_installed = True
|
||||
except (ImportError, AttributeError):
|
||||
BaseModel = object
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
def Body(*x, **y):
|
||||
pass
|
||||
class ChatCompletionInput(BaseModel):
|
||||
messages: list[Message]
|
||||
|
||||
_serve_dependencies_installed = False
|
||||
stream: Optional[bool] = False
|
||||
model: Optional[str] = None
|
||||
request_id: Optional[str] = None
|
||||
extra_body: Optional[dict] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
logit_bias: Optional[list[float]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
stop: Optional[list[str]] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
# Additional options supported by the HFH InferenceClient
|
||||
# that aren't yet supported here.
|
||||
|
||||
# logprobs: Optional[bool] = None
|
||||
tools: Any = None
|
||||
# n: Optional[int] = None
|
||||
# presence_penalty: Optional[float] = None
|
||||
# response_format: Optional[ChatCompletionInputGrammarType] = None
|
||||
# stream_options: Optional[ChatCompletionInputStreamOptions] = None
|
||||
# tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None
|
||||
# tool_prompt: Optional[str] = None
|
||||
# top_logprobs: Optional[int] = None
|
||||
|
||||
|
||||
logger = logging.get_logger("transformers/serving")
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Possible tokens that indicate the start/end of a tool call
|
||||
# TODO (joao, matt): streamline tool token detection logic
|
||||
_TOOL_CALL_TOKENS = {
|
||||
"qwen": {
|
||||
"start": "<tool_call>",
|
||||
"end": "</tool_call>",
|
||||
},
|
||||
}
|
||||
_MODELS_WITH_TOOL_SUPPORT = list(_TOOL_CALL_TOKENS.keys())
|
||||
|
||||
|
||||
def serve_command_factory(args: Namespace):
|
||||
@ -46,50 +108,114 @@ def serve_command_factory(args: Namespace):
|
||||
|
||||
Returns: ServeCommand
|
||||
"""
|
||||
nlp = pipeline(
|
||||
task=args.task,
|
||||
model=args.model if args.model else None,
|
||||
config=args.config,
|
||||
tokenizer=args.tokenizer,
|
||||
device=args.device,
|
||||
return ServeCommand(args)
|
||||
|
||||
|
||||
def create_generation_config_from_req(req: "ChatCompletionInput") -> "GenerationConfig":
|
||||
"""
|
||||
Creates a generation config from the parameters of the request. Note that we can pass a `GenerationConfig`
|
||||
(serialized into a `dict`) in `extra_body`, for full `generate` parameterization.
|
||||
|
||||
Args:
|
||||
req (`ChatCompletionInput`): The request which may optionally contain generation parameters.
|
||||
|
||||
Returns:
|
||||
The prepared `GenerationConfig` object.
|
||||
"""
|
||||
if req.extra_body is not None and "generation_config" in req.extra_body:
|
||||
for key in req.extra_body["generation_config"].keys():
|
||||
if key in ChatCompletionInput.base_field_names.keys():
|
||||
return {"error": "Duplicated key in the root request and in the passed generation config."}
|
||||
|
||||
if req.extra_body is not None and "generation_config" in req.extra_body:
|
||||
generation_config = GenerationConfig(**(req.extra_body["generation_config"]))
|
||||
else:
|
||||
generation_config = GenerationConfig()
|
||||
|
||||
if req.frequency_penalty is not None:
|
||||
generation_config.repetition_penalty = req.frequency_penalty
|
||||
if req.logit_bias is not None:
|
||||
generation_config.sequence_bias = req.logit_bias
|
||||
if req.stop is not None:
|
||||
generation_config.stop_strings = req.stop
|
||||
if req.temperature is not None:
|
||||
generation_config.temperature = req.temperature
|
||||
if req.top_p is not None:
|
||||
generation_config.top_p = req.top_p
|
||||
if req.seed is not None:
|
||||
torch.manual_seed(req.seed)
|
||||
|
||||
return generation_config
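# A minimal client-side sketch of the `extra_body` path described above, assuming a `transformers serve`
# instance reachable at http://localhost:8000 and a huggingface_hub version whose `chat_completion`
# accepts `extra_body` (endpoint and sampling values below are hypothetical):
def _example_chat_completion_with_generation_config():
    from huggingface_hub import InferenceClient

    client = InferenceClient("http://localhost:8000")
    stream = client.chat_completion(
        [{"role": "user", "content": "Hello!"}],
        stream=True,
        extra_body={"generation_config": GenerationConfig(do_sample=True, temperature=0.7).to_dict()},
    )
    for chunk in stream:
        # Each streamed chunk mirrors the OpenAI-style `chat.completion.chunk` payload built by the server
        print(chunk.choices[0].delta.content or "", end="")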
|
||||
|
||||
|
||||
class ToolState:
|
||||
"""Lightweight class to keep track of the tool call state."""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""Reset the tool call state (assumes we're outside a tool call)."""
|
||||
self.inside_tool_call = False
|
||||
self.has_tool_name_defined = False
|
||||
self.arg_nesting_level = 0
|
||||
self.buffer = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServeArguments:
|
||||
r"""
|
||||
Arguments for the serve CLI.
|
||||
|
||||
See the metadata arg for each argument's description -- the metadata will be printed with
|
||||
`transformers serve --help`
|
||||
"""
|
||||
|
||||
device: str = field(default="cpu", metadata={"help": "Device to use for inference."})
|
||||
torch_dtype: Optional[str] = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": "Override the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, "
|
||||
"the dtype will be automatically derived from the model's weights.",
|
||||
"choices": ["auto", "bfloat16", "float16", "float32"],
|
||||
},
|
||||
)
|
||||
return ServeCommand(nlp, args.host, args.port, args.workers)
|
||||
trust_remote_code: bool = field(
|
||||
default=False, metadata={"help": "Whether to trust remote code when loading a model."}
|
||||
)
|
||||
attn_implementation: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in "
|
||||
"which case you must install this manually by running `pip install flash-attn --no-build-isolation`."
|
||||
},
|
||||
)
|
||||
load_in_8bit: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use 8 bit precision for the base model - works only with LoRA."},
|
||||
)
|
||||
load_in_4bit: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use 4 bit precision for the base model - works only with LoRA."},
|
||||
)
|
||||
bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]})
|
||||
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether to use nested quantization."})
|
||||
|
||||
# Serving settings
|
||||
host: str = field(default="localhost", metadata={"help": "Interface the server will listen on."})
|
||||
port: int = field(default=8000, metadata={"help": "Port the server will listen to."})
|
||||
|
||||
class ServeModelInfoResult(BaseModel):
|
||||
"""
|
||||
Expose model information
|
||||
"""
|
||||
|
||||
infos: dict
|
||||
|
||||
|
||||
class ServeTokenizeResult(BaseModel):
|
||||
"""
|
||||
Tokenize result model
|
||||
"""
|
||||
|
||||
tokens: list[str]
|
||||
tokens_ids: Optional[list[int]]
|
||||
|
||||
|
||||
class ServeDeTokenizeResult(BaseModel):
|
||||
"""
|
||||
DeTokenize result model
|
||||
"""
|
||||
|
||||
text: str
|
||||
|
||||
|
||||
class ServeForwardResult(BaseModel):
|
||||
"""
|
||||
Forward result model
|
||||
"""
|
||||
|
||||
output: Any
|
||||
# Other settings
|
||||
log_level: str = field(
|
||||
default="info", metadata={"help": "Logging level as a string. Example: 'info' or 'warning'."}
|
||||
)
|
||||
|
||||
|
||||
class ServeCommand(BaseTransformersCLICommand):
|
||||
loaded_model: Optional[str] = None
|
||||
model: PreTrainedModel
|
||||
tokenizer: PreTrainedTokenizerFast
|
||||
|
||||
@staticmethod
|
||||
def register_subcommand(parser: ArgumentParser):
|
||||
"""
|
||||
@ -98,131 +224,409 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
Args:
|
||||
parser: Root parser to register command-specific arguments
|
||||
"""
|
||||
serve_parser = parser.add_parser(
|
||||
"serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
|
||||
)
|
||||
serve_parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
choices=get_supported_tasks(),
|
||||
help="The task to run the pipeline on",
|
||||
)
|
||||
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
|
||||
serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
|
||||
serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
|
||||
serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
|
||||
serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
|
||||
serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
|
||||
serve_parser.add_argument(
|
||||
"--device",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
|
||||
)
|
||||
dataclass_types = (ServeArguments,)
|
||||
serve_parser = parser.add_parser("serve", dataclass_types=dataclass_types)
|
||||
serve_parser.set_defaults(func=serve_command_factory)
|
||||
|
||||
def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
|
||||
self._pipeline = pipeline
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.workers = workers
|
||||
|
||||
if not _serve_dependencies_installed:
|
||||
raise RuntimeError(
|
||||
"Using serve command requires FastAPI and uvicorn. "
|
||||
'Please install transformers with [serving]: pip install "transformers[serving]". '
|
||||
"Or install FastAPI and uvicorn separately."
|
||||
def __init__(self, args: ServeArguments):
|
||||
if not is_pydantic_available() or not is_fastapi_available() or not is_uvicorn_available():
|
||||
raise ImportError(
|
||||
"Missing dependencies for the serving CLI. Please install with `pip install transformers[serving]`"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Serving model over {host}:{port}")
|
||||
self._app = FastAPI(
|
||||
routes=[
|
||||
APIRoute(
|
||||
"/",
|
||||
self.model_info,
|
||||
response_model=ServeModelInfoResult,
|
||||
response_class=JSONResponse,
|
||||
methods=["GET"],
|
||||
|
||||
self.args = args
|
||||
self.use_continuous_batching = self.args.attn_implementation == "sdpa_paged"
|
||||
|
||||
# State: preserves information about the last call and last KV cache, to determine whether we can reuse the KV
|
||||
# cache and avoid re-running prefill
|
||||
self.last_messages = None
|
||||
self.last_kv_cache = None
|
||||
|
||||
transformers_logger = logging.get_logger("transformers")
|
||||
transformers_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
|
||||
|
||||
cb_logger = logging.get_logger("transformers.generation.continuous_batching")
|
||||
cb_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
|
||||
|
||||
def build_chunk(
|
||||
self,
|
||||
content: str,
|
||||
request_id: str,
|
||||
role: Optional[str] = None,
|
||||
finish_reason: Optional[str] = None,
|
||||
tool_calls: Optional[list[ChatCompletionStreamOutputDeltaToolCall]] = None,
|
||||
) -> str:
|
||||
payload = {
|
||||
"object": "chat.completion.chunk",
|
||||
"id": request_id,
|
||||
"created": int(time.time()),
|
||||
"model": self.loaded_model,
|
||||
"system_fingerprint": "",
|
||||
"choices": [
|
||||
ChatCompletionStreamOutputChoice(
|
||||
delta=ChatCompletionStreamOutputDelta(
|
||||
role=role,
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
APIRoute(
|
||||
"/tokenize",
|
||||
self.tokenize,
|
||||
response_model=ServeTokenizeResult,
|
||||
response_class=JSONResponse,
|
||||
methods=["POST"],
|
||||
),
|
||||
APIRoute(
|
||||
"/detokenize",
|
||||
self.detokenize,
|
||||
response_model=ServeDeTokenizeResult,
|
||||
response_class=JSONResponse,
|
||||
methods=["POST"],
|
||||
),
|
||||
APIRoute(
|
||||
"/forward",
|
||||
self.forward,
|
||||
response_model=ServeForwardResult,
|
||||
response_class=JSONResponse,
|
||||
methods=["POST"],
|
||||
),
|
||||
],
|
||||
timeout=600,
|
||||
)
|
||||
index=0,
|
||||
logprobs=None,
|
||||
finish_reason=finish_reason,
|
||||
),
|
||||
],
|
||||
}
|
||||
return f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
def run(self):
|
||||
run(self._app, host=self.host, port=self.port, workers=self.workers)
|
||||
app = FastAPI()
|
||||
|
||||
def model_info(self):
|
||||
return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
|
||||
if self.use_continuous_batching:
|
||||
self.continuous_batching(app)
|
||||
else:
|
||||
self.generate(app)
|
||||
|
||||
def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_text_gen_models() -> list[ModelInfo]:
|
||||
"""
|
||||
This is by no means a limit on which models may be used with `transformers serve`: any chat-based
model that works with `generate` can work.
|
||||
|
||||
This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party
|
||||
integrations.
|
||||
"""
|
||||
return [
|
||||
model_info("Menlo/Jan-nano"),
|
||||
model_info("Menlo/Jan-nano-128k"),
|
||||
model_info("Qwen/Qwen2.5-0.5B-Instruct"),
|
||||
model_info("Qwen/Qwen2.5-3B-Instruct"),
|
||||
model_info("Qwen/Qwen2.5-7B-Instruct"),
|
||||
model_info("Qwen/Qwen2.5-14B-Instruct"),
|
||||
model_info("meta-llama/Llama-3.1-8B-Instruct"),
|
||||
model_info("meta-llama/Llama-3.2-1B-Instruct"),
|
||||
model_info("meta-llama/Llama-3.3-70B-Instruct"),
|
||||
]
|
||||
|
||||
@app.get("/v1/models")
|
||||
def get_all_models():
|
||||
return JSONResponse(
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": model.id,
|
||||
"object": "model",
|
||||
"crated": model.created_at.timestamp(),
|
||||
"owned_by": model.author,
|
||||
}
|
||||
for model in get_text_gen_models()
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)
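# A minimal sketch of querying the models endpoint exposed above, assuming the server is listening on the
# default http://localhost:8000 and that `requests` is installed (all values hypothetical):
@staticmethod
def _example_list_models():
    import requests

    response = requests.get("http://localhost:8000/v1/models", timeout=10)
    for entry in response.json()["data"]:
        print(entry["id"], entry["owned_by"])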
|
||||
|
||||
def continuous_batching(self, app):
|
||||
generation_config = GenerationConfig(
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
use_cache=False,
|
||||
num_blocks=1,
|
||||
block_size=1024,
|
||||
do_sample=False,
|
||||
max_batch_tokens=10,
|
||||
scheduler="fifo",
|
||||
)
|
||||
|
||||
manager: ContinuousBatchingManager = self.model.init_continuous_batching(
|
||||
generation_config=generation_config, streaming=True
|
||||
)
|
||||
manager.start()
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
def _serve(req: "ChatCompletionInput"):
|
||||
if not req.stream:
|
||||
return {"error": "Only streaming mode is supported."}
|
||||
|
||||
update_model = req.model != self.loaded_model
|
||||
|
||||
if update_model:
|
||||
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
|
||||
|
||||
chat = req.messages
|
||||
inputs = self.tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
|
||||
self.model.device
|
||||
)
|
||||
|
||||
generation_config = create_generation_config_from_req(req)
|
||||
|
||||
def stream_response(_inputs):
|
||||
try:
|
||||
max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
|
||||
request_id = manager.add_request(_inputs, request_id=req.request_id, max_new_tokens=max_new_tokens)
|
||||
queue_is_flushed = False
|
||||
|
||||
for result in manager:
|
||||
if req.request_id is not None and not queue_is_flushed:
|
||||
if result.status == RequestStatus.FINISHED:
|
||||
continue
|
||||
else:
|
||||
queue_is_flushed = True
|
||||
|
||||
finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
|
||||
yield self.build_chunk(result.next_token, request_id=request_id, finish_reason=finish_reason)
|
||||
|
||||
if result.status == RequestStatus.FINISHED:
|
||||
break
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
yield f'data: {{"error": "{str(e)}"}}'
|
||||
|
||||
return StreamingResponse(stream_response(inputs[0]), media_type="text/event-stream")
|
||||
|
||||
def is_continuation(self, req: "ChatCompletionInput") -> bool:
|
||||
"""
|
||||
Tokenize the provided input and eventually returns corresponding tokens id: - **text_input**: String to
|
||||
tokenize - **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer
|
||||
mapping.
|
||||
"""
|
||||
try:
|
||||
tokens_txt = self._pipeline.tokenizer.tokenize(text_input)
|
||||
Determines whether the current request is a continuation of the last request. In other words, if it is the
|
||||
same chat session.
|
||||
|
||||
if return_ids:
|
||||
tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
|
||||
return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
|
||||
Args:
|
||||
req (`ChatCompletionInput`): The request to check.
|
||||
|
||||
Returns:
|
||||
`True` if the request is a continuation of the last request, `False` otherwise.
|
||||
"""
|
||||
req_continues_last_messages = True
|
||||
|
||||
# No cached messages: this is a new request
|
||||
if self.last_messages is None:
|
||||
req_continues_last_messages = False
|
||||
# The new request has fewer rounds of conversation: this is a new request
|
||||
elif len(self.last_messages) > len(req.messages):
|
||||
req_continues_last_messages = False
|
||||
# Otherwise, check that the last messages are a subset of the new request
|
||||
else:
|
||||
for i in range(len(self.last_messages)):
|
||||
if self.last_messages[i] != req.messages[i]:
|
||||
req_continues_last_messages = False
|
||||
break
|
||||
|
||||
self.last_messages = req.messages
|
||||
return req_continues_last_messages
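# Illustration of the continuation check above with hypothetical messages: a request that extends the
# previous conversation round-for-round is treated as the same chat session, so the cached KV prefix can be reused.
@staticmethod
def _example_is_continuation():
    last = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
    new = last + [{"role": "user", "content": "Tell me a joke"}]
    # `new` starts with every message of `last`, so prefill for the shared prefix can be skipped
    return all(last[i] == new[i] for i in range(len(last)))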
|
||||
|
||||
def generate(self, app):
|
||||
@app.post("/v1/chat/completions")
|
||||
def _serve(req: "ChatCompletionInput"):
|
||||
update_model = req.model != self.loaded_model
|
||||
|
||||
if update_model:
|
||||
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
|
||||
|
||||
if not req.stream:
|
||||
return {"error": "Only streaming mode is supported."}
|
||||
|
||||
# HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
|
||||
# request whose last message is from the assistant.
|
||||
if req.messages[-1].role == "assistant":
|
||||
return
|
||||
|
||||
# ====== TOOL PREPROCESSING LOGIC ======
|
||||
tool_model_family = None
|
||||
for supported_model_families in _MODELS_WITH_TOOL_SUPPORT:
|
||||
if supported_model_families in self.model.config.architectures[0].lower():
|
||||
tool_model_family = supported_model_families
|
||||
break
|
||||
# TODO: trigger 2 constrained generations after the tool call start token is emitted:
|
||||
# 1. force generation to pick from the tool names
|
||||
# 2. force generation to pick from that tool's arguments
|
||||
# ====== END OF TOOL PREPROCESSING LOGIC ======
|
||||
|
||||
if tool_model_family is not None:
|
||||
text = self.tokenizer.apply_chat_template(
|
||||
req.messages, add_generation_prompt=True, tokenize=False, tools=req.tools
|
||||
)
|
||||
else:
|
||||
return ServeTokenizeResult(tokens=tokens_txt)
|
||||
text = self.tokenizer.apply_chat_template(req.messages, add_generation_prompt=True, tokenize=False)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
|
||||
inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
|
||||
request_id = req.request_id if req.request_id is not None else "req_0"
|
||||
|
||||
def detokenize(
|
||||
self,
|
||||
tokens_ids: list[int] = Body(None, embed=True),
|
||||
skip_special_tokens: bool = Body(False, embed=True),
|
||||
cleanup_tokenization_spaces: bool = Body(True, embed=True),
|
||||
):
|
||||
"""
|
||||
Detokenize the provided tokens ids to readable text: - **tokens_ids**: List of tokens ids -
|
||||
**skip_special_tokens**: Flag indicating to not try to decode special tokens - **cleanup_tokenization_spaces**:
|
||||
Flag indicating to remove all leading/trailing spaces and intermediate ones.
|
||||
"""
|
||||
try:
|
||||
decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
|
||||
return ServeDeTokenizeResult(model="", text=decoded_str)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
|
||||
generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
|
||||
|
||||
async def forward(self, inputs=Body(None, embed=True)):
|
||||
"""
|
||||
**inputs**: **attention_mask**: **tokens_type_ids**:
|
||||
"""
|
||||
generation_config = create_generation_config_from_req(req)
|
||||
max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
|
||||
generation_config.max_new_tokens = max_new_tokens
|
||||
|
||||
# Check we don't have empty string
|
||||
if len(inputs) == 0:
|
||||
return ServeForwardResult(output=[], attention=[])
|
||||
last_kv_cache = None
|
||||
if self.is_continuation(req) and not update_model:
|
||||
last_kv_cache = self.last_kv_cache
|
||||
|
||||
try:
|
||||
# Forward through the model
|
||||
output = self._pipeline(inputs)
|
||||
return ServeForwardResult(output=output)
|
||||
except Exception as e:
|
||||
raise HTTPException(500, {"error": str(e)})
|
||||
generation_kwargs = {
|
||||
"inputs": inputs,
|
||||
"attention_mask": torch.ones_like(inputs),
|
||||
"streamer": generation_streamer,
|
||||
"generation_config": generation_config,
|
||||
"return_dict_in_generate": True,
|
||||
"past_key_values": last_kv_cache,
|
||||
}
|
||||
|
||||
def stream_response(streamer, _request_id):
|
||||
# Thin wrapper to save the KV cache after generation
|
||||
def generate_with_cache(**kwargs):
|
||||
generate_output = self.model.generate(**kwargs)
|
||||
self.last_kv_cache = generate_output.past_key_values
|
||||
|
||||
thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
|
||||
|
||||
try:
|
||||
thread.start()
|
||||
tool_state = ToolState()
|
||||
|
||||
for result in streamer:
|
||||
# ====== TOOL CALL LOGIC ======
|
||||
if tool_model_family is not None:
|
||||
# Start of a tool call: reset state variables, set `inside_tool_call`
|
||||
if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]:
|
||||
tool_state.inside_tool_call = True
|
||||
continue
|
||||
|
||||
# End of tool call: reset `inside_tool_call`, emit a `finish_reason`
|
||||
if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]:
|
||||
tool_state.reset()
|
||||
yield self.build_chunk("", _request_id, role=None, finish_reason="tool_calls")
|
||||
continue
|
||||
|
||||
# Inside a tool call
|
||||
if tool_state.inside_tool_call:
|
||||
tool_state.buffer += result
|
||||
|
||||
# First step: extract the tool name (may need several tokens, and we can't emit a delta
|
||||
# until we have the full name)
|
||||
if not tool_state.has_tool_name_defined:
|
||||
tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer)
|
||||
if tool_name is None:
|
||||
continue
|
||||
else:
|
||||
tool_name = tool_name.group(1)
|
||||
tool_state.has_tool_name_defined = True
|
||||
tool = ChatCompletionStreamOutputDeltaToolCall(
|
||||
function=ChatCompletionStreamOutputFunction(
|
||||
name=tool_name,
|
||||
arguments=None,
|
||||
),
|
||||
index=0,
|
||||
type="function",
|
||||
id=_request_id + "_tool_call", # Only the first tool call delta has an id
|
||||
)
|
||||
|
||||
# Second step: extract tool arguments. The tool arguments can be seen as a json string
|
||||
# within the tool json string. We emit a delta for the arguments.
|
||||
else:
|
||||
# Empty text: skip
|
||||
if result == "":
|
||||
continue
|
||||
# Until we see the `"arguments": {` in the buffer, we skip
|
||||
# TODO: other models will likely need more elaborate processing here
|
||||
if '"arguments": {' not in tool_state.buffer:
|
||||
continue
|
||||
|
||||
# Handle nesting. We want to exclude the last } from the emitted arguments (it's
|
||||
# closing the outermost nesting level, outside the arguments block)
|
||||
tool_state.arg_nesting_level += result.count("{")
|
||||
tool_state.arg_nesting_level -= result.count("}")
|
||||
if tool_state.arg_nesting_level < 0:
|
||||
result = "".join(result.split("}")[:-2]) + "}" # e.g. "4}}\n" -> "4}"
|
||||
|
||||
tool = ChatCompletionStreamOutputDeltaToolCall(
|
||||
function=ChatCompletionStreamOutputFunction(
|
||||
arguments=result,
|
||||
),
|
||||
index=0,
|
||||
type="function",
|
||||
id=None,
|
||||
)
|
||||
|
||||
yield self.build_chunk(None, _request_id, role=None, tool_calls=[tool])
|
||||
continue
|
||||
# ====== END OF TOOL CALL LOGIC ======
|
||||
|
||||
# All non-tool related tokens are emitted as assistant messages
|
||||
yield self.build_chunk(result, _request_id, role="assistant")
|
||||
yield self.build_chunk(None, _request_id, role=None, finish_reason="stop")
|
||||
|
||||
thread.join()
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
raise
|
||||
yield f'data: {{"error": "{str(e)}"}}'
|
||||
|
||||
finally:
|
||||
thread.join()
|
||||
|
||||
return StreamingResponse(stream_response(generation_streamer, request_id), media_type="text/event-stream")
|
||||
|
||||
@staticmethod
|
||||
def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]:
|
||||
if model_args.load_in_4bit:
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
# For consistency with model weights, we use the same value as `torch_dtype`
|
||||
bnb_4bit_compute_dtype=model_args.torch_dtype,
|
||||
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
|
||||
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
|
||||
bnb_4bit_quant_storage=model_args.torch_dtype,
|
||||
)
|
||||
elif model_args.load_in_8bit:
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
)
|
||||
else:
|
||||
quantization_config = None
|
||||
|
||||
return quantization_config
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
self, model_id_and_revision: str, args: ServeArguments
|
||||
) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
|
||||
logger.warning(f"Loading {model_id_and_revision}")
|
||||
|
||||
if "@" in model_id_and_revision:
|
||||
model_id, revision = model_id_and_revision.split("@", 1)
|
||||
else:
|
||||
model_id, revision = model_id_and_revision, "main"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
)
|
||||
|
||||
torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
|
||||
quantization_config = self.get_quantization_config(args)
|
||||
|
||||
model_kwargs = {
|
||||
"revision": revision,
|
||||
"attn_implementation": args.attn_implementation,
|
||||
"torch_dtype": torch_dtype,
|
||||
"device_map": "auto",
|
||||
"quantization_config": quantization_config,
|
||||
"trust_remote_code": args.trust_remote_code,
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
|
||||
|
||||
if model.generation_config.max_new_tokens is not None and model.generation_config.max_new_tokens < 256:
|
||||
model.generation_config.max_new_tokens = 256
|
||||
|
||||
if getattr(model, "hf_device_map", None) is None:
|
||||
model = model.to(args.device)
|
||||
|
||||
self.loaded_model = model_id_and_revision
|
||||
|
||||
print("Loaded model", model_id_and_revision)
|
||||
return model, tokenizer
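# The `model@revision` convention handled above, illustrated with a hypothetical id:
#     "Qwen/Qwen2.5-0.5B-Instruct@main" -> model_id "Qwen/Qwen2.5-0.5B-Instruct", revision "main"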
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
serve = ServeCommand()
|
||||
serve.run()
|
||||
|
@ -54,7 +54,7 @@ deps = {
|
||||
"protobuf": "protobuf",
|
||||
"psutil": "psutil",
|
||||
"pyyaml": "pyyaml>=5.1",
|
||||
"pydantic": "pydantic",
|
||||
"pydantic": "pydantic>=2",
|
||||
"pytest": "pytest>=7.2.0",
|
||||
"pytest-asyncio": "pytest-asyncio",
|
||||
"pytest-rerunfailures": "pytest-rerunfailures",
|
||||
|
@ -27,6 +27,8 @@ from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.decoders import DecodeStream
|
||||
from torch.profiler import profile, schedule, tensorboard_trace_handler
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -72,6 +74,7 @@ class GenerationOutput:
|
||||
error: Optional[str] = None
|
||||
status: RequestStatus = RequestStatus.PENDING
|
||||
created_time: float = field(default_factory=time.time)
|
||||
next_token: Optional[int] = field(default_factory=int)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -96,6 +99,7 @@ class RequestState:
|
||||
eos_token_id: int = -1
|
||||
created_time: float = field(default_factory=time.time)
|
||||
error: Optional[str] = None
|
||||
next_token: Optional[str] = None
|
||||
|
||||
def current_len(self) -> int:
|
||||
"""Get the current length of the sequence (prompt + generated tokens)."""
|
||||
@ -139,6 +143,7 @@ class RequestState:
|
||||
generated_tokens=self.static_outputs,
|
||||
logprobs=[],
|
||||
error=self.error,
|
||||
next_token=self.next_token,
|
||||
)
|
||||
|
||||
|
||||
@ -764,6 +769,9 @@ class ContinuousBatchProcessor:
|
||||
|
||||
self.setup_static_tensors()
|
||||
|
||||
self.tokenizer = Tokenizer.from_pretrained(self.config._name_or_path)
|
||||
self.decode_stream = DecodeStream(skip_special_tokens=True)
|
||||
|
||||
@traced(standalone=True)
|
||||
def setup_static_tensors(self):
|
||||
T = self.max_batch_tokens
|
||||
@ -995,7 +1003,7 @@ class ContinuousBatchProcessor:
|
||||
def _maybe_send_output(self, state: RequestState, token: int):
|
||||
"""Send output to the queue based on streaming mode and request state."""
|
||||
if self.streaming:
|
||||
state.next_token = token
|
||||
state.next_token = self.decode_stream.step(self.tokenizer, state.static_outputs[-1])
|
||||
self.output_queue.put(state.to_generation_output())
|
||||
elif state.status == RequestStatus.FINISHED:
|
||||
self.output_queue.put(state.to_generation_output())
|
||||
@ -1102,6 +1110,7 @@ class ContinuousBatchingManager:
|
||||
self.profile = getattr(generation_config, "profile", False)
|
||||
self.manual_eviction = manual_eviction
|
||||
self.batch_processor: Optional[ContinuousBatchProcessor] = None
|
||||
self.decode_stream = DecodeStream(skip_special_tokens=True)
|
||||
|
||||
@traced
|
||||
def start(self):
|
||||
|
@ -733,7 +733,9 @@ class GenerationMixin(ContinuousMixin):
|
||||
# - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
|
||||
# pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
|
||||
if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
|
||||
if not self.config.is_encoder_decoder:
|
||||
if model_kwargs["inputs_embeds"] is None:
|
||||
model_kwargs.pop("inputs_embeds")
|
||||
elif not self.config.is_encoder_decoder:
|
||||
has_inputs_embeds_forwarding = "inputs_embeds" in set(
|
||||
inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
|
||||
)
|
||||
@ -748,10 +750,11 @@ class GenerationMixin(ContinuousMixin):
|
||||
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
|
||||
inputs, bos_token_id, model_kwargs=model_kwargs
|
||||
)
|
||||
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
|
||||
else:
|
||||
if inputs is not None:
|
||||
raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
|
||||
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
|
||||
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
|
||||
|
||||
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
|
||||
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
|
||||
|
@ -34,6 +34,10 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union
|
||||
import numpy as np
|
||||
import packaging.version
|
||||
|
||||
|
||||
if os.getenv("WANDB_MODE") == "offline":
|
||||
print("⚙️ Running in WANDB offline mode")
|
||||
|
||||
from .. import PreTrainedModel, TFPreTrainedModel, TrainingArguments
|
||||
from .. import __version__ as version
|
||||
from ..utils import (
|
||||
@ -860,7 +864,7 @@ class WandbCallback(TrainerCallback):
|
||||
**init_args,
|
||||
)
|
||||
# add config parameters (run may have been created manually)
|
||||
self._wandb.config.update(combined_dict, allow_val_change=True)
|
||||
self._wandb.config.update(combined_dict or {}, allow_val_change=True)
|
||||
|
||||
# define default x-axis (for latest wandb versions)
|
||||
if getattr(self._wandb, "define_metric", None):
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import operator
|
||||
import os
|
||||
import re
|
||||
@ -57,10 +58,12 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
|
||||
backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "ccl", "hpu": "hccl"}
|
||||
backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"}
|
||||
backend = backend_map.get(device_type)
|
||||
if device_type == "cpu" and int(os.environ.get("CCL_WORKER_COUNT", 0)):
|
||||
backend = "ccl"
|
||||
if device_type == "xpu" and not is_torch_greater_or_equal("2.8", accept_dev=True):
|
||||
backend = "ccl"
|
||||
|
||||
torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
||||
current_device = getattr(torch, device_type)
|
||||
@ -278,7 +281,48 @@ def repack_weights(
|
||||
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
|
||||
"""
|
||||
Generalized tensor sharding across a multi-dimensional device mesh.
|
||||
Extract only the fraction of the parameter owned by the given `rank` when the parameter is sharded along the provided `dim`.
Extraction follows the pytorch `Shard` placement so that sharding and materializing back to the full tensor follows `Shard` semantics.
`Shard` uses torch.chunk-style sharding of the tensor. We demonstrate below how sharding happens, including some edge cases
such as some ranks receiving an empty tensor as their shard. The implementation below is robust to all these cases.
|
||||
|
||||
Case (1)
|
||||
empty_param (16, 5120, 8190)
|
||||
dim 0
|
||||
device_mesh.size() 4
|
||||
rank 0 gets (4, 5120, 8190) (0 ... 4, 5120, 8190)
|
||||
rank 1 gets (4, 5120, 8190) (4 ... 8, 5120, 8190)
|
||||
rank 2 gets (4, 5120, 8190) (8 ... 12, 5120, 8190)
|
||||
rank 3 gets (4, 5120, 8190) (12 ... 16, 5120, 8190)
|
||||
|
||||
Case (2)
|
||||
empty_param (16, 5120, 8190)
|
||||
dim 0
|
||||
device_mesh.size() 14
|
||||
rank 0 gets (2, 5120, 8190) (0 ... 2, 5120, 8190)
|
||||
rank 1 gets (2, 5120, 8190) (2 ... 4, 5120, 8190)
|
||||
rank 2 gets (2, 5120, 8190) (4 ... 6, 5120, 8190)
|
||||
rank 3 gets (2, 5120, 8190) (6 ... 8, 5120, 8190)
|
||||
rank 4 gets (2, 5120, 8190) (8 ... 10, 5120, 8190)
|
||||
rank 5 gets (2, 5120, 8190) (10 ... 12, 5120, 8190)
|
||||
rank 6 gets (2, 5120, 8190) (12 ... 14, 5120, 8190)
|
||||
rank 7 gets (2, 5120, 8190) (14 ... 16, 5120, 8190)
|
||||
rank 8 gets (0, 5120, 8190)
|
||||
rank 9 gets (0, 5120, 8190)
|
||||
rank 10 gets (0, 5120, 8190)
|
||||
rank 11 gets (0, 5120, 8190)
|
||||
rank 12 gets (0, 5120, 8190)
|
||||
rank 13 gets (0, 5120, 8190)
|
||||
|
||||
Case (3)
|
||||
empty_param (16, 5120, 8190)
|
||||
dim 0
|
||||
device_mesh.size() 3
|
||||
rank 0 gets (6, 5120, 8190) (0 ... 6, 5120, 8190)
|
||||
rank 1 gets (6, 5120, 8190) (6 ... 12, 5120, 8190)
|
||||
rank 2 gets (4, 5120, 8190) (12 ... 16, 5120, 8190)
|
||||
|
||||
In case (2), empty shards are returned with appropriate dimension to allow for operations to work smoothly.
|
||||
Args:
|
||||
param (torch.Tensor): The tensor to shard.
|
||||
empty_param (torch.Tensor): A tensor used for shape reference.
|
||||
@ -287,6 +331,7 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
|
||||
dim (int): Dimension along which to shard the tensor.
|
||||
"""
|
||||
param_dim = empty_param.dim()
|
||||
|
||||
if dim < 0:
|
||||
dim = param_dim + dim
|
||||
if dim >= param_dim:
|
||||
@ -299,15 +344,18 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
|
||||
if rank >= world_size:
|
||||
raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
|
||||
|
||||
shard_size = empty_param.shape[dim] // world_size
|
||||
shard_size = math.ceil(empty_param.shape[dim] / world_size)
|
||||
start = rank * shard_size
|
||||
end = start + shard_size
|
||||
|
||||
# Construct slicing index dynamically
|
||||
end = min(start + shard_size, empty_param.shape[dim])
|
||||
slice_indices = [slice(None)] * param_dim
|
||||
slice_indices[dim] = slice(start, end)
|
||||
|
||||
return param[tuple(slice_indices)]
|
||||
if start < empty_param.shape[dim]:
|
||||
slice_indices[dim] = slice(start, end)
|
||||
return param[tuple(slice_indices)]
|
||||
dimensions = list(param.shape)
|
||||
dimensions[dim] = 0
|
||||
return torch.empty(tuple(dimensions), dtype=torch.int64)
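# A toy illustration of the torch.chunk-style split described in the docstring above
# (Case (3): 16 rows over a mesh of size 3), using a small index tensor instead of real weights:
def _example_chunk_style_shard_sizes():
    rows = torch.arange(16)
    shards = torch.chunk(rows, 3, dim=0)
    # ceil(16 / 3) = 6 rows for the first ranks, the remainder (4) for the last rank
    return [shard.shape[0] for shard in shards]  # [6, 6, 4]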
|
||||
|
||||
|
||||
def distribute_module(
|
||||
@ -498,7 +546,9 @@ class ColwiseParallel(TensorParallelLayer):
|
||||
if to_contiguous:
|
||||
parameter = parameter.contiguous()
|
||||
if self.use_dtensor:
|
||||
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
|
||||
parameter = DTensor.from_local(
|
||||
parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
|
||||
)
|
||||
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
|
||||
|
||||
@staticmethod
|
||||
@ -572,7 +622,9 @@ class RowwiseParallel(TensorParallelLayer):
|
||||
if to_contiguous:
|
||||
parameter = parameter.contiguous()
|
||||
if self.use_dtensor:
|
||||
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
|
||||
parameter = DTensor.from_local(
|
||||
parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
|
||||
)
|
||||
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
|
||||
|
||||
@staticmethod
|
||||
|
@ -508,6 +508,22 @@ def _flash_attention_forward(
|
||||
query_states, key_states, value_states, target_dtype
|
||||
)
|
||||
|
||||
# We will use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
|
||||
# under two cases:
|
||||
# Case 1. If position_ids is provided, check that the examples do not all contain a single sequence: if the tensor is
# monotonically increasing we probably have one sequence, otherwise the batch is packed. Additionally check that we are
# in the pre-fill/training stage.
# Case 2. Some models directly pass pre-computed `cu_seqlens`, so we don't need to infer them from position ids. It is safe
# to use `flash_attn_varlen_func` knowing we already have all the necessary kwargs. NOTE: it is the user's responsibility
# to take care of flattening `position_ids` if that's needed by the model. See #39121 for more information
|
||||
is_fa2_with_position_ids = (
|
||||
position_ids is not None
|
||||
and query_states.shape[0] == 1
|
||||
and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
|
||||
)
|
||||
is_fa2_with_varlen_kwargs = all(
|
||||
kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
|
||||
)
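# Toy illustration of the packedness heuristic above (hypothetical position_ids): for
# position_ids = [[0, 1, 2, 0, 1, 2, 3]] the batch holds two packed sequences, and
# (torch.diff(position_ids, dim=-1) >= 0).all() is False because the diff turns negative at the boundary.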
|
||||
|
||||
# Contains at least one padding token in the sequence
|
||||
if attention_mask is not None:
|
||||
batch_size = query_states.shape[0]
|
||||
@ -531,14 +547,7 @@ def _flash_attention_forward(
|
||||
)
|
||||
attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||
|
||||
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
|
||||
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
|
||||
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
|
||||
elif (
|
||||
position_ids is not None
|
||||
and query_states.shape[0] == 1
|
||||
and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
|
||||
):
|
||||
elif is_fa2_with_varlen_kwargs or is_fa2_with_position_ids:
|
||||
batch_size = query_states.size(0)
|
||||
|
||||
if cu_seq_lens_q is None or cu_seq_lens_k is None:
|
||||
|
@ -3746,7 +3746,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
module_map[name + f".{key}"] = module
|
||||
state_dict = model_to_save.state_dict()
|
||||
|
||||
if any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS):
|
||||
if any(
|
||||
allowed_name in class_name.__name__.lower()
|
||||
for class_name in self.__class__.__mro__[:-1]
|
||||
for allowed_name in VLMS
|
||||
):
|
||||
reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()}
|
||||
|
||||
original_state_dict = {}
|
||||
@ -4402,7 +4406,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
|
||||
key_mapping = kwargs.pop("key_mapping", None)
|
||||
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
|
||||
if key_mapping is None and any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
|
||||
if key_mapping is None and any(
|
||||
allowed_name in class_name.__name__.lower() for class_name in cls.__mro__[:-1] for allowed_name in VLMS
|
||||
):
|
||||
key_mapping = cls._checkpoint_conversion_mapping
|
||||
|
||||
# Not used anymore -- remove them from the kwargs
|
||||
@ -5837,7 +5843,12 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict,
|
||||
else None
|
||||
)
|
||||
total_byte_count = defaultdict(lambda: 0)
|
||||
tied_param_names = _get_tied_weight_keys(model)
|
||||
for param_name, device in accelerator_device_map.items():
|
||||
# Skip if the parameter has already been accounted for (tied weights)
|
||||
if param_name in tied_param_names:
|
||||
continue
|
||||
|
||||
param = model.get_parameter_or_buffer(param_name)
|
||||
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
|
||||
param_byte_count = param.numel() * param.element_size()
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -25,14 +25,15 @@ from torch import nn
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithNoAttention,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
BaseModelOutputWithPooling,
|
||||
BaseModelOutputWithPoolingAndNoAttention,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import ModelOutput, auto_docstring, logging
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_align import AlignConfig, AlignTextConfig, AlignVisionConfig
|
||||
|
||||
|
||||
@ -90,7 +91,7 @@ class AlignOutput(ModelOutput):
|
||||
The text embeddings obtained by applying the projection layer to the pooled output of [`AlignTextModel`].
|
||||
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
|
||||
The output of [`AlignVisionModel`].
|
||||
text_model_output (`BaseModelOutputWithPoolingAndCrossAttentions`):
|
||||
text_model_output (`BaseModelOutputWithPooling`):
|
||||
The output of the [`AlignTextModel`].
|
||||
vision_model_output (`BaseModelOutputWithPoolingAndNoAttention`):
|
||||
The output of the [`AlignVisionModel`].
|
||||
@ -101,7 +102,7 @@ class AlignOutput(ModelOutput):
|
||||
logits_per_text: Optional[torch.FloatTensor] = None
|
||||
text_embeds: Optional[torch.FloatTensor] = None
|
||||
image_embeds: Optional[torch.FloatTensor] = None
|
||||
text_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None
|
||||
text_model_output: BaseModelOutputWithPooling = None
|
||||
vision_model_output: BaseModelOutputWithPoolingAndNoAttention = None
|
||||
|
||||
def to_tuple(self) -> tuple[Any]:
|
||||
@ -508,7 +509,6 @@ class AlignVisionEncoder(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->AlignText
|
||||
class AlignTextEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
|
||||
@ -537,7 +537,6 @@ class AlignTextEmbeddings(nn.Module):
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
past_key_values_length: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
@ -547,7 +546,7 @@ class AlignTextEmbeddings(nn.Module):
|
||||
seq_length = input_shape[1]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
|
||||
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
|
||||
@ -573,9 +572,35 @@ class AlignTextEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->AlignText
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
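# Minimal shape check for the eager attention path above, with toy sizes (batch=1, heads=2, seq=4, head_dim=8):
def _example_eager_attention_shapes():
    query = key = value = torch.randn(1, 2, 4, 8)
    attn_output, attn_weights = eager_attention_forward(
        nn.Module(), query, key, value, attention_mask=None, scaling=8**-0.5
    )
    # attn_output comes back as (batch, seq, heads, head_dim); attn_weights as (batch, heads, seq, seq)
    return attn_output.shape, attn_weights.shape  # (1, 4, 2, 8), (1, 2, 4, 4)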
|
||||
|
||||
|
||||
class AlignTextSelfAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
@ -583,6 +608,7 @@ class AlignTextSelfAttention(nn.Module):
                f"heads ({config.num_attention_heads})"
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
@ -592,20 +618,12 @@ class AlignTextSelfAttention(nn.Module):
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)
        self.attention_dropout = config.attention_probs_dropout_prob
        self.scaling = self.attention_head_size**-0.5

    @deprecate_kwarg("encoder_hidden_states", version="4.54.0")
    @deprecate_kwarg("encoder_attention_mask", version="4.54.0")
    @deprecate_kwarg("past_key_value", version="4.54.0")
    def forward(
        self,
        hidden_states: torch.Tensor,
@ -615,96 +633,33 @@ class AlignTextSelfAttention(nn.Module):
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        mixed_query_layer = self.query(hidden_states)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.attention_head_size)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None
        query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        query_layer = self.transpose_for_scores(mixed_query_layer)
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            head_mask=head_mask,
            **kwargs,
        )
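The call above relies on a lookup keyed by `config._attn_implementation`; a rough sketch of that dispatch pattern follows, assuming a stand-in registry and config rather than the real `ALL_ATTENTION_FUNCTIONS` mapping.

# Illustrative sketch of the attention-backend dispatch; the registry and config are assumptions.
from typing import Callable

def eager_attention_forward(module, query, key, value, attention_mask, **kwargs):
    ...  # eager math, as in the helper reconstructed above

def sdpa_attention_forward(module, query, key, value, attention_mask, **kwargs):
    ...  # e.g. built on torch.nn.functional.scaled_dot_product_attention

ATTENTION_REGISTRY = {"sdpa": sdpa_attention_forward}

class DummyConfig:
    _attn_implementation = "sdpa"

config = DummyConfig()
attention_interface: Callable = eager_attention_forward
if config._attn_implementation != "eager":
    attention_interface = ATTENTION_REGISTRY[config._attn_implementation]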
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
||||
-1, 1
|
||||
)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in AlignTextModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -723,18 +678,10 @@ class AlignTextSelfOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
ALIGN_TEXT_SELF_ATTENTION_CLASSES = {
|
||||
"eager": AlignTextSelfAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->AlignText,BERT->ALIGN_TEXT
|
||||
class AlignTextAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.self = ALIGN_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config, position_embedding_type=position_embedding_type
|
||||
)
|
||||
self.self = AlignTextSelfAttention(config)
|
||||
self.output = AlignTextSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
@ -756,6 +703,9 @@ class AlignTextAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -765,15 +715,14 @@ class AlignTextAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
@ -811,22 +760,18 @@ class AlignTextOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->AlignText
|
||||
class AlignTextLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = AlignTextAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = AlignTextAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = AlignTextIntermediate(config)
|
||||
self.output = AlignTextOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -836,60 +781,23 @@ class AlignTextLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
@ -898,14 +806,18 @@ class AlignTextLayer(GradientCheckpointingLayer):
        return layer_output
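For reference, a simplified stand-in for the feed-forward chunking used above, assuming a toy chunk size and linear layer rather than the real `apply_chunking_to_forward` helper.

# Illustrative sketch: process the sequence dimension in fixed-size chunks.
import torch

def chunked_feed_forward(forward_fn, chunk_size, seq_dim, tensor):
    if chunk_size == 0:
        return forward_fn(tensor)
    chunks = tensor.split(chunk_size, dim=seq_dim)
    return torch.cat([forward_fn(c) for c in chunks], dim=seq_dim)

ff = torch.nn.Linear(16, 16)
x = torch.randn(2, 8, 16)              # (batch, seq_len, hidden) -- toy sizes
y = chunked_feed_forward(ff, 4, 1, x)  # feed-forward applied to chunks of 4 positions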
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->AlignText
|
||||
class AlignTextEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([AlignTextLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([AlignTextLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -918,65 +830,36 @@ class AlignTextEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -1052,6 +935,7 @@ class AlignTextModel(AlignPreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1059,12 +943,13 @@ class AlignTextModel(AlignPreTrainedModel):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Examples:
|
||||
|
||||
@ -1133,20 +1018,17 @@ class AlignTextModel(AlignPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -1180,6 +1062,7 @@ class AlignVisionModel(AlignPreTrainedModel):
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.vision_model.embeddings.convolution
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1219,7 +1102,7 @@ class AlignVisionModel(AlignPreTrainedModel):
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
# Apply pooling
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
@ -1227,9 +1110,6 @@ class AlignVisionModel(AlignPreTrainedModel):
|
||||
# Reshape (batch_size, projection_dim, 1 , 1) -> (batch_size, projection_dim)
|
||||
pooled_output = pooled_output.reshape(pooled_output.shape[:2])
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndNoAttention(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
@ -1369,6 +1249,7 @@ class AlignModel(AlignPreTrainedModel):
|
||||
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1419,7 +1300,7 @@ class AlignModel(AlignPreTrainedModel):
|
||||
vision_outputs = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
text_outputs = self.text_model(
|
||||
@ -1431,7 +1312,7 @@ class AlignModel(AlignPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
image_embeds = vision_outputs[1]
|
||||
@ -1450,10 +1331,6 @@ class AlignModel(AlignPreTrainedModel):
|
||||
if return_loss:
|
||||
loss = align_loss(logits_per_text)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return AlignOutput(
|
||||
loss=loss,
|
||||
logits_per_image=logits_per_image,
|
||||
|
@ -26,14 +26,14 @@ from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPooling,
    BaseModelOutputWithPoolingAndCrossAttentions,
    BaseModelOutputWithPoolingAndProjection,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
from ...utils.deprecation import deprecate_kwarg
from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig
|
||||
|
||||
|
||||
@ -180,7 +180,6 @@ class AltRobertaEmbeddings(nn.Module):
|
||||
return position_ids.unsqueeze(0).expand(input_shape)
|
||||
|
||||
|
||||
# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->AltRoberta
|
||||
class AltRobertaSelfAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
@ -206,13 +205,9 @@ class AltRobertaSelfAttention(nn.Module):
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -223,55 +218,19 @@ class AltRobertaSelfAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
||||
-1, 1
|
||||
)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
@ -310,8 +269,6 @@ class AltRobertaSelfAttention(nn.Module):
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -335,7 +292,6 @@ ALT_ROBERTA_SELF_ATTENTION_CLASSES = {
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->AltRoberta,ROBERTA->ALT_ROBERTA
|
||||
class AltRobertaAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
@ -363,6 +319,9 @@ class AltRobertaAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -375,12 +334,9 @@ class AltRobertaAttention(nn.Module):
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
@ -418,22 +374,19 @@ class AltRobertaOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->AltRoberta
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->AltRoberta
|
||||
class AltRobertaLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = AltRobertaAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = AltRobertaAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = AltRobertaIntermediate(config)
|
||||
self.output = AltRobertaOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -443,60 +396,23 @@ class AltRobertaLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -505,14 +421,19 @@ class AltRobertaLayer(GradientCheckpointingLayer):
|
||||
return layer_output
|
||||
|
||||
|
||||
# Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->AltRoberta
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->AltRoberta
|
||||
class AltRobertaEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([AltRobertaLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([AltRobertaLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -525,65 +446,36 @@ class AltRobertaEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -787,6 +679,7 @@ class AltCLIPEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([AltCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -853,8 +746,6 @@ class AltCLIPEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
@ -1008,6 +899,7 @@ class AltCLIPVisionTransformer(nn.Module):
|
||||
self.encoder = AltCLIPEncoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1033,16 +925,13 @@ class AltCLIPVisionTransformer(nn.Module):
|
||||
inputs_embeds=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
@ -1106,16 +995,11 @@ class AltCLIPVisionModel(AltCLIPPreTrainedModel):
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||
cross-attention is added between the self-attention layers, following the architecture described in *Attention is
|
||||
The model behaves as an encoder following the architecture described in *Attention is
|
||||
all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
|
||||
Kaiser and Illia Polosukhin.
|
||||
|
||||
To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
|
||||
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
|
||||
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
|
||||
|
||||
.. _*Attention is all you need*: https://huggingface.co/papers/1706.03762
|
||||
.. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
|
||||
"""
|
||||
)
|
||||
class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
@ -1152,6 +1036,10 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@auto_docstring
|
||||
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
|
||||
def forward(
|
||||
@ -1176,11 +1064,6 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -1194,11 +1077,8 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
attention_mask = torch.ones(((batch_size, seq_length)), device=device)
|
||||
|
||||
if token_type_ids is None:
|
||||
if hasattr(self.embeddings, "token_type_ids"):
|
||||
@ -1212,21 +1092,6 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
@ -1235,33 +1100,23 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -1284,6 +1139,9 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
|
||||
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
|
||||
return super().resize_token_embeddings(new_num_tokens)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1326,11 +1184,9 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
# last module outputs
|
||||
@ -1343,9 +1199,6 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
|
||||
projection_state = self.transformation(sequence_output)
|
||||
pooler_output = projection_state[:, 0]
|
||||
|
||||
if not return_dict:
|
||||
return (projection_state, pooler_output) + outputs[2:4]
|
||||
|
||||
return BaseModelOutputWithPoolingAndProjection(
|
||||
last_hidden_state=projection_state,
|
||||
pooler_output=pooler_output,
|
||||
|
@ -225,7 +225,7 @@ class ArceeAttention(nn.Module):
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
|
@ -532,7 +532,7 @@ class AriaTextAttention(nn.Module):
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
@ -1113,11 +1113,12 @@ class AriaModel(AriaPreTrainedModel):
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
image_embeds = input_ids == self.config.image_token_id
|
||||
special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values,
|
||||
pixel_mask=pixel_mask,
|
||||
|
@ -1446,11 +1446,12 @@ class AriaModel(LlavaModel):
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
image_embeds = input_ids == self.config.image_token_id
|
||||
special_image_mask = image_embeds.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_image_tokens = (image_embeds).sum(dim=1).sum(dim=0)
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = self.get_image_features(
|
||||
pixel_values=pixel_values,
|
||||
pixel_mask=pixel_mask,
|
||||
|
@ -302,14 +302,14 @@ class AyaVisionModel(AyaVisionPreTrainedModel):
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
            n_image_tokens = (input_ids == self.config.image_token_id).sum()
            special_image_mask = input_ids == self.config.image_token_id

        n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

        if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
            n_image_tokens = (input_ids == self.config.image_token_id).sum()
            n_image_features = image_features.shape[0] * image_features.shape[1]
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
|
@ -223,14 +223,14 @@ class AyaVisionModel(LlavaModel):
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
|
@ -296,7 +296,7 @@ class BambaAttention(nn.Module):
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
|
@ -1855,6 +1855,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
    _supports_quantized_cache = False  # not all LM backbones support (e.g. T5)
|
||||
|
||||
_keep_in_fp32_modules = ["query_tokens", "qformer"]
|
||||
|
||||
def __init__(self, config: Blip2Config):
|
||||
@ -1971,10 +1972,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
input_ids: torch.FloatTensor,
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
@ -2066,14 +2068,25 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
language_model_attention_mask = torch.ones(
|
||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||
)
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# if the model already has "image_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concating
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "image_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||
else:
|
||||
@ -2146,6 +2159,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
pixel_values: torch.FloatTensor,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
**generate_kwargs,
|
||||
) -> torch.LongTensor:
|
||||
@ -2159,6 +2173,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
The sequence used as a prompt for the generation.
|
||||
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
|
||||
Mask to avoid performing attention on padding token indices
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Embedded representation of the inputs. Should be float, not int tokens.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the positional encoding of the image embeddings.
|
||||
|
||||
Returns:
|
||||
captions (list): A list of strings of length batch_size * num_captions.
|
||||
@ -2193,22 +2211,32 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
|
||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "image_token_id", None) is not None:
|
||||
start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
if inputs_embeds is None:
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "image_token_id", None) is not None:
|
||||
start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# if the model already has "image_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "image_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in BLIP-2 should be done in processing. "
|
||||
|
@ -1026,7 +1026,6 @@ class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@auto_docstring
|
||||
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
|
@ -26,13 +26,14 @@ from torch.nn import CrossEntropyLoss
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import ModelOutput, auto_docstring, logging
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_bros import BrosConfig
|
||||
|
||||
|
||||
@ -150,7 +151,6 @@ class BrosTextEmbeddings(nn.Module):
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
past_key_values_length: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
@ -160,7 +160,7 @@ class BrosTextEmbeddings(nn.Module):
|
||||
seq_length = input_shape[1]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
if token_type_ids is None:
|
||||
if hasattr(self, "token_type_ids"):
|
||||
@ -208,14 +208,7 @@ class BrosSelfAttention(nn.Module):
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor):
|
||||
new_x_shape = x.size()[:-1] + (
|
||||
self.num_attention_heads,
|
||||
self.attention_head_size,
|
||||
)
|
||||
x = x.view(*new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -227,42 +220,21 @@ class BrosSelfAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[torch.Tensor] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size)
|
||||
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
if is_cross_attention:
|
||||
key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
@ -317,7 +289,7 @@ class BrosSelfAttention(nn.Module):
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
outputs = outputs + (None,)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -364,6 +336,7 @@ class BrosAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -382,7 +355,6 @@ class BrosAttention(nn.Module):
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
@ -435,6 +407,7 @@ class BrosLayer(GradientCheckpointingLayer):
|
||||
self.intermediate = BrosIntermediate(config)
|
||||
self.output = BrosOutput(config)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -446,50 +419,38 @@ class BrosLayer(GradientCheckpointingLayer):
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
bbox_pos_emb=bbox_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if hasattr(self, "crossattention"):
|
||||
raise Exception(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk,
|
||||
self.chunk_size_feed_forward,
|
||||
@ -500,7 +461,7 @@ class BrosLayer(GradientCheckpointingLayer):
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
outputs = outputs + (None,)
|
||||
|
||||
return outputs
|
||||
|
||||
@ -516,6 +477,9 @@ class BrosEncoder(nn.Module):
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([BrosLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -529,33 +493,28 @@ class BrosEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithCrossAttentions]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
bbox_pos_emb,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
hidden_states=hidden_states,
|
||||
bbox_pos_emb=bbox_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
@ -564,21 +523,8 @@ class BrosEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutputWithCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
@ -689,6 +635,9 @@ class BrosModel(BrosPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -736,11 +685,6 @@ class BrosModel(BrosPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -756,9 +700,6 @@ class BrosModel(BrosPreTrainedModel):
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(input_shape, device=device)
|
||||
|
||||
@ -797,7 +738,6 @@ class BrosModel(BrosPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
# if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token
|
||||
@ -813,22 +753,16 @@ class BrosModel(BrosPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
@ -852,6 +786,7 @@ class BrosForTokenClassification(BrosPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -908,7 +843,7 @@ class BrosForTokenClassification(BrosPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -927,10 +862,6 @@ class BrosForTokenClassification(BrosPreTrainedModel):
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
@ -976,6 +907,7 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1037,7 +969,7 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
last_hidden_states = outputs[0]
|
||||
@ -1082,10 +1014,6 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
|
||||
|
||||
loss = initial_token_loss + subsequent_token_loss
|
||||
|
||||
if not return_dict:
|
||||
output = (initial_token_logits, subsequent_token_logits) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return BrosSpadeOutput(
|
||||
loss=loss,
|
||||
initial_token_logits=initial_token_logits,
|
||||
@ -1118,6 +1046,7 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1173,7 +1102,7 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
last_hidden_states = outputs[0]
|
||||
@ -1203,10 +1132,6 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
|
||||
|
||||
loss = loss_fct(logits.view(-1, max_seq_length + 1)[mask], labels.view(-1)[mask])
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
|
@ -963,25 +963,28 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_tokens = self.get_image_tokens(pixel_values)
|
||||
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
|
||||
if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel():
|
||||
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
|
||||
n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
|
||||
|
||||
n_image_tokens_in_text = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
image_embeds = self.get_image_features(pixel_values)
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_embeds.numel():
|
||||
n_image_features = image_embeds.shape[0] * image_embeds.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
|
||||
)
|
||||
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
|
||||
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)
|
||||
|
||||
# torch.jit.trace() doesn't support cache objects in the output
|
||||
if use_cache and past_key_values is None and not torch.jit.is_tracing():
|
||||
|
@ -14,9 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch Chinese-CLIP model."""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -26,13 +25,13 @@ from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPooling,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import ModelOutput, auto_docstring, logging, torch_int
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_chinese_clip import ChineseCLIPConfig, ChineseCLIPTextConfig, ChineseCLIPVisionConfig
|
||||
|
||||
|
||||
@ -90,7 +89,7 @@ class ChineseCLIPOutput(ModelOutput):
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->ChineseCLIPText
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEmbeddings with Align->ChineseCLIP
|
||||
class ChineseCLIPTextEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
|
||||
@ -119,7 +118,6 @@ class ChineseCLIPTextEmbeddings(nn.Module):
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
past_key_values_length: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
@ -129,7 +127,7 @@ class ChineseCLIPTextEmbeddings(nn.Module):
|
||||
seq_length = input_shape[1]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
|
||||
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
|
||||
@ -239,9 +237,37 @@ class ChineseCLIPVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ChineseCLIPText
|
||||
# Copied from transformers.models.align.modeling_align.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with Align->ChineseCLIP
|
||||
class ChineseCLIPTextSelfAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
@ -249,6 +275,7 @@ class ChineseCLIPTextSelfAttention(nn.Module):
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
@ -258,20 +285,12 @@ class ChineseCLIPTextSelfAttention(nn.Module):
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
self.attention_dropout = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -281,96 +300,33 @@ class ChineseCLIPTextSelfAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
head_mask=head_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
||||
-1, 1
|
||||
)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in ChineseCLIPTextModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -389,18 +345,11 @@ class ChineseCLIPTextSelfOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES = {
|
||||
"eager": ChineseCLIPTextSelfAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ChineseCLIPText,BERT->CHINESE_CLIP_TEXT
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextAttention with Align->ChineseCLIP
|
||||
class ChineseCLIPTextAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.self = CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config, position_embedding_type=position_embedding_type
|
||||
)
|
||||
self.self = ChineseCLIPTextSelfAttention(config)
|
||||
self.output = ChineseCLIPTextSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
@ -422,6 +371,9 @@ class ChineseCLIPTextAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -431,15 +383,14 @@ class ChineseCLIPTextAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
@ -468,66 +419,37 @@ class ChineseCLIPVisionAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output_attentions: Optional[bool] = False,
|
||||
self, hidden_states: torch.Tensor, output_attentions: Optional[bool] = False, **kwargs
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scale
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) * self.scale
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit akward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
None,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
scaling=1.0,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->ChineseCLIPText
|
||||
@ -577,22 +499,19 @@ class ChineseCLIPVisionMLP(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ChineseCLIPText
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextLayer with Align->ChineseCLIP
|
||||
class ChineseCLIPTextLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = ChineseCLIPTextAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = ChineseCLIPTextAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = ChineseCLIPTextIntermediate(config)
|
||||
self.output = ChineseCLIPTextOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -602,60 +521,23 @@ class ChineseCLIPTextLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -777,14 +659,19 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ChineseCLIPText
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->ChineseCLIP
|
||||
class ChineseCLIPTextEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -797,65 +684,36 @@ class ChineseCLIPTextEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -874,6 +732,7 @@ class ChineseCLIPVisionEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([ChineseCLIPVisionLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -922,8 +781,6 @@ class ChineseCLIPVisionEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
@ -940,6 +797,7 @@ class ChineseCLIPVisionTransformer(nn.Module):
|
||||
self.encoder = ChineseCLIPVisionEncoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -965,16 +823,13 @@ class ChineseCLIPVisionTransformer(nn.Module):
|
||||
inputs_embeds=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
@ -1034,6 +889,7 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1050,18 +906,13 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -1093,56 +944,28 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -1343,6 +1166,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
|
||||
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1392,7 +1216,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
text_outputs = self.text_model(
|
||||
@ -1402,7 +1226,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
image_embeds = vision_outputs[1]
|
||||
@ -1424,14 +1248,6 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
|
||||
if return_loss:
|
||||
loss = chinese_clip_loss(logits_per_text)
|
||||
|
||||
if not return_dict:
|
||||
# fix the None pooled_output of text_outputs to conform with dict_output
|
||||
pooled_output = text_outputs[1]
|
||||
if pooled_output is None:
|
||||
text_outputs = (text_outputs[0],) + text_outputs[2:]
|
||||
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return ChineseCLIPOutput(
|
||||
loss=loss,
|
||||
logits_per_image=logits_per_image,
|
||||
|
@ -17,7 +17,7 @@
|
||||
import collections
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -26,13 +26,14 @@ from torch import nn
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
|
||||
from ...utils import ModelOutput, auto_docstring, logging, torch_int
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig
|
||||
|
||||
|
||||
@ -399,11 +400,6 @@ class ClapAudioSelfAttention(nn.Module):
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -412,11 +408,11 @@ class ClapAudioSelfAttention(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
batch_size, dim, num_channels = hidden_states.shape
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
hidden_shape = (batch_size, dim, -1, self.attention_head_size)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
@ -1090,9 +1086,37 @@ class ClapTextEmbeddings(nn.Module):
|
||||
return position_ids.unsqueeze(0).expand(input_shape)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ClapText
|
||||
# Copied from transformers.models.align.modeling_align.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with Align->Clap
|
||||
class ClapTextSelfAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
@ -1100,6 +1124,7 @@ class ClapTextSelfAttention(nn.Module):
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
@ -1109,20 +1134,12 @@ class ClapTextSelfAttention(nn.Module):
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
self.attention_dropout = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1132,96 +1149,33 @@ class ClapTextSelfAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
head_mask=head_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
||||
-1, 1
|
||||
)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in ClapTextModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -1240,18 +1194,11 @@ class ClapTextSelfOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
CLAP_TEXT_SELF_ATTENTION_CLASSES = {
|
||||
"eager": ClapTextSelfAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ClapText,BERT->CLAP_TEXT
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextAttention with Align->Clap
|
||||
class ClapTextAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.self = CLAP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config, position_embedding_type=position_embedding_type
|
||||
)
|
||||
self.self = ClapTextSelfAttention(config)
|
||||
self.output = ClapTextSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
@ -1273,6 +1220,9 @@ class ClapTextAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1282,15 +1232,14 @@ class ClapTextAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
@ -1328,22 +1277,19 @@ class ClapTextOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ClapText
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextLayer with Align->Clap
|
||||
class ClapTextLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = ClapTextAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = ClapTextAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = ClapTextIntermediate(config)
|
||||
self.output = ClapTextOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1353,60 +1299,23 @@ class ClapTextLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -1415,14 +1324,19 @@ class ClapTextLayer(GradientCheckpointingLayer):
|
||||
return layer_output
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ClapText
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->Clap
|
||||
class ClapTextEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([ClapTextLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([ClapTextLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1435,65 +1349,36 @@ class ClapTextEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -1643,6 +1528,11 @@ class ClapTextModel(ClapPreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1666,11 +1556,6 @@ class ClapTextModel(ClapPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -1684,11 +1569,8 @@ class ClapTextModel(ClapPreTrainedModel):
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
attention_mask = torch.ones(((batch_size, seq_length)), device=device)
|
||||
|
||||
if token_type_ids is None:
|
||||
if hasattr(self.embeddings, "token_type_ids"):
|
||||
@ -1702,21 +1584,6 @@ class ClapTextModel(ClapPreTrainedModel):
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
@ -1725,33 +1592,23 @@ class ClapTextModel(ClapPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -1892,6 +1749,7 @@ class ClapModel(ClapPreTrainedModel):
|
||||
|
||||
return audio_features
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1947,7 +1805,7 @@ class ClapModel(ClapPreTrainedModel):
|
||||
is_longer=is_longer,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
text_outputs = self.text_model(
|
||||
@ -1956,7 +1814,7 @@ class ClapModel(ClapPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
audio_embeds = audio_outputs[1] if not return_dict else audio_outputs.pooler_output
|
||||
@ -1981,10 +1839,6 @@ class ClapModel(ClapPreTrainedModel):
|
||||
audio_loss = contrastive_loss(logits_per_audio.t())
|
||||
loss = (caption_loss + audio_loss) / 2.0
|
||||
|
||||
if not return_dict:
|
||||
output = (logits_per_audio, logits_per_text, text_embeds, audio_embeds, text_outputs, audio_outputs)
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return ClapOutput(
|
||||
loss=loss,
|
||||
logits_per_audio=logits_per_audio,
|
||||
@ -2013,6 +1867,7 @@ class ClapTextModelWithProjection(ClapPreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.text_model.embeddings.word_embeddings = value
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -2045,17 +1900,13 @@ class ClapTextModelWithProjection(ClapPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
pooled_output = text_outputs[1] if not return_dict else text_outputs.pooler_output
|
||||
|
||||
text_embeds = self.text_projection(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
|
||||
return tuple(output for output in outputs if output is not None)
|
||||
|
||||
return ClapTextModelOutput(
|
||||
text_embeds=text_embeds,
|
||||
last_hidden_state=text_outputs.last_hidden_state,
|
||||
@ -2079,6 +1930,7 @@ class ClapAudioModelWithProjection(ClapPreTrainedModel):
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.audio_model.audio_encoder.patch_embed.proj
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -2123,17 +1975,13 @@ class ClapAudioModelWithProjection(ClapPreTrainedModel):
|
||||
is_longer=is_longer,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
pooled_output = audio_outputs[1] if not return_dict else audio_outputs.pooler_output
|
||||
|
||||
audio_embeds = self.audio_projection(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
outputs = (audio_embeds, audio_outputs[0]) + audio_outputs[2:]
|
||||
return tuple(output for output in outputs if output is not None)
|
||||
|
||||
return ClapAudioModelOutput(
|
||||
audio_embeds=audio_embeds,
|
||||
last_hidden_state=audio_outputs.last_hidden_state,
|
||||
|
@ -28,7 +28,7 @@ from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepa
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...utils import ModelOutput, auto_docstring, logging, torch_int
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
|
||||
from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig
|
||||
|
||||
|
||||
@ -490,6 +490,7 @@ class CLIPSegEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([CLIPSegEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -555,8 +556,6 @@ class CLIPSegEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
|
@ -311,7 +311,7 @@ class CsmAttention(nn.Module):
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
|
@ -45,6 +45,7 @@ from ...modeling_outputs import (
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_data2vec_audio import Data2VecAudioConfig
|
||||
|
||||
|
||||
@ -240,6 +241,7 @@ class Data2VecAudioAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -247,7 +249,7 @@ class Data2VecAudioAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
@ -268,42 +270,9 @@ class Data2VecAudioAttention(nn.Module):
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -325,7 +294,7 @@ class Data2VecAudioAttention(nn.Module):
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class Data2VecAudioFeedForward(nn.Module):
|
||||
|
@ -634,7 +634,6 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@auto_docstring
|
||||
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
|
@ -281,7 +281,7 @@ class DiaSelfAttention(nn.Module):
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
|
@ -405,11 +405,6 @@ class DonutSwinSelfAttention(nn.Module):
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -418,11 +413,11 @@ class DonutSwinSelfAttention(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
batch_size, dim, num_channels = hidden_states.shape
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
hidden_shape = (batch_size, dim, -1, self.attention_head_size)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
@ -189,7 +189,7 @@ class Emu3Attention(nn.Module):
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
@ -1537,20 +1537,26 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
|
||||
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
|
||||
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
|
||||
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
|
||||
image_embeds = self.get_image_features(pixel_values, image_sizes)
|
||||
image_embeds = torch.cat(image_embeds, dim=0)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
|
@ -1033,20 +1033,26 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_tokens = self.get_image_tokens(pixel_values, image_sizes)
|
||||
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
|
||||
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
|
||||
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
|
||||
image_embeds = self.get_image_features(pixel_values, image_sizes)
|
||||
image_embeds = torch.cat(image_embeds, dim=0)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.text_model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
|
@ -26,14 +26,15 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
MaskedLMOutput,
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils import auto_docstring, can_return_tuple, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_esm import EsmConfig
|
||||
|
||||
|
||||
@ -187,12 +188,16 @@ class EsmEmbeddings(nn.Module):
|
||||
self.mask_token_id = config.mask_token_id
|
||||
|
||||
def forward(
|
||||
self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
):
|
||||
if position_ids is None:
|
||||
if input_ids is not None:
|
||||
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
||||
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
|
||||
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
|
||||
else:
|
||||
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
|
||||
|
||||
@ -281,11 +286,7 @@ class EsmSelfAttention(nn.Module):
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -296,32 +297,22 @@ class EsmSelfAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size)
|
||||
|
||||
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
if is_cross_attention:
|
||||
key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
# Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
|
||||
# ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
|
||||
@ -329,16 +320,6 @@ class EsmSelfAttention(nn.Module):
|
||||
# ESM code and fix rotary embeddings.
|
||||
query_layer = query_layer * self.attention_head_size**-0.5
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
if self.position_embedding_type == "rotary":
|
||||
query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
|
||||
|
||||
@ -385,7 +366,7 @@ class EsmSelfAttention(nn.Module):
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
outputs = outputs + (None,)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -418,6 +399,7 @@ class EsmFlashAttention2(EsmSelfAttention):
|
||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -441,7 +423,6 @@ class EsmFlashAttention2(EsmSelfAttention):
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
@ -450,9 +431,6 @@ class EsmFlashAttention2(EsmSelfAttention):
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
if past_key_value is not None:
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
@ -514,7 +492,7 @@ class EsmFlashAttention2(EsmSelfAttention):
|
||||
|
||||
outputs = (attn_output, None)
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
outputs = outputs + (None,)
|
||||
|
||||
return outputs
|
||||
|
||||
@ -551,6 +529,7 @@ class EsmAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@ -564,12 +543,11 @@ class EsmAttention(nn.Module):
|
||||
hidden_states_ln = self.LayerNorm(hidden_states)
|
||||
self_outputs = self.self(
|
||||
hidden_states_ln,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
@ -616,6 +594,7 @@ class EsmLayer(GradientCheckpointingLayer):
|
||||
self.output = EsmOutput(config)
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@ -626,25 +605,20 @@ class EsmLayer(GradientCheckpointingLayer):
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
):
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise AttributeError(
|
||||
@ -652,31 +626,24 @@ class EsmLayer(GradientCheckpointingLayer):
|
||||
" with cross-attention layers by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
layer_output = self.feed_forward_chunk(attention_output)
|
||||
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
outputs = outputs + (None,)
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -694,6 +661,9 @@ class EsmEncoder(nn.Module):
|
||||
self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@ -707,38 +677,26 @@ class EsmEncoder(nn.Module):
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
):
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
|
||||
"`use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache = next_decoder_cache + (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
@ -750,21 +708,8 @@ class EsmEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutputWithCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
@ -863,6 +808,9 @@ class EsmModel(EsmPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -903,11 +851,6 @@ class EsmModel(EsmPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -921,11 +864,8 @@ class EsmModel(EsmPreTrainedModel):
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
attention_mask = torch.ones(((batch_size, seq_length)), device=device)
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
extended_attention_mask = attention_mask
|
||||
@ -958,7 +898,6 @@ class EsmModel(EsmPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
@ -966,22 +905,16 @@ class EsmModel(EsmPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
@ -1025,6 +958,7 @@ class EsmForMaskedLM(EsmPreTrainedModel):
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head.decoder = new_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1058,7 +992,7 @@ class EsmForMaskedLM(EsmPreTrainedModel):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.lm_head(sequence_output)
|
||||
@ -1070,10 +1004,6 @@ class EsmForMaskedLM(EsmPreTrainedModel):
|
||||
labels = labels.to(prediction_scores.device)
|
||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
||||
|
||||
return MaskedLMOutput(
|
||||
loss=masked_lm_loss,
|
||||
logits=prediction_scores,
|
||||
@ -1125,6 +1055,7 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
|
||||
|
||||
self.post_init()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1154,7 +1085,7 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
logits = self.classifier(sequence_output)
|
||||
@ -1184,10 +1115,6 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
@ -1210,6 +1137,7 @@ class EsmForTokenClassification(EsmPreTrainedModel):
|
||||
|
||||
self.post_init()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1237,7 +1165,7 @@ class EsmForTokenClassification(EsmPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -1252,10 +1180,6 @@ class EsmForTokenClassification(EsmPreTrainedModel):
|
||||
labels = labels.to(logits.device)
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
@ -1283,7 +1207,7 @@ class EsmClassificationHead(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
|
||||
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
||||
"""
|
||||
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
||||
are ignored. This is modified from fairseq's `utils.make_positions`.
|
||||
@ -1295,7 +1219,7 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
|
||||
"""
|
||||
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
||||
mask = input_ids.ne(padding_idx).int()
|
||||
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
|
||||
incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
|
||||
return incremental_indices.long() + padding_idx
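For reference, a minimal sketch (editor's illustration, not part of the commit) of what the updated helper computes once `past_key_values_length` is dropped, assuming `padding_idx = 1`:

```python
import torch

def create_position_ids_from_input_ids(input_ids, padding_idx):
    # Non-padding tokens get positions counting up from padding_idx + 1; pads keep padding_idx.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
    return incremental_indices.long() + padding_idx

input_ids = torch.tensor([[5, 8, 9, 1, 1]])  # 1 is the padding token here
print(create_position_ids_from_input_ids(input_ids, padding_idx=1))
# tensor([[2, 3, 4, 1, 1]])
```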
|
||||
|
||||
|
||||
|
@ -206,14 +206,22 @@ class FuyuModel(FuyuPreTrainedModel):
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
if image_patches is not None and past_key_values is None:
|
||||
patch_embeddings = self.get_image_features(image_patches)
|
||||
patch_embeddings = torch.cat(patch_embeddings, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
patch_embeddings = patch_embeddings.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings)
|
||||
if image_patches is not None:
|
||||
patch_embeddings = self.get_image_features(image_patches)
|
||||
patch_embeddings = torch.cat(patch_embeddings, dim=0)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
patch_embeddings = patch_embeddings.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings)
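A toy illustration (editor's sketch with made-up sizes and a hypothetical `image_token_id`) of the `masked_scatter` pattern used above to drop patch embeddings into the placeholder image positions:

```python
import torch

hidden_size, image_token_id = 4, 99
input_ids = torch.tensor([[7, 99, 99, 12]])              # two image placeholder tokens
inputs_embeds = torch.zeros(1, 4, hidden_size)
patch_embeddings = torch.arange(2 * hidden_size, dtype=torch.float).view(2, hidden_size)

special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings)
print(inputs_embeds[0, 1], inputs_embeds[0, 2])
# tensor([0., 1., 2., 3.]) tensor([4., 5., 6., 7.])  -> patches fill the placeholder slots in order
```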
|
||||
|
||||
outputs = self.language_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
@ -222,7 +222,7 @@ class GemmaAttention(nn.Module):
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
|
@ -898,9 +898,11 @@ class Gemma3Model(Gemma3PreTrainedModel):
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
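The `input_ids is None` branch above recovers the image positions by comparing each row of `inputs_embeds` against the embedding of the image token. Editor's toy sketch (hypothetical vocab size and token id):

```python
import torch
import torch.nn as nn

vocab_size, hidden_size, image_token_id = 16, 8, 3
embed = nn.Embedding(vocab_size, hidden_size)
input_ids = torch.tensor([[1, 3, 3, 2]])
inputs_embeds = embed(input_ids)                  # pretend only this is available

image_token_embed = embed(torch.tensor(image_token_id))
special_image_mask = (inputs_embeds == image_token_embed).all(-1)
print(special_image_mask)                          # tensor([[False,  True,  True, False]])
```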
|
||||
|
@ -800,9 +800,11 @@ class Gemma3Model(PaliGemmaModel):
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
|
@ -301,10 +301,10 @@ class Gemma3nTextConfig(PretrainedConfig):
|
||||
|
||||
class Gemma3nAudioConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`], based on Gogole's
|
||||
[Universal Speech Model](). It is used to instantiate an Gemma3nAudioEncoder model according to the specified
|
||||
arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar
|
||||
configuration to that of the Gemma 3n E4B, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
|
||||
This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`]. It is used to instantiate
|
||||
a `Gemma3nAudioEncoder` model according to the specified arguments, defining the model architecture. Instantiating
|
||||
a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g.,
|
||||
[google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
|
||||
|
||||
Configuration objects inherit from [`Gemma3nAudioConfig`] and can be used to control the model outputs. Read
|
||||
the documentation from [`Gemma3nAudioConfig`] for more information.
|
||||
|
@ -911,7 +911,7 @@ class Gemma3nAudioConformerBlock(nn.Module):
|
||||
|
||||
|
||||
class Gemma3nAudioEncoder(PreTrainedModel):
|
||||
"""A Universal Speech Encoder -- https://arxiv.org/abs/2303.01037"""
|
||||
"""An audio encoder based on the [Universal Speech Model](https://arxiv.org/abs/2303.01037) architecture."""
|
||||
|
||||
config_class = Gemma3nAudioConfig
|
||||
|
||||
@ -1135,9 +1135,17 @@ class Gemma3nTextAltUp(nn.Module):
|
||||
corrected += predictions # add the original input
|
||||
return corrected.contiguous().type_as(activated)
|
||||
|
||||
def forward(self, corrected: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
This is only defined as the `forward` so that accelerate hooks can correctly move `correct_output_scale`
|
||||
(which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
|
||||
`scale_corrected_output`
|
||||
"""
|
||||
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
|
||||
|
||||
def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
|
||||
"""Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
|
||||
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
|
||||
return self.forward(corrected)
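The docstring above is the key point of this hunk: offloading frameworks such as accelerate typically wrap `module.forward` with a version that moves parameters between devices, so routing the scaling through `forward` lets `correct_output_scale` be handled too. Editor's minimal sketch of that pattern (not the real AltUp module):

```python
import torch
import torch.nn as nn

class ScaleOnly(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.correct_output_scale = nn.Parameter(torch.ones(dim))

    def forward(self, corrected: torch.Tensor) -> torch.Tensor:
        return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)

    def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
        return self.forward(corrected)  # goes through whatever `forward` has been patched to

m = ScaleOnly(4)
original_forward = m.forward
def wrapped_forward(x):                # stand-in for an offloading hook's wrapper
    print("wrapper ran")
    return original_forward(x)
m.forward = wrapped_forward            # roughly what hook frameworks do to the instance
m.scale_corrected_output(torch.randn(2, 4))   # prints "wrapper ran"
```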
|
||||
|
||||
|
||||
class Gemma3nTextRotaryEmbedding(nn.Module):
|
||||
@ -1290,7 +1298,7 @@ class Gemma3nTextAttention(nn.Module):
|
||||
self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
|
||||
|
||||
first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
|
||||
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx
|
||||
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
|
||||
# Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
|
||||
layer_type = config.layer_types[layer_idx]
|
||||
self.kv_shared_layer_index = (
|
||||
@ -1319,21 +1327,22 @@ class Gemma3nTextAttention(nn.Module):
|
||||
query_states = query_states.transpose(1, 2)
|
||||
|
||||
if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
|
||||
# HybridCache has complex slicing when layer_type == "sliding_attention" that impacts the shared KV cache.
|
||||
# Device of past layer may be different from current one
|
||||
indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device)
|
||||
# In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond)
|
||||
if isinstance(past_key_value, HybridCache) and self.is_sliding:
|
||||
max_length = past_key_value.sliding_window
|
||||
if cache_position.shape[0] > max_length:
|
||||
# If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache,
|
||||
# slice into the entire cache.
|
||||
indices = slice(0, max_length)
|
||||
else:
|
||||
# If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1
|
||||
indices = cache_position.clamp(min=0, max=max_length - 1)
|
||||
else:
|
||||
indices = cache_position
|
||||
indices = (
|
||||
slice(0, max_length)
|
||||
if cache_position.shape[0] > max_length
|
||||
else cache_position.clamp(min=0, max=max_length - 1)
|
||||
)
|
||||
|
||||
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices]
|
||||
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices]
|
||||
# Device of past layer may be different from current one
|
||||
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device)
|
||||
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to(
|
||||
query_states.device
|
||||
)
|
||||
else:
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape)
|
||||
key_states = self.k_norm(key_states)
|
||||
@ -1447,10 +1456,9 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
|
||||
attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
|
||||
corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
|
||||
|
||||
first_prediction = corrected_predictions[self.config.altup_active_idx]
|
||||
first_prediction_clone = first_prediction.clone()
|
||||
first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
|
||||
if self.config.altup_correct_scale:
|
||||
first_prediction = self.altup.scale_corrected_output(first_prediction_clone)
|
||||
first_prediction = self.altup.scale_corrected_output(first_prediction)
|
||||
|
||||
# per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
|
||||
first_prediction = self.per_layer_input_gate(first_prediction)
|
||||
@ -1475,7 +1483,7 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
|
||||
config_class = Gemma3nConfig
|
||||
base_model_prefix = ""
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Gemma3nDecoderLayer"]
|
||||
_no_split_modules = ["Gemma3nTextDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
@ -1656,18 +1664,17 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
||||
position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
|
||||
|
||||
# Expand hidden_states to support per-layer inputs
|
||||
target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
|
||||
epsilon_tensor = torch.tensor(torch.finfo().min)
|
||||
target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
|
||||
epsilon_tensor = torch.tensor(1e-5)
|
||||
|
||||
temp_hidden_states = [hidden_states_0]
|
||||
for i in range(1, self.config.altup_num_inputs):
|
||||
# altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
|
||||
altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0)
|
||||
current_hidden_state = altup_proj.type(hidden_states_0.dtype)
|
||||
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
|
||||
current_hidden_state = current_hidden_state * (
|
||||
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
|
||||
)
|
||||
altup_proj = self.altup_projections[i - 1](hidden_states_0)
|
||||
current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
|
||||
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
|
||||
new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
|
||||
current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
|
||||
temp_hidden_states.append(current_hidden_state)
|
||||
|
||||
hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size]
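The hunk above swaps the `torch.finfo().min` epsilon for a small positive floor and keeps every operand on the same device; the computation itself is RMS-magnitude matching of each projected stream to the original hidden states. Editor's toy sketch:

```python
import torch

hidden_states_0 = torch.randn(2, 5, 8)            # [batch, seq_len, hidden_size]
altup_proj = torch.randn(2, 5, 8) * 3.0           # stand-in for altup_projections[i - 1](hidden_states_0)
epsilon_tensor = torch.tensor(1e-5)

target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
new_magnitude = torch.mean(altup_proj**2, dim=-1, keepdim=True)
new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor))
rescaled = altup_proj * target_magnitude / new_magnitude

# The rescaled stream now has (almost) the same per-token RMS as hidden_states_0.
print(torch.allclose(torch.mean(rescaled**2, dim=-1).sqrt(), target_magnitude.squeeze(-1), atol=1e-5))
```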
|
||||
@ -1685,9 +1692,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
position_embeddings_global=position_embeddings_global,
|
||||
position_embeddings_local=position_embeddings_local,
|
||||
per_layer_input=per_layer_input,
|
||||
position_embeddings_global,
|
||||
position_embeddings_local,
|
||||
per_layer_input,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
@ -1712,11 +1719,10 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
||||
for i in range(1, self.config.altup_num_inputs):
|
||||
# altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
|
||||
altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
|
||||
current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype)
|
||||
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
|
||||
current_hidden_state = current_hidden_state * (
|
||||
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
|
||||
)
|
||||
current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
|
||||
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
|
||||
new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
|
||||
current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
|
||||
temp_hidden_states.append(current_hidden_state)
|
||||
|
||||
hidden_states = torch.stack(temp_hidden_states)
|
||||
@ -1743,7 +1749,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
||||
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
|
||||
per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype)
|
||||
per_layer_projection *= self.per_layer_projection_scale.to(
|
||||
dtype=inputs_embeds.dtype, device=per_layer_projection.device
|
||||
)
|
||||
per_layer_projection = per_layer_projection.reshape(
|
||||
*inputs_embeds.shape[:-1],
|
||||
self.config.num_hidden_layers,
|
||||
@ -1758,7 +1766,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
||||
# per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
|
||||
per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
|
||||
|
||||
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype)
|
||||
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
|
||||
dtype=inputs_embeds.dtype, device=per_layer_projection.device
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
|
||||
|
@ -313,10 +313,10 @@ class Gemma3nTextConfig(Gemma2Config, PretrainedConfig):
|
||||
|
||||
class Gemma3nAudioConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`], based on Gogole's
|
||||
[Universal Speech Model](). It is used to instantiate an Gemma3nAudioEncoder model according to the specified
|
||||
arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar
|
||||
configuration to that of the Gemma 3n E4B, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
|
||||
This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`]. It is used to instantiate
|
||||
a `Gemma3nAudioEncoder` model according to the specified arguments, defining the model architecture. Instantiating
|
||||
a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g.,
|
||||
[google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
|
||||
|
||||
Configuration objects inherit from [`Gemma3nAudioConfig`] and can be used to control the model outputs. Read
|
||||
the documentation from [`Gemma3nAudioConfig`] for more information.
|
||||
@ -1473,7 +1473,7 @@ class Gemma3nAudioConformerBlock(nn.Module):
|
||||
|
||||
|
||||
class Gemma3nAudioEncoder(PreTrainedModel):
|
||||
"""A Universal Speech Encoder -- https://arxiv.org/abs/2303.01037"""
|
||||
"""An audio encoder based on the [Universal Speech Model](https://arxiv.org/abs/2303.01037) architecture."""
|
||||
|
||||
config_class = Gemma3nAudioConfig
|
||||
|
||||
@ -1685,9 +1685,17 @@ class Gemma3nTextAltUp(nn.Module):
|
||||
corrected += predictions # add the original input
|
||||
return corrected.contiguous().type_as(activated)
|
||||
|
||||
def forward(self, corrected: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
This is only defined as the `forward` so that accelerate hooks can correctly move `correct_output_scale`
|
||||
(which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
|
||||
`scale_corrected_output`
|
||||
"""
|
||||
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
|
||||
|
||||
def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
|
||||
"""Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
|
||||
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
|
||||
return self.forward(corrected)
|
||||
|
||||
|
||||
class Gemma3nTextRotaryEmbedding(Gemma2RotaryEmbedding):
|
||||
@ -1732,7 +1740,7 @@ class Gemma3nTextAttention(Gemma3Attention):
|
||||
self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
|
||||
|
||||
first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
|
||||
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx
|
||||
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
|
||||
# Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
|
||||
layer_type = config.layer_types[layer_idx]
|
||||
self.kv_shared_layer_index = (
|
||||
@ -1761,21 +1769,22 @@ class Gemma3nTextAttention(Gemma3Attention):
|
||||
query_states = query_states.transpose(1, 2)
|
||||
|
||||
if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
|
||||
# HybridCache has complex slicing when layer_type == "sliding_attention" that impacts the shared KV cache.
|
||||
# Device of past layer may be different from current one
|
||||
indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device)
|
||||
# In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond)
|
||||
if isinstance(past_key_value, HybridCache) and self.is_sliding:
|
||||
max_length = past_key_value.sliding_window
|
||||
if cache_position.shape[0] > max_length:
|
||||
# If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache,
|
||||
# slice into the entire cache.
|
||||
indices = slice(0, max_length)
|
||||
else:
|
||||
# If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1
|
||||
indices = cache_position.clamp(min=0, max=max_length - 1)
|
||||
else:
|
||||
indices = cache_position
|
||||
indices = (
|
||||
slice(0, max_length)
|
||||
if cache_position.shape[0] > max_length
|
||||
else cache_position.clamp(min=0, max=max_length - 1)
|
||||
)
|
||||
|
||||
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices]
|
||||
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices]
|
||||
# Device of past layer may be different from current one
|
||||
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device)
|
||||
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to(
|
||||
query_states.device
|
||||
)
|
||||
else:
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape)
|
||||
key_states = self.k_norm(key_states)
|
||||
@ -1880,10 +1889,9 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer):
|
||||
attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
|
||||
corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
|
||||
|
||||
first_prediction = corrected_predictions[self.config.altup_active_idx]
|
||||
first_prediction_clone = first_prediction.clone()
|
||||
first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
|
||||
if self.config.altup_correct_scale:
|
||||
first_prediction = self.altup.scale_corrected_output(first_prediction_clone)
|
||||
first_prediction = self.altup.scale_corrected_output(first_prediction)
|
||||
|
||||
# per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
|
||||
first_prediction = self.per_layer_input_gate(first_prediction)
|
||||
@ -1906,7 +1914,7 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer):
|
||||
class Gemma3nPreTrainedModel(Gemma2PreTrainedModel):
|
||||
config_class = Gemma3nConfig
|
||||
base_model_prefix = ""
|
||||
_no_split_modules = ["Gemma3nDecoderLayer"]
|
||||
_no_split_modules = ["Gemma3nTextDecoderLayer"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of Gemma2 isn't meant for training from scratch - only
|
||||
@ -1995,7 +2003,9 @@ class Gemma3nTextModel(Gemma3TextModel):
|
||||
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
|
||||
per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype)
|
||||
per_layer_projection *= self.per_layer_projection_scale.to(
|
||||
dtype=inputs_embeds.dtype, device=per_layer_projection.device
|
||||
)
|
||||
per_layer_projection = per_layer_projection.reshape(
|
||||
*inputs_embeds.shape[:-1],
|
||||
self.config.num_hidden_layers,
|
||||
@ -2010,7 +2020,9 @@ class Gemma3nTextModel(Gemma3TextModel):
|
||||
# per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
|
||||
per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
|
||||
|
||||
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype)
|
||||
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
|
||||
dtype=inputs_embeds.dtype, device=per_layer_projection.device
|
||||
)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
@ -2091,18 +2103,17 @@ class Gemma3nTextModel(Gemma3TextModel):
|
||||
position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
|
||||
|
||||
# Expand hidden_states to support per-layer inputs
|
||||
target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
|
||||
epsilon_tensor = torch.tensor(torch.finfo().min)
|
||||
target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
|
||||
epsilon_tensor = torch.tensor(1e-5)
|
||||
|
||||
temp_hidden_states = [hidden_states_0]
|
||||
for i in range(1, self.config.altup_num_inputs):
|
||||
# altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
|
||||
altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0)
|
||||
current_hidden_state = altup_proj.type(hidden_states_0.dtype)
|
||||
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
|
||||
current_hidden_state = current_hidden_state * (
|
||||
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
|
||||
)
|
||||
altup_proj = self.altup_projections[i - 1](hidden_states_0)
|
||||
current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
|
||||
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
|
||||
new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
|
||||
current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
|
||||
temp_hidden_states.append(current_hidden_state)
|
||||
|
||||
hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size]
|
||||
@ -2120,9 +2131,9 @@ class Gemma3nTextModel(Gemma3TextModel):
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
position_embeddings_global=position_embeddings_global,
|
||||
position_embeddings_local=position_embeddings_local,
|
||||
per_layer_input=per_layer_input,
|
||||
position_embeddings_global,
|
||||
position_embeddings_local,
|
||||
per_layer_input,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
@ -2147,11 +2158,10 @@ class Gemma3nTextModel(Gemma3TextModel):
|
||||
for i in range(1, self.config.altup_num_inputs):
|
||||
# altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
|
||||
altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
|
||||
current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype)
|
||||
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
|
||||
current_hidden_state = current_hidden_state * (
|
||||
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
|
||||
)
|
||||
current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
|
||||
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
|
||||
new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
|
||||
current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
|
||||
temp_hidden_states.append(current_hidden_state)
|
||||
|
||||
hidden_states = torch.stack(temp_hidden_states)
|
||||
|
@ -39,6 +39,7 @@ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
torch_int,
|
||||
)
|
||||
@ -770,6 +771,7 @@ class GitVisionEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -836,8 +838,6 @@ class GitVisionEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
|
@ -184,7 +184,7 @@ class GlmAttention(nn.Module):
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
|
@ -242,7 +242,7 @@ class Glm4Attention(nn.Module):
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
|
@ -121,6 +121,7 @@ class Glm4vImageProcessorFast(BaseImageProcessorFast):
|
||||
do_convert_rgb: bool,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]],
|
||||
device: Optional[Union[str, torch.device]],
|
||||
disable_grouping: Optional[bool],
|
||||
):
|
||||
"""
|
||||
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
|
||||
@ -173,7 +174,7 @@ class Glm4vImageProcessorFast(BaseImageProcessorFast):
|
||||
resized_height, resized_width = height, width
|
||||
|
||||
# Group images by size for batched resizing
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images)
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
|
||||
resized_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_resize:
|
||||
@ -191,7 +192,7 @@ class Glm4vImageProcessorFast(BaseImageProcessorFast):
|
||||
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
|
||||
# Group images by size for further processing
|
||||
# Needed in case do_resize is False, or resize returns images with different sizes
|
||||
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
|
||||
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
|
||||
processed_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
# Fused rescale and normalize
|
||||
@ -249,6 +250,7 @@ class Glm4vImageProcessorFast(BaseImageProcessorFast):
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
device: Optional["torch.device"] = None,
|
||||
disable_grouping: Optional[bool] = False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@ -323,6 +325,7 @@ class Glm4vImageProcessorFast(BaseImageProcessorFast):
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
disable_grouping=disable_grouping,
|
||||
)
|
||||
pixel_values.extend(patches)
|
||||
vision_grid_thws.append(image_grid_thw)
|
||||
@ -351,11 +354,11 @@ class Glm4vImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
factor = patch_size * merge_size
|
||||
resized_height, resized_width = smart_resize(
|
||||
t=self.temporal_patch_size,
|
||||
num_frames=self.temporal_patch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
temporal_factor=self.temporal_patch_size,
|
||||
factor=factor,
|
||||
t_factor=self.temporal_patch_size,
|
||||
)
|
||||
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
|
||||
return grid_h * grid_w
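After the `smart_resize` keyword fix above, the returned token count is simply the patch-grid area of the resized image. Editor's sketch with assumed numbers:

```python
patch_size, merge_size = 14, 2
factor = patch_size * merge_size                   # resized sides are multiples of this
resized_height, resized_width = 448, 672           # assumed smart_resize output
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
print(grid_h * grid_w)                             # 32 * 48 = 1536 patches
```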
|
||||
|
@ -279,14 +279,16 @@ def eager_attention_forward(
|
||||
class Glm4vVisionAttention(nn.Module):
|
||||
def __init__(self, config: Glm4vVisionConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.dim = config.hidden_size
|
||||
self.num_heads = config.num_heads
|
||||
self.head_dim = config.hidden_size // self.num_heads
|
||||
self.num_key_value_groups = 1
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.head_dim = self.dim // self.num_heads
|
||||
self.num_key_value_groups = 1 # needed for eager attention
|
||||
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
|
||||
self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.config = config
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -294,23 +296,31 @@ class Glm4vVisionAttention(nn.Module):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
query_states, key_states, value_states = (
|
||||
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||
)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
|
||||
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
|
||||
attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -321,13 +331,17 @@ class Glm4vVisionAttention(nn.Module):
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
attention_mask=attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scale,
|
||||
scaling=self.scaling,
|
||||
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
|
||||
cu_seq_lens_k=cu_seqlens,
|
||||
max_length_q=max_seqlen,
|
||||
max_length_k=max_seqlen,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.squeeze(0)
|
||||
|
||||
attn_output = attn_output.reshape(seq_length, -1).contiguous()
|
||||
attn_output = self.proj(attn_output)
|
||||
return attn_output
|
||||
@ -347,6 +361,7 @@ class Glm4vVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = hidden_states + self.attn(
|
||||
@ -354,6 +369,7 @@ class Glm4vVisionBlock(GradientCheckpointingLayer):
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb=rotary_pos_emb,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
||||
@ -451,6 +467,25 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
return rotary_pos_emb, pos_ids
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
|
||||
# NOTE: the created attention mask only approximates the ragged FA2 attention by
|
||||
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
|
||||
# blocks. Though it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
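The new `_prepare_attention_mask` builds an additive block-diagonal mask so that, outside the flash-attention path, tokens only attend within their own `cu_seqlens` segment. Editor's small sketch with three packed sequences of lengths 2, 3 and 1:

```python
import torch

cu_seqlens = torch.tensor([0, 2, 5, 6])
seq_length, dtype = 6, torch.float32

attention_mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
for i in range(1, len(cu_seqlens)):
    attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

print((attention_mask[0, 0] == 0).int())
# Three diagonal blocks of 1s: bidirectional attention inside each segment, none across segments.
```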
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -480,14 +515,15 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)
|
||||
|
||||
for blk in self.blocks:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
blk.__call__, hidden_states, cu_seqlens, None, position_embeddings
|
||||
)
|
||||
else:
|
||||
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
|
||||
hidden_states = blk(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = self.post_layernorm(hidden_states)
|
||||
|
||||
@ -1016,7 +1052,7 @@ class Glm4vModel(Glm4vPreTrainedModel):
|
||||
dtype=input_ids.dtype,
|
||||
device=input_ids.device,
|
||||
)
|
||||
|
||||
image_index, video_index = 0, 0
|
||||
attention_mask = attention_mask.to(total_input_ids.device)
|
||||
for i, input_ids in enumerate(total_input_ids):
|
||||
input_ids = input_ids[attention_mask[i] == 1]
|
||||
@ -1046,7 +1082,6 @@ class Glm4vModel(Glm4vPreTrainedModel):
|
||||
|
||||
llm_pos_ids_list = []
|
||||
video_frame_num = 1
|
||||
image_index, video_index = 0, 0
|
||||
|
||||
for modality_type, start_idx, end_idx in input_type_group:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
@ -1088,9 +1123,7 @@ class Glm4vModel(Glm4vPreTrainedModel):
|
||||
t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
|
||||
|
||||
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
|
||||
|
||||
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
|
||||
|
||||
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
|
||||
|
||||
video_index += 1
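For the video branch above, the t/h/w indices enumerate the temporal frame and the 2D patch grid before the running offset `st_idx` is added. Editor's toy example with one frame and a 2x3 grid:

```python
import torch

llm_grid_h, llm_grid_w = 2, 3
t_idx = [0]                                        # single frame

t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()

print(torch.stack([t_index, h_index, w_index]))
# tensor([[0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 1, 1, 1],
#         [0, 1, 2, 0, 1, 2]])
```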
|
||||
@ -1204,50 +1237,59 @@ class Glm4vModel(Glm4vPreTrainedModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
|
||||
image_embeds = torch.cat(image_embeds, dim=0)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
|
||||
if input_ids is None:
|
||||
image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
image_mask = image_mask.all(-1)
|
||||
else:
|
||||
image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = image_mask.sum()
|
||||
image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_image_features = image_embeds.shape[0]
|
||||
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
|
||||
mask = input_ids == self.config.image_token_id
|
||||
mask_unsqueezed = mask.unsqueeze(-1)
|
||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||
image_mask = mask_expanded.to(inputs_embeds.device)
|
||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
||||
|
||||
if pixel_values_videos is not None:
|
||||
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
|
||||
video_embeds = torch.cat(video_embeds, dim=0)
|
||||
n_video_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
|
||||
if input_ids is None:
|
||||
video_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
video_mask = video_mask.all(-1)
|
||||
else:
|
||||
video_mask = input_ids == self.config.video_token_id
|
||||
|
||||
n_video_tokens = (video_mask).sum()
|
||||
n_video_features = video_embeds.shape[0]
|
||||
video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
)
|
||||
|
||||
mask = input_ids == self.config.image_token_id # GLM-4.1V use image_token_id for video
|
||||
mask_unsqueezed = mask.unsqueeze(-1)
|
||||
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||
video_mask = mask_expanded.to(inputs_embeds.device)
|
||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if position_ids is None:
|
||||
attention_mask_tensor = attention_mask
|
||||
attention_mask_tensor = (
|
||||
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
|
||||
)
|
||||
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
|
||||
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
|
||||
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
|
||||
@ -1538,6 +1580,7 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
|
||||
def _get_image_nums_and_video_nums(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
|
||||
@ -1552,9 +1595,29 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
|
||||
video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
|
||||
"""
|
||||
|
||||
is_image = input_ids == self.config.image_start_token_id
|
||||
is_video_start = input_ids == self.config.video_start_token_id
|
||||
is_video_end = input_ids == self.config.video_end_token_id
|
||||
if inputs_embeds is not None:
|
||||
is_image = (
|
||||
inputs_embeds
|
||||
== self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_start_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
)[..., 0]
|
||||
is_video_start = (
|
||||
inputs_embeds
|
||||
== self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_start_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
)[..., 0]
|
||||
is_video_end = (
|
||||
inputs_embeds
|
||||
== self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_end_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
)[..., 0]
|
||||
else:
|
||||
is_image = input_ids == self.config.image_start_token_id
|
||||
is_video_start = input_ids == self.config.video_start_token_id
|
||||
is_video_end = input_ids == self.config.video_end_token_id
|
||||
|
||||
# Cumulative sum to track if we're inside a video span
|
||||
# We'll assume well-formed video tags (i.e. matching starts and ends)
|
||||
@ -1590,7 +1653,9 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
|
||||
def _expand_dict_for_generation_visual(dict_to_expand):
|
||||
image_grid_thw = model_kwargs.get("image_grid_thw", None)
|
||||
video_grid_thw = model_kwargs.get("video_grid_thw", None)
|
||||
image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids)
|
||||
image_nums, video_nums = self._get_image_nums_and_video_nums(
|
||||
input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
|
||||
)
|
||||
|
||||
def _repeat_interleave_samples(x, lengths, repeat_times):
|
||||
samples = torch.split(x, lengths)
|
||||
@ -1646,10 +1711,7 @@ class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin):
|
||||
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
|
||||
return dict_to_expand
|
||||
|
||||
# input_ids is required for expanding visual inputs
|
||||
# If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
|
||||
if input_ids is not None and input_ids.numel() != 0:
|
||||
model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
|
||||
model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
|
||||
|
||||
if input_ids is not None:
|
||||
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
|
||||
|
@ -50,8 +50,8 @@ from ..qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||
Qwen2_5_VLPreTrainedModel,
|
||||
Qwen2_5_VLRotaryEmbedding,
|
||||
Qwen2_5_VLTextModel,
|
||||
Qwen2_5_VLVisionAttention,
|
||||
Qwen2_5_VLVisionBlock,
|
||||
apply_rotary_pos_emb_vision,
|
||||
)
|
||||
from ..qwen2_5_vl.processing_qwen2_5_vl import (
|
||||
Qwen2_5_VLProcessor,
|
||||
@ -505,62 +505,13 @@ class Glm4vVisionEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class Glm4vVisionAttention(nn.Module):
|
||||
class Glm4vVisionAttention(Qwen2_5_VLVisionAttention):
|
||||
def __init__(self, config: Glm4vVisionConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_heads = config.num_heads
|
||||
self.head_dim = config.hidden_size // self.num_heads
|
||||
self.num_key_value_groups = 1
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
|
||||
self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: Optional[torch.Tensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> torch.Tensor:
|
||||
seq_length = hidden_states.shape[0]
|
||||
query_states, key_states, value_states = (
|
||||
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
||||
)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
|
||||
|
||||
query_states = query_states.transpose(0, 1).unsqueeze(0)
|
||||
key_states = key_states.transpose(0, 1).unsqueeze(0)
|
||||
value_states = value_states.transpose(0, 1).unsqueeze(0)
|
||||
|
||||
attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, _ = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scale,
|
||||
is_causal=False,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.squeeze(0)
|
||||
attn_output = attn_output.reshape(seq_length, -1).contiguous()
|
||||
attn_output = self.proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class Glm4vVisionBlock(Qwen2_5_VLVisionBlock):
|
||||
def __init__(self, config) -> None:
|
||||
@ -652,6 +603,25 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
return rotary_pos_emb, pos_ids
|
||||
|
||||
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
# Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
|
||||
# NOTE: the created attention mask only approximates the ragged FA2 attention by
|
||||
# allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
|
||||
# blocks. Though it will not be a 100% match for FA2's `varlen` path
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
return None
|
||||
|
||||
seq_length = inputs_tensor.shape[0]
|
||||
attention_mask = torch.full(
|
||||
[1, 1, seq_length, seq_length],
|
||||
torch.finfo(inputs_tensor.dtype).min,
|
||||
device=inputs_tensor.device,
|
||||
dtype=inputs_tensor.dtype,
|
||||
)
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
||||
return attention_mask
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -681,14 +651,15 @@ class Glm4vVisionModel(Glm4vPreTrainedModel):
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
|
||||
attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)
|
||||
|
||||
for blk in self.blocks:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
blk.__call__, hidden_states, cu_seqlens, None, position_embeddings
|
||||
)
|
||||
else:
|
||||
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
|
||||
hidden_states = blk(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = self.post_layernorm(hidden_states)
|
||||
|
||||
@ -1115,7 +1086,7 @@ class Glm4vModel(Qwen2_5_VLModel):
|
||||
dtype=input_ids.dtype,
|
||||
device=input_ids.device,
|
||||
)
|
||||
|
||||
image_index, video_index = 0, 0
|
||||
attention_mask = attention_mask.to(total_input_ids.device)
|
||||
for i, input_ids in enumerate(total_input_ids):
|
||||
input_ids = input_ids[attention_mask[i] == 1]
|
||||
@ -1145,7 +1116,6 @@ class Glm4vModel(Qwen2_5_VLModel):
|
||||
|
||||
llm_pos_ids_list = []
|
||||
video_frame_num = 1
|
||||
image_index, video_index = 0, 0
|
||||
|
||||
for modality_type, start_idx, end_idx in input_type_group:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
@ -1187,9 +1157,7 @@ class Glm4vModel(Qwen2_5_VLModel):
|
||||
t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
|
||||
|
||||
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
|
||||
|
||||
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
|
||||
|
||||
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
|
||||
|
||||
video_index += 1
|
||||
@ -1269,50 +1237,59 @@ class Glm4vModel(Qwen2_5_VLModel):
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)

if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

if pixel_values is not None:
image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_embeds = torch.cat(image_embeds, dim=0)
n_image_tokens = (input_ids == self.config.image_token_id).sum()

if input_ids is None:
image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
image_mask = image_mask.all(-1)
else:
image_mask = input_ids == self.config.image_token_id

n_image_tokens = image_mask.sum()
image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
n_image_features = image_embeds.shape[0]
if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)

mask = input_ids == self.config.image_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

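The image branch above relies on `Tensor.masked_scatter` to drop the projected image features into the placeholder positions of the text embeddings. A self-contained sketch of that mechanic (toy sizes and a made-up token id, illustrative only):

import torch

hidden = 4
inputs_embeds = torch.zeros(1, 5, hidden)         # 5 text positions
input_ids = torch.tensor([[11, 99, 99, 99, 12]])  # 99 plays the role of image_token_id
image_embeds = torch.arange(3 * hidden, dtype=torch.float32).reshape(3, hidden)

image_mask = (input_ids == 99).unsqueeze(-1).expand_as(inputs_embeds)
# masked_scatter fills the True positions row by row with the flattened image features,
# so the number of masked token positions must equal image_embeds.numel() // hidden.
merged = inputs_embeds.masked_scatter(image_mask, image_embeds)
print(merged[0, 1])  # tensor([0., 1., 2., 3.]) -- the first image feature row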
if pixel_values_videos is not None:
video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
video_embeds = torch.cat(video_embeds, dim=0)
n_video_tokens = (input_ids == self.config.image_token_id).sum()

if input_ids is None:
video_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
video_mask = video_mask.all(-1)
else:
video_mask = input_ids == self.config.video_token_id

n_video_tokens = (video_mask).sum()
n_video_features = video_embeds.shape[0]
video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)

mask = input_ids == self.config.image_token_id  # GLM-4.1V use image_token_id for video
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
video_mask = mask_expanded.to(inputs_embeds.device)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

if position_ids is None:
attention_mask_tensor = attention_mask
attention_mask_tensor = (
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
)
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
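An illustrative sketch (an assumption about the idea, not the repository's exact helper) of turning a 4-D additive attention mask back into a 2-D padding mask: the diagonal of the last two dimensions holds 0 for kept tokens and the dtype minimum for masked ones, so dividing by the dtype minimum flips it to a 1/0 indicator, which is then inverted.

import torch

batch, seq = 2, 4
keep = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.float32)  # reference 2-D mask
min_val = torch.finfo(torch.float32).min
additive = (1.0 - keep)[:, None, None, :] * min_val  # (batch, 1, 1, seq) additive mask
additive = additive.expand(batch, 1, seq, seq)       # broadcast to (batch, 1, seq, seq)

diag = torch.diagonal(additive[:, 0], dim1=1, dim2=2)  # (batch, seq): 0 kept, min masked
recovered = 1.0 - diag / min_val                       # 1 where kept, 0 where masked
print(torch.equal(recovered, keep))                    # True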
@ -1532,6 +1509,7 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
|
||||
def _get_image_nums_and_video_nums(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
|
||||
@ -1546,9 +1524,29 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
|
||||
video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
|
||||
"""
|
||||
|
||||
is_image = input_ids == self.config.image_start_token_id
|
||||
is_video_start = input_ids == self.config.video_start_token_id
|
||||
is_video_end = input_ids == self.config.video_end_token_id
|
||||
if inputs_embeds is not None:
|
||||
is_image = (
|
||||
inputs_embeds
|
||||
== self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_start_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
)[..., 0]
|
||||
is_video_start = (
|
||||
inputs_embeds
|
||||
== self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_start_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
)[..., 0]
|
||||
is_video_end = (
|
||||
inputs_embeds
|
||||
== self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_end_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
)[..., 0]
|
else:
is_image = input_ids == self.config.image_start_token_id
is_video_start = input_ids == self.config.video_start_token_id
is_video_end = input_ids == self.config.video_end_token_id

# Cumulative sum to track if we're inside a video span
# We'll assume well-formed video tags (i.e. matching starts and ends)

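A standalone illustration of the cumulative-sum trick referenced in the comment above (toy token strings, not the exact code from this diff): with well-formed start/end tags, the running difference of start and end counts tells whether each position lies inside a video span, and the number of end tags counts the finished spans.

import torch

tokens = ["a", "<vs>", "f1", "f2", "<ve>", "b", "<vs>", "f3", "<ve>"]
is_video_start = torch.tensor([t == "<vs>" for t in tokens], dtype=torch.long)
is_video_end = torch.tensor([t == "<ve>" for t in tokens], dtype=torch.long)

# A position is inside a video if more starts than ends have been seen so far,
# counting the start token itself but not the end token.
inside_video = (is_video_start.cumsum(-1) - is_video_end.cumsum(-1)) > 0
video_nums = is_video_end.sum()  # one finished span per end tag
print(inside_video.tolist(), int(video_nums))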
@ -648,24 +648,27 @@ class GotOcr2Model(GotOcr2PreTrainedModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
if n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
|
@ -339,24 +339,27 @@ class GotOcr2Model(LlavaModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype))
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
if n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
|
@ -148,7 +148,7 @@ class GraniteAttention(nn.Module):
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
|
@ -224,7 +224,7 @@ class HeliumAttention(nn.Module):
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
|
@ -37,6 +37,7 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassif
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_hubert import HubertConfig
|
||||
|
||||
|
||||
@ -300,6 +301,7 @@ class HubertAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -307,7 +309,7 @@ class HubertAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
@ -328,42 +330,9 @@ class HubertAttention(nn.Module):
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -385,7 +354,7 @@ class HubertAttention(nn.Module):
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class HubertFeedForward(nn.Module):
|
||||
|
@ -28,6 +28,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
)
|
||||
from .configuration_idefics import IdeficsVisionConfig
|
||||
@ -351,6 +352,7 @@ class IdeficsVisionEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([IdeficsVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -417,8 +419,6 @@ class IdeficsVisionEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
|
@ -933,10 +933,18 @@ class Idefics2Model(Idefics2PreTrainedModel):
|
||||
- The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
- To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
"""
special_image_token_mask = input_ids == self.image_token_id
new_inputs_embeds = inputs_embeds.clone()
new_inputs_embeds[special_image_token_mask] = image_hidden_states.to(new_inputs_embeds.device)
return new_inputs_embeds
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id

special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states)
return inputs_embeds

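The rewritten merger above also has to recover the placeholder mask directly from `inputs_embeds` when `input_ids` is unavailable. A standalone sketch of that comparison (toy vocabulary and a made-up image token id, illustrative only): every position whose embedding vector equals the image-token embedding row is treated as a placeholder, and exact float equality works because it is the same embedding lookup.

import torch

vocab_size, hidden, image_token_id = 10, 4, 7
embed = torch.nn.Embedding(vocab_size, hidden)
input_ids = torch.tensor([[1, 7, 7, 3]])
inputs_embeds = embed(input_ids)

image_token_embed = embed(torch.tensor(image_token_id))
# Compare every position's vector against the image-token embedding, then require
# equality along the hidden dimension to get a per-token boolean mask.
special_image_mask = (inputs_embeds == image_token_embed).all(-1)
print(special_image_mask)  # tensor([[False,  True,  True, False]])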
def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
|
||||
"""
|
||||
@ -1041,25 +1049,8 @@ class Idefics2Model(Idefics2PreTrainedModel):
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
past_seen_tokens = 0
|
||||
# kept for BC (non `Cache` `past_key_values` inputs)
|
||||
return_legacy_cache = False
|
||||
if use_cache:
|
||||
if not isinstance(past_key_values, Cache):
|
||||
return_legacy_cache = True
|
||||
if past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
else:
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
logger.warning_once(
|
||||
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
||||
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
||||
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
||||
)
|
||||
past_seen_tokens = past_key_values.get_seq_length()
|
||||
|
||||
if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
|
||||
raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")
|
||||
if use_cache and not isinstance(past_key_values, Cache):
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
|
||||
@ -1072,7 +1063,7 @@ class Idefics2Model(Idefics2PreTrainedModel):
|
||||
elif image_hidden_states is not None:
|
||||
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
|
||||
|
||||
if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
|
||||
if image_hidden_states is not None:
|
||||
# When we generate, we don't want to replace the potential image_token_id that we generated by images
|
||||
# that simply don't exist
|
||||
inputs_embeds = self.inputs_merger(
|
||||
@ -1094,9 +1085,6 @@ class Idefics2Model(Idefics2PreTrainedModel):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if return_legacy_cache and use_cache:
|
||||
outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
|
||||
|
||||
return Idefics2BaseModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
@ -1304,37 +1292,11 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin)
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
# but IDEFICS requires both ids and embeds to be present
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs["input_ids"] = input_ids
|
||||
|
||||
if image_hidden_states is not None:
|
||||
if image_hidden_states is not None or cache_position[0] != 0:
|
||||
model_inputs["pixel_values"] = None
|
||||
model_inputs["pixel_attention_mask"] = None
|
||||
|
||||
return model_inputs
|
||||
|
||||
def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
|
||||
model_kwargs = super()._update_model_kwargs_for_generation(
|
||||
outputs=outputs,
|
||||
model_kwargs=model_kwargs,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
**kwargs,
|
||||
)
|
||||
# Get the precomputed image_hidden_states
|
||||
model_kwargs["image_hidden_states"] = outputs.image_hidden_states
|
||||
return model_kwargs
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.opt.modeling_opt.OPTForCausalLM._reorder_cache
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
|
||||
)
|
||||
return reordered_past
|
||||
|
||||
|
||||
__all__ = ["Idefics2ForConditionalGeneration", "Idefics2PreTrainedModel", "Idefics2Model"]
|
||||
|
@ -663,15 +663,18 @@ class Idefics3Model(Idefics3PreTrainedModel):
|
||||
- The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
|
||||
- To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
|
||||
"""
|
||||
special_image_token_mask = input_ids == self.image_token_id
|
||||
# Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
|
||||
new_inputs_embeds = inputs_embeds.clone()
|
||||
# Flatten `image_hidden_states` if not flat yet
|
||||
image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1])
|
||||
# cast to the dtype of the input_embeds to support quantized models
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
new_inputs_embeds[special_image_token_mask] = image_hidden_states
|
||||
return new_inputs_embeds
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_hidden_states)
|
||||
return inputs_embeds
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
|
||||
"""
|
||||
@ -773,11 +776,8 @@ class Idefics3Model(Idefics3PreTrainedModel):
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
past_seen_tokens = 0
|
||||
if use_cache:
|
||||
if past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
past_seen_tokens = past_key_values.get_seq_length()
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device)
|
||||
@ -790,7 +790,7 @@ class Idefics3Model(Idefics3PreTrainedModel):
|
||||
elif image_hidden_states is not None:
|
||||
image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
|
||||
|
||||
if past_seen_tokens == 0 and input_ids is not None and image_hidden_states is not None:
|
||||
if image_hidden_states is not None:
|
||||
# When we generate, we don't want to replace the potential image_token_id that we generated by images
|
||||
# that simply don't exist
|
||||
inputs_embeds = self.inputs_merger(
|
||||
@ -1042,28 +1042,11 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin)
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
# but IDEFICS requires both ids and embeds to be present
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs["input_ids"] = input_ids
|
||||
|
||||
if image_hidden_states is not None:
|
||||
if image_hidden_states is not None or cache_position[0] != 0:
|
||||
model_inputs["pixel_values"] = None
|
||||
model_inputs["pixel_attention_mask"] = None
|
||||
|
||||
return model_inputs
|
||||
|
||||
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration._update_model_kwargs_for_generation
|
||||
def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
|
||||
model_kwargs = super()._update_model_kwargs_for_generation(
|
||||
outputs=outputs,
|
||||
model_kwargs=model_kwargs,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
**kwargs,
|
||||
)
|
||||
# Get the precomputed image_hidden_states
|
||||
model_kwargs["image_hidden_states"] = outputs.image_hidden_states
|
||||
return model_kwargs
|
||||
|
||||
|
||||
__all__ = ["Idefics3ForConditionalGeneration", "Idefics3PreTrainedModel", "Idefics3Model", "Idefics3VisionTransformer"]
|
||||
|
@ -1255,6 +1255,7 @@ class InstructBlipModel(InstructBlipPreTrainedModel):
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
@ -1328,12 +1329,20 @@ class InstructBlipModel(InstructBlipPreTrainedModel):
|
||||
|
||||
# step 3: use the language model, conditioned on the query outputs and the prompt
|
||||
language_model_inputs = self.language_projection(query_output)
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
else:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||
|
||||
if self.config.use_decoder_only_language_model:
|
||||
outputs = self.language_model(
|
||||
@ -1513,6 +1522,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
@ -1604,15 +1614,26 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||
)
|
||||
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# if the model already has "image_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "image_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in InstructBLIP should be done in processing. "
|
||||
@ -1673,6 +1694,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
qformer_attention_mask: Optional[torch.LongTensor] = None,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
**generate_kwargs,
|
||||
) -> torch.LongTensor:
|
||||
@ -1690,6 +1712,8 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
The sequence used as a prompt for the generation.
|
||||
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Embedded representation of the inputs. Should be float, not int tokens.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the positional encoding of the image embeddings.
|
||||
|
||||
@ -1712,23 +1736,32 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
|
||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "image_token_id", None) is not None:
|
||||
start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
if inputs_embeds is None:
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "image_token_id", None) is not None:
|
||||
start_tokens = [self.config.image_token_id] * self.config.num_query_tokens + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# if the model already has "image_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "image_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for image tokens in InstructBLIP should be done in processing. "
|
||||
|
@ -1251,6 +1251,7 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
@ -1334,12 +1335,20 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel):
|
||||
|
||||
# unbatch inputs back, each video-frame gets `num_query_tokens` seq length
|
||||
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
special_image_mask = input_ids == self.config.video_token_id
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
else:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||
|
||||
if self.config.use_decoder_only_language_model:
|
||||
outputs = self.language_model(
|
||||
@ -1485,6 +1494,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
@ -1599,15 +1609,26 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||
)
|
||||
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# if the model already has "video_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.video_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
|
||||
@ -1668,6 +1689,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
qformer_attention_mask: Optional[torch.LongTensor] = None,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
**generate_kwargs,
|
||||
) -> torch.LongTensor:
|
||||
@ -1685,6 +1707,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
The sequence used as a prompt for the generation.
|
||||
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Embedded representation of the inputs. Should be float, not int tokens.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the positional encoding of the image embeddings.
|
||||
|
||||
@ -1708,23 +1732,32 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
if inputs_embeds is None:
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# if the model already has "video_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.video_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
|
||||
|
@ -202,6 +202,7 @@ class InstructBlipVideoModel(InstructBlipModel):
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
@ -255,12 +256,20 @@ class InstructBlipVideoModel(InstructBlipModel):
|
||||
|
||||
# unbatch inputs back, each video-frame gets `num_query_tokens` seq length
|
||||
language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
special_image_mask = input_ids == self.config.video_token_id
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
else:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||
|
||||
if self.config.use_decoder_only_language_model:
|
||||
outputs = self.language_model(
|
||||
@ -372,6 +381,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
@ -451,15 +461,26 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||
)
|
||||
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# if the model already has "video_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.video_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
|
||||
@ -520,6 +541,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
qformer_attention_mask: Optional[torch.LongTensor] = None,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
**generate_kwargs,
|
||||
) -> torch.LongTensor:
|
||||
@ -537,6 +559,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
The sequence used as a prompt for the generation.
|
||||
attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Embedded representation of the inputs. Should be float, not int tokens.
|
||||
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
||||
Whether to interpolate the positional encoding of the image embeddings.
|
||||
|
||||
@ -560,23 +584,32 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera
|
||||
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
|
||||
)
|
||||
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
if inputs_embeds is None:
|
||||
if input_ids is None:
|
||||
start_tokens = [self.config.text_config.bos_token_id]
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
start_tokens = [self.config.video_token_id] * self.config.num_query_tokens * 4 + start_tokens
|
||||
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=language_model_inputs.device)
|
||||
input_ids = input_ids.repeat(batch_size, 1)
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# if the model already has "video_token_id" then the input is expanded to account for image embeds
|
||||
# otherwise we expand manually by concatenating
|
||||
if getattr(self.config, "video_token_id", None) is not None:
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
||||
inputs_embeds[special_image_mask] = language_model_inputs.flatten().to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.video_token_id
|
||||
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Expanding inputs for video tokens in InstructBLIPVideo should be done in processing. "
|
||||
|
@ -710,14 +710,14 @@ class InternVLModel(InternVLPreTrainedModel):
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
|
@ -641,14 +641,14 @@ class InternVLModel(LlavaModel):
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
|
@ -1102,23 +1102,21 @@ class JanusModel(JanusPreTrainedModel):
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
if input_ids is None:
|
||||
image_attention_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
image_attention_mask = image_attention_mask.all(-1)
|
||||
else:
|
||||
image_attention_mask = input_ids == self.config.image_token_id
|
||||
|
||||
image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_embeds = self.get_image_features(pixel_values)
|
||||
image_attention_mask = input_ids == self.config.image_token_id
|
||||
|
||||
embed_dim = inputs_embeds.shape[-1]
|
||||
image_features = image_embeds.reshape(-1, embed_dim)
|
||||
image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
|
||||
|
||||
image_attention_mask = image_attention_mask.to(inputs_embeds.device)
|
||||
image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1])
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
|
||||
|
||||
|
@ -955,23 +955,21 @@ class JanusModel(JanusPreTrainedModel):
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
if input_ids is None:
|
||||
image_attention_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
image_attention_mask = image_attention_mask.all(-1)
|
||||
else:
|
||||
image_attention_mask = input_ids == self.config.image_token_id
|
||||
|
||||
image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
image_embeds = self.get_image_features(pixel_values)
|
||||
image_attention_mask = input_ids == self.config.image_token_id
|
||||
|
||||
embed_dim = inputs_embeds.shape[-1]
|
||||
image_features = image_embeds.reshape(-1, embed_dim)
|
||||
image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
|
||||
|
||||
image_attention_mask = image_attention_mask.to(inputs_embeds.device)
|
||||
image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1])
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
|
||||
|
||||
|
@ -451,6 +451,7 @@ class Kosmos2VisionEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -517,8 +518,6 @@ class Kosmos2VisionEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
@ -1468,25 +1467,19 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
|
||||
image_embeds_position_mask=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
cache_position=None,
|
||||
**model_kwargs,
|
||||
):
|
||||
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
|
||||
|
||||
# Kosmos2 has offset for position ids, so we need to create them correctly
|
||||
position_ids = create_position_ids_from_input_ids(
|
||||
input_ids,
|
||||
padding_idx=self.config.pad_token_id,
|
||||
past_key_values_length=0,
|
||||
)
|
||||
|
||||
if past_key_values is not None:
|
||||
image_embeds = None
|
||||
image_embeds_position_mask = None
|
||||
# appending `False` to `image_embeds_position_mask` (because `input_ids` grows during generation)
|
||||
elif image_embeds_position_mask is not None:
|
||||
batch_size, seq_len = input_ids.size()
|
||||
batch_size, seq_len = inputs_embeds.size()[:-1] if inputs_embeds is not None else input_ids.size()
|
||||
mask_len = image_embeds_position_mask.size()[-1]
|
||||
image_embeds_position_mask = torch.cat(
|
||||
(
|
||||
@ -1502,11 +1495,13 @@ class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
|
||||
attention_mask=attention_mask,
|
||||
image_embeds=image_embeds,
|
||||
image_embeds_position_mask=image_embeds_position_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
**model_kwargs,
|
||||
)
|
||||
# Kosmos2 has offset for position ids, so we need to create them correctly in PositionEmbedding layer
|
||||
model_inputs.pop("position_ids", None)
|
||||
|
||||
return model_inputs
|
||||
|
||||
@ -1876,6 +1871,7 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_embeds: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# in order to allow `inputs` argument (as in `GenerationMixin`)
|
||||
@ -1901,6 +1897,7 @@ class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
|
||||
attention_mask=attention_mask,
|
||||
image_embeds=image_embeds,
|
||||
image_embeds_position_mask=image_embeds_position_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""LayoutLM model configuration"""
|
||||
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
@ -130,10 +131,22 @@ class LayoutLMConfig(PretrainedConfig):
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.position_embedding_type = position_embedding_type
|
||||
self._position_embedding_type = position_embedding_type
|
||||
self.use_cache = use_cache
|
||||
self.max_2d_position_embeddings = max_2d_position_embeddings
|
||||
|
||||
@property
|
||||
def position_embedding_type(self):
|
||||
warnings.warn(
|
||||
"The `position_embedding_type` attribute is deprecated and will be removed in v4.55.",
|
||||
FutureWarning,
|
||||
)
|
||||
return self._position_embedding_type
|
||||
|
||||
@position_embedding_type.setter
|
||||
def position_embedding_type(self, value):
|
||||
self._position_embedding_type = value
|
||||
|
||||
|
||||
class LayoutLMOnnxConfig(OnnxConfig):
|
||||
def __init__(
|
||||
|
@ -14,8 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch LayoutLM model."""
|
||||
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -25,16 +24,17 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
MaskedLMOutput,
|
||||
QuestionAnsweringModelOutput,
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils import auto_docstring, can_return_tuple, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_layoutlm import LayoutLMConfig
|
||||
|
||||
|
||||
@ -120,9 +120,37 @@ class LayoutLMEmbeddings(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->LayoutLM
|
||||
# Copied from transformers.models.align.modeling_align.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->LayoutLM
|
||||
class LayoutLMSelfAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
@ -130,6 +158,7 @@ class LayoutLMSelfAttention(nn.Module):
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
@ -139,20 +168,12 @@ class LayoutLMSelfAttention(nn.Module):
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
self.attention_dropout = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -162,96 +183,33 @@ class LayoutLMSelfAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
head_mask=head_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
||||
-1, 1
|
||||
)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in LayoutLMModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -270,18 +228,11 @@ class LayoutLMSelfOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
LAYOUTLM_SELF_ATTENTION_CLASSES = {
|
||||
"eager": LayoutLMSelfAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->LayoutLM,BERT->LAYOUTLM
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->LayoutLM
|
||||
class LayoutLMAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.self = LAYOUTLM_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config, position_embedding_type=position_embedding_type
|
||||
)
|
||||
self.self = LayoutLMSelfAttention(config)
|
||||
self.output = LayoutLMSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
@ -303,6 +254,9 @@ class LayoutLMAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -312,15 +266,14 @@ class LayoutLMAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
@ -358,22 +311,19 @@ class LayoutLMOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->LayoutLM
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->LayoutLM
|
||||
class LayoutLMLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = LayoutLMAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = LayoutLMAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = LayoutLMIntermediate(config)
|
||||
self.output = LayoutLMOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -383,60 +333,23 @@ class LayoutLMLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -445,14 +358,19 @@ class LayoutLMLayer(GradientCheckpointingLayer):
|
||||
return layer_output
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->LayoutLM
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->LayoutLM
|
||||
class LayoutLMEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([LayoutLMLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([LayoutLMLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -465,65 +383,36 @@ class LayoutLMEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -648,6 +537,9 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -663,7 +555,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
|
||||
) -> Union[tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
|
||||
Bounding boxes of each input sequence tokens. Selected in the range `[0,
|
||||
@ -756,20 +648,16 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -796,6 +684,9 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
|
||||
self.cls.predictions.decoder = new_embeddings
|
||||
self.cls.predictions.bias = new_embeddings.bias
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -871,11 +762,9 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -889,10 +778,6 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
|
||||
labels.view(-1),
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
||||
|
||||
return MaskedLMOutput(
|
||||
loss=masked_lm_loss,
|
||||
logits=prediction_scores,
|
||||
@ -921,6 +806,7 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
|
||||
def get_input_embeddings(self):
|
||||
return self.layoutlm.embeddings.word_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -996,7 +882,7 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
@ -1026,9 +912,6 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
@ -1059,6 +942,7 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
|
||||
def get_input_embeddings(self):
|
||||
return self.layoutlm.embeddings.word_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1132,7 +1016,7 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -1145,10 +1029,6 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
@ -1176,6 +1056,7 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
|
||||
def get_input_embeddings(self):
|
||||
return self.layoutlm.embeddings.word_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1253,7 +1134,7 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -1280,10 +1161,6 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + outputs[2:]
|
||||
return ((total_loss,) + output) if total_loss is not None else output
|
||||
|
||||
return QuestionAnsweringModelOutput(
|
||||
loss=total_loss,
|
||||
start_logits=start_logits,
|
||||
|
@ -224,7 +224,7 @@ class LlamaAttention(nn.Module):
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
|
@ -1358,27 +1358,28 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
image_sizes=image_sizes,
|
||||
)
|
||||
original_inputs_embeds_shape = inputs_embeds.shape
|
||||
|
||||
vision_flat = image_features.view(-1, image_features.size(-1))
|
||||
projected_vision_flat = self.multi_modal_projector(vision_flat)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
final_mask = special_image_mask.to(inputs_embeds.device)
|
||||
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
final_mask_1d = final_mask[..., 0].reshape(-1)
|
||||
num_tokens_to_fill = final_mask_1d.sum()
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if num_tokens_to_fill != projected_vision_flat.size(0):
|
||||
if n_image_tokens != projected_vision_flat.size(0):
|
||||
raise ValueError(
|
||||
f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
|
||||
f"Mismatch: final_mask wants {n_image_tokens} embeddings, "
|
||||
f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
|
||||
)
|
||||
|
||||
expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
|
||||
inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, projected_vision_flat)
|
||||
inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
|
||||
projected_vision_flat = projected_vision_flat.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, projected_vision_flat)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
|
@ -284,14 +284,14 @@ class LlavaModel(LlavaPreTrainedModel):
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
|
@ -468,11 +468,6 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if pixel_values is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
@ -485,10 +480,18 @@ class LlavaNextModel(LlavaNextPreTrainedModel):
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
|
@ -519,12 +519,6 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
|
||||
"and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
@ -537,10 +531,18 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
@ -559,10 +561,18 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel):
|
||||
video_features = torch.cat(video_features, dim=0)
|
||||
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
|
||||
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.video_token_id
|
||||
|
||||
n_video_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
|
@ -440,12 +440,6 @@ class LlavaNextVideoModel(LlavaNextModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
|
||||
"and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
@ -458,10 +452,18 @@ class LlavaNextVideoModel(LlavaNextModel):
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
@ -480,10 +482,18 @@ class LlavaNextVideoModel(LlavaNextModel):
|
||||
video_features = torch.cat(video_features, dim=0)
|
||||
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
|
||||
|
||||
special_image_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.video_token_id
|
||||
|
||||
n_video_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||
n_video_features = video_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
|
@ -551,12 +551,6 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
|
||||
"and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
@ -571,10 +565,18 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
@ -595,10 +597,18 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel):
|
||||
video_features = torch.cat((video_features, image_newline), dim=1)
|
||||
video_features = video_features.flatten(0, 1)
|
||||
|
||||
special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
|
||||
special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_video_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_video_mask = special_video_mask.all(-1)
|
||||
else:
|
||||
special_video_mask = input_ids == self.config.video_token_id
|
||||
|
||||
n_video_tokens = (special_video_mask).sum()
|
||||
special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_video_mask].numel() != video_features.numel():
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum()
|
||||
n_video_features = video_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
|
@ -535,12 +535,6 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if (pixel_values is not None or pixel_values_videos is not None) and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both `pixel_values`/`pixel_values_videos` and `inputs_embeds` at the same time, "
|
||||
"and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
@ -555,10 +549,18 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
)
|
||||
image_features = torch.cat(image_features, dim=0)
|
||||
|
||||
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = (special_image_mask).sum()
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
|
||||
n_image_tokens = (input_ids == self.config.image_token_id).sum()
|
||||
n_image_features = image_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
@ -579,10 +581,18 @@ class LlavaOnevisionModel(LlavaNextVideoModel):
|
||||
video_features = torch.cat((video_features, image_newline), dim=1)
|
||||
video_features = video_features.flatten(0, 1)
|
||||
|
||||
special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1)
|
||||
special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
if input_ids is None:
|
||||
special_video_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_video_mask = special_video_mask.all(-1)
|
||||
else:
|
||||
special_video_mask = input_ids == self.config.video_token_id
|
||||
|
||||
n_video_tokens = (special_video_mask).sum()
|
||||
special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_video_mask].numel() != video_features.numel():
|
||||
n_video_tokens = (input_ids == self.config.video_token_id).sum()
|
||||
n_video_features = video_features.shape[0]
|
||||
raise ValueError(
|
||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||
|
@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""MarkupLM model configuration"""
|
||||
|
||||
import warnings
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
@ -141,7 +143,7 @@ class MarkupLMConfig(PretrainedConfig):
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.position_embedding_type = position_embedding_type
|
||||
self._position_embedding_type = position_embedding_type
|
||||
self.use_cache = use_cache
|
||||
self.classifier_dropout = classifier_dropout
|
||||
# additional properties
|
||||
@ -152,5 +154,17 @@ class MarkupLMConfig(PretrainedConfig):
|
||||
self.subs_pad_id = subs_pad_id
|
||||
self.xpath_unit_hidden_size = xpath_unit_hidden_size
|
||||
|
||||
@property
|
||||
def position_embedding_type(self):
|
||||
warnings.warn(
|
||||
"The `position_embedding_type` attribute is deprecated and will be removed in v4.55.",
|
||||
FutureWarning,
|
||||
)
|
||||
return self._position_embedding_type
|
||||
|
||||
@position_embedding_type.setter
|
||||
def position_embedding_type(self, value):
|
||||
self._position_embedding_type = value
|
||||
|
||||
|
||||
__all__ = ["MarkupLMConfig"]
|
||||
|
@ -14,9 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch MarkupLM model."""
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -26,20 +25,22 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
MaskedLMOutput,
|
||||
QuestionAnsweringModelOutput,
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import (
|
||||
ALL_ATTENTION_FUNCTIONS,
|
||||
PreTrainedModel,
|
||||
apply_chunking_to_forward,
|
||||
find_pruneable_heads_and_indices,
|
||||
prune_linear_layer,
|
||||
)
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils import auto_docstring, can_return_tuple, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_markuplm import MarkupLMConfig
|
||||
|
||||
|
||||
@ -326,9 +327,37 @@ class MarkupLMOnlyMLMHead(nn.Module):
|
||||
return prediction_scores
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->MarkupLM
|
||||
# Copied from transformers.models.align.modeling_align.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->MarkupLM
|
||||
class MarkupLMSelfAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
@ -336,6 +365,7 @@ class MarkupLMSelfAttention(nn.Module):
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
@ -345,20 +375,12 @@ class MarkupLMSelfAttention(nn.Module):
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
self.attention_dropout = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -368,111 +390,41 @@ class MarkupLMSelfAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        mixed_query_layer = self.query(hidden_states)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.attention_head_size)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None
        query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        query_layer = self.transpose_for_scores(mixed_query_layer)
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            head_mask=head_mask,
            **kwargs,
        )

        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in MarkupLMModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs

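The hunk above moves MarkupLM's self-attention onto the shared attention dispatch (`eager_attention_forward` by default, otherwise `ALL_ATTENTION_FUNCTIONS[...]`). As a rough, hypothetical sketch of the call contract that dispatch expects — not the exact helper shipped in the repository — an eager implementation looks like this:

```python
# Hedged sketch: a standalone eager attention function mirroring the
# (module, query, key, value, mask, dropout=..., scaling=..., ...) call used
# above. Names and defaults here are illustrative assumptions.
from typing import Optional

import torch
import torch.nn as nn


def eager_attention_sketch(
    module: nn.Module,
    query: torch.Tensor,        # (batch, heads, q_len, head_dim)
    key: torch.Tensor,          # (batch, heads, k_len, head_dim)
    value: torch.Tensor,        # (batch, heads, k_len, head_dim)
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: float = 1.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    # scaled dot-product scores
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask  # additive (0 / large-negative) mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    if head_mask is not None:
        attn_weights = attn_weights * head_mask
    attn_output = torch.matmul(attn_weights, value)          # (batch, heads, q_len, head_dim)
    attn_output = attn_output.transpose(1, 2).contiguous()   # (batch, q_len, heads, head_dim)
    return attn_output, attn_weights
```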
MARKUPLM_SELF_ATTENTION_CLASSES = {
    "eager": MarkupLMSelfAttention,
}


# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->MarkupLM,BERT->MARKUPLM
# Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->MarkupLM
class MarkupLMAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
    def __init__(self, config):
        super().__init__()
        self.self = MARKUPLM_SELF_ATTENTION_CLASSES[config._attn_implementation](
            config, position_embedding_type=position_embedding_type
        )
        self.self = MarkupLMSelfAttention(config)
        self.output = MarkupLMSelfOutput(config)
        self.pruned_heads = set()

@ -494,6 +446,9 @@ class MarkupLMAttention(nn.Module):
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @deprecate_kwarg("encoder_hidden_states", version="4.54.0")
    @deprecate_kwarg("encoder_attention_mask", version="4.54.0")
    @deprecate_kwarg("past_key_value", version="4.54.0")
    def forward(
        self,
        hidden_states: torch.Tensor,
@ -503,37 +458,33 @@ class MarkupLMAttention(nn.Module):
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            **kwargs,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs

# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->MarkupLM
# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->MarkupLM
class MarkupLMLayer(GradientCheckpointingLayer):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = MarkupLMAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = MarkupLMAttention(config, position_embedding_type="absolute")
        self.intermediate = MarkupLMIntermediate(config)
        self.output = MarkupLMOutput(config)

    @deprecate_kwarg("encoder_hidden_states", version="4.54.0")
    @deprecate_kwarg("encoder_attention_mask", version="4.54.0")
    @deprecate_kwarg("past_key_value", version="4.54.0")
    def forward(
        self,
        hidden_states: torch.Tensor,
@ -543,60 +494,23 @@ class MarkupLMLayer(GradientCheckpointingLayer):
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
            **kwargs,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
@ -605,14 +519,19 @@ class MarkupLMLayer(GradientCheckpointingLayer):
        return layer_output

# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->MarkupLM
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->MarkupLM
class MarkupLMEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([MarkupLMLayer(config) for _ in range(config.num_hidden_layers)])
        self.layer = nn.ModuleList([MarkupLMLayer(config) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    @deprecate_kwarg("encoder_hidden_states", version="4.54.0")
    @deprecate_kwarg("encoder_attention_mask", version="4.54.0")
    @deprecate_kwarg("past_key_values", version="4.54.0")
    @deprecate_kwarg("use_cache", version="4.54.0")
    @can_return_tuple
    def forward(
        self,
        hidden_states: torch.Tensor,
@ -625,65 +544,36 @@ class MarkupLMEncoder(nn.Module):
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        **kwargs,
    ) -> Union[tuple[torch.Tensor], BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                layer_head_mask,
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=layer_head_mask,
                output_attentions=output_attentions,
                **kwargs,
            )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

@ -749,6 +639,7 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
@ -763,7 +654,7 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        r"""
        xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
            Tag IDs for each token in the input sequence, padded up to config.max_depth.
@ -839,21 +730,16 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            return_dict=True,
        )
        sequence_output = encoder_outputs[0]

        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )

    # Copied from transformers.models.bert.modeling_bert.BertModel._reorder_cache
@ -879,6 +765,7 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
@ -939,7 +826,7 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            return_dict=True,
        )

        sequence_output = outputs[0]
@ -966,10 +853,6 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
@ -1000,6 +883,7 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
@ -1058,7 +942,7 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            return_dict=True,
        )

        sequence_output = outputs[0]
@ -1072,10 +956,6 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
                labels.view(-1),
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=prediction_scores,
@ -1107,6 +987,7 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
@ -1164,7 +1045,7 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            return_dict=True,
        )

        pooled_output = outputs[1]
@ -1194,9 +1075,6 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,

@ -354,11 +354,6 @@ class MaskFormerSwinSelfAttention(nn.Module):

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
@ -367,11 +362,11 @@ class MaskFormerSwinSelfAttention(nn.Module):
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor]:
        batch_size, dim, num_channels = hidden_states.shape
        mixed_query_layer = self.query(hidden_states)
        hidden_shape = (batch_size, dim, -1, self.attention_head_size)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)
        query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
        key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
        value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

@ -308,11 +308,6 @@ class Mistral3Model(Mistral3PreTrainedModel):
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

@ -324,10 +319,18 @@ class Mistral3Model(Mistral3PreTrainedModel):
            )
            image_features = torch.cat(image_features, dim=0)

            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
            if input_ids is None:
                special_image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
                special_image_mask = special_image_mask.all(-1)
            else:
                special_image_mask = input_ids == self.config.image_token_id

            n_image_tokens = (special_image_mask).sum()
            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                n_image_features = image_features.shape[0] * image_features.shape[1]
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

@ -204,11 +204,6 @@ class Mistral3Model(LlavaModel):
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

@ -220,10 +215,18 @@ class Mistral3Model(LlavaModel):
            )
            image_features = torch.cat(image_features, dim=0)

            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
            if input_ids is None:
                special_image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
                special_image_mask = special_image_mask.all(-1)
            else:
                special_image_mask = input_ids == self.config.image_token_id

            n_image_tokens = (special_image_mask).sum()
            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                n_image_tokens = (input_ids == self.config.image_token_id).sum()
                n_image_features = image_features.shape[0] * image_features.shape[1]
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"

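Both Mistral3 hunks above switch to deriving the image-placeholder mask either from `input_ids` or, when only embeddings are passed, by comparing `inputs_embeds` against the embedding of the image token. A small illustrative sketch of that equivalence (toy sizes and a made-up `image_token_id`, not the model's real configuration):

```python
# Hedged sketch of the placeholder-lookup pattern above, with toy shapes.
import torch
import torch.nn as nn

vocab_size, hidden = 10, 4
image_token_id = 7                      # illustrative stand-in
embed = nn.Embedding(vocab_size, hidden)

input_ids = torch.tensor([[1, 7, 7, 2]])
inputs_embeds = embed(input_ids)

# Branch 1: ids are available -> direct comparison
mask_from_ids = input_ids == image_token_id

# Branch 2: only embeddings are available -> compare against the image-token embedding
image_embedding = embed(torch.tensor(image_token_id))
mask_from_embeds = (inputs_embeds == image_embedding).all(-1)

assert torch.equal(mask_from_ids, mask_from_embeds)  # lookup copies the embedding row exactly

# Either mask is then broadcast to the hidden dimension before masked_scatter
expanded = mask_from_ids.unsqueeze(-1).expand_as(inputs_embeds)
print(expanded.shape)  # torch.Size([1, 4, 4])
```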
@ -182,7 +182,6 @@ def eager_attention_forward(
    return attn_output, attn_weights


# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Musicgen
class MusicgenAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""


@ -189,7 +189,7 @@ def eager_attention_forward(
    return attn_output, attn_weights


# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->MusicgenMelody
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->MusicgenMelody
class MusicgenMelodyAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""


@ -503,7 +503,7 @@ def eager_attention_forward(
    return attn_output, attn_weights


# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->NllbMoe,key_value_states->encoder_hidden_states
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->NllbMoe,key_value_states->encoder_hidden_states
class NllbMoeAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

@ -331,9 +331,11 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel):
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
            special_image_mask = input_ids == self.config.image_token_id

        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

        if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
            image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]

@ -29,6 +29,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import auto_docstring, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_patchtsmixer import PatchTSMixerConfig


@ -303,6 +304,7 @@ class PatchTSMixerAttention(nn.Module):
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    @deprecate_kwarg("past_key_value", version="4.54.0")
    def forward(
        self,
        hidden_states: torch.Tensor,
@ -310,7 +312,7 @@ class PatchTSMixerAttention(nn.Module):
        past_key_value: Optional[tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_attentions: Optional[bool] = False,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
@ -331,42 +333,9 @@ class PatchTSMixerAttention(nn.Module):
        # get query proj
        query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)

        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
            value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
            value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
            value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)
        current_states = key_value_states if is_cross_attention else hidden_states
        key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
        value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
@ -388,7 +357,7 @@ class PatchTSMixerAttention(nn.Module):
        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights, past_key_value
        return attn_output, attn_weights, None


class PatchMixerBlock(nn.Module):

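The PatchTSMixer hunk above (and the matching PatchTST hunk that follows) collapses the cached key/value branches into a single projection of `current_states` and always returns `None` in place of `past_key_value`, since the kwarg is deprecated. A minimal sketch of that selection, using hypothetical helper names:

```python
# Hedged sketch (hypothetical helper) of the key/value source selection that
# replaces the removed cache branches: self-attention projects hidden_states,
# cross-attention projects the encoder's key_value_states; nothing is cached.
import torch
import torch.nn as nn

def project_kv(k_proj: nn.Linear, v_proj: nn.Linear, hidden_states, key_value_states=None):
    is_cross_attention = key_value_states is not None
    current_states = key_value_states if is_cross_attention else hidden_states
    # head splitting / transposition is omitted in this sketch
    return k_proj(current_states), v_proj(current_states)

k_proj, v_proj = nn.Linear(8, 8), nn.Linear(8, 8)
hidden = torch.randn(2, 5, 8)
k, v = project_kv(k_proj, v_proj, hidden)   # self-attention path
print(k.shape, v.shape)                     # torch.Size([2, 5, 8]) twice
```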
@ -28,6 +28,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import ModelOutput, auto_docstring, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_patchtst import PatchTSTConfig


@ -100,6 +101,7 @@ class PatchTSTAttention(nn.Module):
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    @deprecate_kwarg("past_key_value", version="4.54.0")
    def forward(
        self,
        hidden_states: torch.Tensor,
@ -107,7 +109,7 @@ class PatchTSTAttention(nn.Module):
        past_key_value: Optional[tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_attentions: Optional[bool] = False,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
@ -128,42 +130,9 @@ class PatchTSTAttention(nn.Module):
        # get query proj
        query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)

        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
            value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
            value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
            value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)
        current_states = key_value_states if is_cross_attention else hidden_states
        key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
        value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
@ -185,7 +154,7 @@ class PatchTSTAttention(nn.Module):
        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights, past_key_value
        return attn_output, attn_weights, None


class PatchTSTBatchNorm(nn.Module):

@ -607,6 +607,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
                f" and `num_heads`: {self.num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = 0.0
        self.is_decoder = False
        self.is_causal = False

@ -619,6 +620,7 @@ class Qwen2_5OmniAudioAttention(nn.Module):
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
@ -634,15 +636,6 @@ class Qwen2_5OmniAudioAttention(nn.Module):
        value_states = value_states.transpose(0, 1).unsqueeze(0)
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

        attention_mask = torch.full(
            [1, 1, seq_length, key_states.shape[-2]],
            torch.finfo(query_states.dtype).min,
            device=query_states.device,
            dtype=query_states.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
@ -652,13 +645,13 @@ class Qwen2_5OmniAudioAttention(nn.Module):
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            attention_mask=attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            cu_seqlens_q=cu_seqlens,  # pass cu seq lens for FA2
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
            cu_seq_lens_k=cu_seqlens,
            max_length_q=max_seqlen,
            max_length_k=max_seqlen,
            is_causal=False,
            **kwargs,
        )
@ -686,6 +679,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer):
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
@ -704,6 +698,7 @@ class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer):
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = residual + hidden_states
@ -785,6 +780,25 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
    def set_input_embeddings(self, value: nn.Module):
        self.conv1 = value

    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
        # NOTE: the created attention masl only approximates the ragged FA2 attention by
        # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
        # blocks. Though it will not be a 100% match for FA2's `varlen` path
        if self.config._attn_implementation == "flash_attention_2":
            return None

        seq_length = inputs_tensor.shape[0]
        attention_mask = torch.full(
            [1, 1, seq_length, seq_length],
            torch.finfo(inputs_tensor.dtype).min,
            device=inputs_tensor.device,
            dtype=inputs_tensor.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
        return attention_mask

    @auto_docstring
    def forward(
        self,
@ -833,9 +847,15 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel):
                padded_mask_after_cnn.sum(1).cumsum(0),
            )
        ).to(torch.int32)
        attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(hidden_states, cu_seqlens, **kwargs)
            layer_outputs = encoder_layer(
                hidden_states,
                cu_seqlens=cu_seqlens,
                attention_mask=attention_mask,
                **kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0)

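The new `_prepare_attention_mask` helper above builds a block-diagonal additive mask from `cu_seqlens` so that non-FA2 backends only attend within each packed segment. An illustrative toy example (assumed `cu_seqlens` values, float32):

```python
# Hedged illustration of the block-diagonal mask built above:
# cu_seqlens = [0, 3, 5] describes two packed segments (lengths 3 and 2);
# the additive mask is 0 inside each block and a large negative value elsewhere.
import torch

cu_seqlens = torch.tensor([0, 3, 5])
seq_length = int(cu_seqlens[-1])
dtype = torch.float32

mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
for i in range(1, len(cu_seqlens)):
    mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

print((mask[0, 0] == 0).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1]], dtype=torch.int32)
```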
@ -928,12 +948,15 @@ class Qwen2_5OmniVisionAttention(nn.Module):
        self.scaling = self.head_dim**-0.5
        self.num_key_value_groups = 1  # needed for eager attention
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
@ -943,18 +966,9 @@ class Qwen2_5OmniVisionAttention(nn.Module):
        query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
        key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)

        attention_mask = torch.full(
            [1, 1, seq_length, seq_length],
            torch.finfo(query_states.dtype).min,
            device=query_states.device,
            dtype=query_states.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

        query_states = query_states.transpose(0, 1).unsqueeze(0)  # unsqueeze batch_dim
        key_states = key_states.transpose(0, 1).unsqueeze(0)  # unsqueeze batch_dim
        value_states = value_states.transpose(0, 1).unsqueeze(0)  # unsqueeze batch_dim
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

        attention_interface: Callable = eager_attention_forward
@ -966,13 +980,13 @@ class Qwen2_5OmniVisionAttention(nn.Module):
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0,
            attention_mask=attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            cu_seqlens_q=cu_seqlens,  # pass cu seq lens for FA2
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
            cu_seq_lens_k=cu_seqlens,
            max_length_q=max_seqlen,
            max_length_k=max_seqlen,
            is_causal=False,
            **kwargs,
        )
@ -1009,10 +1023,15 @@ class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer):
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, **kwargs
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states
@ -1171,6 +1190,25 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):

        return window_index, cu_window_seqlens

    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
        # NOTE: the created attention masl only approximates the ragged FA2 attention by
        # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
        # blocks. Though it will not be a 100% match for FA2's `varlen` path
        if self.config._attn_implementation == "flash_attention_2":
            return None

        seq_length = inputs_tensor.shape[0]
        attention_mask = torch.full(
            [1, 1, seq_length, seq_length],
            torch.finfo(inputs_tensor.dtype).min,
            device=inputs_tensor.device,
            dtype=inputs_tensor.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
        return attention_mask

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
@ -1217,10 +1255,13 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel):
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens

            attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens_now)
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
                rotary_pos_emb=rotary_pos_emb,
                attention_mask=attention_mask,
                **kwargs,
            )
        hidden_states = self.merger(hidden_states)

@ -1862,43 +1903,51 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # 2. Merge text , audios , image and video
        if input_ids is not None and input_ids.shape[1] != 1:  # Prefill stage
            if input_features is not None:
                audio_features = self.get_audio_features(
                    input_features,
                    feature_attention_mask=feature_attention_mask,
                    audio_feature_lengths=audio_feature_lengths,
        if input_features is not None:
            audio_features = self.get_audio_features(
                input_features,
                feature_attention_mask=feature_attention_mask,
                audio_feature_lengths=audio_feature_lengths,
            )
            if input_ids is None:
                audio_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
                audio_mask = (
                    (input_ids == self.config.audio_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)
                audio_mask = audio_mask.all(-1)
            else:
                audio_mask = input_ids == self.config.audio_token_id

            if pixel_values is not None:
                image_embeds = self.get_image_features(pixel_values, image_grid_thw)
                image_mask = (
                    (input_ids == self.config.image_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
            audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)

            if pixel_values_videos is not None:
                video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
                video_mask = (
                    (input_ids == self.config.video_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
        if pixel_values is not None:
            image_embeds = self.get_image_features(pixel_values, image_grid_thw)
            if input_ids is None:
                image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
                image_mask = image_mask.all(-1)
            else:
                image_mask = input_ids == self.config.image_token_id

            image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        if pixel_values_videos is not None:
            video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
            if input_ids is None:
                video_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
                video_mask = video_mask.all(-1)
            else:
                video_mask = input_ids == self.config.video_token_id

            video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        if feature_attention_mask is not None:
            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)

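The Thinker hunk above splices audio, image, and video embeddings into the text sequence with `masked_scatter` after expanding the placeholder mask to the hidden dimension. A toy sketch of that pattern with made-up shapes:

```python
# Hedged toy example of the masked_scatter pattern used above. Shapes and
# values are illustrative only.
import torch

hidden = 4
inputs_embeds = torch.zeros(1, 5, hidden)          # 5 text positions
image_mask = torch.tensor([[False, True, True, False, False]])
image_embeds = torch.ones(2, hidden)               # one row per placeholder token

expanded_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, image_embeds)

print(inputs_embeds[0, :, 0])  # tensor([0., 1., 1., 0., 0.])
```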
Some files were not shown because too many files have changed in this diff.