Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-30 09:42:22 +06:00)

Commit b82b47e5d5: Merge branch 'main' into vas-bert-attn-refactors
@@ -288,7 +288,7 @@ Keywords: Music understanding, Music generation

## [dalle-flow](https://github.com/jina-ai/dalle-flow)

DALL·E Flow is an interactive workflow for generating high-definition images from a text prompt. Itt leverages DALL·E-Mega, GLID-3 XL, and Stable Diffusion to generate image candidates, and then calls CLIP-as-service to rank the candidates w.r.t. the prompt.
DALL·E Flow is an interactive workflow for generating high-definition images from a text prompt. It leverages DALL·E-Mega, GLID-3 XL, and Stable Diffusion to generate image candidates, and then calls CLIP-as-service to rank the candidates w.r.t. the prompt.
The preferred candidate is fed to GLID-3 XL for diffusion, which often enriches the texture and background. Finally, the candidate is upscaled to 1024x1024 via SwinIR.

Keywords: High-definition image generation, Stable Diffusion, DALL-E Mega, GLID-3 XL, CLIP, SwinIR

@@ -526,7 +526,7 @@ Keywords: Model deployment, CLoud, Mobile, Edge

## [underthesea](https://github.com/undertheseanlp/underthesea)

[underthesea](https://github.com/undertheseanlp/underthesea) is a Vietnamese NLP toolkit. Underthesea is a suite of open source Python modules data sets and tutorials supporting research and development in Vietnamese Natural Language Processing. We provides extremely easy API to quickly apply pretrained NLP models to your Vietnamese text, such as word segmentation, part-of-speech tagging (PoS), named entity recognition (NER), text classification and dependency parsing.
[underthesea](https://github.com/undertheseanlp/underthesea) is a Vietnamese NLP toolkit. Underthesea is a suite of open source Python modules data sets and tutorials supporting research and development in Vietnamese Natural Language Processing. We provide extremely easy API to quickly apply pretrained NLP models to your Vietnamese text, such as word segmentation, part-of-speech tagging (PoS), named entity recognition (NER), text classification and dependency parsing.

Keywords: Vietnamese, NLP
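For a concrete sense of that API, a minimal sketch using underthesea's top-level `word_tokenize` and `pos_tag` helpers (treat the exact outputs in the comments as illustrative):

```python
# pip install underthesea
from underthesea import pos_tag, word_tokenize

sentence = "Chàng trai 9X Quảng Trị khởi nghiệp từ nấm sò"
print(word_tokenize(sentence))  # multi-word tokens, e.g. ['Chàng trai', '9X', 'Quảng Trị', ...]
print(pos_tag(sentence))        # (word, part-of-speech tag) pairs
```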
@@ -56,7 +56,7 @@ Create a [`ImageTextToTextPipeline`] and pass the chat to it. For large models,

import torch
from transformers import pipeline

pipeline = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf", device="cuda", torch_dtype=torch.float16)
pipeline = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf", device_map="auto", torch_dtype=torch.float16)
pipeline(text=messages, max_new_tokens=50, return_full_text=False)
[{'input_text': [{'role': 'system',
'content': [{'type': 'text',

@@ -175,7 +175,7 @@ processed_chat = processor.apply_chat_template(

    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    video_fps=32,
    video_fps=16,
    video_load_backend="decord",
)
print(processed_chat.keys())

@@ -26,6 +26,7 @@ Pass the audio signal, typically stored in `array`, to the feature extractor and

from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
dataset = load_dataset("PolyAI/minds14", name="en-US", split="train")
processed_sample = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=16000)
processed_sample
{'input_values': [array([ 9.4472744e-05, 3.0777880e-03, -2.8888427e-03, ...,
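Continuing from the snippet above, real clips have different lengths, so the same feature extractor is usually also asked to pad a batch to a uniform shape. A minimal sketch, assuming the standard `padding` and `return_tensors` arguments of the Wav2Vec2 feature extractor:

```py
audio_arrays = [dataset[i]["audio"]["array"] for i in range(4)]
batch = feature_extractor(audio_arrays, sampling_rate=16000, padding=True, return_tensors="pt")
print(batch["input_values"].shape)  # (4, length_of_longest_clip)
```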
@@ -14,59 +14,123 @@ rendered properly in your Markdown viewer.

-->

# BigBirdPegasus

<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
</div>

## Overview
# BigBirdPegasus

The BigBird model was proposed in [Big Bird: Transformers for Longer Sequences](https://huggingface.co/papers/2007.14062) by
Zaheer, Manzil and Guruganesh, Guru and Dubey, Kumar Avinava and Ainslie, Joshua and Alberti, Chris and Ontanon,
Santiago and Pham, Philip and Ravula, Anirudh and Wang, Qifan and Yang, Li and others. BigBird, is a sparse-attention
based transformer which extends Transformer based models, such as BERT to much longer sequences. In addition to sparse
attention, BigBird also applies global attention as well as random attention to the input sequence. Theoretically, it
has been shown that applying sparse, global, and random attention approximates full attention, while being
computationally much more efficient for longer sequences. As a consequence of the capability to handle longer context,
BigBird has shown improved performance on various long document NLP tasks, such as question answering and
summarization, compared to BERT or RoBERTa.
[BigBirdPegasus](https://huggingface.co/papers/2007.14062) is an encoder-decoder (sequence-to-sequence) transformer model for long-input summarization. It extends the [BigBird](./big_bird) architecture with an additional pretraining objective borrowed from [Pegasus](./pegasus) called gap sequence generation (GSG). Whole sentences are masked and the model has to fill in the gaps in the document. BigBirdPegasus's ability to keep track of long contexts makes it effective at summarizing lengthy inputs, surpassing the performance of base Pegasus models.

The abstract from the paper is the following:
You can find all the original BigBirdPegasus checkpoints under the [Google](https://huggingface.co/google/models?search=bigbird-pegasus) organization.

*Transformers-based models, such as BERT, have been one of the most successful deep learning models for NLP.
Unfortunately, one of their core limitations is the quadratic dependency (mainly in terms of memory) on the sequence
length due to their full attention mechanism. To remedy this, we propose, BigBird, a sparse attention mechanism that
reduces this quadratic dependency to linear. We show that BigBird is a universal approximator of sequence functions and
is Turing complete, thereby preserving these properties of the quadratic, full attention model. Along the way, our
theoretical analysis reveals some of the benefits of having O(1) global tokens (such as CLS), that attend to the entire
sequence as part of the sparse attention mechanism. The proposed sparse attention can handle sequences of length up to
8x of what was previously possible using similar hardware. As a consequence of the capability to handle longer context,
BigBird drastically improves performance on various NLP tasks such as question answering and summarization. We also
propose novel applications to genomics data.*
> [!TIP]
> This model was contributed by [vasudevgupta](https://huggingface.co/vasudevgupta).
>
> Click on the BigBirdPegasus models in the right sidebar for more examples of how to apply BigBirdPegasus to different language tasks.

The original code can be found [here](https://github.com/google-research/bigbird).
The example below demonstrates how to summarize text with [`Pipeline`], [`AutoModel`], and from the command line.

## Usage tips
<hfoptions id="usage">
<hfoption id="Pipeline">

- For an in-detail explanation on how BigBird's attention works, see [this blog post](https://huggingface.co/blog/big-bird).
- BigBird comes with 2 implementations: **original_full** & **block_sparse**. For the sequence length < 1024, using
**original_full** is advised as there is no benefit in using **block_sparse** attention.
- The code currently uses window size of 3 blocks and 2 global blocks.
- Sequence length must be divisible by block size.
- Current implementation supports only **ITC**.
- Current implementation doesn't support **num_random_blocks = 0**.
- BigBirdPegasus uses the [PegasusTokenizer](https://github.com/huggingface/transformers/blob/main/src/transformers/models/pegasus/tokenization_pegasus.py).
- BigBird is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than
the left.
```py
import torch
from transformers import pipeline

pipeline = pipeline(
    task="summarization",
    model="google/bigbird-pegasus-large-arxiv",
    torch_dtype=torch.float32,
    device=0
)
pipeline("""Plants are among the most remarkable and essential life forms on Earth, possessing a unique ability to produce their own food through a process known as photosynthesis. This complex biochemical process is fundamental not only to plant life but to virtually all life on the planet.
Through photosynthesis, plants capture energy from sunlight using a green pigment called chlorophyll, which is located in specialized cell structures called chloroplasts. In the presence of light, plants absorb carbon dioxide from the atmosphere through small pores in their leaves called stomata, and take in water from the soil through their root systems.
These ingredients are then transformed into glucose, a type of sugar that serves as a source of chemical energy, and oxygen, which is released as a byproduct into the atmosphere. The glucose produced during photosynthesis is not just used immediately; plants also store it as starch or convert it into other organic compounds like cellulose, which is essential for building their cellular structure.
This energy reserve allows them to grow, develop leaves, produce flowers, bear fruit, and carry out various physiological processes throughout their lifecycle.""")
```
</hfoption>
<hfoption id="AutoModel">

```py
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained(
    "google/bigbird-pegasus-large-arxiv"
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/bigbird-pegasus-large-arxiv",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

input_text = """Plants are among the most remarkable and essential life forms on Earth, possessing a unique ability to produce their own food through a process known as photosynthesis. This complex biochemical process is fundamental not only to plant life but to virtually all life on the planet.
Through photosynthesis, plants capture energy from sunlight using a green pigment called chlorophyll, which is located in specialized cell structures called chloroplasts. In the presence of light, plants absorb carbon dioxide from the atmosphere through small pores in their leaves called stomata, and take in water from the soil through their root systems.
These ingredients are then transformed into glucose, a type of sugar that serves as a source of chemical energy, and oxygen, which is released as a byproduct into the atmosphere. The glucose produced during photosynthesis is not just used immediately; plants also store it as starch or convert it into other organic compounds like cellulose, which is essential for building their cellular structure.
This energy reserve allows them to grow, develop leaves, produce flowers, bear fruit, and carry out various physiological processes throughout their lifecycle."""
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

output = model.generate(**input_ids, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>
<hfoption id="transformers-cli">

```bash
echo -e "Plants are among the most remarkable and essential life forms on Earth, possessing a unique ability to produce their own food through a process known as photosynthesis. This complex biochemical process is fundamental not only to plant life but to virtually all life on the planet. Through photosynthesis, plants capture energy from sunlight using a green pigment called chlorophyll, which is located in specialized cell structures called chloroplasts." | transformers-cli run --task summarization --model google/bigbird-pegasus-large-arxiv --device 0
```

</hfoption>
</hfoptions>

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.

The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to int4.

```py
import torch
from transformers import BitsAndBytesConfig, AutoModelForSeq2SeqLM, AutoTokenizer

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4"
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/bigbird-pegasus-large-arxiv",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained(
    "google/bigbird-pegasus-large-arxiv"
)

input_text = """Plants are among the most remarkable and essential life forms on Earth, possessing a unique ability to produce their own food through a process known as photosynthesis. This complex biochemical process is fundamental not only to plant life but to virtually all life on the planet.
Through photosynthesis, plants capture energy from sunlight using a green pigment called chlorophyll, which is located in specialized cell structures called chloroplasts. In the presence of light, plants absorb carbon dioxide from the atmosphere through small pores in their leaves called stomata, and take in water from the soil through their root systems.
These ingredients are then transformed into glucose, a type of sugar that serves as a source of chemical energy, and oxygen, which is released as a byproduct into the atmosphere. The glucose produced during photosynthesis is not just used immediately; plants also store it as starch or convert it into other organic compounds like cellulose, which is essential for building their cellular structure.
This energy reserve allows them to grow, develop leaves, produce flowers, bear fruit, and carry out various physiological processes throughout their lifecycle."""
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

output = model.generate(**input_ids, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

## Notes

- BigBirdPegasus also uses the [`PegasusTokenizer`].
- Inputs should be padded on the right because BigBird uses absolute position embeddings.
- BigBirdPegasus supports `original_full` and `block_sparse` attention. If the input sequence length is less than 1024, it is recommended to use `original_full` since sparse patterns don't offer much benefit for smaller inputs.
- The current implementation uses window size of 3 blocks and 2 global blocks, only supports the ITC-implementation, and doesn't support `num_random_blocks=0`.
- The sequence length must be divisible by the block size.
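The attention settings referenced in the notes above are ordinary configuration fields. A minimal sketch of inspecting and overriding them, assuming the current `BigBirdPegasusConfig` field names `attention_type`, `block_size`, and `num_random_blocks`:

```py
from transformers import AutoModelForSeq2SeqLM, BigBirdPegasusConfig

config = BigBirdPegasusConfig.from_pretrained("google/bigbird-pegasus-large-arxiv")
print(config.attention_type, config.block_size, config.num_random_blocks)

# For inputs shorter than 1024 tokens, sparse attention brings no benefit,
# so fall back to full attention when loading the model.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/bigbird-pegasus-large-arxiv",
    attention_type="original_full",
)
```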
## Resources

- [Text classification task guide](../tasks/sequence_classification)
- [Question answering task guide](../tasks/question_answering)
- [Causal language modeling task guide](../tasks/language_modeling)
- [Translation task guide](../tasks/translation)
- [Summarization task guide](../tasks/summarization)
Read the [Understanding BigBird's Block Sparse Attention](https://huggingface.co/blog/big-bird) blog post for more details about how BigBird's attention works.

## BigBirdPegasusConfig

@@ -15,9 +15,9 @@ rendered properly in your Markdown viewer.

# Distributed inference

When a model doesn't fit on a single GPU, distributed inference with [tensor parallelism](./perf_train_gpu_many#tensor-parallelism) can help. Tensor parallelism shards a model onto multiple GPUs and parallelizes computations such as matrix multiplication. It enables fitting larger model sizes into memory and is faster because each GPU can process a tensor slice.
When a model doesn't fit on a single GPU, distributed inference with [tensor parallelism](./perf_train_gpu_many#tensor-parallelism) can help. Tensor parallelism shards a model onto multiple accelerators (CUDA GPU, Intel XPU, etc.) and parallelizes computations such as matrix multiplication. It enables fitting larger model sizes into memory and is faster because each accelerator can process a tensor slice.

However, tensor parallelism adds communication overhead and should be used on single machine setups with multiple GPUs to take advantage of fast intra-node communication. For multi-node training, it may be more efficient to use pipeline or data parallelism depending on your use case.
However, tensor parallelism adds communication overhead and should be used on single machine setups with multiple accelerators to take advantage of fast intra-node communication. For multi-node training, it may be more efficient to use pipeline or data parallelism depending on your use case.
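As a concrete reference point, a minimal tensor parallel inference sketch (assuming a single node with several accelerators and a transformers version that supports `tp_plan="auto"`); it must be launched with `torchrun` so that each process receives its rank:

```py
# torchrun --nproc-per-node 4 tp_inference.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan="auto")

inputs = tokenizer("Tensor parallelism splits a model across devices by", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=30)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```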
> [!TIP]
> Refer to the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism) section on tensor parallelism to learn more.

@@ -308,4 +308,4 @@ The most important part of DTensor is the `placement` attribute because it tells

bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Replicate()]) # Replicate bias across all GPUs
```

- `Partial()` - Indicates a tensor is pending a reduction operation (not typically relevant for usage in Transformers).
- `Partial()` - Indicates a tensor is pending a reduction operation (not typically relevant for usage in Transformers).
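To make the placement types tangible, a small standalone sketch (assuming PyTorch 2.4+ where `torch.distributed.tensor` is public; launch with `torchrun --nproc-per-node 2` so a process group exists):

```py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

torch.manual_seed(0)  # same source tensor on every rank
weight = torch.randn(8, 4)

sharded = distribute_tensor(weight, mesh["tp"], placements=[Shard(0)])        # rows split across the 2 ranks
replicated = distribute_tensor(weight, mesh["tp"], placements=[Replicate()])  # full copy on every rank

print(sharded.placements, sharded.to_local().shape)        # expect (Shard(dim=0),) and a (4, 4) local shard
print(replicated.placements, replicated.to_local().shape)  # expect (Replicate(),) and the full (8, 4) tensor
```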
@@ -47,7 +47,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="

tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device.type)

output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))

@@ -49,6 +49,7 @@ Check the table below to see if your hardware is compatible.

| Component | Compatibility |
|----------|----------------|
| CUDA Versions | ✅ cu118, cu126, cu128 |
| XPU Versions | ✅ pytorch2.8 |
| CPU | ✅ change `device_map="cpu"` (see examples below) |

@@ -278,6 +279,71 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))

</hfoption>
</hfoptions>

### Intel XPU
<hfoptions id="examples-Intel-XPU">
<hfoption id="int8-dynamic-and-weight-only">

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig

quant_config = Int8DynamicActivationInt8WeightConfig()
# or int8 weight only quantization
# quant_config = Int8WeightOnlyConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("xpu")

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>

<hfoption id="int4-weight-only">

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import Int4XPULayout
from torchao.quantization.quant_primitives import ZeroPointDomain

quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT)
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("xpu")

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>
</hfoptions>

### CPU
<hfoptions id="examples-CPU">
<hfoption id="int8-dynamic-and-weight-only">

@@ -363,7 +429,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)

# Manual Testing
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
inputs = tokenizer(prompt, return_tensors="pt").to(quantized_model.device.type)
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False

@@ -434,7 +500,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device.type)

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")

@@ -474,7 +540,7 @@ tokenizer.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128")

## Loading quantized models

Loading a quantized model depends on the quantization scheme. For quantization schemes, like int8 and float8, you can quantize the model on any device and also load it on any device. The example below demonstrates quantizing a model on the CPU and then loading it on CUDA.
Loading a quantized model depends on the quantization scheme. For quantization schemes, like int8 and float8, you can quantize the model on any device and also load it on any device. The example below demonstrates quantizing a model on the CPU and then loading it on CUDA or XPU.
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

@@ -491,7 +557,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(

    quantization_config=quantization_config
)
# save the quantized model
output_dir = "llama-3.1-8b-torchao-int8-cuda"
output_dir = "llama-3.1-8b-torchao-int8"
quantized_model.save_pretrained(output_dir, safe_serialization=False)

# reload the quantized model

@@ -502,7 +568,7 @@ reloaded_model = AutoModelForCausalLM.from_pretrained(

)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
input_ids = tokenizer(input_text, return_tensors="pt").to(reloaded_model.device.type)

output = reloaded_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))

@@ -57,10 +57,12 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):

local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])

backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "ccl", "hpu": "hccl"}
backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"}
backend = backend_map.get(device_type)
if device_type == "cpu" and int(os.environ.get("CCL_WORKER_COUNT", 0)):
backend = "ccl"
if device_type == "xpu" and not is_torch_greater_or_equal("2.8", accept_dev=True):
backend = "ccl"

torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
current_device = getattr(torch, device_type)
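# Illustrative sketch (standalone, not from this file): the change above makes "xccl" the
# default XPU backend and keeps "ccl" only as a fallback for PyTorch older than 2.8. A tiny
# self-contained version of that decision logic, with assumed tuple-style version inputs:
def _pick_backend(device_type, torch_version):
    backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"}
    backend = backend_map.get(device_type)
    if device_type == "xpu" and torch_version < (2, 8):
        backend = "ccl"
    return backend

assert _pick_backend("cuda", (2, 7)) == "nccl"
assert _pick_backend("xpu", (2, 7)) == "ccl"
assert _pick_backend("xpu", (2, 8)) == "xccl"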
@@ -3746,7 +3746,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi

module_map[name + f".{key}"] = module
state_dict = model_to_save.state_dict()

if any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS):
if any(
allowed_name in class_name.__name__.lower()
for class_name in self.__class__.__mro__[:-1]
for allowed_name in VLMS
):
reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()}

original_state_dict = {}

@@ -4402,7 +4406,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi

key_mapping = kwargs.pop("key_mapping", None)
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
if key_mapping is None and any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
if key_mapping is None and any(
allowed_name in class_name.__name__.lower() for class_name in cls.__mro__[:-1] for allowed_name in VLMS
):
key_mapping = cls._checkpoint_conversion_mapping

# Not used anymore -- remove them from the kwargs
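# Illustrative sketch (standalone toy classes, not from transformers): why the check now walks
# cls.__mro__ instead of looking only at cls.__name__. A user subclass keeps the hardcoded VLM
# key mapping even though its own class name no longer contains the keyword.
VLMS_EXAMPLE = ["llava"]

class LlavaForConditionalGeneration:  # stand-in for the real VLM class
    pass

class MyFineTunedModel(LlavaForConditionalGeneration):
    pass

name_only = any(n in MyFineTunedModel.__name__.lower() for n in VLMS_EXAMPLE)  # False
mro_walk = any(
    n in klass.__name__.lower() for klass in MyFineTunedModel.__mro__[:-1] for n in VLMS_EXAMPLE
)  # True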
@@ -16,7 +16,7 @@

import math
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

import torch
import torch.utils.checkpoint

@@ -25,14 +25,15 @@ from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithNoAttention,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_align import AlignConfig, AlignTextConfig, AlignVisionConfig

@@ -90,7 +91,7 @@ class AlignOutput(ModelOutput):

The text embeddings obtained by applying the projection layer to the pooled output of [`AlignTextModel`].
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
The output of [`AlignVisionModel`].
text_model_output (`BaseModelOutputWithPoolingAndCrossAttentions`):
text_model_output (`BaseModelOutputWithPooling`):
The output of the [`AlignTextModel`].
vision_model_output (`BaseModelOutputWithPoolingAndNoAttention`):
The output of the [`AlignVisionModel`].

@@ -101,7 +102,7 @@ class AlignOutput(ModelOutput):

logits_per_text: Optional[torch.FloatTensor] = None
text_embeds: Optional[torch.FloatTensor] = None
image_embeds: Optional[torch.FloatTensor] = None
text_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None
text_model_output: BaseModelOutputWithPooling = None
vision_model_output: BaseModelOutputWithPoolingAndNoAttention = None

def to_tuple(self) -> tuple[Any]:

@@ -508,7 +509,6 @@ class AlignVisionEncoder(nn.Module):

)

# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->AlignText
class AlignTextEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""

@@ -537,7 +537,6 @@ class AlignTextEmbeddings(nn.Module):

token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()

@@ -547,7 +546,7 @@ class AlignTextEmbeddings(nn.Module):

seq_length = input_shape[1]

if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
position_ids = self.position_ids[:, :seq_length]

# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves

@@ -573,9 +572,35 @@ class AlignTextEmbeddings(nn.Module):

return embeddings

# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->AlignText
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
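# Illustrative sketch (standalone, not from this file): eager_attention_forward takes
# (batch, num_heads, seq_len, head_dim) projections and returns the re-ordered context
# plus the attention map.
_b, _h, _s, _d = 2, 4, 7, 16
_q, _k, _v = (torch.randn(_b, _h, _s, _d) for _ in range(3))
_ctx, _probs = eager_attention_forward(nn.Module(), _q, _k, _v, attention_mask=None, scaling=_d**-0.5)
assert _ctx.shape == (_b, _s, _h, _d) and _probs.shape == (_b, _h, _s, _s)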

class AlignTextSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(

@@ -583,6 +608,7 @@ class AlignTextSelfAttention(nn.Module):

f"heads ({config.num_attention_heads})"
)

self.config = config
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size

@@ -592,20 +618,12 @@ class AlignTextSelfAttention(nn.Module):

self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

self.is_decoder = config.is_decoder

def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
self.attention_dropout = config.attention_probs_dropout_prob
self.scaling = self.attention_head_size**-0.5

@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,

@@ -615,96 +633,33 @@ class AlignTextSelfAttention(nn.Module):

encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)

# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

query_layer = self.transpose_for_scores(mixed_query_layer)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
head_mask=head_mask,
**kwargs,
)

use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r

positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility

if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in AlignTextModel forward() function)
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)

# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = torch.matmul(attention_probs, value_layer)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

if self.is_decoder:
outputs = outputs + (past_key_value,)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs

@@ -723,18 +678,10 @@ class AlignTextSelfOutput(nn.Module):

return hidden_states

ALIGN_TEXT_SELF_ATTENTION_CLASSES = {
"eager": AlignTextSelfAttention,
}

# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->AlignText,BERT->ALIGN_TEXT
class AlignTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
self.self = ALIGN_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.self = AlignTextSelfAttention(config)
self.output = AlignTextSelfOutput(config)
self.pruned_heads = set()

@@ -756,6 +703,9 @@ class AlignTextAttention(nn.Module):

self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)

@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,

@@ -765,15 +715,14 @@ class AlignTextAttention(nn.Module):

encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
**kwargs,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them

@@ -811,22 +760,18 @@ class AlignTextOutput(nn.Module):

return hidden_states

# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->AlignText
class AlignTextLayer(GradientCheckpointingLayer):
def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = AlignTextAttention(config)
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = AlignTextAttention(config, position_embedding_type="absolute")
self.intermediate = AlignTextIntermediate(config)
self.output = AlignTextOutput(config)

@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,

@@ -836,60 +781,23 @@ class AlignTextLayer(GradientCheckpointingLayer):

encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
**kwargs,
)
attention_output = self_attention_outputs[0]

# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights

cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
" by setting `config.add_cross_attention=True`"
)

# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
cross_attn_past_key_value,
output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights

# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value

outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs

# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)

return outputs

def feed_forward_chunk(self, attention_output):

@@ -898,14 +806,18 @@ class AlignTextLayer(GradientCheckpointingLayer):

return layer_output

# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->AlignText
class AlignTextEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([AlignTextLayer(config) for _ in range(config.num_hidden_layers)])
self.layer = nn.ModuleList([AlignTextLayer(config) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False

@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_values", version="4.54.0")
@deprecate_kwarg("use_cache", version="4.54.0")
@can_return_tuple
def forward(
self,
hidden_states: torch.Tensor,

@@ -918,65 +830,36 @@ class AlignTextEncoder(nn.Module):

output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
**kwargs,
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False

next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None

layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states, # as a positional argument for gradient checkpointing
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=layer_head_mask,
output_attentions=output_attentions,
**kwargs,
)

hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
return BaseModelOutput(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)

@@ -1052,6 +935,7 @@ class AlignTextModel(AlignPreTrainedModel):

def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value

@can_return_tuple
@auto_docstring
def forward(
self,

@@ -1059,12 +943,13 @@ class AlignTextModel(AlignPreTrainedModel):

attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
**kwargs,
) -> Union[tuple, BaseModelOutputWithPooling]:
r"""
Examples:

@@ -1133,20 +1018,17 @@ class AlignTextModel(AlignPreTrainedModel):

head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
**kwargs,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]

return BaseModelOutputWithPoolingAndCrossAttentions(
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)

@@ -1180,6 +1062,7 @@ class AlignVisionModel(AlignPreTrainedModel):

def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.convolution

@can_return_tuple
@auto_docstring
def forward(
self,

@@ -1219,7 +1102,7 @@ class AlignVisionModel(AlignPreTrainedModel):

encoder_outputs = self.encoder(
embedding_output,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)
# Apply pooling
last_hidden_state = encoder_outputs[0]

@@ -1227,9 +1110,6 @@ class AlignVisionModel(AlignPreTrainedModel):

# Reshape (batch_size, projection_dim, 1 , 1) -> (batch_size, projection_dim)
pooled_output = pooled_output.reshape(pooled_output.shape[:2])

if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]

return BaseModelOutputWithPoolingAndNoAttention(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,

@@ -1369,6 +1249,7 @@ class AlignModel(AlignPreTrainedModel):

return image_features

@can_return_tuple
@auto_docstring
def forward(
self,

@@ -1419,7 +1300,7 @@ class AlignModel(AlignPreTrainedModel):

vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)

text_outputs = self.text_model(

@@ -1431,7 +1312,7 @@ class AlignModel(AlignPreTrainedModel):

inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
return_dict=True,
)

image_embeds = vision_outputs[1]

@@ -1450,10 +1331,6 @@ class AlignModel(AlignPreTrainedModel):

if return_loss:
loss = align_loss(logits_per_text)

if not return_dict:
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
return ((loss,) + output) if loss is not None else output

return AlignOutput(
loss=loss,
logits_per_image=logits_per_image,
@ -26,14 +26,14 @@ from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPooling,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndProjection,
|
||||
)
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import ModelOutput, auto_docstring, logging, torch_int
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig
|
||||
|
||||
|
||||
@@ -180,7 +180,6 @@ class AltRobertaEmbeddings(nn.Module):
return position_ids.unsqueeze(0).expand(input_shape)


# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->AltRoberta
class AltRobertaSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
@@ -206,13 +205,9 @@ class AltRobertaSelfAttention(nn.Module):
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

self.is_decoder = config.is_decoder

def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
@deprecate_kwarg("past_key_value", version="4.54.0")
def forward(
self,
hidden_states: torch.Tensor,
@@ -223,55 +218,19 @@ class AltRobertaSelfAttention(nn.Module):
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)

# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))

query_layer = self.transpose_for_scores(mixed_query_layer)

use_cache = past_key_value is not None
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
if use_cache:
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-1, 1
)
else:
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r

@@ -310,8 +269,6 @@ class AltRobertaSelfAttention(nn.Module):

outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs


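The refactor above drops `transpose_for_scores` (view + permute) in favor of a direct `view(...).transpose(1, 2)`. A small self-contained check, with toy shapes that are not part of the diff, shows the two layouts agree:

# Toy check: both paths produce [batch, heads, seq, head_dim] from a [batch, seq, hidden] projection.
import torch

batch, seq_len, num_heads, head_dim = 2, 5, 4, 8
hidden = torch.randn(batch, seq_len, num_heads * head_dim)

old = hidden.view(batch, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
new = hidden.view(batch, seq_len, -1, head_dim).transpose(1, 2)

print(torch.equal(old, new))  # True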
@ -335,7 +292,6 @@ ALT_ROBERTA_SELF_ATTENTION_CLASSES = {
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.roberta.modeling_roberta.RobertaAttention with Roberta->AltRoberta,ROBERTA->ALT_ROBERTA
|
||||
class AltRobertaAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
@ -363,6 +319,9 @@ class AltRobertaAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -375,12 +334,9 @@ class AltRobertaAttention(nn.Module):
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
@ -418,22 +374,19 @@ class AltRobertaOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->AltRoberta
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->AltRoberta
|
||||
class AltRobertaLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = AltRobertaAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = AltRobertaAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = AltRobertaIntermediate(config)
|
||||
self.output = AltRobertaOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -443,60 +396,23 @@ class AltRobertaLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -505,14 +421,19 @@ class AltRobertaLayer(GradientCheckpointingLayer):
|
||||
return layer_output
|
||||
|
||||
|
||||
# Copied from transformers.models.roberta.modeling_roberta.RobertaEncoder with Roberta->AltRoberta
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->AltRoberta
|
||||
class AltRobertaEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([AltRobertaLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([AltRobertaLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -525,65 +446,36 @@ class AltRobertaEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -787,6 +679,7 @@ class AltCLIPEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([AltCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -853,8 +746,6 @@ class AltCLIPEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
@ -1008,6 +899,7 @@ class AltCLIPVisionTransformer(nn.Module):
|
||||
self.encoder = AltCLIPEncoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1033,16 +925,13 @@ class AltCLIPVisionTransformer(nn.Module):
|
||||
inputs_embeds=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
@ -1106,16 +995,11 @@ class AltCLIPVisionModel(AltCLIPPreTrainedModel):
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||
cross-attention is added between the self-attention layers, following the architecture described in *Attention is
|
||||
The model behaves as an encoder following the architecture described in *Attention is
|
||||
all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
|
||||
Kaiser and Illia Polosukhin.
|
||||
|
||||
To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
|
||||
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
|
||||
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
|
||||
|
||||
.. _*Attention is all you need*: https://huggingface.co/papers/1706.03762
|
||||
.. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
|
||||
"""
|
||||
)
|
||||
class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
@ -1152,6 +1036,10 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@auto_docstring
|
||||
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
|
||||
def forward(
|
||||
@ -1176,11 +1064,6 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -1194,11 +1077,8 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
attention_mask = torch.ones(((batch_size, seq_length)), device=device)
|
||||
|
||||
if token_type_ids is None:
|
||||
if hasattr(self.embeddings, "token_type_ids"):
|
||||
@ -1212,21 +1092,6 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
@ -1235,33 +1100,23 @@ class AltRobertaModel(AltCLIPPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -1284,6 +1139,9 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
|
||||
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
|
||||
return super().resize_token_embeddings(new_num_tokens)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1326,11 +1184,9 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
# last module outputs
|
||||
@ -1343,9 +1199,6 @@ class AltCLIPTextModel(AltCLIPPreTrainedModel):
|
||||
projection_state = self.transformation(sequence_output)
|
||||
pooler_output = projection_state[:, 0]
|
||||
|
||||
if not return_dict:
|
||||
return (projection_state, pooler_output) + outputs[2:4]
|
||||
|
||||
return BaseModelOutputWithPoolingAndProjection(
|
||||
last_hidden_state=projection_state,
|
||||
pooler_output=pooler_output,
|
||||
|
@ -1026,7 +1026,6 @@ class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@auto_docstring
|
||||
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
|
@ -26,13 +26,14 @@ from torch.nn import CrossEntropyLoss
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import ModelOutput, auto_docstring, logging
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_bros import BrosConfig
|
||||
|
||||
|
||||
@ -150,7 +151,6 @@ class BrosTextEmbeddings(nn.Module):
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
past_key_values_length: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
@ -160,7 +160,7 @@ class BrosTextEmbeddings(nn.Module):
|
||||
seq_length = input_shape[1]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
if token_type_ids is None:
|
||||
if hasattr(self, "token_type_ids"):
|
||||
@ -208,14 +208,7 @@ class BrosSelfAttention(nn.Module):
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor):
|
||||
new_x_shape = x.size()[:-1] + (
|
||||
self.num_attention_heads,
|
||||
self.attention_head_size,
|
||||
)
|
||||
x = x.view(*new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -227,42 +220,21 @@ class BrosSelfAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[torch.Tensor] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size)
|
||||
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
if is_cross_attention:
|
||||
key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
@ -317,7 +289,7 @@ class BrosSelfAttention(nn.Module):
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
outputs = outputs + (None,)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -364,6 +336,7 @@ class BrosAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -382,7 +355,6 @@ class BrosAttention(nn.Module):
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
@ -435,6 +407,7 @@ class BrosLayer(GradientCheckpointingLayer):
|
||||
self.intermediate = BrosIntermediate(config)
|
||||
self.output = BrosOutput(config)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -446,50 +419,38 @@ class BrosLayer(GradientCheckpointingLayer):
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
bbox_pos_emb=bbox_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if hasattr(self, "crossattention"):
|
||||
raise Exception(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk,
|
||||
self.chunk_size_feed_forward,
|
||||
@ -500,7 +461,7 @@ class BrosLayer(GradientCheckpointingLayer):
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
outputs = outputs + (None,)
|
||||
|
||||
return outputs
|
||||
|
||||
@ -516,6 +477,9 @@ class BrosEncoder(nn.Module):
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([BrosLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -529,33 +493,28 @@ class BrosEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithCrossAttentions]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
bbox_pos_emb,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
hidden_states=hidden_states,
|
||||
bbox_pos_emb=bbox_pos_emb,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
@ -564,21 +523,8 @@ class BrosEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutputWithCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
@ -689,6 +635,9 @@ class BrosModel(BrosPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -736,11 +685,6 @@ class BrosModel(BrosPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -756,9 +700,6 @@ class BrosModel(BrosPreTrainedModel):
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(input_shape, device=device)
|
||||
|
||||
@ -797,7 +738,6 @@ class BrosModel(BrosPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
# if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token
|
||||
@ -813,22 +753,16 @@ class BrosModel(BrosPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
@ -852,6 +786,7 @@ class BrosForTokenClassification(BrosPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -908,7 +843,7 @@ class BrosForTokenClassification(BrosPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -927,10 +862,6 @@ class BrosForTokenClassification(BrosPreTrainedModel):
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
@ -976,6 +907,7 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1037,7 +969,7 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
last_hidden_states = outputs[0]
|
||||
@ -1082,10 +1014,6 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
|
||||
|
||||
loss = initial_token_loss + subsequent_token_loss
|
||||
|
||||
if not return_dict:
|
||||
output = (initial_token_logits, subsequent_token_logits) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return BrosSpadeOutput(
|
||||
loss=loss,
|
||||
initial_token_logits=initial_token_logits,
|
||||
@ -1118,6 +1046,7 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1173,7 +1102,7 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
last_hidden_states = outputs[0]
|
||||
@ -1203,10 +1132,6 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
|
||||
|
||||
loss = loss_fct(logits.view(-1, max_seq_length + 1)[mask], labels.view(-1)[mask])
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
|
@@ -14,9 +14,8 @@
# limitations under the License.
"""PyTorch Chinese-CLIP model."""

import math
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

import torch
import torch.utils.checkpoint
@@ -26,13 +25,13 @@ from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
from ...utils.deprecation import deprecate_kwarg
from .configuration_chinese_clip import ChineseCLIPConfig, ChineseCLIPTextConfig, ChineseCLIPVisionConfig


@@ -90,7 +89,7 @@ class ChineseCLIPOutput(ModelOutput):
)


# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->ChineseCLIPText
# Copied from transformers.models.align.modeling_align.AlignTextEmbeddings with Align->ChineseCLIP
class ChineseCLIPTextEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""

@@ -119,7 +118,6 @@ class ChineseCLIPTextEmbeddings(nn.Module):
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
if input_ids is not None:
input_shape = input_ids.size()
@@ -129,7 +127,7 @@ class ChineseCLIPTextEmbeddings(nn.Module):
seq_length = input_shape[1]

if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
position_ids = self.position_ids[:, :seq_length]

# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
@@ -239,9 +237,37 @@ class ChineseCLIPVisionEmbeddings(nn.Module):
return embeddings


# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ChineseCLIPText
# Copied from transformers.models.align.modeling_align.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights


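For orientation, here is a self-contained toy computation of the same eager attention math as the helper added above; the shapes and values are illustrative only and not part of the diff.

# Sketch: scores = Q·K^T * scaling, softmax in float32, weighted sum over V,
# then the [batch, heads, seq, head_dim] -> [batch, seq, heads, head_dim] transpose.
import torch
from torch import nn

batch, heads, seq_len, head_dim = 2, 4, 6, 8
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)
scaling = head_dim**-0.5

attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling          # [B, H, S, S]
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value)                            # [B, H, S, D]
attn_output = attn_output.transpose(1, 2).contiguous()                     # [B, S, H, D]

# Callers then flatten the head dimension before the output projection.
print(attn_output.reshape(batch, seq_len, heads * head_dim).shape)         # torch.Size([2, 6, 32])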
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with Align->ChineseCLIP
|
||||
class ChineseCLIPTextSelfAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
@ -249,6 +275,7 @@ class ChineseCLIPTextSelfAttention(nn.Module):
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
@ -258,20 +285,12 @@ class ChineseCLIPTextSelfAttention(nn.Module):
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
self.attention_dropout = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -281,96 +300,33 @@ class ChineseCLIPTextSelfAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
head_mask=head_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
||||
-1, 1
|
||||
)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in ChineseCLIPTextModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -389,18 +345,11 @@ class ChineseCLIPTextSelfOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES = {
|
||||
"eager": ChineseCLIPTextSelfAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ChineseCLIPText,BERT->CHINESE_CLIP_TEXT
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextAttention with Align->ChineseCLIP
|
||||
class ChineseCLIPTextAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.self = CHINESE_CLIP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config, position_embedding_type=position_embedding_type
|
||||
)
|
||||
self.self = ChineseCLIPTextSelfAttention(config)
|
||||
self.output = ChineseCLIPTextSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
@ -422,6 +371,9 @@ class ChineseCLIPTextAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -431,15 +383,14 @@ class ChineseCLIPTextAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
@ -468,66 +419,37 @@ class ChineseCLIPVisionAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output_attentions: Optional[bool] = False,
|
||||
self, hidden_states: torch.Tensor, output_attentions: Optional[bool] = False, **kwargs
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scale
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) * self.scale
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit akward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
None,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
scaling=1.0,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped
|
||||
return attn_output, attn_weights
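This hunk captures the core of the refactor: instead of a hand-rolled `bmm` + softmax, the module resolves an attention callable from `ALL_ATTENTION_FUNCTIONS` based on `config._attn_implementation` and falls back to the local `eager_attention_forward`. The self-contained sketch below illustrates that dispatch pattern; `TinyAttention` and its sizes are hypothetical stand-ins, not classes from this diff.

```python
# Sketch of the dispatch pattern used by the refactored attention modules.
# `eager_attention_forward` here is a stand-in with the same contract as the
# one defined later in this diff; ALL_ATTENTION_FUNCTIONS maps implementation
# names (e.g. "sdpa", "flash_attention_2") to functions with that contract.
from typing import Callable, Optional

import torch
from torch import nn


def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


ALL_ATTENTION_FUNCTIONS: dict[str, Callable] = {}  # filled by the framework in the real code


class TinyAttention(nn.Module):
    def __init__(self, embed_dim=64, num_heads=4, attn_implementation="eager"):
        super().__init__()
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self._attn_implementation = attn_implementation

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        input_shape = hidden_states.shape[:-1]             # (batch, seq_len)
        hidden_shape = (*input_shape, -1, self.head_dim)   # (batch, seq_len, heads, head_dim)

        query = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Pick the configured implementation, defaulting to the eager path.
        attention_interface = eager_attention_forward
        if self._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self, query, key, value, attention_mask, scaling=self.head_dim**-0.5
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        return self.out_proj(attn_output), attn_weights


x = torch.randn(2, 10, 64)
out, weights = TinyAttention()(x)
print(out.shape, weights.shape)  # torch.Size([2, 10, 64]) torch.Size([2, 4, 10, 10])
```

In the vision attention above, passing `None` as the mask and `scaling=1.0` is intentional: the query has already been multiplied by `self.scale` before the call.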
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->ChineseCLIPText
|
||||
@ -577,22 +499,19 @@ class ChineseCLIPVisionMLP(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ChineseCLIPText
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextLayer with Align->ChineseCLIP
|
||||
class ChineseCLIPTextLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = ChineseCLIPTextAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = ChineseCLIPTextAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = ChineseCLIPTextIntermediate(config)
|
||||
self.output = ChineseCLIPTextOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -602,60 +521,23 @@ class ChineseCLIPTextLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -777,14 +659,19 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ChineseCLIPText
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->ChineseCLIP
|
||||
class ChineseCLIPTextEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([ChineseCLIPTextLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -797,65 +684,36 @@ class ChineseCLIPTextEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -874,6 +732,7 @@ class ChineseCLIPVisionEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([ChineseCLIPVisionLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -922,8 +781,6 @@ class ChineseCLIPVisionEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
@ -940,6 +797,7 @@ class ChineseCLIPVisionTransformer(nn.Module):
|
||||
self.encoder = ChineseCLIPVisionEncoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -965,16 +823,13 @@ class ChineseCLIPVisionTransformer(nn.Module):
|
||||
inputs_embeds=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
@ -1034,6 +889,7 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1050,18 +906,13 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -1093,56 +944,28 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
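With the manual `if not return_dict:` branches removed, tuple conversion is handled by the `@can_return_tuple` decorator and the text model now returns a `BaseModelOutputWithPooling`. A hedged caller-side sketch follows; from the user's point of view the behaviour is expected to be unchanged.

```python
# Caller-side sketch: the ModelOutput path and the tuple path should agree.
# ChineseCLIPTextConfig / ChineseCLIPTextModel are the real classes; the tiny
# config values below are arbitrary and the weights are untrained.
import torch
from transformers import ChineseCLIPTextConfig, ChineseCLIPTextModel

config = ChineseCLIPTextConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = ChineseCLIPTextModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 8))

with torch.no_grad():
    out = model(input_ids)                     # BaseModelOutputWithPooling after this refactor
    tup = model(input_ids, return_dict=False)  # plain tuple

print(out.last_hidden_state.shape)                    # torch.Size([1, 8, 32])
print(torch.allclose(out.last_hidden_state, tup[0]))  # True
```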
|
||||
|
||||
|
||||
@ -1343,6 +1166,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
|
||||
|
||||
return image_features
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1392,7 +1216,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
text_outputs = self.text_model(
|
||||
@ -1402,7 +1226,7 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
image_embeds = vision_outputs[1]
|
||||
@ -1424,14 +1248,6 @@ class ChineseCLIPModel(ChineseCLIPPreTrainedModel):
|
||||
if return_loss:
|
||||
loss = chinese_clip_loss(logits_per_text)
|
||||
|
||||
if not return_dict:
|
||||
# fix the None pooled_output of text_outputs to conform with dict_output
|
||||
pooled_output = text_outputs[1]
|
||||
if pooled_output is None:
|
||||
text_outputs = (text_outputs[0],) + text_outputs[2:]
|
||||
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return ChineseCLIPOutput(
|
||||
loss=loss,
|
||||
logits_per_image=logits_per_image,
|
||||
@ -17,7 +17,7 @@
|
||||
import collections
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -26,13 +26,14 @@ from torch import nn
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
|
||||
from ...utils import ModelOutput, auto_docstring, logging, torch_int
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig
|
||||
|
||||
|
||||
@ -399,11 +400,6 @@ class ClapAudioSelfAttention(nn.Module):
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -412,11 +408,11 @@ class ClapAudioSelfAttention(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
batch_size, dim, num_channels = hidden_states.shape
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
hidden_shape = (batch_size, dim, -1, self.attention_head_size)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
@ -1090,9 +1086,37 @@ class ClapTextEmbeddings(nn.Module):
|
||||
return position_ids.unsqueeze(0).expand(input_shape)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->ClapText
# Copied from transformers.models.align.modeling_align.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights


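As a usage note for the helper above: `query`, `key` and `value` are laid out as `(batch, num_heads, seq_len, head_dim)`, and the output comes back as `(batch, seq_len, num_heads, head_dim)`, ready to be flattened by the caller. A quick shape check:

```python
# Quick shape check for the helper defined above (assumes eager_attention_forward
# and the module-level torch / nn imports are already in scope).
import torch
from torch import nn

module = nn.Module()  # eager_attention_forward only consults `module.training`
q = k = v = torch.randn(2, 4, 10, 16)  # (batch, num_heads, seq_len, head_dim)
attn_output, attn_weights = eager_attention_forward(module, q, k, v, attention_mask=None, scaling=16**-0.5)
print(attn_output.shape)   # torch.Size([2, 10, 4, 16]); callers reshape this to (2, 10, 64)
print(attn_weights.shape)  # torch.Size([2, 4, 10, 10])
```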
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with Align->Clap
|
||||
class ClapTextSelfAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
@ -1100,6 +1124,7 @@ class ClapTextSelfAttention(nn.Module):
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
@ -1109,20 +1134,12 @@ class ClapTextSelfAttention(nn.Module):
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
self.attention_dropout = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1132,96 +1149,33 @@ class ClapTextSelfAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
head_mask=head_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
||||
-1, 1
|
||||
)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in ClapTextModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
||||
return outputs
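The reshaping change in this hunk is mechanical: the old `transpose_for_scores` (`view` + `permute(0, 2, 1, 3)`) and the new `view(hidden_shape).transpose(1, 2)` produce the same `(batch, heads, seq, head_dim)` layout. A small standalone check with arbitrary sizes:

```python
# Sanity check (standalone sketch) that the old and new reshaping paths agree.
import torch

batch, seq_len, num_heads, head_size = 2, 7, 4, 16
hidden = torch.randn(batch, seq_len, num_heads * head_size)

# old path: transpose_for_scores
old = hidden.view(batch, seq_len, num_heads, head_size).permute(0, 2, 1, 3)

# new path: view(*input_shape, -1, head_size).transpose(1, 2)
new = hidden.view(batch, seq_len, -1, head_size).transpose(1, 2)

print(old.shape)              # torch.Size([2, 4, 7, 16])
print(torch.equal(old, new))  # True
```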
|
||||
|
||||
|
||||
@ -1240,18 +1194,11 @@ class ClapTextSelfOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
CLAP_TEXT_SELF_ATTENTION_CLASSES = {
"eager": ClapTextSelfAttention,
}


# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->ClapText,BERT->CLAP_TEXT
# Copied from transformers.models.align.modeling_align.AlignTextAttention with Align->Clap
class ClapTextAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
def __init__(self, config):
super().__init__()
self.self = CLAP_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
config, position_embedding_type=position_embedding_type
)
self.self = ClapTextSelfAttention(config)
self.output = ClapTextSelfOutput(config)
self.pruned_heads = set()
|
||||
@ -1273,6 +1220,9 @@ class ClapTextAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1282,15 +1232,14 @@ class ClapTextAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
@ -1328,22 +1277,19 @@ class ClapTextOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ClapText
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextLayer with Align->Clap
|
||||
class ClapTextLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = ClapTextAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = ClapTextAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = ClapTextIntermediate(config)
|
||||
self.output = ClapTextOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1353,60 +1299,23 @@ class ClapTextLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -1415,14 +1324,19 @@ class ClapTextLayer(GradientCheckpointingLayer):
|
||||
return layer_output
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->ClapText
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->Clap
|
||||
class ClapTextEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([ClapTextLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([ClapTextLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1435,65 +1349,36 @@ class ClapTextEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -1643,6 +1528,11 @@ class ClapTextModel(ClapPreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1666,11 +1556,6 @@ class ClapTextModel(ClapPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -1684,11 +1569,8 @@ class ClapTextModel(ClapPreTrainedModel):
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
attention_mask = torch.ones(((batch_size, seq_length)), device=device)
|
||||
|
||||
if token_type_ids is None:
|
||||
if hasattr(self.embeddings, "token_type_ids"):
|
||||
@ -1702,21 +1584,6 @@ class ClapTextModel(ClapPreTrainedModel):
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
@ -1725,33 +1592,23 @@ class ClapTextModel(ClapPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -1892,6 +1749,7 @@ class ClapModel(ClapPreTrainedModel):
|
||||
|
||||
return audio_features
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1947,7 +1805,7 @@ class ClapModel(ClapPreTrainedModel):
|
||||
is_longer=is_longer,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
text_outputs = self.text_model(
|
||||
@ -1956,7 +1814,7 @@ class ClapModel(ClapPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
audio_embeds = audio_outputs[1] if not return_dict else audio_outputs.pooler_output
|
||||
@ -1981,10 +1839,6 @@ class ClapModel(ClapPreTrainedModel):
|
||||
audio_loss = contrastive_loss(logits_per_audio.t())
|
||||
loss = (caption_loss + audio_loss) / 2.0
|
||||
|
||||
if not return_dict:
|
||||
output = (logits_per_audio, logits_per_text, text_embeds, audio_embeds, text_outputs, audio_outputs)
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return ClapOutput(
|
||||
loss=loss,
|
||||
logits_per_audio=logits_per_audio,
|
||||
@ -2013,6 +1867,7 @@ class ClapTextModelWithProjection(ClapPreTrainedModel):
|
||||
def set_input_embeddings(self, value):
|
||||
self.text_model.embeddings.word_embeddings = value
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -2045,17 +1900,13 @@ class ClapTextModelWithProjection(ClapPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
pooled_output = text_outputs[1] if not return_dict else text_outputs.pooler_output
|
||||
|
||||
text_embeds = self.text_projection(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
|
||||
return tuple(output for output in outputs if output is not None)
|
||||
|
||||
return ClapTextModelOutput(
|
||||
text_embeds=text_embeds,
|
||||
last_hidden_state=text_outputs.last_hidden_state,
|
||||
@ -2079,6 +1930,7 @@ class ClapAudioModelWithProjection(ClapPreTrainedModel):
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.audio_model.audio_encoder.patch_embed.proj
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -2123,17 +1975,13 @@ class ClapAudioModelWithProjection(ClapPreTrainedModel):
|
||||
is_longer=is_longer,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
pooled_output = audio_outputs[1] if not return_dict else audio_outputs.pooler_output
|
||||
|
||||
audio_embeds = self.audio_projection(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
outputs = (audio_embeds, audio_outputs[0]) + audio_outputs[2:]
|
||||
return tuple(output for output in outputs if output is not None)
|
||||
|
||||
return ClapAudioModelOutput(
|
||||
audio_embeds=audio_embeds,
|
||||
last_hidden_state=audio_outputs.last_hidden_state,
|
||||
@ -28,7 +28,7 @@ from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepa
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...utils import ModelOutput, auto_docstring, logging, torch_int
|
||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
|
||||
from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig
|
||||
|
||||
|
||||
@ -490,6 +490,7 @@ class CLIPSegEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([CLIPSegEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -555,8 +556,6 @@ class CLIPSegEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
@ -45,6 +45,7 @@ from ...modeling_outputs import (
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_data2vec_audio import Data2VecAudioConfig
|
||||
|
||||
|
||||
@ -240,6 +241,7 @@ class Data2VecAudioAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -247,7 +249,7 @@ class Data2VecAudioAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
@ -268,42 +270,9 @@ class Data2VecAudioAttention(nn.Module):
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -325,7 +294,7 @@ class Data2VecAudioAttention(nn.Module):
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights, None
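With caching and the prefix-tuning shortcut gone, the four-way key/value branch collapses to a single source selection. The sketch below shows the equivalent minimal logic with hypothetical projection layers and sizes:

```python
# Minimal sketch of the simplified key/value selection: the only remaining
# question is whether this call is self-attention or cross-attention.
import torch
from torch import nn

embed_dim, num_heads = 64, 4
head_dim = embed_dim // num_heads
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)

hidden_states = torch.randn(2, 10, embed_dim)     # the module's own stream
key_value_states = torch.randn(2, 15, embed_dim)  # encoder stream; None for self-attention

is_cross_attention = key_value_states is not None
current_states = key_value_states if is_cross_attention else hidden_states
kv_input_shape = (current_states.shape[0], -1, num_heads, head_dim)

key_states = k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
value_states = v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
print(key_states.shape)  # torch.Size([2, 4, 15, 16])
```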
|
||||
|
||||
|
||||
class Data2VecAudioFeedForward(nn.Module):
|
||||
@ -634,7 +634,6 @@ class Data2VecTextModel(Data2VecTextPreTrainedModel):
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@auto_docstring
|
||||
# Copied from transformers.models.clap.modeling_clap.ClapTextModel.forward
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
@ -405,11 +405,6 @@ class DonutSwinSelfAttention(nn.Module):
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -418,11 +413,11 @@ class DonutSwinSelfAttention(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
batch_size, dim, num_channels = hidden_states.shape
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
hidden_shape = (batch_size, dim, -1, self.attention_head_size)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
@ -26,14 +26,15 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
MaskedLMOutput,
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils import auto_docstring, can_return_tuple, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_esm import EsmConfig
|
||||
|
||||
|
||||
@ -187,12 +188,16 @@ class EsmEmbeddings(nn.Module):
|
||||
self.mask_token_id = config.mask_token_id
|
||||
|
||||
def forward(
|
||||
self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
):
|
||||
if position_ids is None:
|
||||
if input_ids is not None:
|
||||
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
||||
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
|
||||
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
|
||||
else:
|
||||
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
|
||||
|
||||
@ -281,11 +286,7 @@ class EsmSelfAttention(nn.Module):
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -296,32 +297,22 @@ class EsmSelfAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size)
|
||||
|
||||
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
if is_cross_attention:
|
||||
key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
# Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
# ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,

@ -329,16 +320,6 @@ class EsmSelfAttention(nn.Module):
# ESM code and fix rotary embeddings.
query_layer = query_layer * self.attention_head_size**-0.5

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

if self.position_embedding_type == "rotary":
query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

@ -385,7 +366,7 @@ class EsmSelfAttention(nn.Module):
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

if self.is_decoder:
outputs = outputs + (past_key_value,)
outputs = outputs + (None,)
return outputs
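# Sketch (not part of the diff, toy tensors): the comment above notes that scaling the
# query by head_dim**-0.5 is equivalent to scaling the attention logits by the same
# factor after the matmul:
import torch

q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
d = q.shape[-1]
scores_bert_style = torch.matmul(q, k.transpose(-1, -2)) / d**0.5  # scale the logits
scores_esm_style = torch.matmul(q * d**-0.5, k.transpose(-1, -2))  # scale the query
assert torch.allclose(scores_bert_style, scores_esm_style, atol=1e-6)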
@ -418,6 +399,7 @@ class EsmFlashAttention2(EsmSelfAttention):
|
||||
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
|
||||
self.dropout_prob = config.attention_probs_dropout_prob
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -441,7 +423,6 @@ class EsmFlashAttention2(EsmSelfAttention):
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
@ -450,9 +431,6 @@ class EsmFlashAttention2(EsmSelfAttention):
|
||||
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
if past_key_value is not None:
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
@ -514,7 +492,7 @@ class EsmFlashAttention2(EsmSelfAttention):
|
||||
|
||||
outputs = (attn_output, None)
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
outputs = outputs + (None,)
|
||||
|
||||
return outputs
|
||||
|
||||
@ -551,6 +529,7 @@ class EsmAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@ -564,12 +543,11 @@ class EsmAttention(nn.Module):
|
||||
hidden_states_ln = self.LayerNorm(hidden_states)
|
||||
self_outputs = self.self(
|
||||
hidden_states_ln,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
@ -616,6 +594,7 @@ class EsmLayer(GradientCheckpointingLayer):
|
||||
self.output = EsmOutput(config)
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@ -626,25 +605,20 @@ class EsmLayer(GradientCheckpointingLayer):
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
):
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise AttributeError(
|
||||
@ -652,31 +626,24 @@ class EsmLayer(GradientCheckpointingLayer):
|
||||
" with cross-attention layers by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
layer_output = self.feed_forward_chunk(attention_output)
|
||||
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
outputs = outputs + (None,)
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -694,6 +661,9 @@ class EsmEncoder(nn.Module):
|
||||
self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@ -707,38 +677,26 @@ class EsmEncoder(nn.Module):
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
):
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
|
||||
"`use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache = next_decoder_cache + (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
@ -750,21 +708,8 @@ class EsmEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutputWithCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
@ -863,6 +808,9 @@ class EsmModel(EsmPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -903,11 +851,6 @@ class EsmModel(EsmPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -921,11 +864,8 @@ class EsmModel(EsmPreTrainedModel):
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
attention_mask = torch.ones(((batch_size, seq_length)), device=device)
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
extended_attention_mask = attention_mask
|
||||
@ -958,7 +898,6 @@ class EsmModel(EsmPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
@ -966,22 +905,16 @@ class EsmModel(EsmPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
@ -1025,6 +958,7 @@ class EsmForMaskedLM(EsmPreTrainedModel):
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head.decoder = new_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1058,7 +992,7 @@ class EsmForMaskedLM(EsmPreTrainedModel):
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.lm_head(sequence_output)
|
||||
@ -1070,10 +1004,6 @@ class EsmForMaskedLM(EsmPreTrainedModel):
|
||||
labels = labels.to(prediction_scores.device)
|
||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
||||
|
||||
return MaskedLMOutput(
|
||||
loss=masked_lm_loss,
|
||||
logits=prediction_scores,
|
||||
@ -1125,6 +1055,7 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
|
||||
|
||||
self.post_init()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1154,7 +1085,7 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = outputs[0]
|
||||
logits = self.classifier(sequence_output)
|
||||
@ -1184,10 +1115,6 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
@ -1210,6 +1137,7 @@ class EsmForTokenClassification(EsmPreTrainedModel):
|
||||
|
||||
self.post_init()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1237,7 +1165,7 @@ class EsmForTokenClassification(EsmPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -1252,10 +1180,6 @@ class EsmForTokenClassification(EsmPreTrainedModel):
|
||||
labels = labels.to(logits.device)
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
@ -1283,7 +1207,7 @@ class EsmClassificationHead(nn.Module):
return x


def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
def create_position_ids_from_input_ids(input_ids, padding_idx):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.

@ -1295,7 +1219,7 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
"""
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int()
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
return incremental_indices.long() + padding_idx
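# Sketch (not part of the diff): what the simplified helper above computes for a
# right-padded batch with padding_idx = 1 -- non-padding tokens get positions starting
# at padding_idx + 1, padding positions stay at padding_idx:
import torch

ids = torch.tensor([[5, 7, 9, 1, 1]])
mask = ids.ne(1).int()                                    # tensor([[1, 1, 1, 0, 0]])
positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + 1
print(positions)                                          # tensor([[2, 3, 4, 1, 1]])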
@ -1135,9 +1135,17 @@ class Gemma3nTextAltUp(nn.Module):
|
||||
corrected += predictions # add the original input
|
||||
return corrected.contiguous().type_as(activated)
|
||||
|
||||
def forward(self, corrected: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
This is only defined as the `forward` so that accelerate hooks can move correctly `correct_output_scale`
|
||||
(which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
|
||||
`scale_corrected_output`
|
||||
"""
|
||||
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
|
||||
|
||||
def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
|
||||
"""Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
|
||||
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
|
||||
return self.forward(corrected)
|
||||
|
||||
|
||||
class Gemma3nTextRotaryEmbedding(nn.Module):
|
||||
@ -1290,7 +1298,7 @@ class Gemma3nTextAttention(nn.Module):
|
||||
self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
|
||||
|
||||
first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
|
||||
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx
|
||||
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
|
||||
# Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
|
||||
layer_type = config.layer_types[layer_idx]
|
||||
self.kv_shared_layer_index = (
|
||||
@ -1319,21 +1327,22 @@ class Gemma3nTextAttention(nn.Module):
|
||||
query_states = query_states.transpose(1, 2)
|
||||
|
||||
if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
|
||||
# HybridCache has complex slicing when layer_type == "sliding_attention" that impact Shared KV Cache.
|
||||
# Device of past layer may be different from current one
|
||||
indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device)
|
||||
# In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond)
|
||||
if isinstance(past_key_value, HybridCache) and self.is_sliding:
|
||||
max_length = past_key_value.sliding_window
|
||||
if cache_position.shape[0] > max_length:
|
||||
# If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache,
|
||||
# slice into the entire cache.
|
||||
indices = slice(0, max_length)
|
||||
else:
|
||||
# If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1
|
||||
indices = cache_position.clamp(min=0, max=max_length - 1)
|
||||
else:
|
||||
indices = cache_position
|
||||
indices = (
|
||||
slice(0, max_length)
|
||||
if cache_position.shape[0] > max_length
|
||||
else cache_position.clamp(min=0, max=max_length - 1)
|
||||
)
|
||||
|
||||
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices]
|
||||
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices]
|
||||
# Device of past layer may be different from current one
|
||||
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device)
|
||||
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to(
|
||||
query_states.device
|
||||
)
|
||||
else:
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape)
|
||||
key_states = self.k_norm(key_states)
|
||||
@ -1447,10 +1456,9 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
|
||||
attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
|
||||
corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
|
||||
|
||||
first_prediction = corrected_predictions[self.config.altup_active_idx]
|
||||
first_prediction_clone = first_prediction.clone()
|
||||
first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
|
||||
if self.config.altup_correct_scale:
|
||||
first_prediction = self.altup.scale_corrected_output(first_prediction_clone)
|
||||
first_prediction = self.altup.scale_corrected_output(first_prediction)
|
||||
|
||||
# per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
|
||||
first_prediction = self.per_layer_input_gate(first_prediction)
|
||||
@ -1475,7 +1483,7 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
|
||||
config_class = Gemma3nConfig
|
||||
base_model_prefix = ""
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Gemma3nDecoderLayer"]
|
||||
_no_split_modules = ["Gemma3nTextDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
@ -1656,18 +1664,17 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)

# Expand hidden_states to support per-layer inputs
target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
epsilon_tensor = torch.tensor(torch.finfo().min)
target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
epsilon_tensor = torch.tensor(1e-5)

temp_hidden_states = [hidden_states_0]
for i in range(1, self.config.altup_num_inputs):
# altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0)
current_hidden_state = altup_proj.type(hidden_states_0.dtype)
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
current_hidden_state = current_hidden_state * (
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
)
altup_proj = self.altup_projections[i - 1](hidden_states_0)
current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
temp_hidden_states.append(current_hidden_state)

hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size]
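# Sketch (not part of the diff, illustrative shapes): the updated code above floors the
# per-token mean square at 1e-5 before the square root (torch.finfo().min is a large
# negative number, so the old value never acted as a floor) and then rescales each
# projected stream so its RMS matches the RMS of hidden_states_0:
import torch

hidden_states_0 = torch.randn(2, 5, 16)  # [batch, seq, hidden]
projected = torch.randn(2, 5, 16)        # stand-in for altup_projections[i - 1](hidden_states_0)

target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
new_magnitude = torch.sqrt(torch.maximum(torch.mean(projected**2, dim=-1, keepdim=True), torch.tensor(1e-5)))
rescaled = projected * target_magnitude / new_magnitude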
@ -1685,9 +1692,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
position_embeddings_global=position_embeddings_global,
|
||||
position_embeddings_local=position_embeddings_local,
|
||||
per_layer_input=per_layer_input,
|
||||
position_embeddings_global,
|
||||
position_embeddings_local,
|
||||
per_layer_input,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
@ -1712,11 +1719,10 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
||||
for i in range(1, self.config.altup_num_inputs):
|
||||
# altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
|
||||
altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
|
||||
current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype)
|
||||
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
|
||||
current_hidden_state = current_hidden_state * (
|
||||
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
|
||||
)
|
||||
current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
|
||||
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
|
||||
new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
|
||||
current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
|
||||
temp_hidden_states.append(current_hidden_state)
|
||||
|
||||
hidden_states = torch.stack(temp_hidden_states)
|
||||
@ -1743,7 +1749,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
||||
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
|
||||
per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype)
|
||||
per_layer_projection *= self.per_layer_projection_scale.to(
|
||||
dtype=inputs_embeds.dtype, device=per_layer_projection.device
|
||||
)
|
||||
per_layer_projection = per_layer_projection.reshape(
|
||||
*inputs_embeds.shape[:-1],
|
||||
self.config.num_hidden_layers,
|
||||
@ -1758,7 +1766,9 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
|
||||
# per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
|
||||
per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
|
||||
|
||||
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype)
|
||||
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
|
||||
dtype=inputs_embeds.dtype, device=per_layer_projection.device
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.")
@ -1685,9 +1685,17 @@ class Gemma3nTextAltUp(nn.Module):
|
||||
corrected += predictions # add the original input
|
||||
return corrected.contiguous().type_as(activated)
|
||||
|
||||
def forward(self, corrected: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
This is only defined as the `forward` so that accelerate hooks can move correctly `correct_output_scale`
|
||||
(which is a nn.Parameter, not a Module) between devices when offloading. It is otherwise only used in
|
||||
`scale_corrected_output`
|
||||
"""
|
||||
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
|
||||
|
||||
def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
|
||||
"""Scales the provided 3D tensor of shape [batch_size, num_tokens, hidden_size]."""
|
||||
return (corrected.type_as(self.correct_output_scale) * self.correct_output_scale).type_as(corrected)
|
||||
return self.forward(corrected)
|
||||
|
||||
|
||||
class Gemma3nTextRotaryEmbedding(Gemma2RotaryEmbedding):
|
||||
@ -1732,7 +1740,7 @@ class Gemma3nTextAttention(Gemma3Attention):
|
||||
self.v_norm = Gemma3nRMSNorm(dim=config.head_dim, eps=config.rms_norm_eps, with_scale=False)
|
||||
|
||||
first_kv_shared_layer_idx = self.config.num_hidden_layers - self.config.num_kv_shared_layers
|
||||
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx
|
||||
self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
|
||||
# Find the index of the last sliding or full layer before sharing starts (or None if no sharing)
|
||||
layer_type = config.layer_types[layer_idx]
|
||||
self.kv_shared_layer_index = (
|
||||
@ -1761,21 +1769,22 @@ class Gemma3nTextAttention(Gemma3Attention):
|
||||
query_states = query_states.transpose(1, 2)
|
||||
|
||||
if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None:
|
||||
# HybridCache has complex slicing when layer_type == "sliding_attention" that impact Shared KV Cache.
|
||||
# Device of past layer may be different from current one
|
||||
indices = cache_position.to(past_key_value.key_cache[self.kv_shared_layer_index].device)
|
||||
# In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond)
|
||||
if isinstance(past_key_value, HybridCache) and self.is_sliding:
|
||||
max_length = past_key_value.sliding_window
|
||||
if cache_position.shape[0] > max_length:
|
||||
# If in the prefill phase for a "sliding_attention" layer and the prefill is larger than the cache,
|
||||
# slice into the entire cache.
|
||||
indices = slice(0, max_length)
|
||||
else:
|
||||
# If prefill fits or generating for a "sliding_attention" layer, clamp to max_cache_len - 1
|
||||
indices = cache_position.clamp(min=0, max=max_length - 1)
|
||||
else:
|
||||
indices = cache_position
|
||||
indices = (
|
||||
slice(0, max_length)
|
||||
if cache_position.shape[0] > max_length
|
||||
else cache_position.clamp(min=0, max=max_length - 1)
|
||||
)
|
||||
|
||||
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices]
|
||||
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices]
|
||||
# Device of past layer may be different from current one
|
||||
key_states = past_key_value.key_cache[self.kv_shared_layer_index][:, :, indices].to(query_states.device)
|
||||
value_states = past_key_value.value_cache[self.kv_shared_layer_index][:, :, indices].to(
|
||||
query_states.device
|
||||
)
|
||||
else:
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape)
|
||||
key_states = self.k_norm(key_states)
|
||||
@ -1880,10 +1889,9 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer):
|
||||
attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
|
||||
corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
|
||||
|
||||
first_prediction = corrected_predictions[self.config.altup_active_idx]
|
||||
first_prediction_clone = first_prediction.clone()
|
||||
first_prediction = corrected_predictions[self.config.altup_active_idx].clone()
|
||||
if self.config.altup_correct_scale:
|
||||
first_prediction = self.altup.scale_corrected_output(first_prediction_clone)
|
||||
first_prediction = self.altup.scale_corrected_output(first_prediction)
|
||||
|
||||
# per_layer_input_gate adapted from jax.numpy.einsum("btd,dp->btp", ...)
|
||||
first_prediction = self.per_layer_input_gate(first_prediction)
|
||||
@ -1906,7 +1914,7 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer):
|
||||
class Gemma3nPreTrainedModel(Gemma2PreTrainedModel):
|
||||
config_class = Gemma3nConfig
|
||||
base_model_prefix = ""
|
||||
_no_split_modules = ["Gemma3nDecoderLayer"]
|
||||
_no_split_modules = ["Gemma3nTextDecoderLayer"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of Gemma2 isn't meant for training from scratch - only
|
||||
@ -1995,7 +2003,9 @@ class Gemma3nTextModel(Gemma3TextModel):
|
||||
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
per_layer_projection: torch.Tensor = self.per_layer_model_projection(inputs_embeds)
|
||||
per_layer_projection *= self.per_layer_projection_scale.type(inputs_embeds.dtype)
|
||||
per_layer_projection *= self.per_layer_projection_scale.to(
|
||||
dtype=inputs_embeds.dtype, device=per_layer_projection.device
|
||||
)
|
||||
per_layer_projection = per_layer_projection.reshape(
|
||||
*inputs_embeds.shape[:-1],
|
||||
self.config.num_hidden_layers,
|
||||
@ -2010,7 +2020,9 @@ class Gemma3nTextModel(Gemma3TextModel):
|
||||
# per-layer inputs are sometimes padded with zeros, slice the relevant embeddings.
|
||||
per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
|
||||
|
||||
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.type(inputs_embeds.dtype)
|
||||
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale.to(
|
||||
dtype=inputs_embeds.dtype, device=per_layer_projection.device
|
||||
)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
@ -2091,18 +2103,17 @@ class Gemma3nTextModel(Gemma3TextModel):
|
||||
position_embeddings_local = self.rotary_emb_local(hidden_states_0, position_ids)
|
||||
|
||||
# Expand hidden_states to support per-layer inputs
|
||||
target_magnitude: torch.Tensor = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
|
||||
epsilon_tensor = torch.tensor(torch.finfo().min)
|
||||
target_magnitude = torch.mean(hidden_states_0**2, dim=-1, keepdim=True) ** 0.5
|
||||
epsilon_tensor = torch.tensor(1e-5)
|
||||
|
||||
temp_hidden_states = [hidden_states_0]
|
||||
for i in range(1, self.config.altup_num_inputs):
|
||||
# altup_proj adapted from jax.numpy.einsum("btp,pd->btd", ...)
|
||||
altup_proj: torch.Tensor = self.altup_projections[i - 1](hidden_states_0)
|
||||
current_hidden_state = altup_proj.type(hidden_states_0.dtype)
|
||||
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
|
||||
current_hidden_state = current_hidden_state * (
|
||||
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
|
||||
)
|
||||
altup_proj = self.altup_projections[i - 1](hidden_states_0)
|
||||
current_hidden_state = altup_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
|
||||
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
|
||||
new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
|
||||
current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
|
||||
temp_hidden_states.append(current_hidden_state)
|
||||
|
||||
hidden_states = torch.stack(temp_hidden_states, dim=0) # [num_altup_inputs, batch, seq_len, hidden_size]
|
||||
@ -2120,9 +2131,9 @@ class Gemma3nTextModel(Gemma3TextModel):
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
position_embeddings_global=position_embeddings_global,
|
||||
position_embeddings_local=position_embeddings_local,
|
||||
per_layer_input=per_layer_input,
|
||||
position_embeddings_global,
|
||||
position_embeddings_local,
|
||||
per_layer_input,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
@ -2147,11 +2158,10 @@ class Gemma3nTextModel(Gemma3TextModel):
|
||||
for i in range(1, self.config.altup_num_inputs):
|
||||
# altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
|
||||
altup_unemb_proj: torch.Tensor = self.altup_unembed_projections[i - 1](hidden_states[i])
|
||||
current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype)
|
||||
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
|
||||
current_hidden_state = current_hidden_state * (
|
||||
target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
|
||||
)
|
||||
current_hidden_state = altup_unemb_proj.to(dtype=hidden_states_0.dtype, device=target_magnitude.device)
|
||||
new_magnitude = torch.mean(current_hidden_state**2, dim=-1, keepdim=True)
|
||||
new_magnitude = torch.sqrt(torch.maximum(new_magnitude, epsilon_tensor.to(target_magnitude.device)))
|
||||
current_hidden_state = current_hidden_state * target_magnitude / new_magnitude
|
||||
temp_hidden_states.append(current_hidden_state)
|
||||
|
||||
hidden_states = torch.stack(temp_hidden_states)
@ -39,6 +39,7 @@ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
auto_docstring,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
torch_int,
|
||||
)
|
||||
@ -770,6 +771,7 @@ class GitVisionEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -836,8 +838,6 @@ class GitVisionEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
@ -37,6 +37,7 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassif
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_hubert import HubertConfig
|
||||
|
||||
|
||||
@ -300,6 +301,7 @@ class HubertAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -307,7 +309,7 @@ class HubertAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
@ -328,42 +330,9 @@ class HubertAttention(nn.Module):
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -385,7 +354,7 @@ class HubertAttention(nn.Module):
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class HubertFeedForward(nn.Module):
@ -28,6 +28,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
)
|
||||
from .configuration_idefics import IdeficsVisionConfig
|
||||
@ -351,6 +352,7 @@ class IdeficsVisionEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([IdeficsVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -417,8 +419,6 @@ class IdeficsVisionEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
@ -451,6 +451,7 @@ class Kosmos2VisionEncoder(nn.Module):
|
||||
self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
@ -517,8 +518,6 @@ class Kosmos2VisionEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
@ -14,6 +14,7 @@
# limitations under the License.
"""LayoutLM model configuration"""

import warnings
from collections import OrderedDict
from collections.abc import Mapping
from typing import Any, Optional

@ -130,10 +131,22 @@ class LayoutLMConfig(PretrainedConfig):
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self._position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.max_2d_position_embeddings = max_2d_position_embeddings

@property
def position_embedding_type(self):
warnings.warn(
"The `position_embedding_type` attribute is deprecated and will be removed in v4.55.",
FutureWarning,
)
return self._position_embedding_type

@position_embedding_type.setter
def position_embedding_type(self, value):
self._position_embedding_type = value


class LayoutLMOnnxConfig(OnnxConfig):
def __init__(
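# Sketch (not from the diff, hypothetical class name): the pattern introduced above --
# the value lives on a private attribute, reads go through a property that emits a
# FutureWarning, and writes still work through the setter:
import warnings

class DemoConfig:
    def __init__(self, position_embedding_type="absolute"):
        self._position_embedding_type = position_embedding_type

    @property
    def position_embedding_type(self):
        warnings.warn("`position_embedding_type` is deprecated.", FutureWarning)
        return self._position_embedding_type

    @position_embedding_type.setter
    def position_embedding_type(self, value):
        self._position_embedding_type = value

cfg = DemoConfig()
cfg.position_embedding_type = "relative_key"  # no warning on assignment
print(cfg.position_embedding_type)            # warns, then prints "relative_key"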
@ -14,8 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch LayoutLM model."""
|
||||
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -25,16 +24,17 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
MaskedLMOutput,
|
||||
QuestionAnsweringModelOutput,
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils import auto_docstring, can_return_tuple, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_layoutlm import LayoutLMConfig
|
||||
|
||||
|
||||
@ -120,9 +120,37 @@ class LayoutLMEmbeddings(nn.Module):
return embeddings


# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->LayoutLM
# Copied from transformers.models.align.modeling_align.eager_attention_forward
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
head_mask: Optional[torch.Tensor] = None,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

if head_mask is not None:
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
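# Sketch (not part of the diff, toy tensors, assumes the eager_attention_forward defined
# above is in scope): projections come in as [batch, heads, seq, head_dim], the additive
# mask uses 0 for "attend" and a large negative value for "ignore", and the output comes
# back transposed to [batch, seq, heads, head_dim]:
import torch
from torch import nn

attn_module = nn.Module()       # only provides the .training flag used by the dropout call
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
mask = torch.zeros(1, 1, 4, 4)  # nothing masked out

out, weights = eager_attention_forward(attn_module, q, k, v, mask, scaling=8**-0.5)
print(out.shape, weights.shape)  # torch.Size([1, 4, 2, 8]) torch.Size([1, 2, 4, 4])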
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->LayoutLM
|
||||
class LayoutLMSelfAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
@ -130,6 +158,7 @@ class LayoutLMSelfAttention(nn.Module):
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
@ -139,20 +168,12 @@ class LayoutLMSelfAttention(nn.Module):
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
self.attention_dropout = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -162,96 +183,33 @@ class LayoutLMSelfAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
head_mask=head_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
||||
-1, 1
|
||||
)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in LayoutLMModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -270,18 +228,11 @@ class LayoutLMSelfOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
LAYOUTLM_SELF_ATTENTION_CLASSES = {
|
||||
"eager": LayoutLMSelfAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->LayoutLM,BERT->LAYOUTLM
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->LayoutLM
|
||||
class LayoutLMAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.self = LAYOUTLM_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config, position_embedding_type=position_embedding_type
|
||||
)
|
||||
self.self = LayoutLMSelfAttention(config)
|
||||
self.output = LayoutLMSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
@ -303,6 +254,9 @@ class LayoutLMAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -312,15 +266,14 @@ class LayoutLMAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
@ -358,22 +311,19 @@ class LayoutLMOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->LayoutLM
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->LayoutLM
|
||||
class LayoutLMLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = LayoutLMAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = LayoutLMAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = LayoutLMIntermediate(config)
|
||||
self.output = LayoutLMOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -383,60 +333,23 @@ class LayoutLMLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -445,14 +358,19 @@ class LayoutLMLayer(GradientCheckpointingLayer):
|
||||
return layer_output
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->LayoutLM
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->LayoutLM
|
||||
class LayoutLMEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([LayoutLMLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([LayoutLMLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -465,65 +383,36 @@ class LayoutLMEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -648,6 +537,9 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -663,7 +555,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
|
||||
) -> Union[tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
|
||||
Bounding boxes of each input sequence tokens. Selected in the range `[0,
|
||||
@ -756,20 +648,16 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -796,6 +684,9 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
|
||||
self.cls.predictions.decoder = new_embeddings
|
||||
self.cls.predictions.bias = new_embeddings.bias
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -871,11 +762,9 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -889,10 +778,6 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
|
||||
labels.view(-1),
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
||||
|
||||
return MaskedLMOutput(
|
||||
loss=masked_lm_loss,
|
||||
logits=prediction_scores,
|
||||
@ -921,6 +806,7 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
|
||||
def get_input_embeddings(self):
|
||||
return self.layoutlm.embeddings.word_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -996,7 +882,7 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
@ -1026,9 +912,6 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
@ -1059,6 +942,7 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
|
||||
def get_input_embeddings(self):
|
||||
return self.layoutlm.embeddings.word_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1132,7 +1016,7 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -1145,10 +1029,6 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
@ -1176,6 +1056,7 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
|
||||
def get_input_embeddings(self):
|
||||
return self.layoutlm.embeddings.word_embeddings
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1253,7 +1134,7 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -1280,10 +1161,6 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + outputs[2:]
|
||||
return ((total_loss,) + output) if total_loss is not None else output
|
||||
|
||||
return QuestionAnsweringModelOutput(
|
||||
loss=total_loss,
|
||||
start_logits=start_logits,
|
||||
|
@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""MarkupLM model configuration"""
|
||||
|
||||
import warnings
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
@ -141,7 +143,7 @@ class MarkupLMConfig(PretrainedConfig):
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.position_embedding_type = position_embedding_type
|
||||
self._position_embedding_type = position_embedding_type
|
||||
self.use_cache = use_cache
|
||||
self.classifier_dropout = classifier_dropout
|
||||
# additional properties
|
||||
@ -152,5 +154,17 @@ class MarkupLMConfig(PretrainedConfig):
|
||||
self.subs_pad_id = subs_pad_id
|
||||
self.xpath_unit_hidden_size = xpath_unit_hidden_size
|
||||
|
||||
@property
|
||||
def position_embedding_type(self):
|
||||
warnings.warn(
|
||||
"The `position_embedding_type` attribute is deprecated and will be removed in v4.55.",
|
||||
FutureWarning,
|
||||
)
|
||||
return self._position_embedding_type
|
||||
|
||||
@position_embedding_type.setter
|
||||
def position_embedding_type(self, value):
|
||||
self._position_embedding_type = value
|
||||
|
||||
|
||||
__all__ = ["MarkupLMConfig"]
|
||||
|
@ -14,9 +14,8 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch MarkupLM model."""
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -26,20 +25,22 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
MaskedLMOutput,
|
||||
QuestionAnsweringModelOutput,
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import (
|
||||
ALL_ATTENTION_FUNCTIONS,
|
||||
PreTrainedModel,
|
||||
apply_chunking_to_forward,
|
||||
find_pruneable_heads_and_indices,
|
||||
prune_linear_layer,
|
||||
)
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils import auto_docstring, can_return_tuple, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_markuplm import MarkupLMConfig
|
||||
|
||||
|
||||
@ -326,9 +327,37 @@ class MarkupLMOnlyMLMHead(nn.Module):
|
||||
return prediction_scores
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->MarkupLM
|
||||
# Copied from transformers.models.align.modeling_align.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->MarkupLM
|
||||
class MarkupLMSelfAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
@ -336,6 +365,7 @@ class MarkupLMSelfAttention(nn.Module):
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
@ -345,20 +375,12 @@ class MarkupLMSelfAttention(nn.Module):
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
self.attention_dropout = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -368,111 +390,41 @@ class MarkupLMSelfAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
head_mask=head_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
||||
-1, 1
|
||||
)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in MarkupLMModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
||||
return outputs
|
||||
|
||||
|
||||
MARKUPLM_SELF_ATTENTION_CLASSES = {
|
||||
"eager": MarkupLMSelfAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->MarkupLM,BERT->MARKUPLM
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->MarkupLM
|
||||
class MarkupLMAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.self = MARKUPLM_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config, position_embedding_type=position_embedding_type
|
||||
)
|
||||
self.self = MarkupLMSelfAttention(config)
|
||||
self.output = MarkupLMSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
@ -494,6 +446,9 @@ class MarkupLMAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -503,37 +458,33 @@ class MarkupLMAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->MarkupLM
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->MarkupLM
|
||||
class MarkupLMLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = MarkupLMAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = MarkupLMAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = MarkupLMIntermediate(config)
|
||||
self.output = MarkupLMOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -543,60 +494,23 @@ class MarkupLMLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -605,14 +519,19 @@ class MarkupLMLayer(GradientCheckpointingLayer):
|
||||
return layer_output
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->MarkupLM
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->MarkupLM
|
||||
class MarkupLMEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([MarkupLMLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([MarkupLMLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -625,65 +544,36 @@ class MarkupLMEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -749,6 +639,7 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -763,7 +654,7 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
|
||||
) -> Union[tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
|
||||
Tag IDs for each token in the input sequence, padded up to config.max_depth.
|
||||
@ -839,21 +730,16 @@ class MarkupLMModel(MarkupLMPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertModel._reorder_cache
|
||||
@ -879,6 +765,7 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -939,7 +826,7 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -966,10 +853,6 @@ class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
|
||||
if not return_dict:
|
||||
output = (start_logits, end_logits) + outputs[2:]
|
||||
return ((total_loss,) + output) if total_loss is not None else output
|
||||
|
||||
return QuestionAnsweringModelOutput(
|
||||
loss=total_loss,
|
||||
start_logits=start_logits,
|
||||
@ -1000,6 +883,7 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1058,7 +942,7 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@ -1072,10 +956,6 @@ class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
|
||||
labels.view(-1),
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (prediction_scores,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=loss,
|
||||
logits=prediction_scores,
|
||||
@ -1107,6 +987,7 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -1164,7 +1045,7 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
@ -1194,9 +1075,6 @@ class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
|
||||
elif self.config.problem_type == "multi_label_classification":
|
||||
loss_fct = BCEWithLogitsLoss()
|
||||
loss = loss_fct(logits, labels)
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
|
@ -354,11 +354,6 @@ class MaskFormerSwinSelfAttention(nn.Module):
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -367,11 +362,11 @@ class MaskFormerSwinSelfAttention(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> tuple[torch.Tensor]:
|
||||
batch_size, dim, num_channels = hidden_states.shape
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
hidden_shape = (batch_size, dim, -1, self.attention_head_size)
|
||||
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
@ -182,7 +182,6 @@ def eager_attention_forward(
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Musicgen
|
||||
class MusicgenAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
|
@ -189,7 +189,7 @@ def eager_attention_forward(
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->MusicgenMelody
|
||||
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->MusicgenMelody
|
||||
class MusicgenMelodyAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
|
@ -503,7 +503,7 @@ def eager_attention_forward(
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->NllbMoe,key_value_states->encoder_hidden_states
|
||||
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->NllbMoe,key_value_states->encoder_hidden_states
|
||||
class NllbMoeAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
|
@ -29,6 +29,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...processing_utils import Unpack
|
||||
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_patchtsmixer import PatchTSMixerConfig
|
||||
|
||||
|
||||
@ -303,6 +304,7 @@ class PatchTSMixerAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -310,7 +312,7 @@ class PatchTSMixerAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
@ -331,42 +333,9 @@ class PatchTSMixerAttention(nn.Module):
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -388,7 +357,7 @@ class PatchTSMixerAttention(nn.Module):
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class PatchMixerBlock(nn.Module):
|
||||
|
@ -28,6 +28,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
|
||||
from ...utils import ModelOutput, auto_docstring, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_patchtst import PatchTSTConfig
|
||||
|
||||
|
||||
@ -100,6 +101,7 @@ class PatchTSTAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -107,7 +109,7 @@ class PatchTSTAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
@ -128,42 +130,9 @@ class PatchTSTAttention(nn.Module):
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -185,7 +154,7 @@ class PatchTSTAttention(nn.Module):
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class PatchTSTBatchNorm(nn.Module):
|
||||
|
@ -37,6 +37,7 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassif
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_sew import SEWConfig
|
||||
|
||||
|
||||
@ -293,6 +294,7 @@ class SEWAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -300,7 +302,7 @@ class SEWAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
@ -321,42 +323,9 @@ class SEWAttention(nn.Module):
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -378,7 +347,7 @@ class SEWAttention(nn.Module):
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class SEWFeedForward(nn.Module):
@ -205,7 +205,7 @@ def eager_attention_forward(
return attn_output, attn_weights


# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Speech2Text
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenAttention with Musicgen->Speech2Text
class Speech2TextAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@ -14,9 +14,8 @@
# limitations under the License.
"""PyTorch Splinter model."""

import math
from dataclasses import dataclass
from typing import Optional, Union
from typing import Callable, Optional, Union

import torch
import torch.utils.checkpoint
@ -25,13 +24,19 @@ from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_outputs import (
BaseModelOutput,
ModelOutput,
QuestionAnsweringModelOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
auto_docstring,
can_return_tuple,
logging,
)
from ...utils.deprecation import deprecate_kwarg
from .configuration_splinter import SplinterConfig


@ -64,7 +69,6 @@ class SplinterEmbeddings(nn.Module):
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: Optional[int] = 0,
) -> tuple:
if input_ids is not None:
input_shape = input_ids.size()
@ -74,7 +78,7 @@ class SplinterEmbeddings(nn.Module):
seq_length = input_shape[1]

if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
position_ids = self.position_ids[:, :seq_length]

if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
@ -92,9 +96,37 @@ class SplinterEmbeddings(nn.Module):
return embeddings


# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Splinter
|
||||
# Copied from transformers.models.align.modeling_align.eager_attention_forward
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->Splinter
|
||||
class SplinterSelfAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
@ -102,6 +134,7 @@ class SplinterSelfAttention(nn.Module):
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
@ -111,20 +144,12 @@ class SplinterSelfAttention(nn.Module):
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = position_embedding_type or getattr(
|
||||
config, "position_embedding_type", "absolute"
|
||||
)
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
self.attention_dropout = config.attention_probs_dropout_prob
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -134,96 +159,33 @@ class SplinterSelfAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# If this is instantiated as a cross-attention module, the keys
|
||||
# and values come from an encoder; the attention mask needs to be
|
||||
# such that the encoder's padding tokens are not attended to.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_value[0]
|
||||
value_layer = past_key_value[1]
|
||||
attention_mask = encoder_attention_mask
|
||||
elif is_cross_attention:
|
||||
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||
attention_mask = encoder_attention_mask
|
||||
elif past_key_value is not None:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||
else:
|
||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
head_mask=head_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
use_cache = past_key_value is not None
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_layer, value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
||||
if use_cache:
|
||||
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
||||
-1, 1
|
||||
)
|
||||
else:
|
||||
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
||||
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
||||
distance = position_ids_l - position_ids_r
|
||||
|
||||
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
||||
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
||||
|
||||
if self.position_embedding_type == "relative_key":
|
||||
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores
|
||||
elif self.position_embedding_type == "relative_key_query":
|
||||
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
||||
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
||||
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
||||
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in SplinterModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (past_key_value,)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
||||
return outputs
|
||||
|
||||
|
||||
@ -242,18 +204,11 @@ class SplinterSelfOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
SPLINTER_SELF_ATTENTION_CLASSES = {
|
||||
"eager": SplinterSelfAttention,
|
||||
}
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter,BERT->SPLINTER
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->Splinter
|
||||
class SplinterAttention(nn.Module):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.self = SPLINTER_SELF_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config, position_embedding_type=position_embedding_type
|
||||
)
|
||||
self.self = SplinterSelfAttention(config)
|
||||
self.output = SplinterSelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
@ -275,6 +230,9 @@ class SplinterAttention(nn.Module):
|
||||
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
||||
self.pruned_heads = self.pruned_heads.union(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -284,15 +242,14 @@ class SplinterAttention(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
self_outputs = self.self(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self.output(self_outputs[0], hidden_states)
|
||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||
@ -330,22 +287,19 @@ class SplinterOutput(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Splinter
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->Splinter
|
||||
class SplinterLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = SplinterAttention(config)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = SplinterAttention(config, position_embedding_type="absolute")
|
||||
self.intermediate = SplinterIntermediate(config)
|
||||
self.output = SplinterOutput(config)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -355,60 +309,23 @@ class SplinterLayer(GradientCheckpointingLayer):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
|
||||
# if decoder, the last output is tuple of self-attn cache
|
||||
if self.is_decoder:
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
else:
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
cross_attn_present_key_value = None
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||
cross_attention_outputs = self.crossattention(
|
||||
attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
cross_attn_past_key_value,
|
||||
output_attentions,
|
||||
)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||
|
||||
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||
present_key_value = present_key_value + cross_attn_present_key_value
|
||||
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
# if decoder, return the attn key/values as the last output
|
||||
if self.is_decoder:
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
@ -417,14 +334,19 @@ class SplinterLayer(GradientCheckpointingLayer):
|
||||
return layer_output
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Splinter
|
||||
# Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->Splinter
|
||||
class SplinterEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([SplinterLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layer = nn.ModuleList([SplinterLayer(config) for i in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -437,65 +359,36 @@ class SplinterEncoder(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
return_dict: Optional[bool] = True,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.Tensor], BaseModelOutput]:
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@ -554,6 +447,11 @@ class SplinterModel(SplinterPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
@deprecate_kwarg("encoder_hidden_states", version="4.54.0")
|
||||
@deprecate_kwarg("encoder_attention_mask", version="4.54.0")
|
||||
@deprecate_kwarg("past_key_values", version="4.54.0")
|
||||
@deprecate_kwarg("use_cache", version="4.54.0")
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
@ -570,7 +468,7 @@ class SplinterModel(SplinterPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
) -> Union[tuple, BaseModelOutput]:
|
||||
r"""
|
||||
token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
|
||||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
||||
@ -592,11 +490,6 @@ class SplinterModel(SplinterPreTrainedModel):
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
@ -610,11 +503,8 @@ class SplinterModel(SplinterPreTrainedModel):
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
attention_mask = torch.ones(((batch_size, seq_length)), device=device)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
@ -622,17 +512,6 @@ class SplinterModel(SplinterPreTrainedModel):
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.is_decoder and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
@ -645,31 +524,21 @@ class SplinterModel(SplinterPreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
return_dict=True,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
|
||||
if not return_dict:
|
||||
return (sequence_output,) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=sequence_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
@ -435,11 +435,6 @@ class SwinSelfAttention(nn.Module):

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
self,
hidden_states: torch.Tensor,
@ -448,11 +443,11 @@ class SwinSelfAttention(nn.Module):
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
hidden_shape = (batch_size, dim, -1, self.attention_head_size)

key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

@ -45,6 +45,7 @@ from ...modeling_outputs import (
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_unispeech import UniSpeechConfig
|
||||
|
||||
|
||||
@ -332,6 +333,7 @@ class UniSpeechAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -339,7 +341,7 @@ class UniSpeechAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
@ -360,42 +362,9 @@ class UniSpeechAttention(nn.Module):
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -417,7 +386,7 @@ class UniSpeechAttention(nn.Module):
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class UniSpeechFeedForward(nn.Module):
|
||||
|
@ -47,6 +47,7 @@ from ...modeling_outputs import (
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available, logging
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_unispeech_sat import UniSpeechSatConfig
|
||||
|
||||
|
||||
@ -337,6 +338,7 @@ class UniSpeechSatAttention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -344,7 +346,7 @@ class UniSpeechSatAttention(nn.Module):
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
@ -365,42 +367,9 @@ class UniSpeechSatAttention(nn.Module):
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -422,7 +391,7 @@ class UniSpeechSatAttention(nn.Module):
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class UniSpeechSatFeedForward(nn.Module):
|
||||
|
@ -55,6 +55,7 @@ from ...utils import (
|
||||
is_torch_flex_attn_available,
|
||||
logging,
|
||||
)
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_wav2vec2 import Wav2Vec2Config
|
||||
|
||||
|
||||
@ -524,6 +525,7 @@ class Wav2Vec2Attention(nn.Module):
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
|
||||
@deprecate_kwarg("past_key_value", version="4.54.0")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -531,7 +533,7 @@ class Wav2Vec2Attention(nn.Module):
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
# TODO: we need a refactor so that the different attention modules can get their specific kwargs
|
||||
# ATM, we have mixed things encoder, decoder, and encoder-decoder attn
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
@ -552,42 +554,9 @@ class Wav2Vec2Attention(nn.Module):
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
|
||||
|
||||
# get key, value proj
|
||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
||||
# the provided `key_value_states` to support prefix tuning
|
||||
if (
|
||||
is_cross_attention
|
||||
and past_key_value is not None
|
||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
||||
):
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
current_states = key_value_states if is_cross_attention else hidden_states
|
||||
key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
@ -609,7 +578,7 @@ class Wav2Vec2Attention(nn.Module):
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
return attn_output, attn_weights, None
|
||||
|
||||
|
||||
class Wav2Vec2FeedForward(nn.Module):
@ -336,6 +336,11 @@ class WhisperGenerationMixin(GenerationMixin):
if num_input_ids is not None:
weights = weights[:, :, num_input_ids:, :]

# Since we ignore `decoder_input_ids` in the DTW and in the case where we generated only one token (for which we don't have cross attentions, see below comments),
# the DTW sequence length is 0 and we should return only 0.0s for the token timestamps
if weights.shape[2] == 0:
return timestamps

if num_frames is None or isinstance(num_frames, int):
# Normalize and smoothen the weights.
std = torch.std(weights, dim=-2, keepdim=True, unbiased=False)
@ -366,9 +371,12 @@ class WhisperGenerationMixin(GenerationMixin):
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] * time_precision

# each predicted token has a corresponding timestamp, expect the eos token for which we don't retrieve cross attentions
# each predicted token has a corresponding timestamp, except the eos token (or last predicted token) for which we don't retrieve cross attentions
# (indeed contrary to OAI that re-run a full forward to retrieve cross attentions for each token and therefore also the last one predicted, we retrieve
# cross attentions directly from the auto-regressive generation, so we don't have cross attentions for the token at the end of the sequence. Nevertheless,
# that is not important since we expect this last token to be the eos token)
# 1. for decoder_input_ids, we set the timestamps to 0.0
# 2. for the eos token, we simply duplicate the timestamp of the last non-eos token
# 2. for the eos token (or last predicted token), we simply duplicate the timestamp of the last non-eos token
timestamps[batch_idx] = torch.cat(
[torch.zeros(num_input_ids), torch.tensor(jump_times), torch.tensor([jump_times[-1]])]
)

@ -30,6 +30,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import (
ModelOutput,
auto_docstring,
can_return_tuple,
logging,
torch_int,
)
@ -576,6 +577,7 @@ class XCLIPEncoder(nn.Module):
self.layers = nn.ModuleList([XCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False

@can_return_tuple
def forward(
self,
inputs_embeds,
@ -642,8 +644,6 @@ class XCLIPEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)

@ -1642,7 +1642,6 @@ def set_model_tester_for_less_flaky_test(test_case):
"AriaVisionText2TextModelTester",
"GPTNeoModelTester",
"DPTModelTester",
"Gemma3nTextModelTester", # cannot have a single layer combined with the cache sharing config attrs in the tester
]
if test_case.model_tester.__class__.__name__ in exceptional_classes:
target_num_hidden_layers = None

@ -297,7 +297,7 @@ class AltCLIPTextModelTester:
@require_torch
class AltCLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (AltCLIPTextModel,) if is_torch_available() else ()
fx_compatible = True
fx_compatible = False # Cannot support if `can_return_tuple`
test_pruning = False
test_head_masking = False

@ -411,7 +411,7 @@ def prepare_img():
class AltCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (AltCLIPModel,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": AltCLIPModel} if is_torch_available() else {}
fx_compatible = True
fx_compatible = False # Cannot support if `can_return_tuple`
test_head_masking = False
test_pruning = False
test_resize_embeddings = False

@ -39,13 +39,20 @@ from transformers.testing_utils import (
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_sdpa,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
from ...test_modeling_common import (
|
||||
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
|
||||
ModelTesterMixin,
|
||||
_test_eager_matches_sdpa_inference,
|
||||
floats_tensor,
|
||||
ids_tensor,
|
||||
)
|
||||
from ..gemma.test_modeling_gemma import GemmaModelTester
|
||||
|
||||
|
||||
@ -256,6 +263,7 @@ class Gemma3nTextModelTester(GemmaModelTester):
|
||||
vocab_size=99,
|
||||
vocab_size_per_layer_input=99,
|
||||
hidden_size=16,
|
||||
hidden_size_per_layer_input=16,
|
||||
num_hidden_layers=4, # override to correctly test sharing cache pattern
|
||||
num_kv_shared_layers=2, # important to override
|
||||
layer_types=[
|
||||
@ -291,6 +299,7 @@ class Gemma3nTextModelTester(GemmaModelTester):
|
||||
self.vocab_size = vocab_size
|
||||
self.vocab_size_per_layer_input = vocab_size_per_layer_input
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_size_per_layer_input = hidden_size_per_layer_input
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_kv_shared_layers = num_kv_shared_layers
|
||||
self.layer_types = layer_types
|
||||
@ -317,7 +326,6 @@ class Gemma3nTextModelTester(GemmaModelTester):
|
||||
for_causal_lm_class = Gemma3nForCausalLM
|
||||
|
||||
|
||||
@unittest.skip("Skipped for now!")
|
||||
@require_torch
|
||||
class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Gemma3nTextModel, Gemma3nForCausalLM) if is_torch_available() else ()
|
||||
@ -365,6 +373,64 @@ class Gemma3nTextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
|
||||
[expected_shape] * len(iter_hidden_states),
|
||||
)
|
||||
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
@require_torch_sdpa
|
||||
def test_eager_matches_sdpa_inference(
|
||||
self,
|
||||
name,
|
||||
torch_dtype,
|
||||
padding_side,
|
||||
use_attention_mask,
|
||||
output_attentions,
|
||||
enable_kernels,
|
||||
):
|
||||
"We need to relax a bit the `atols` for fp32 here due to the altup projections"
|
||||
atols = {
|
||||
("cpu", False, torch.float32): 1e-3, # this was relaxed
|
||||
("cpu", False, torch.float16): 5e-3,
|
||||
("cpu", False, torch.bfloat16): 1e-2,
|
||||
("cpu", True, torch.float32): 1e-3, # this was relaxed
|
||||
("cpu", True, torch.float16): 5e-3,
|
||||
("cpu", True, torch.bfloat16): 1e-2,
|
||||
("cuda", False, torch.float32): 1e-3, # this was relaxed
|
||||
("cuda", False, torch.bfloat16): 1e-2,
|
||||
("cuda", False, torch.float16): 5e-3,
|
||||
("cuda", True, torch.float32): 1e-3, # this was relaxed
|
||||
("cuda", True, torch.bfloat16): 1e-2,
|
||||
("cuda", True, torch.float16): 5e-3,
|
||||
}
|
||||
_test_eager_matches_sdpa_inference(
|
||||
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels, atols=atols
|
||||
)
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(
|
||||
"Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding"
|
||||
)
|
||||
def test_contrastive_generate(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(
|
||||
"Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding"
|
||||
)
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(
|
||||
"Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with contrastive decoding"
|
||||
)
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(
|
||||
"Gemma3n has a special shape for hidden states (due to per-layer projs) which is not compatible with dola decoding"
|
||||
)
|
||||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
|
||||
class Gemma3nVision2TextModelTester:
|
||||
text_config = {"activation_sparsity_pattern": None}
@ -243,7 +243,7 @@ class LayoutLMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
if is_torch_available()
else {}
)
fx_compatible = True
fx_compatible = False # Cannot support if `can_return_tuple`

def setUp(self):
self.model_tester = LayoutLMModelTester(self)

@ -372,6 +372,18 @@ class SplinterModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class))

@unittest.skip(
"Splinter GC with `use_reentrant` fails after #38751, FIXME raushan after deprecated args are removed"
)
def test_training_gradient_checkpointing(self):
pass

@unittest.skip(
"Splinter GC with `use_reentrant` fails after #38751, FIXME raushan after deprecated args are removed"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass


@require_torch
class SplinterModelIntegrationTest(unittest.TestCase):

@ -1797,7 +1797,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
])
# fmt: on

torch.testing.assert_close(generated_ids[0], EXPECTED_OUTPUT)
torch.testing.assert_close(generated_ids, EXPECTED_OUTPUT)

EXPECTED_TRANSCRIPT = [
{
@ -1971,16 +1971,11 @@ class WhisperModelIntegrationTests(unittest.TestCase):
},
{
"text": " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and",
# "timestamp": (39.80, 45.36),
# above is the expected output on A100.
# on CI T4s, due to sligth difference in floating points operations, expected is below
"timestamp": (39.80, 45.38),
"timestamp": (39.80, 45.36),
},
{
"text": " can discover in it but little of rocky Ithaca.",
# "timestamp": (45.36, 49.0),
# see above
"timestamp": (45.38, 49.0),
"timestamp": (45.36, 49.0),
},
{
"text": " Lenell's pictures are a sort of up-guards-and-atom paintings, and Mason's exquisite ittles",
@ -2220,7 +2215,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
torch.tensor([44.7000, 44.8600, 44.9400, 45.1400, 45.1400, 45.2800, 45.6200, 45.9000, 46.2600, 47.1600, 47.4800, 47.7400, 48.1000, 48.2800, 48.4000, 48.6200, 48.8400, 49.0400, 49.2800, 49.4800, 49.6600, 49.9400, 50.5400, 50.5400]),
|
||||
torch.tensor([50.5400, 50.6600, 50.8800, 51.2400, 51.7200, 52.8400, 52.9600]),
|
||||
torch.tensor([52.9600, 53.0400, 53.2600, 53.4200, 53.5800, 53.9200, 54.1200, 54.7200, 54.9400, 55.2600, 55.6200, 55.9800, 56.5600, 56.8000, 56.9200, 57.3600, 57.9200, 58.1600, 58.5200, 58.6400, 58.8200, 59.4200, 59.4200]),
|
||||
torch.tensor([58.6800, 59.1400, 59.5400, 59.9200, 60.1400, 60.3800, 60.8400, 61.6000, 62.2400, 62.4200, 62.4200])
|
||||
torch.tensor([58.6800, 59.1400, 59.5400, 59.9200, 60.1400, 60.3800, 60.8400, 61.6000, 62.2400, 62.3800, 62.4400])
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
@ -2894,10 +2889,10 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
" Folks, if you watch the show, you know I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories, developing the central headline pawns, definitely maneuvering an oh-so-topical night to F6, faming of classic Sicilian, named or variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a Fisher shows in lip-nitsky attack that culminates in the elegant lethal slow-played all-pass on checkmate that is my nightly monologue, but sometimes sometimes, sometimes folks I sometimes I start a little wake-up side down in the monkey bars of a condemned playground on a super fun site, get all hept up on goofballs, rummage that would discard a tag bag of defective toys, yank out a fistball of disembodied doll limbs, toss them on a stain kid's place mad from a defunct denies, set up a table inside a rusty cargo container down by the warf and challenge toothless drifters to the godless bughouse blitz of tournament that is my segment.",
|
||||
" Folks, I spent a lot of time right over there night after night, actually. Carefully selecting for you the day's newsiest, most aerodynamic headlines, stress testing on those topical anti-lock breaks and power steering, painstakingly stitching, leather seating, so soft, it would make JD power and her associates blush. To create the luxury sedan that is my nightly monologue, but sometimes I just sometimes folks, I lurched to consciousness in the back of an abandoned school bus and slapped myself awake with a crusty floor mat. Before using a mouse-bitten timing belt to strap some old plywood to a couple of discarded oil drums, then by the light of a heathen-moon render a gas tank out of an empty big gulp, filled with white claw and de-natured alcohol, then light a match, letter-rip, and the dis-mented one-man soapbox derby of news that is my segment. Meanwhile.",
|
||||
" Ladies and gentlemen, you know, I spent a lot of time right over there, raising the finest hosting news cattle firmly, yet tenderly milking the latest headlines from their jokes, swollen teats, churning the daily stories into the decadent Provincil style triple cream-breed. It is my nightly monologue, but sometimes sometimes I stagger home hungry after being released by the police and root around in the neighbor's trash can for an old milk carton scraped out the blooming dairy residue into the remains of a wet cheese rod I won from a rat in a pre-dawn street fight. Put it in a discarded paint can to leave it to ferment next to a trash fire than a hunker down in hallucinate while eating the Listeria latent demon custard of news that is my segment.",
|
||||
" Folks, you watched this show. You know I spend most of my time right over there, carefully sorting through the days, big stories, and selecting only the most subtle and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the Icol Greg Waferandi, who carefully die them in a pallet of bright, zesty shades, and adorn them in the finest, most topical inlay work, using hand tools and double magnifying glasses, then assemble them according to now classic and elegant geometry using our signature saddle stitching, and line it with bees, wax, coated linen, finally attach a mallet hammered strap, pearl hardware, and close-shet to create for you the one-of-a-kind, hout-cout-tour, earned me his burkin bag that is my monologue, but sometimes, sometimes, folks. Sometimes, sometimes, sometimes, sometimes I wake up in the last car of an abandoned roller coaster at Coney Island, where I'm hiding from the triads, I huff some engine lubricants out of a safe way bag, and staggered down the shore to tear the sail off a beach skoener, then I ripped the coaxial cable out of an RV and elderly couple from Utah, Hank, and Mabel, lovely folks. And use it to stitch the sail into a loose pouch-like rock sack, and I stow in the back of a garbage truck to the junkyard, where I pick through to the debris for only the broken toys that make me the saddest, until I have loaded for you. The hobo fugitives bug out bindle of news that is my segment. Meanwhile!",
|
||||
" Folks, you watched this show. You know I spend most of my time right over there, carefully sorting through the days, big stories, and selecting only the most subtle and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the Icol Greg Waferandi, who carefully die them in a pallet of bright, zesty shades, and adorn them in the finest, most topical inlay work, using hand tools and double magnifying glasses, then assemble them according to now classic and elegant geometry using our signature saddle stitching, and line it with bees, wax, coated linen, finally attach a mallet hammered strap, pearl hardware, and close-shet to create for you the one-of-a-kind, hout-cout-tour, earned me his burkin bag that is my monologue, but sometimes, sometimes, folks. Sometimes, sometimes, sometimes, sometimes I wake up in the last car of an abandoned roller coaster at Coney Island, where I'm hiding from the triads, I huff some engine lubricants out of a safe way bag, and staggered down the shore to tear the sail off a beach skoener, then I ripped the coaxial cable out of an RV and elderly couple from Utah, Hank, and Mabel, lovely folks. And use it to stitch the sail into a loose pouch-like rock sack, and I stow in the back of a garbage truck to the junkyard, where I pick through to the debris for only the broken toys that make me the saddest, until I have loaded for you. The hobo fugitives bug out bindle of news that is my segment. Meanwhile.",
|
||||
" You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's big stories right over there. meticulously selecting the most topical chakra affirming scented candles, using Feng Shui, to perfectly align the joke energy in the exclusive boutique yoga retreat that is my monologue, but sometimes just sometimes, I go to the dumpster behind the waffle house at three in the morning, take off my shirt, cover myself and use fry oil, wrap my hands and some old duct tape I stole from a broken car window, pound a six pack of blueberry hardcelser and a sack of pills I stole from a parked ambulance, then arm wrestle a raccoon in the back alley vision quest of news that is my segment.",
|
||||
" You know, folks, I spend most of my time right over there. Mining the days, biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels, then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest and glory. Then, using the Germanic tradition press, black process, I place thin sheets of foil against the scenes and by hammering or otherwise applying pressure from the back, I project these scenes into a pair of cheat cards and a face plate, and finally using fluted strips of white alloyed molding I divide the designs into framed panels and hold it all together using bronze rivets to create the beautiful and intimidating Anglo-Saxon battle helm that is my nightly monologue. Sometimes, sometimes, folks. Sometimes, just sometimes, I come to my senses fully naked on the deck of a pirate, besieged, melee, container ship that picked me up floating on the detached door of a port of potty in the Indian Ocean. Then, after a sunstroke induced realization of the crew of this ship plans to sell me and exchange for a bag of oranges to fight off scurvy, I lead a mutiny using only a PVC pipe and a pool chain that accepting my new role as captain and declaring myself King of the Windark Seas. I grab a dirty mop bucket covered in barnacles and adorn it with the teeth of the vanquished to create these shopping wet pirate crown of news that is my segment. Meanwhile!",
|
||||
" Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's newsiest, most topical flower eggs, milk and butter. And straining into a fine batter to make delicate and informative comedy pancakes, then I glaze them in the juice and zest of the most relevant midnight valencio oranges. And doubts at all, and I find delimane de voyage cognac, before from bang and basting them tables, I deserve you the James Beard Award worthy creeps to ZET. That is my nightly monologue, but sometimes sometimes folks, I wake up in the baggage hole of Greyhound bus. It's being hoisted by the scrapyard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps, busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strained pair of sweatpants. As ovenmets to extract and serve the demented transience pound cake of news that is my segment.",
|
||||
' You know, folks, I spend most of my time right over there. Mining the days, biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels, then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest and glory. Then, using the Germanic tradition press, black process, I place thin sheets of foil against the scenes and by hammering or otherwise applying pressure from the back, I project these scenes into a pair of cheat cards and a face plate, and finally using fluted strips of white alloyed molding I divide the designs into framed panels and hold it all together using bronze rivets to create the beautiful and intimidating Anglo-Saxon battle helm that is my nightly monologue. Sometimes, sometimes, folks. Sometimes, just sometimes, I come to my senses fully naked on the deck of a pirate, besieged, melee, container ship that picked me up floating on the detached door of a port of potty in the Indian Ocean. Then, after a sunstroke induced realization of the crew of this ship plans to sell me and exchange for a bag of oranges to fight off scurvy, I lead a mutiny using only a PVC pipe and a pool chain that accepting my new role as captain and declaring myself King of the Windark Seas. I grab a dirty mop bucket covered in barnacles and adorn it with the teeth of the vanquished to create these shopping wet pirate crown of news that is my segment. Meanwhile, young man.',
|
||||
" Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's newsiest, most topical flower eggs, milk and butter. And straining into a fine batter to make delicate and informative comedy pancakes, then I glaze them in the juice and zest of the most relevant midnight valencio oranges. And doubts at all, and I find delimane de voyage cognac, before from bang and basting them tables, I deserve you the James Beard Award worthy creeps to ZET. That is my nightly monologue, but sometimes sometimes folks, I wake up in the baggage hole of Greyhound bus. It's being hoisted by the scrapyard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps, busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strained pair of sweatpants. As ovenmets to extract and serve the Demented Transience pound cake of news that is my segment.",
|
||||
" Folks, if you watch the show and I hope you do, I spend a lot of time right over there. Tirelessly studying the lineage of the day's most important thoroughbred stories and whole-stiner headlines, working with the best trainers money can buy to rear their comedy offspring with a hand that is stern yet gentle into the triple crown winning equine specimen that is my nightly monologue. But sometimes sometimes folks I break into an unincorporated veterinary genetics lab. And grab whatever test tubes I can find and then under a grow light I got from it a discarded chia pet. I mixed the pill for DNA of a horse and whatever was in a tube labeled Keith Cole and extra. Slering the concoction with caffeine pills and a microwave bread bowl, I screamed sing a prayer to Janice initiator of human life and God of transformation as a half horse, half man freak, seizes to life before me and the hideous collection of loose animal parts and corrupted men tissue that is my segment. Meanwhile.",
|
||||
]
|
||||
# fmt: on
|
||||
@ -2935,6 +2930,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
"renormalize_logits": True, # necessary to match OAI beam search implementation
|
||||
}
|
||||
|
||||
set_seed(0)
|
||||
result = model.generate(**inputs, **gen_kwargs)
|
||||
decoded_all = processor.batch_decode(result, skip_special_tokens=True)
|
||||
|
||||
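For context, the hunk above checks Whisper beam-search generation with `renormalize_logits=True`. A minimal sketch of such a run is shown below; the checkpoint, dataset, and generation settings are assumptions for illustration, not the exact values used in `WhisperModelIntegrationTests`:

import torch
from datasets import load_dataset
from transformers import AutoProcessor, WhisperForConditionalGeneration, set_seed

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").eval()

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
inputs = processor(ds[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")

gen_kwargs = {
    "num_beams": 5,
    "renormalize_logits": True,  # necessary to match OAI beam search implementation
}

set_seed(0)
with torch.no_grad():
    result = model.generate(**inputs, **gen_kwargs)
decoded_all = processor.batch_decode(result, skip_special_tokens=True)
print(decoded_all)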
|
@ -156,6 +156,334 @@ TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION = [
|
||||
] + [("fp32_pad_left_output_attentions", "fp32", "left", True, True, False)]
|
||||
|
||||
|
||||
def _test_eager_matches_sdpa_inference(
|
||||
self,
|
||||
name,
|
||||
torch_dtype,
|
||||
padding_side,
|
||||
use_attention_mask,
|
||||
output_attentions,
|
||||
enable_kernels,
|
||||
atols=None,
|
||||
rtols=None,
|
||||
):
|
||||
"""
|
||||
This test is written as a regular function so that it can easily be overloaded with different tolerances.
|
||||
Otherwise, `parameterized.expand` prevents it, as it removes the original function from the namespace.
|
||||
"""
|
||||
# TODO: we shouldn't need to do this skip, i.e. the test would be composable from the model tester. CLIP-like
|
||||
# models have a custom mixin, which we detect to skip this test.
|
||||
if any(".CLIPModelTesterMixin" in str(base) for base in self.__class__.__bases__):
|
||||
self.skipTest(reason="CLIP-like models have a different `test_eager_matches_sdpa_inference`")
|
||||
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
if not self.all_model_classes[0]._supports_sdpa:
|
||||
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
|
||||
|
||||
# convert shorthand name to torch.dtype
|
||||
if torch_dtype == "fp16":
|
||||
torch_dtype = torch.float16
|
||||
elif torch_dtype == "bf16":
|
||||
torch_dtype = torch.bfloat16
|
||||
elif torch_dtype == "fp32":
|
||||
torch_dtype = torch.float32
|
||||
|
||||
if not is_torch_fp16_available_on_device(torch_device) and torch_dtype == torch.float16:
|
||||
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
|
||||
|
||||
if not is_torch_bf16_available_on_device(torch_device) and torch_dtype == torch.bfloat16:
|
||||
self.skipTest(
|
||||
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
|
||||
)
|
||||
|
||||
# Dictionary of tolerances for eager <> sdpa tests. Key = (device, sdpa_kernels_enabled, dtype)
|
||||
if atols is None:
|
||||
atols = {
|
||||
("cpu", False, torch.float32): 1e-6,
|
||||
("cpu", False, torch.float16): 5e-3,
|
||||
("cpu", False, torch.bfloat16): 1e-2,
|
||||
("cpu", True, torch.float32): 1e-6,
|
||||
("cpu", True, torch.float16): 5e-3,
|
||||
("cpu", True, torch.bfloat16): 1e-2,
|
||||
("cuda", False, torch.float32): 1e-6,
|
||||
("cuda", False, torch.bfloat16): 1e-2,
|
||||
("cuda", False, torch.float16): 5e-3,
|
||||
("cuda", True, torch.float32): 1e-6,
|
||||
("cuda", True, torch.bfloat16): 1e-2,
|
||||
("cuda", True, torch.float16): 5e-3,
|
||||
}
|
||||
if rtols is None:
|
||||
rtols = {
|
||||
("cpu", False, torch.float32): 1e-4,
|
||||
("cpu", False, torch.float16): 5e-3,
|
||||
("cpu", False, torch.bfloat16): 1e-2,
|
||||
("cpu", True, torch.float32): 1e-4,
|
||||
("cpu", True, torch.float16): 5e-3,
|
||||
("cpu", True, torch.bfloat16): 1e-2,
|
||||
("cuda", False, torch.float32): 1e-4,
|
||||
("cuda", False, torch.bfloat16): 1e-2,
|
||||
("cuda", False, torch.float16): 5e-3,
|
||||
("cuda", True, torch.float32): 1e-4,
|
||||
("cuda", True, torch.bfloat16): 3e-2, # (different from others)
|
||||
("cuda", True, torch.float16): 5e-3,
|
||||
}
|
||||
|
||||
set_model_tester_for_less_flaky_test(self)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
set_config_for_less_flaky_test(config)
|
||||
model = model_class(config)
|
||||
# TODO: standardize the interfaces for musicgen models, see other todo in this test
|
||||
if model.__class__.__name__ == "MusicgenMelodyForConditionalGeneration":
|
||||
is_encoder_decoder = True
|
||||
else:
|
||||
is_encoder_decoder = model.config.is_encoder_decoder
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_from_pretrained_kwargs = {
|
||||
"pretrained_model_name_or_path": tmpdirname,
|
||||
"torch_dtype": torch_dtype,
|
||||
}
|
||||
|
||||
if hasattr(config, "use_mask_token") or "use_mask_token" in inspect.signature(model.__init__).parameters:
|
||||
model_from_pretrained_kwargs["use_mask_token"] = True
|
||||
|
||||
# TODO: remove this try/except, models should have a shared API
|
||||
try:
|
||||
model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="sdpa")
|
||||
except ValueError:
|
||||
model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs)
|
||||
model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype)
|
||||
|
||||
model_eager = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="eager")
|
||||
model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype)
|
||||
|
||||
set_model_for_less_flaky_test(model_eager)
|
||||
set_model_for_less_flaky_test(model_sdpa)
|
||||
|
||||
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
|
||||
if not (self.has_attentions and can_output_attn) and output_attentions:
|
||||
self.skipTest(reason="Model does not support output_attentions")
|
||||
|
||||
# TODO: can we also check with `batch_size=1` without being flaky?
|
||||
for batch_size in [7]:
|
||||
# musicgen decoder models; TODO: find better abstraction
|
||||
if (
|
||||
model.__class__.__name__.startswith("Musicgen")
|
||||
and hasattr(self.model_tester, "num_codebooks")
|
||||
and not hasattr(model_eager, "text_encoder")
|
||||
):
|
||||
input_data_batch_size = batch_size * self.model_tester.num_codebooks
|
||||
else:
|
||||
input_data_batch_size = batch_size
|
||||
|
||||
processed_inputs = {}
|
||||
processed_inputs[model.main_input_name] = inputs_dict[model.main_input_name]
|
||||
|
||||
for key in getattr(self, "additional_model_inputs", []):
|
||||
# Some models don't have all `additional_model_inputs`, especially when we
|
||||
# craft cases to test the model in different settings
|
||||
if key in inputs_dict:
|
||||
processed_inputs[key] = inputs_dict[key]
|
||||
|
||||
for key, value in processed_inputs.items():
|
||||
if torch.is_floating_point(value):
|
||||
value = value.to(torch_dtype)
|
||||
|
||||
# extend value to have at least `input_data_batch_size` elements
|
||||
if value.shape[0] < input_data_batch_size:
|
||||
size = (input_data_batch_size - value.shape[0], *value.shape[1:])
|
||||
if torch.is_floating_point(value):
|
||||
extension = torch.rand(size=size, dtype=value.dtype, device=torch_device)
|
||||
else:
|
||||
extension = torch.randint(high=5, size=size, dtype=value.dtype, device=torch_device)
|
||||
value = torch.cat((value, extension), dim=0).to(torch_device)
|
||||
|
||||
processed_inputs[key] = value[:input_data_batch_size]
|
||||
|
||||
if not use_attention_mask:
|
||||
dummy_attention_mask = None
|
||||
else:
|
||||
dummy_attention_mask = inputs_dict.get("attention_mask", None)
|
||||
if dummy_attention_mask is None:
|
||||
if is_encoder_decoder:
|
||||
seqlen = inputs_dict.get("decoder_input_ids", processed_inputs[model.main_input_name]).shape[
|
||||
-1
|
||||
]
|
||||
else:
|
||||
seqlen = processed_inputs[model.main_input_name].shape[-1]
|
||||
dummy_attention_mask = torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
|
||||
|
||||
# extend dummy_attention_mask to have at least `batch_size` elements
|
||||
if dummy_attention_mask.shape[0] < batch_size:
|
||||
size = (batch_size - dummy_attention_mask.shape[0], *dummy_attention_mask.shape[1:])
|
||||
extension = torch.ones(size=size, dtype=dummy_attention_mask.dtype, device=torch_device)
|
||||
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
|
||||
|
||||
dummy_attention_mask = dummy_attention_mask[:batch_size].to(torch_device)
|
||||
|
||||
dummy_attention_mask[:] = 1
|
||||
if padding_side == "left":
|
||||
dummy_attention_mask[-1, :2] = 0
|
||||
dummy_attention_mask[-1, 2:] = 1
|
||||
elif padding_side == "right":
|
||||
dummy_attention_mask[-1, -2:] = 0
|
||||
dummy_attention_mask[-1, :-2] = 1
|
||||
|
||||
if is_encoder_decoder:
|
||||
# musicgen encoder-decoder models; TODO: find better abstraction
|
||||
if model.__class__.__name__.startswith("Musicgen") and hasattr(self.model_tester, "num_codebooks"):
|
||||
input_data_batch_size = batch_size * self.model_tester.num_codebooks
|
||||
else:
|
||||
input_data_batch_size = batch_size
|
||||
|
||||
decoder_input_ids = inputs_dict.get("decoder_input_ids", processed_inputs[model.main_input_name])
|
||||
decoder_input_ids = decoder_input_ids[:input_data_batch_size]
|
||||
if decoder_input_ids.shape[0] != input_data_batch_size:
|
||||
extension = torch.ones(
|
||||
input_data_batch_size - decoder_input_ids.shape[0],
|
||||
*decoder_input_ids.shape[1:],
|
||||
dtype=decoder_input_ids.dtype,
|
||||
device=torch_device,
|
||||
)
|
||||
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
|
||||
decoder_input_ids = decoder_input_ids.to(torch_device)
|
||||
|
||||
# TODO: never an `attention_mask` arg here?
|
||||
processed_inputs.update(
|
||||
{
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": dummy_attention_mask,
|
||||
"output_hidden_states": True,
|
||||
}
|
||||
)
|
||||
else:
|
||||
processed_inputs.update(
|
||||
{
|
||||
"output_hidden_states": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Otherwise fails for e.g. WhisperEncoderModel
|
||||
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
|
||||
processed_inputs["attention_mask"] = dummy_attention_mask
|
||||
|
||||
if self.has_attentions and "output_attentions" in inspect.signature(model_sdpa.forward).parameters:
|
||||
processed_inputs["output_attentions"] = output_attentions
|
||||
if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters:
|
||||
dummy_mask = torch.ones((self.model_tester.num_masks,))
|
||||
|
||||
# In case of additional token (like class) we define a custom `mask_length`
|
||||
if hasattr(self.model_tester, "mask_length"):
|
||||
mask_length = self.model_tester.mask_length - dummy_mask.size(0)
|
||||
else:
|
||||
mask_length = self.model_tester.seq_length - dummy_mask.size(0)
|
||||
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
|
||||
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
|
||||
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
|
||||
|
||||
if "noise" in inspect.signature(model_eager.forward).parameters:
|
||||
np.random.seed(2)
|
||||
num_patches = int((self.model_tester.image_size // self.model_tester.patch_size) ** 2)
|
||||
noise = np.random.uniform(size=(batch_size, num_patches))
|
||||
processed_inputs["noise"] = torch.from_numpy(noise)
|
||||
|
||||
# TODO: test gradients as well (& for FA2 as well!)
|
||||
with torch.no_grad():
|
||||
with sdpa_kernel(
|
||||
enable_flash=enable_kernels,
|
||||
enable_math=True,
|
||||
enable_mem_efficient=enable_kernels,
|
||||
):
|
||||
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
|
||||
prepared_inputs = {
|
||||
k: v.to(torch_device) if isinstance(v, torch.Tensor) else v for k, v in prepared_inputs.items()
|
||||
}
|
||||
outputs_eager = model_eager(**prepared_inputs)
|
||||
outputs_sdpa = model_sdpa(**prepared_inputs)
|
||||
|
||||
if "logits_per_text" in outputs_eager:
|
||||
key = "logits_per_text"
|
||||
elif "vision_hidden_states" in outputs_eager:
|
||||
key = "vision_hidden_states"
|
||||
elif "audio_values" in outputs_eager:
|
||||
key = "audio_values"
|
||||
elif "decoder_hidden_states" in outputs_eager:
|
||||
key = "decoder_hidden_states"
|
||||
elif "logits" in outputs_eager and "Classification" in model_class.__name__:
|
||||
key = "logits"
|
||||
elif "language_model_outputs" in outputs_eager and "blip" in model_class.__name__.lower():
|
||||
outputs_eager = outputs_eager["language_model_outputs"]
|
||||
outputs_sdpa = outputs_sdpa["language_model_outputs"]
|
||||
key = "hidden_states" if "hidden_states" in outputs_eager else "decoder_hidden_states"
|
||||
else:
|
||||
key = "hidden_states"
|
||||
|
||||
# TODO: rename logits -> hidden_states
|
||||
logits_eager = outputs_eager[key]
|
||||
logits_sdpa = outputs_sdpa[key]
|
||||
|
||||
if key in ["vision_hidden_states", "decoder_hidden_states", "hidden_states"]:
|
||||
logits_eager = logits_eager[-1]
|
||||
logits_sdpa = logits_sdpa[-1]
|
||||
|
||||
if key == "logits_per_text":
|
||||
nan_mask = torch.isnan(logits_eager)
|
||||
logits_eager[nan_mask] = 0
|
||||
logits_sdpa[nan_mask] = 0
|
||||
|
||||
if torch_device in ["cpu", "cuda"]:
|
||||
atol = atols[torch_device, enable_kernels, torch_dtype]
|
||||
rtol = rtols[torch_device, enable_kernels, torch_dtype]
|
||||
elif torch_device == "hpu":
|
||||
atol = atols["cuda", enable_kernels, torch_dtype]
|
||||
rtol = rtols["cuda", enable_kernels, torch_dtype]
|
||||
elif torch_device == "xpu":
|
||||
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
|
||||
# which is implemented on PyTorch level using aten operators and is
|
||||
# device agnostic with respect to implementation of each aten operator.
|
||||
atol = atols["cuda", False, torch_dtype]
|
||||
rtol = rtols["cuda", False, torch_dtype]
|
||||
else:
|
||||
atol = 1e-7
|
||||
rtol = 1e-4
|
||||
|
||||
# Masked tokens output slightly deviates - we don't mind that.
|
||||
if use_attention_mask:
|
||||
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
|
||||
_logits_eager = torch.zeros_like(input=logits_eager)
|
||||
|
||||
_logits_sdpa[:-1] = logits_sdpa[:-1]
|
||||
_logits_eager[:-1] = logits_eager[:-1]
|
||||
|
||||
if padding_side == "left":
|
||||
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
|
||||
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
|
||||
|
||||
elif padding_side == "right":
|
||||
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
|
||||
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
|
||||
|
||||
logits_sdpa = _logits_sdpa
|
||||
logits_eager = _logits_eager
|
||||
|
||||
results = [
|
||||
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
|
||||
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
|
||||
]
|
||||
# If at least 80% of the batch elements have matching results, it's fine
|
||||
if np.mean(results) < 0.8:
|
||||
mean_relative_diff = ((logits_sdpa - logits_eager).abs() / (logits_eager.abs() + 1e-12)).mean()
|
||||
raise ValueError(
|
||||
f"mean relative difference for {key}: {mean_relative_diff:.3e}, torch atol = {atol}, torch rtol = "
|
||||
f"{rtol}"
|
||||
)
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
configs_no_init = copy.deepcopy(config)
|
||||
for key in configs_no_init.__dict__.keys():
|
||||
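The `_test_eager_matches_sdpa_inference` helper above compares eager and SDPA attention outputs under `(device, kernels_enabled, dtype)`-specific tolerances. The standalone sketch below shows the core idea outside the test harness; the tiny checkpoint and the fp32/CPU tolerances are assumptions chosen to mirror the tables above, not a reproduction of the test itself:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModel, AutoTokenizer

model_id = "hf-internal-testing/tiny-random-bert"  # assumed tiny checkpoint for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("hello world", return_tensors="pt")

model_eager = AutoModel.from_pretrained(model_id, attn_implementation="eager").eval()
model_sdpa = AutoModel.from_pretrained(model_id, attn_implementation="sdpa").eval()

with torch.no_grad():
    out_eager = model_eager(**inputs).last_hidden_state
    with sdpa_kernel(SDPBackend.MATH):  # math backend only, as on CPU
        out_sdpa = model_sdpa(**inputs).last_hidden_state

# fp32 on CPU without fused kernels -> atol=1e-6, rtol=1e-4, mirroring the tolerance tables above
assert torch.allclose(out_sdpa, out_eager, atol=1e-6, rtol=1e-4)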
@ -3410,321 +3738,9 @@ class ModelTesterMixin:
|
||||
def test_eager_matches_sdpa_inference(
|
||||
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
|
||||
):
|
||||
# TODO: we shouldn't need to do this skip, i.e. the test would be composable from the model tester. CLIP-like
|
||||
# models have a custom mixin, which we detect to skip this test.
|
||||
if any(".CLIPModelTesterMixin" in str(base) for base in self.__class__.__bases__):
|
||||
self.skipTest(reason="CLIP-like models have a different `test_eager_matches_sdpa_inference`")
|
||||
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
if not self.all_model_classes[0]._supports_sdpa:
|
||||
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
|
||||
|
||||
# convert shorthand name to torch.dtype
|
||||
if torch_dtype == "fp16":
|
||||
torch_dtype = torch.float16
|
||||
elif torch_dtype == "bf16":
|
||||
torch_dtype = torch.bfloat16
|
||||
elif torch_dtype == "fp32":
|
||||
torch_dtype = torch.float32
|
||||
|
||||
if not is_torch_fp16_available_on_device(torch_device) and torch_dtype == torch.float16:
|
||||
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
|
||||
|
||||
if not is_torch_bf16_available_on_device(torch_device) and torch_dtype == torch.bfloat16:
|
||||
self.skipTest(
|
||||
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
|
||||
)
|
||||
|
||||
# Dictionary of tolerances for eager <> sdpa tests. Key = (device, sdpa_kernels_enabled, dtype)
|
||||
atols = {
|
||||
("cpu", False, torch.float32): 1e-6,
|
||||
("cpu", False, torch.float16): 5e-3,
|
||||
("cpu", False, torch.bfloat16): 1e-2,
|
||||
("cpu", True, torch.float32): 1e-6,
|
||||
("cpu", True, torch.float16): 5e-3,
|
||||
("cpu", True, torch.bfloat16): 1e-2,
|
||||
("cuda", False, torch.float32): 1e-6,
|
||||
("cuda", False, torch.bfloat16): 1e-2,
|
||||
("cuda", False, torch.float16): 5e-3,
|
||||
("cuda", True, torch.float32): 1e-6,
|
||||
("cuda", True, torch.bfloat16): 1e-2,
|
||||
("cuda", True, torch.float16): 5e-3,
|
||||
}
|
||||
rtols = {
|
||||
("cpu", False, torch.float32): 1e-4,
|
||||
("cpu", False, torch.float16): 5e-3,
|
||||
("cpu", False, torch.bfloat16): 1e-2,
|
||||
("cpu", True, torch.float32): 1e-4,
|
||||
("cpu", True, torch.float16): 5e-3,
|
||||
("cpu", True, torch.bfloat16): 1e-2,
|
||||
("cuda", False, torch.float32): 1e-4,
|
||||
("cuda", False, torch.bfloat16): 1e-2,
|
||||
("cuda", False, torch.float16): 5e-3,
|
||||
("cuda", True, torch.float32): 1e-4,
|
||||
("cuda", True, torch.bfloat16): 3e-2, # (different from others)
|
||||
("cuda", True, torch.float16): 5e-3,
|
||||
}
|
||||
|
||||
set_model_tester_for_less_flaky_test(self)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
set_config_for_less_flaky_test(config)
|
||||
model = model_class(config)
|
||||
# TODO: standardize the interfaces for musicgen models, see other todo in this test
|
||||
if model.__class__.__name__ == "MusicgenMelodyForConditionalGeneration":
|
||||
is_encoder_decoder = True
|
||||
else:
|
||||
is_encoder_decoder = model.config.is_encoder_decoder
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_from_pretrained_kwargs = {
|
||||
"pretrained_model_name_or_path": tmpdirname,
|
||||
"torch_dtype": torch_dtype,
|
||||
}
|
||||
|
||||
if (
|
||||
hasattr(config, "use_mask_token")
|
||||
or "use_mask_token" in inspect.signature(model.__init__).parameters
|
||||
):
|
||||
model_from_pretrained_kwargs["use_mask_token"] = True
|
||||
|
||||
# TODO: remove this try/except, models should have a shared API
|
||||
try:
|
||||
model_sdpa = model_class.from_pretrained(
|
||||
**model_from_pretrained_kwargs, attn_implementation="sdpa"
|
||||
)
|
||||
except ValueError:
|
||||
model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs)
|
||||
model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype)
|
||||
|
||||
model_eager = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="eager")
|
||||
model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype)
|
||||
|
||||
set_model_for_less_flaky_test(model_eager)
|
||||
set_model_for_less_flaky_test(model_sdpa)
|
||||
|
||||
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
|
||||
if not (self.has_attentions and can_output_attn) and output_attentions:
|
||||
self.skipTest(reason="Model does not support output_attentions")
|
||||
|
||||
# TODO: if we can also check with `batch_size=1` without being flaky?
|
||||
for batch_size in [7]:
|
||||
# musicgen decoder models; TODO: find better abstraction
|
||||
if (
|
||||
model.__class__.__name__.startswith("Musicgen")
|
||||
and hasattr(self.model_tester, "num_codebooks")
|
||||
and not hasattr(model_eager, "text_encoder")
|
||||
):
|
||||
input_data_batch_size = batch_size * self.model_tester.num_codebooks
|
||||
else:
|
||||
input_data_batch_size = batch_size
|
||||
|
||||
processed_inputs = {}
|
||||
processed_inputs[model.main_input_name] = inputs_dict[model.main_input_name]
|
||||
|
||||
for key in getattr(self, "additional_model_inputs", []):
|
||||
# Some models don't have all `additional_model_inputs`, especially when we
|
||||
# craft cases to test model in different settings
|
||||
if key in inputs_dict:
|
||||
processed_inputs[key] = inputs_dict[key]
|
||||
|
||||
for key, value in processed_inputs.items():
|
||||
if torch.is_floating_point(value):
|
||||
value = value.to(torch_dtype)
|
||||
|
||||
# extend value to have at least `input_data_batch_size` elements
|
||||
if value.shape[0] < input_data_batch_size:
|
||||
size = (input_data_batch_size - value.shape[0], *value.shape[1:])
|
||||
if torch.is_floating_point(value):
|
||||
extension = torch.rand(size=size, dtype=value.dtype, device=torch_device)
|
||||
else:
|
||||
extension = torch.randint(high=5, size=size, dtype=value.dtype, device=torch_device)
|
||||
value = torch.cat((value, extension), dim=0).to(torch_device)
|
||||
|
||||
processed_inputs[key] = value[:input_data_batch_size]
|
||||
|
||||
if not use_attention_mask:
|
||||
dummy_attention_mask = None
|
||||
else:
|
||||
dummy_attention_mask = inputs_dict.get("attention_mask", None)
|
||||
if dummy_attention_mask is None:
|
||||
if is_encoder_decoder:
|
||||
seqlen = inputs_dict.get(
|
||||
"decoder_input_ids", processed_inputs[model.main_input_name]
|
||||
).shape[-1]
|
||||
else:
|
||||
seqlen = processed_inputs[model.main_input_name].shape[-1]
|
||||
dummy_attention_mask = torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device)
|
||||
|
||||
# extend dummy_attention_mask to have at least `batch_size` elements
|
||||
if dummy_attention_mask.shape[0] < batch_size:
|
||||
size = (batch_size - dummy_attention_mask.shape[0], *dummy_attention_mask.shape[1:])
|
||||
extension = torch.ones(size=size, dtype=dummy_attention_mask.dtype, device=torch_device)
|
||||
dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0)
|
||||
|
||||
dummy_attention_mask = dummy_attention_mask[:batch_size].to(torch_device)
|
||||
|
||||
dummy_attention_mask[:] = 1
|
||||
if padding_side == "left":
|
||||
dummy_attention_mask[-1, :2] = 0
|
||||
dummy_attention_mask[-1, 2:] = 1
|
||||
elif padding_side == "right":
|
||||
dummy_attention_mask[-1, -2:] = 0
|
||||
dummy_attention_mask[-1, :-2] = 1
|
||||
|
||||
if is_encoder_decoder:
|
||||
# musicgen encoder-decoder models; TODO: find better abstraction
|
||||
if model.__class__.__name__.startswith("Musicgen") and hasattr(self.model_tester, "num_codebooks"):
|
||||
input_data_batch_size = batch_size * self.model_tester.num_codebooks
|
||||
else:
|
||||
input_data_batch_size = batch_size
|
||||
|
||||
decoder_input_ids = inputs_dict.get("decoder_input_ids", processed_inputs[model.main_input_name])
|
||||
decoder_input_ids = decoder_input_ids[:input_data_batch_size]
|
||||
if decoder_input_ids.shape[0] != input_data_batch_size:
|
||||
extension = torch.ones(
|
||||
input_data_batch_size - decoder_input_ids.shape[0],
|
||||
*decoder_input_ids.shape[1:],
|
||||
dtype=decoder_input_ids.dtype,
|
||||
device=torch_device,
|
||||
)
|
||||
decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0)
|
||||
decoder_input_ids = decoder_input_ids.to(torch_device)
|
||||
|
||||
# TODO: never an `attention_mask` arg here?
|
||||
processed_inputs.update(
|
||||
{
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": dummy_attention_mask,
|
||||
"output_hidden_states": True,
|
||||
}
|
||||
)
|
||||
else:
|
||||
processed_inputs.update(
|
||||
{
|
||||
"output_hidden_states": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Otherwise fails for e.g. WhisperEncoderModel
|
||||
if "attention_mask" in inspect.signature(model_eager.forward).parameters:
|
||||
processed_inputs["attention_mask"] = dummy_attention_mask
|
||||
|
||||
if self.has_attentions and "output_attentions" in inspect.signature(model_sdpa.forward).parameters:
|
||||
processed_inputs["output_attentions"] = output_attentions
|
||||
if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters:
|
||||
dummy_mask = torch.ones((self.model_tester.num_masks,))
|
||||
|
||||
# In case of additional token (like class) we define a custom `mask_length`
|
||||
if hasattr(self.model_tester, "mask_length"):
|
||||
mask_length = self.model_tester.mask_length - dummy_mask.size(0)
|
||||
else:
|
||||
mask_length = self.model_tester.seq_length - dummy_mask.size(0)
|
||||
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
|
||||
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
|
||||
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
|
||||
|
||||
if "noise" in inspect.signature(model_eager.forward).parameters:
|
||||
np.random.seed(2)
|
||||
num_patches = int((self.model_tester.image_size // self.model_tester.patch_size) ** 2)
|
||||
noise = np.random.uniform(size=(batch_size, num_patches))
|
||||
processed_inputs["noise"] = torch.from_numpy(noise)
|
||||
|
||||
# TODO: test gradients as well (& for FA2 as well!)
|
||||
with torch.no_grad():
|
||||
with sdpa_kernel(
|
||||
enable_flash=enable_kernels,
|
||||
enable_math=True,
|
||||
enable_mem_efficient=enable_kernels,
|
||||
):
|
||||
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
|
||||
prepared_inputs = {
|
||||
k: v.to(torch_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in prepared_inputs.items()
|
||||
}
|
||||
outputs_eager = model_eager(**prepared_inputs)
|
||||
outputs_sdpa = model_sdpa(**prepared_inputs)
|
||||
|
||||
if "logits_per_text" in outputs_eager:
|
||||
key = "logits_per_text"
|
||||
elif "vision_hidden_states" in outputs_eager:
|
||||
key = "vision_hidden_states"
|
||||
elif "audio_values" in outputs_eager:
|
||||
key = "audio_values"
|
||||
elif "decoder_hidden_states" in outputs_eager:
|
||||
key = "decoder_hidden_states"
|
||||
elif "logits" in outputs_eager and "Classification" in model_class.__name__:
|
||||
key = "logits"
|
||||
elif "language_model_outputs" in outputs_eager and "blip" in model_class.__name__.lower():
|
||||
outputs_eager = outputs_eager["language_model_outputs"]
|
||||
outputs_sdpa = outputs_sdpa["language_model_outputs"]
|
||||
key = "hidden_states" if "hidden_states" in outputs_eager else "decoder_hidden_states"
|
||||
else:
|
||||
key = "hidden_states"
|
||||
|
||||
# TODO: rename logits -> hidden_states
|
||||
logits_eager = outputs_eager[key]
|
||||
logits_sdpa = outputs_sdpa[key]
|
||||
|
||||
if key in ["vision_hidden_states", "decoder_hidden_states", "hidden_states"]:
|
||||
logits_eager = logits_eager[-1]
|
||||
logits_sdpa = logits_sdpa[-1]
|
||||
|
||||
if key == "logits_per_text":
|
||||
nan_mask = torch.isnan(logits_eager)
|
||||
logits_eager[nan_mask] = 0
|
||||
logits_sdpa[nan_mask] = 0
|
||||
|
||||
if torch_device in ["cpu", "cuda"]:
|
||||
atol = atols[torch_device, enable_kernels, torch_dtype]
|
||||
rtol = rtols[torch_device, enable_kernels, torch_dtype]
|
||||
elif torch_device == "hpu":
|
||||
atol = atols["cuda", enable_kernels, torch_dtype]
|
||||
rtol = rtols["cuda", enable_kernels, torch_dtype]
|
||||
elif torch_device == "xpu":
|
||||
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
|
||||
# which is implemented on PyTorch level using aten operators and is
|
||||
# device agnostic with respect to implementation of each aten operator.
|
||||
atol = atols["cuda", False, torch_dtype]
|
||||
rtol = rtols["cuda", False, torch_dtype]
|
||||
else:
|
||||
atol = 1e-7
|
||||
rtol = 1e-4
|
||||
|
||||
# Masked tokens output slightly deviates - we don't mind that.
|
||||
if use_attention_mask:
|
||||
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
|
||||
_logits_eager = torch.zeros_like(input=logits_eager)
|
||||
|
||||
_logits_sdpa[:-1] = logits_sdpa[:-1]
|
||||
_logits_eager[:-1] = logits_eager[:-1]
|
||||
|
||||
if padding_side == "left":
|
||||
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
|
||||
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
|
||||
|
||||
elif padding_side == "right":
|
||||
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
|
||||
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
|
||||
|
||||
logits_sdpa = _logits_sdpa
|
||||
logits_eager = _logits_eager
|
||||
|
||||
results = [
|
||||
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
|
||||
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
|
||||
]
|
||||
# If 80% batch elements have matched results, it's fine
|
||||
if np.mean(results) < 0.8:
|
||||
mean_relative_diff = ((logits_sdpa - logits_eager).abs() / (logits_eager.abs() + 1e-12)).mean()
|
||||
raise ValueError(
|
||||
f"mean relative difference for {key}: {mean_relative_diff:.3e}, torch atol = {atol}, torch rtol = "
|
||||
f"{rtol}"
|
||||
)
|
||||
_test_eager_matches_sdpa_inference(
|
||||
self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
|
||||
)
|
||||
|
||||
@require_torch_sdpa
|
||||
@require_torch_accelerator
|
||||
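After this refactor, the parameterized `test_eager_matches_sdpa_inference` method above is a thin wrapper around the module-level `_test_eager_matches_sdpa_inference`, so individual model test classes can re-run the same logic with looser tolerances. A hedged sketch of that pattern follows; the class name, tolerance values, and import path are hypothetical assumptions:

import unittest

import torch
from parameterized import parameterized

from .test_modeling_common import (  # assumed import path
    TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
    ModelTesterMixin,
    _test_eager_matches_sdpa_inference,
)


class MyNoisyModelTest(ModelTesterMixin, unittest.TestCase):
    @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
    def test_eager_matches_sdpa_inference(
        self, name, torch_dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
    ):
        # Hypothetical: relax every tolerance for a numerically noisy architecture.
        looser = {
            (device, kernels, dtype): 1e-2
            for device in ("cpu", "cuda")
            for kernels in (False, True)
            for dtype in (torch.float32, torch.float16, torch.bfloat16)
        }
        _test_eager_matches_sdpa_inference(
            self,
            name,
            torch_dtype,
            padding_side,
            use_attention_mask,
            output_attentions,
            enable_kernels,
            atols=looser,
            rtols=looser,
        )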
|
@ -276,6 +276,9 @@ SPECIAL_CASES_TO_ALLOW = {
|
||||
"attention_chunk_size",
|
||||
],
|
||||
"Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
|
||||
# position_embedding_type not used and deprecated. Should be deleted in v4.55
|
||||
"LayoutLMConfig": ["position_embedding_type"],
|
||||
"MarkupLMConfig": ["position_embedding_type"],
|
||||
"SmolLM3Config": ["no_rope_layer_interval"],
|
||||
"Gemma3nVisionConfig": ["architecture", "do_pooling", "model_args"], # this is for use in `timm`
|
||||
}
|
||||
|
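The `SPECIAL_CASES_TO_ALLOW` entries above whitelist config attributes that the repo-consistency check would otherwise flag as unused. The sketch below shows the general shape of such a check; the function, its arguments, and the example allow-list are illustrative assumptions, not the actual `check_config_attributes.py` implementation:

def find_unused_config_attributes(config_class_name, config_attributes, used_attributes, allow_list):
    # Attributes that are neither referenced in the modeling code nor explicitly allowed.
    allowed = set(allow_list.get(config_class_name, []))
    return [
        attr
        for attr in config_attributes
        if attr not in used_attributes and attr not in allowed
    ]


example_allow_list = {"LayoutLMConfig": ["position_embedding_type"]}
print(
    find_unused_config_attributes(
        "LayoutLMConfig",
        config_attributes=["hidden_size", "position_embedding_type"],
        used_attributes={"hidden_size"},
        allow_list=example_allow_list,
    )
)  # -> []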