Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)

Commit 58bae8dd36: Merge branch 'main' into hz-fix-int-decompression
@@ -433,6 +433,8 @@
      title: DiffLlama
    - local: model_doc/distilbert
      title: DistilBERT
+   - local: model_doc/dots1
+     title: dots1
    - local: model_doc/dpr
      title: DPR
    - local: model_doc/electra
@@ -655,6 +657,8 @@
      title: SwitchTransformers
    - local: model_doc/t5
      title: T5
+   - local: model_doc/t5gemma
+     title: T5Gemma
    - local: model_doc/t5v1.1
      title: T5v1.1
    - local: model_doc/tapex
@@ -843,7 +847,7 @@
      title: GraniteSpeech
    - local: model_doc/hubert
      title: Hubert
-   - local: model_doc/stt
+   - local: model_doc/kyutai_speech_to_text
      title: Kyutai Speech-To-Text
    - local: model_doc/mctct
      title: MCTCT
@@ -955,6 +959,8 @@
      title: Gemma3
    - local: model_doc/git
      title: GIT
+   - local: model_doc/glm4v
+     title: glm4v
    - local: model_doc/got_ocr2
      title: GOT-OCR2
    - local: model_doc/granitevision
@@ -1047,6 +1053,8 @@
      title: SigLIP
    - local: model_doc/siglip2
      title: SigLIP2
+   - local: model_doc/smollm3
+     title: SmolLM3
    - local: model_doc/smolvlm
      title: SmolVLM
    - local: model_doc/speech-encoder-decoder
docs/source/en/model_doc/dots1.md (new file, 40 lines)
@@ -0,0 +1,40 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# dots.llm1

## Overview

The `dots.llm1` model was proposed in the [dots.llm1 technical report](https://www.arxiv.org/pdf/2506.05767) by the rednote-hilab team.

The abstract from the report is the following:

*Mixture of Experts (MoE) models have emerged as a promising paradigm for scaling language models efficiently by activating only a subset of parameters for each input token. In this report, we present dots.llm1, a large-scale MoE model that activates 14B parameters out of a total of 142B parameters, delivering performance on par with state-of-the-art models while reducing training and inference costs. Leveraging our meticulously crafted and efficient data processing pipeline, dots.llm1 achieves performance comparable to Qwen2.5-72B after pretraining on high-quality corpus and post-training to fully unlock its capabilities. Notably, no synthetic data is used during pretraining. To foster further research, we open-source intermediate training checkpoints spanning the entire training process, providing valuable insights into the learning dynamics of large language models.*
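A minimal usage sketch for the new model (illustrative only, not part of the diffed file; the checkpoint name `rednote-hilab/dots.llm1.inst` is an assumption):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical checkpoint id for the instruction-tuned variant.
model_id = "rednote-hilab/dots.llm1.inst"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

messages = [{"role": "user", "content": "Explain mixture-of-experts routing in one sentence."}]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
output_ids = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True))
```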
## Dots1Config

[[autodoc]] Dots1Config

## Dots1Model

[[autodoc]] Dots1Model
    - forward

## Dots1ForCausalLM

[[autodoc]] Dots1ForCausalLM
    - forward
docs/source/en/model_doc/glm4v.md (new file, 180 lines)
@@ -0,0 +1,180 @@
<!--Copyright 2025 The ZhipuAI Inc. and The HuggingFace Inc. team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
    </div>
</div>

# GLM-4.1V

The example below demonstrates how to generate text based on an image with [`Pipeline`] or the [`AutoModel`] class.

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
import torch
from transformers import pipeline

pipe = pipeline(
    task="image-text-to-text",
    model="THUDM/GLM-4.1V-9B-Thinking",
    device=0,
    torch_dtype=torch.bfloat16
)
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ]
    }
]
pipe(text=messages, max_new_tokens=20, return_full_text=False)
```

</hfoption>
<hfoption id="AutoModel">

```py
import torch
from transformers import Glm4vForConditionalGeneration, AutoProcessor

model = Glm4vForConditionalGeneration.from_pretrained(
    "THUDM/GLM-4.1V-9B-Thinking",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"
)
processor = AutoProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking")
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
            },
            {
                "type": "text",
                "text": "Describe this image."
            }
        ]
    }
]

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt"
).to("cuda")

generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```

</hfoption>
</hfoptions>

Using GLM-4.1V with video input is similar to using it with image input.
The model can process video data and generate text based on the content of the video.

```python
from transformers import AutoProcessor, Glm4vForConditionalGeneration
import torch

processor = AutoProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking")
model = Glm4vForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path="THUDM/GLM-4.1V-9B-Thinking",
    torch_dtype=torch.bfloat16,
    device_map="cuda:0"
)

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "video",
                "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
            },
            {
                "type": "text",
                "text": "Describe this video.",
            },
        ],
    }
]
inputs = processor.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt", padding=True).to("cuda:0")
generated_ids = model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=1.0)
output_text = processor.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(output_text)
```

## Glm4vConfig

[[autodoc]] Glm4vConfig

## Glm4vTextConfig

[[autodoc]] Glm4vTextConfig

## Glm4vImageProcessor

[[autodoc]] Glm4vImageProcessor
    - preprocess

## Glm4vVideoProcessor

[[autodoc]] Glm4vVideoProcessor
    - preprocess

## Glm4vImageProcessorFast

[[autodoc]] Glm4vImageProcessorFast
    - preprocess

## Glm4vProcessor

[[autodoc]] Glm4vProcessor

## Glm4vTextModel

[[autodoc]] Glm4vTextModel
    - forward

## Glm4vModel

[[autodoc]] Glm4vModel
    - forward

## Glm4vForConditionalGeneration

[[autodoc]] Glm4vForConditionalGeneration
    - forward
@@ -36,10 +36,10 @@ from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForCondi

# 1. load the model and the processor
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-model_id = "kyutai/stt-2.6b-en"
+model_id = "kyutai/stt-2.6b-en-trfs"

processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
-model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
+model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device, torch_dtype="auto")

# 2. load audio samples
ds = load_dataset(
@@ -69,10 +69,10 @@ from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForCondi

# 1. load the model and the processor
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-model_id = "kyutai/stt-2.6b-en"
+model_id = "kyutai/stt-2.6b-en-trfs"

processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
-model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
+model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device, torch_dtype="auto")

# 2. load audio samples
ds = load_dataset(
@@ -56,7 +56,7 @@ Here is how to use the processor to process text and audio:

```python
>>> # let's load an audio sample from an Arabic speech corpus
>>> from datasets import load_dataset
->>> dataset = load_dataset("arabic_speech_corpus", split="test", streaming=True, trust_remote_code=True)
+>>> dataset = load_dataset("halabi2016/arabic_speech_corpus", split="test", streaming=True)
>>> audio_sample = next(iter(dataset))["audio"]

>>> # now, process it
@@ -56,7 +56,7 @@ Here is how to use the processor to process text and audio:

```python
>>> # let's load an audio sample from an Arabic speech corpus
>>> from datasets import load_dataset
->>> dataset = load_dataset("arabic_speech_corpus", split="test", streaming=True, trust_remote_code=True)
+>>> dataset = load_dataset("halabi2016/arabic_speech_corpus", split="test", streaming=True)
>>> audio_sample = next(iter(dataset))["audio"]

>>> # now, process it
docs/source/en/model_doc/smollm3.md (new file, 173 lines)
@@ -0,0 +1,173 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
    </div>
</div>

# SmolLM3

SmolLM3 is a fully open, compact language model designed for efficient deployment while maintaining strong performance. It uses a Transformer decoder architecture with Grouped Query Attention (GQA) to reduce the KV cache, and no RoPE, enabling improved performance on long-context tasks. It is trained using a multi-stage training approach on high-quality public datasets across web, code, and math domains. The model is multilingual and supports very large context lengths. The instruct variant is optimized for reasoning and tool use.

> [!TIP]
> Click on the SmolLM3 models in the right sidebar for more examples of how to apply SmolLM3 to different language tasks.

The example below demonstrates how to generate text with [`Pipeline`], [`AutoModel`], and from the command line using the instruction-tuned models.

<hfoptions id="usage">
<hfoption id="Pipeline">

```python
import torch
from transformers import pipeline

pipe = pipeline(
    task="text-generation",
    model="HuggingFaceTB/SmolLM3-3B",
    torch_dtype=torch.bfloat16,
    device_map=0
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Tell me about yourself."},
]
outputs = pipe(messages, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
print(outputs[0]["generated_text"][-1]["content"])
```

</hfoption>
<hfoption id="AutoModel">

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM3-3B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"
)
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")

prompt = "Give me a short introduction to large language models."
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to("cuda")

generated_ids = model.generate(
    model_inputs.input_ids,
    cache_implementation="static",
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95
)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
```

</hfoption>
<hfoption id="transformers CLI">

```bash
# pip install -U flash-attn --no-build-isolation
transformers chat HuggingFaceTB/SmolLM3-3B --torch_dtype auto --attn_implementation flash_attention_2 --device 0
```

</hfoption>
</hfoptions>

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.

The example below uses [bitsandbytes](../quantization/bitsandbytes) to quantize the weights to 4-bits.

```python
# pip install -U flash-attn --no-build-isolation
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM3-3B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
    attn_implementation="flash_attention_2"
)

inputs = tokenizer("Gravity is the force", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

## Notes

- Ensure your Transformers library version is up to date. SmolLM3 requires Transformers>=4.53.0 for full support; a one-line upgrade command is shown below.
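For reference, the upgrade itself is a one-liner (not part of the diffed file):

```bash
pip install -U "transformers>=4.53.0"
```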
## SmolLM3Config

[[autodoc]] SmolLM3Config

## SmolLM3Model

[[autodoc]] SmolLM3Model
    - forward

## SmolLM3ForCausalLM

[[autodoc]] SmolLM3ForCausalLM
    - forward

## SmolLM3ForSequenceClassification

[[autodoc]] SmolLM3ForSequenceClassification
    - forward

## SmolLM3ForTokenClassification

[[autodoc]] SmolLM3ForTokenClassification
    - forward

## SmolLM3ForQuestionAnswering

[[autodoc]] SmolLM3ForQuestionAnswering
    - forward
docs/source/en/model_doc/t5gemma.md (new file, 107 lines)
@@ -0,0 +1,107 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# T5Gemma

T5Gemma (aka encoder-decoder Gemma) was proposed in a [research paper](https://arxiv.org/abs/2504.06225) by Google. It is a family of encoder-decoder large language models, developed by adapting pretrained decoder-only models into encoder-decoder models. T5Gemma includes pretrained and instruction-tuned variants. The architecture is based on the transformer encoder-decoder design following T5, with improvements from Gemma 2: GQA, RoPE, GeGLU activation, RMSNorm, and interleaved local/global attention.

T5Gemma comes in two groups of model sizes: 1) [Gemma 2](https://ai.google.dev/gemma/docs/core/model_card_2) sizes (2B-2B, 9B-2B, and 9B-9B), which are based on the official Gemma 2 models (2B and 9B); and 2) [T5](https://arxiv.org/abs/1910.10683) sizes (Small, Base, Large, and XL), which are pretrained under the Gemma 2 framework following the T5 configuration. In addition, we also provide a model at ML size (medium large, ~2B in total), which sits between T5 Large and T5 XL.

The pretrained variants are trained with two objectives: prefix language modeling with knowledge distillation (PrefixLM) and UL2, separately. We release both variants for each model size. The instruction-tuned variants were post-trained with supervised fine-tuning and reinforcement learning.

The example below demonstrates how to chat with the model with [`Pipeline`] or the [`AutoModel`] class, and from the command line.

<hfoptions id="usage">
<hfoption id="Pipeline">

```python
import torch
from transformers import pipeline

pipe = pipeline(
    task="text2text-generation",
    model="google/t5gemma-placeholder",
    torch_dtype=torch.bfloat16,
    device="cuda",
)

pipe("Question: Why is the sky blue?\nAnswer:", max_new_tokens=50)
```

</hfoption>
<hfoption id="AutoModel">

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-placeholder")
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/t5gemma-placeholder",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

input_text = "Question: Why is the sky blue?\nAnswer:"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

</hfoption>
<hfoption id="transformers CLI">

```bash
echo -e "Question: Why is the sky blue? Answer:" | transformers run --task text2text-generation --model google/t5gemma-placeholder --device 0
```

</hfoption>
</hfoptions>

## T5GemmaConfig

[[autodoc]] T5GemmaConfig

## T5GemmaModuleConfig

[[autodoc]] T5GemmaModuleConfig

## T5GemmaModel

[[autodoc]] T5GemmaModel
    - forward

## T5GemmaEncoderModel

[[autodoc]] T5GemmaEncoderModel
    - forward

## T5GemmaForConditionalGeneration

[[autodoc]] T5GemmaForConditionalGeneration
    - forward

## T5GemmaForSequenceClassification

[[autodoc]] T5GemmaForSequenceClassification
    - forward

## T5GemmaForTokenClassification

[[autodoc]] T5GemmaForTokenClassification
    - forward
@@ -18,7 +18,7 @@ rendered properly in your Markdown viewer.

Transformers provides many pretrained models that are ready to use with a single line of code. It requires a model class and the [`~PreTrainedModel.from_pretrained`] method.

-Call [`~PreTrainedModel.from_pretrained`] to download and load a models weights and configuration stored on the Hugging Face [Hub](https://hf.co/models).
+Call [`~PreTrainedModel.from_pretrained`] to download and load a model's weights and configuration stored on the Hugging Face [Hub](https://hf.co/models).

> [!TIP]
> The [`~PreTrainedModel.from_pretrained`] method loads weights stored in the [safetensors](https://hf.co/docs/safetensors/index) file format if they're available. Traditionally, PyTorch model weights are serialized with the [pickle](https://docs.python.org/3/library/pickle.html) utility, which is known to be insecure. Safetensors files are more secure and faster to load.
@@ -264,7 +264,6 @@ class ExamplesTests(TestCasePlus):
            --dataset_config clean
            --train_split_name validation
            --eval_split_name validation
-           --trust_remote_code
            --output_dir {tmp_dir}
            --overwrite_output_dir
            --num_train_epochs=2

@@ -312,7 +312,6 @@ class ExamplesTestsNoTrainer(TestCasePlus):
            {self.examples_dir}/pytorch/image-classification/run_image_classification_no_trainer.py
            --model_name_or_path google/vit-base-patch16-224-in21k
            --dataset_name hf-internal-testing/cats_vs_dogs_sample
-           --trust_remote_code
            --learning_rate 1e-4
            --per_device_train_batch_size 2
            --per_device_eval_batch_size 1

@@ -17,7 +17,6 @@ import json
import logging
import os
import sys
import unittest
from unittest.mock import patch

from transformers import ViTMAEForPreTraining, Wav2Vec2ForPreTraining
@@ -391,7 +390,6 @@ class ExamplesTests(TestCasePlus):
            --output_dir {tmp_dir}
            --model_name_or_path google/vit-base-patch16-224-in21k
            --dataset_name hf-internal-testing/cats_vs_dogs_sample
-           --trust_remote_code
            --do_train
            --do_eval
            --learning_rate 1e-4
@@ -415,7 +413,6 @@ class ExamplesTests(TestCasePlus):
        result = get_results(tmp_dir)
        self.assertGreaterEqual(result["eval_accuracy"], 0.8)

    @unittest.skip("temporary to avoid failing on circleci")
    def test_run_speech_recognition_ctc(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
@@ -426,7 +423,6 @@ class ExamplesTests(TestCasePlus):
            --dataset_config_name clean
            --train_split_name validation
            --eval_split_name validation
-           --trust_remote_code
            --do_train
            --do_eval
            --learning_rate 1e-4
@@ -447,7 +443,6 @@ class ExamplesTests(TestCasePlus):
        result = get_results(tmp_dir)
        self.assertLess(result["eval_loss"], result["train_loss"])

    @unittest.skip("temporary to avoid failing on circleci")
    def test_run_speech_recognition_ctc_adapter(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
@@ -458,7 +453,6 @@ class ExamplesTests(TestCasePlus):
            --dataset_config_name clean
            --train_split_name validation
            --eval_split_name validation
-           --trust_remote_code
            --do_train
            --do_eval
            --learning_rate 1e-4
@@ -481,7 +475,6 @@ class ExamplesTests(TestCasePlus):
        self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "./adapter.tur.safetensors")))
        self.assertLess(result["eval_loss"], result["train_loss"])

    @unittest.skip("temporary to avoid failing on circleci")
    def test_run_speech_recognition_seq2seq(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
@@ -492,7 +485,6 @@ class ExamplesTests(TestCasePlus):
            --dataset_config_name clean
            --train_split_name validation
            --eval_split_name validation
-           --trust_remote_code
            --do_train
            --do_eval
            --learning_rate 1e-4
@@ -520,7 +512,6 @@ class ExamplesTests(TestCasePlus):
            --output_dir {tmp_dir}
            --model_name_or_path hf-internal-testing/tiny-random-wav2vec2
            --dataset_name anton-l/superb_demo
-           --trust_remote_code
            --dataset_config_name ks
            --train_split_name test
            --eval_split_name test
@@ -555,7 +546,6 @@ class ExamplesTests(TestCasePlus):
            --dataset_name hf-internal-testing/librispeech_asr_dummy
            --dataset_config_names clean
            --dataset_split_names validation
-           --trust_remote_code
            --learning_rate 1e-4
            --per_device_train_batch_size 4
            --per_device_eval_batch_size 4
@@ -576,7 +566,6 @@ class ExamplesTests(TestCasePlus):
            run_mae.py
            --output_dir {tmp_dir}
            --dataset_name hf-internal-testing/cats_vs_dogs_sample
-           --trust_remote_code
            --do_train
            --do_eval
            --learning_rate 1e-4

@@ -315,7 +315,6 @@ class ExamplesTests(TestCasePlus):
        testargs = f"""
            run_image_classification.py
            --dataset_name hf-internal-testing/cats_vs_dogs_sample
-           --trust_remote_code
            --model_name_or_path microsoft/resnet-18
            --do_train
            --do_eval
@@ -52,6 +52,7 @@ line-ending = "auto"
addopts = "--doctest-glob='**/*.md'"
doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
markers = [
+   "flash_attn_3_test: marks tests related to flash attention 3 (deselect with '-m \"not flash_attn_3_test\"')",
    "flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')",
    "bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
    "generate: marks tests that use the GenerationTesterMixin"
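With the new marker registered, FA3-specific tests can be selected or deselected from the command line in the usual pytest way, for example:

```bash
# run only the Flash Attention 3 tests
pytest -m flash_attn_3_test tests/

# or skip them on machines without FA3 support
pytest -m "not flash_attn_3_test" tests/
```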
@@ -32,6 +32,13 @@ def flash_attention_forward(
    # This is before the transpose
    seq_len = query.shape[2]

+   if any(dim == 0 for dim in query.shape):
+       raise ValueError(
+           "Tensor query has shape with a zero dimension.\n"
+           "FlashAttention does not support inputs with dim=0.\n"
+           "Please check your input shapes or use SDPA instead."
+       )
+
    # FA2 uses non-transposed inputs
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
@@ -68,6 +75,7 @@ def flash_attention_forward(
        softcap=softcap,
        use_top_left_mask=_use_top_left_mask,
        target_dtype=target_dtype,
+       attn_implementation=module.config._attn_implementation,
        **kwargs,
    )
@@ -14,6 +14,7 @@

import inspect
import os
+import warnings
from typing import Optional, TypedDict

import torch
@@ -21,6 +22,7 @@ import torch.nn.functional as F

from .utils import (
    is_flash_attn_2_available,
+   is_flash_attn_3_available,
    is_flash_attn_greater_or_equal,
    is_flash_attn_greater_or_equal_2_10,
    is_torch_npu_available,
@@ -32,18 +34,123 @@ logger = logging.get_logger(__name__)
flash_attn_func = None


-if is_flash_attn_2_available():
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.layers.rotary import apply_rotary_emb  # noqa
def _index_first_axis(tensor, indices):
    """
    A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
    after flattening the first two dimensions of the tensor. This is functionally equivalent to
    FA2's `index_first_axis` and replaces the need to import it.
    """
    # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
    # two dimensions to get (total_tokens, ...) before indexing.
    reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
    return reshaped_tensor[indices]


def _fa3_unpad_input(hidden_states, attention_mask, unused_mask=None):
    """
    FA3-compatible unpad_input function.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
    """
    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

    return (
        _index_first_axis(hidden_states, indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        used_seqlens_in_batch,
    )


def _fa3_pad_input(hidden_states, indices, batch, seqlen):
    """
    FA3-compatible pad_input function.

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    dim = hidden_states.shape[1:]
    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
    output[indices] = hidden_states
    return output.view(batch, seqlen, *dim)
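To make the intent of these helpers concrete, here is a small round-trip sketch (illustrative only, not part of the diff; it assumes the `_fa3_unpad_input` and `_fa3_pad_input` functions defined above are in scope):

```python
import torch

# A padded batch of 2 sequences (lengths 3 and 2) with hidden size 8.
hidden_states = torch.randn(2, 4, 8)
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

unpadded, indices, cu_seqlens, max_seqlen, seqused = _fa3_unpad_input(hidden_states, attention_mask)
print(unpadded.shape)  # torch.Size([5, 8]) -> only the 5 valid tokens, concatenated across the batch
print(cu_seqlens)      # tensor([0, 3, 5], dtype=torch.int32)

# Padding back restores the original layout, with zeros at the padded positions.
repadded = _fa3_pad_input(unpadded, indices, batch=2, seqlen=4)
assert torch.equal(repadded[0, :3], hidden_states[0, :3])
```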

FA_VERSION = None
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func as flash_attn_2_func
    from flash_attn import flash_attn_varlen_func as flash_attn_2_varlen_func
    from flash_attn.bert_padding import pad_input as pad_input_fa2
    from flash_attn.bert_padding import unpad_input as unpad_input_fa2
    from flash_attn.layers.rotary import apply_rotary_emb

    HAS_FA2 = True
    FA_VERSION = 2
else:
    flash_attn_2_func = None
    flash_attn_2_varlen_func = None
    pad_input_fa2 = None
    unpad_input_fa2 = None
    apply_rotary_emb = None
    HAS_FA2 = False

if is_flash_attn_3_available():
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func

    pad_input_fa3 = _fa3_pad_input
    unpad_input_fa3 = _fa3_unpad_input
    HAS_FA3 = True
    FA_VERSION = 3
else:
    flash_attn_3_func = None
    flash_attn_3_varlen_func = None
    pad_input_fa3 = None
    unpad_input_fa3 = None
    HAS_FA3 = False


# Current Flash Attention implementations
if FA_VERSION:
    flash_attn_func = globals()[f"flash_attn_{FA_VERSION}_func"]
    flash_attn_varlen_func = globals()[f"flash_attn_{FA_VERSION}_varlen_func"]
    unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"]
    pad_input = globals()[f"pad_input_fa{FA_VERSION}"]

# patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
if is_torch_npu_available():
-   from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input
-   from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb  # noqa
-   from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
-   from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
+   from .integrations.npu_flash_attention import (
+       npu_apply_rotary_emb as apply_rotary_emb,  # noqa: F401
+   )
+   from .integrations.npu_flash_attention import (
+       npu_flash_attn_func as flash_attn_func,
+   )
+   from .integrations.npu_flash_attention import (
+       npu_flash_attn_varlen_func as flash_attn_varlen_func,
+   )
+   from .integrations.npu_flash_attention import (
+       pad_input,
+       unpad_input,
+   )


_flash_supports_window_size = False
@@ -56,6 +163,9 @@ if flash_attn_func:

def is_flash_attn_available():
    """Determine whether flash-attention can be used or not."""

+   if is_flash_attn_3_available():
+       return True
+
    # if package `flash-attn` is available, flash-attention can be used natively.
    if is_flash_attn_2_available():
        return True
@@ -70,6 +180,9 @@ def is_flash_attn_available():

def flash_attn_supports_top_left_mask():
    """Determine whether flash-attention uses top-left or down-right mask"""

+   if is_flash_attn_3_available():
+       return False
+
    if is_flash_attn_2_available():
        # top-left mask is used in package `flash-attn` with version lower than 2.1.0
        return not is_flash_attn_greater_or_equal_2_10()
@@ -116,6 +229,7 @@ def _upad_input(
    value_layer: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
+   unpad_input_func,
):
    """
    Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
@@ -134,6 +248,8 @@
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
        query_length (`int`):
            Target length.
+       unpad_input_func:
+           The function to use for unpadding the input tensors.

    Return:
        query_layer (`torch.Tensor`):
@@ -158,12 +274,10 @@

    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

-   key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
-   value_layer = index_first_axis(
-       value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
-   )
+   key_layer = _index_first_axis(key_layer, indices_k)
+   value_layer = _index_first_axis(value_layer, indices_k)
    if query_length == kv_seq_len:
-       query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
+       query_layer = _index_first_axis(query_layer, indices_k)
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
@@ -177,7 +291,7 @@
    else:
        # The -q_len: slice assumes left padding.
        attention_mask = attention_mask[:, -query_length:]
-       query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input(query_layer, attention_mask)
+       query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask)

    return (
        query_layer,
@@ -189,7 +303,7 @@
    )


-def prepare_fa2_from_position_ids(query, key, value, position_ids):
+def _prepare_flash_attention_from_position_ids(query, key, value, position_ids):
    """
    This function returns necessary arguments to call `flash_attn_varlen_func`.
    All three query, key, value states will be flattened.
@@ -239,6 +353,14 @@
    return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))


+def prepare_fa2_from_position_ids(*args, **kwargs):
+   warnings.warn(
+       "The function `prepare_fa2_from_position_ids` in `transformers.modeling_flash_attention_utils` is deprecated and will be removed in a future version. Please use `_prepare_flash_attention_from_position_ids` instead.",
+       FutureWarning,
+   )
+   return _prepare_flash_attention_from_position_ids(*args, **kwargs)

def fa_peft_integration_check(
    query: torch.Tensor,
    key: torch.Tensor,
@@ -303,6 +425,7 @@ def _flash_attention_forward(
    max_length_q: Optional[int] = None,
    max_length_k: Optional[int] = None,
    target_dtype: Optional[torch.dtype] = None,
+   attn_implementation: Optional[str] = None,
    **kwargs,
):
    """
@@ -329,7 +452,28 @@ def _flash_attention_forward(
            Softcap for the attention logits, used e.g. in gemma2.
        deterministic (`bool`, *optional*):
            Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
+       attn_implementation (`str`, *optional*):
+           The attention implementation to use. If None, will default to the one based on the environment.
    """
+   if attn_implementation is None:
+       _flash_attn_varlen_func = flash_attn_varlen_func
+       _flash_attn_func = flash_attn_func
+       _pad_input = pad_input
+       _unpad_input = unpad_input
+       _is_fa3 = HAS_FA3
+   elif attn_implementation == "flash_attention_3":
+       _flash_attn_varlen_func = flash_attn_3_varlen_func
+       _flash_attn_func = flash_attn_3_func
+       _pad_input = pad_input_fa3
+       _unpad_input = unpad_input_fa3
+       _is_fa3 = True
+   elif attn_implementation == "flash_attention_2":
+       _flash_attn_varlen_func = flash_attn_2_varlen_func
+       _flash_attn_func = flash_attn_2_func
+       _pad_input = pad_input_fa2
+       _unpad_input = unpad_input_fa2
+       _is_fa3 = False
+
    if not use_top_left_mask:
        causal = is_causal
    else:
@@ -342,6 +486,12 @@ def _flash_attention_forward(
        )
    flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

+   if _is_fa3:
+       if dropout > 0.0:
+           logger.warning_once("Flash Attention 3 does not support dropout. Setting dropout to 0.0.")
+   else:
+       flash_kwargs["dropout_p"] = dropout
+
    if flash_241:
        if deterministic is None:
            global deterministic_g
@@ -362,12 +512,12 @@ def _flash_attention_forward(
    if attention_mask is not None:
        batch_size = query_states.shape[0]
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
-           query_states, key_states, value_states, attention_mask, query_length
+           query_states, key_states, value_states, attention_mask, query_length, _unpad_input
        )
        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

-       attn_output_unpad = flash_attn_varlen_func(
+       attn_output_unpad = _flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
@@ -375,24 +525,25 @@ def _flash_attention_forward(
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
-           dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )
-       attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+       attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length)

    # If position_ids is provided and not all examples contain only 1 sequence (if the tensor is increasing,
    # we probably have one sequence, otherwise it is packed). Additionally check that we are in the pre-fill/training stage.
    # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow the padding-free approach.
-   elif position_ids is not None and (
-       max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
+   elif (
+       position_ids is not None
+       and query_states.shape[0] == 1
+       and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
    ):
        batch_size = query_states.size(0)

        if cu_seq_lens_q is None or cu_seq_lens_k is None:
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
-               prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids)
+               _prepare_flash_attention_from_position_ids(query_states, key_states, value_states, position_ids)
            )

            cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
@@ -403,7 +554,7 @@ def _flash_attention_forward(
            key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
            value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))

-           attn_output = flash_attn_varlen_func(
+           attn_output = _flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
@@ -411,7 +562,6 @@ def _flash_attention_forward(
                cu_seqlens_k=cu_seq_lens_k,
                max_seqlen_q=max_length_q,
                max_seqlen_k=max_length_k,
-               dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
                **flash_kwargs,
@@ -420,10 +570,12 @@ def _flash_attention_forward(
        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))

    else:
-       attn_output = flash_attn_func(
-           query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
+       attn_output = _flash_attn_func(
+           query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
        )

+   if isinstance(attn_output, tuple):
+       return attn_output[0]
    return attn_output
@@ -105,6 +105,7 @@ from .utils import (
    is_accelerate_available,
    is_bitsandbytes_available,
    is_flash_attn_2_available,
+   is_flash_attn_3_available,
    is_kernels_available,
    is_offline_mode,
    is_optimum_available,
@@ -1957,6 +1958,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
    # Flash Attention 2 support
    _supports_flash_attn_2 = False

+   # Flash Attention 3 support
+   _supports_flash_attn_3 = False
+
    # SDPA support
    _supports_sdpa = False

@@ -2247,6 +2251,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
        and config._attn_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys()
    ):
        message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
+       if cls._supports_flash_attn_3:
+           message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
        if cls._supports_flash_attn_2:
            message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
        if cls._supports_sdpa:
@@ -2282,7 +2288,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
        ):
            sub_config._attn_implementation_internal = curr_attn_implementation

-       if config._attn_implementation == "flash_attention_2":
+       if config._attn_implementation == "flash_attention_3":
+           cls._check_and_enable_flash_attn_3(
+               config,
+               torch_dtype=torch_dtype,
+               device_map=device_map,
+               hard_check_only=False,
+               check_device_map=check_device_map,
+           )
+       elif config._attn_implementation == "flash_attention_2":
            cls._check_and_enable_flash_attn_2(
                config,
                torch_dtype=torch_dtype,
@@ -2498,6 +2512,94 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
            config._attn_implementation = "flash_attention_2"
        return config

    @classmethod
    def _check_and_enable_flash_attn_3(
        cls,
        config,
        torch_dtype: Optional[torch.dtype] = None,
        device_map: Optional[Union[str, dict[str, int]]] = None,
        check_device_map: bool = True,
        hard_check_only: bool = False,
    ) -> PretrainedConfig:
        """
        Checks the availability of Flash Attention 3 and compatibility with the current model.

        If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module.
        """
        if not cls._supports_flash_attn_3:
            raise ValueError(
                f"{cls.__name__} does not support Flash Attention 3.0 yet. Please request to add support where"
                f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
                " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
            )

        if not is_flash_attn_3_available():
            preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:"

            if importlib.util.find_spec("flash_attn_3") is None:
                raise ImportError(f"{preface} the package flash_attn_3 seems to be not installed.")

            if torch.cuda.is_available():
                major, _ = torch.cuda.get_device_capability()
                if major < 9:
                    raise ValueError(
                        f"{preface} Flash Attention 3 requires compute capability >= 9.0, but found {torch.cuda.get_device_capability()} with compute capability {major}.0."
                    )
                else:
                    raise ImportError(f"{preface} Flash Attention 3 is not available.")
            else:
                raise ValueError(
                    f"{preface} Flash Attention 3 is not available on CPU. Please make sure torch can access a CUDA device."
                )

        if torch_dtype is None:
            logger.warning_once(
                "You are attempting to use Flash Attention 3 without specifying a torch dtype. This might lead to unexpected behaviour"
            )
        elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
            logger.warning_once(
                "Flash Attention 3 only supports torch.float16 and torch.bfloat16 dtypes, but"
                f" the current dtype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
                ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`'
            )

        if getattr(config, "alibi", False) or getattr(config, "use_alibi", False):
            raise ValueError("Model is configured to use ALiBi, which is not supported by Flash Attention 3.")

        # Check for attention dropout, which is incompatible with FA3
        if hasattr(config, "attention_dropout") and config.attention_dropout > 0:
            raise ValueError(
                f"Model has attention_dropout={config.attention_dropout}, which is not supported by Flash Attention 3."
            )

        # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
        # or the model may be initialized under the context manager `with torch.device("cuda"):`.
        if check_device_map and device_map is None and torch.empty(0).device.type not in ["cuda", "mlu"]:
            if torch.cuda.is_available():
                logger.warning_once(
                    "You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU"
                    " after initializing it on CPU with `model.to('cuda')`."
                )
            else:
                raise ValueError(
                    "You are attempting to use Flash Attention 3 with a model not initialized on GPU and with no GPU available. "
                    "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
                    "or initialising the model on CPU and then moving it to GPU."
                )
        elif (
            check_device_map
            and device_map is not None
            and isinstance(device_map, dict)
            and ("cpu" in device_map.values() or "disk" in device_map.values())
        ):
            raise ValueError(
                "You are attempting to use Flash Attention 3 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
                "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
            )
        if not hard_check_only:
            config._attn_implementation = "flash_attention_3"
        return config
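Together with the dispatch changes above, opting in from user code becomes a one-line switch at load time. A minimal sketch (not part of the diff; it assumes a Hopper-class GPU with compute capability >= 9.0, the `flash_attn_3` package installed, and a model class that sets `_supports_flash_attn_3 = True`):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",                # example checkpoint, as in the warning message above
    attn_implementation="flash_attention_3",  # validated by _check_and_enable_flash_attn_3
    torch_dtype=torch.bfloat16,               # FA3 only supports float16/bfloat16
    device_map="cuda",
)
```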
    @classmethod
    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
        """
@@ -4134,7 +4236,7 @@

            </Tip>
        attn_implementation (`str`, *optional*):
-           The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
+           The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)), or `"flash_attention_3"` (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.

        > Parameters for big model inference

@@ -5770,6 +5872,7 @@ class AttentionInterface(GeneralInterface):
    # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
    # a new instance is created (in order to locally override a given function)
    _global_mapping = {
+       "flash_attention_3": flash_attention_forward,
        "flash_attention_2": flash_attention_forward,
        "flex_attention": flex_attention_forward,
        "paged_attention": paged_attention_forward,
@@ -96,6 +96,7 @@ if TYPE_CHECKING:
    from .distilbert import *
    from .dit import *
    from .donut import *
+   from .dots1 import *
    from .dpr import *
    from .dpt import *
    from .efficientnet import *
@@ -157,6 +158,7 @@ if TYPE_CHECKING:
    from .janus import *
    from .jetmoe import *
    from .kosmos2 import *
+   from .kyutai_speech_to_text import *
    from .layoutlm import *
    from .layoutlmv2 import *
    from .layoutlmv3 import *
@@ -285,7 +287,6 @@ if TYPE_CHECKING:
    from .squeezebert import *
    from .stablelm import *
    from .starcoder2 import *
-   from .stt import *
    from .superglue import *
    from .superpoint import *
    from .swiftformer import *
@@ -294,6 +295,7 @@ if TYPE_CHECKING:
    from .swinv2 import *
    from .switch_transformers import *
    from .t5 import *
+   from .t5gemma import *
    from .table_transformer import *
    from .tapas import *
    from .textnet import *
@@ -321,6 +321,7 @@ class ArceePreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _no_split_modules = ["ArceeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
+   _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

@@ -667,6 +667,7 @@ class AriaPreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _no_split_modules = ["AriaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
+   _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
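For any other architecture, opting in follows the same pattern as the Arcee and Aria hunks above: the model's `PreTrainedModel` subclass declares the capability flag. A schematic sketch (hypothetical class, for illustration only):

```python
from transformers import PretrainedConfig, PreTrainedModel


class MyFa3ReadyPreTrainedModel(PreTrainedModel):
    # Declaring the flag is what lets `_check_and_enable_flash_attn_3` accept
    # `attn_implementation="flash_attention_3"` for this model family.
    config_class = PretrainedConfig
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
```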
@@ -206,7 +206,7 @@ def convert_audio_spectrogram_transformer_checkpoint(model_name, pytorch_dump_fo

    if "speech-commands" in model_name:
        # TODO: Convert dataset to Parquet
-       dataset = load_dataset("google/speech_commands", "v0.02", split="validation", trust_remote_code=True)
+       dataset = load_dataset("google/speech_commands", "v0.02", split="validation")
        waveform = dataset[0]["audio"]["array"]
    else:
        filepath = hf_hub_download(
@ -112,6 +112,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("dinov2_with_registers", "Dinov2WithRegistersConfig"),
|
||||
("distilbert", "DistilBertConfig"),
|
||||
("donut-swin", "DonutSwinConfig"),
|
||||
("dots1", "Dots1Config"),
|
||||
("dpr", "DPRConfig"),
|
||||
("dpt", "DPTConfig"),
|
||||
("efficientformer", "EfficientFormerConfig"),
|
||||
@ -141,6 +142,8 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("git", "GitConfig"),
|
||||
("glm", "GlmConfig"),
|
||||
("glm4", "Glm4Config"),
|
||||
("glm4v", "Glm4vConfig"),
|
||||
("glm4v_text", "Glm4vTextConfig"),
|
||||
("glpn", "GLPNConfig"),
|
||||
("got_ocr2", "GotOcr2Config"),
|
||||
("gpt-sw3", "GPT2Config"),
|
||||
@ -181,6 +184,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("jetmoe", "JetMoeConfig"),
|
||||
("jukebox", "JukeboxConfig"),
|
||||
("kosmos-2", "Kosmos2Config"),
|
||||
("kyutai_speech_to_text", "KyutaiSpeechToTextConfig"),
|
||||
("layoutlm", "LayoutLMConfig"),
|
||||
("layoutlmv2", "LayoutLMv2Config"),
|
||||
("layoutlmv3", "LayoutLMv3Config"),
|
||||
@ -312,6 +316,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("siglip", "SiglipConfig"),
|
||||
("siglip2", "Siglip2Config"),
|
||||
("siglip_vision_model", "SiglipVisionConfig"),
|
||||
("smollm3", "SmolLM3Config"),
|
||||
("smolvlm", "SmolVLMConfig"),
|
||||
("smolvlm_vision", "SmolVLMVisionConfig"),
|
||||
("speech-encoder-decoder", "SpeechEncoderDecoderConfig"),
|
||||
@ -322,7 +327,6 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("squeezebert", "SqueezeBertConfig"),
|
||||
("stablelm", "StableLmConfig"),
|
||||
("starcoder2", "Starcoder2Config"),
|
||||
("stt", "KyutaiSpeechToTextConfig"),
|
||||
("superglue", "SuperGlueConfig"),
|
||||
("superpoint", "SuperPointConfig"),
|
||||
("swiftformer", "SwiftFormerConfig"),
|
||||
@ -331,6 +335,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("swinv2", "Swinv2Config"),
|
||||
("switch_transformers", "SwitchTransformersConfig"),
|
||||
("t5", "T5Config"),
|
||||
("t5gemma", "T5GemmaConfig"),
|
||||
("table-transformer", "TableTransformerConfig"),
|
||||
("tapas", "TapasConfig"),
|
||||
("textnet", "TextNetConfig"),
|
||||
@ -481,6 +486,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
("distilbert", "DistilBERT"),
|
||||
("dit", "DiT"),
|
||||
("donut-swin", "DonutSwin"),
|
||||
("dots1", "dots1"),
|
||||
("dpr", "DPR"),
|
||||
("dpt", "DPT"),
|
||||
("efficientformer", "EfficientFormer"),
|
||||
@ -512,7 +518,9 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
("gemma3_text", "Gemma3ForCausalLM"),
|
||||
("git", "GIT"),
|
||||
("glm", "GLM"),
|
||||
("glm4", "glm4"),
|
||||
("glm4", "GLM4"),
|
||||
("glm4v", "GLM4V"),
|
||||
("glm4v_text", "GLM4V"),
|
||||
("glpn", "GLPN"),
|
||||
("got_ocr2", "GOT-OCR2"),
|
||||
("gpt-sw3", "GPT-Sw3"),
|
||||
@ -554,6 +562,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
("jetmoe", "JetMoe"),
|
||||
("jukebox", "Jukebox"),
|
||||
("kosmos-2", "KOSMOS-2"),
|
||||
("kyutai_speech_to_text", "KyutaiSpeechToText"),
|
||||
("layoutlm", "LayoutLM"),
|
||||
("layoutlmv2", "LayoutLMv2"),
|
||||
("layoutlmv3", "LayoutLMv3"),
|
||||
@ -698,6 +707,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
("siglip2", "SigLIP2"),
|
||||
("siglip2_vision_model", "Siglip2VisionModel"),
|
||||
("siglip_vision_model", "SiglipVisionModel"),
|
||||
("smollm3", "SmolLM3"),
|
||||
("smolvlm", "SmolVLM"),
|
||||
("smolvlm_vision", "SmolVLMVisionTransformer"),
|
||||
("speech-encoder-decoder", "Speech Encoder decoder"),
|
||||
@ -708,7 +718,6 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
("squeezebert", "SqueezeBERT"),
|
||||
("stablelm", "StableLm"),
|
||||
("starcoder2", "Starcoder2"),
|
||||
("stt", "KyutaiSpeechToText"),
|
||||
("superglue", "SuperGlue"),
|
||||
("superpoint", "SuperPoint"),
|
||||
("swiftformer", "SwiftFormer"),
|
||||
@ -717,6 +726,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
("swinv2", "Swin Transformer V2"),
|
||||
("switch_transformers", "SwitchTransformers"),
|
||||
("t5", "T5"),
|
||||
("t5gemma", "T5Gemma"),
|
||||
("t5v1.1", "T5v1.1"),
|
||||
("table-transformer", "Table Transformer"),
|
||||
("tapas", "TAPAS"),
|
||||
@ -827,6 +837,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
|
||||
("clip_text_model", "clip"),
|
||||
("aria_text", "aria"),
|
||||
("gemma3_text", "gemma3"),
|
||||
("glm4v_text", "glm4v"),
|
||||
("idefics3_vision", "idefics3"),
|
||||
("siglip_vision_model", "siglip"),
|
||||
("smolvlm_vision", "smolvlm"),
|
||||
|
@ -65,6 +65,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("groupvit", "CLIPFeatureExtractor"),
|
||||
("hubert", "Wav2Vec2FeatureExtractor"),
|
||||
("imagegpt", "ImageGPTFeatureExtractor"),
|
||||
("kyutai_speech_to_text", "KyutaiSpeechToTextFeatureExtractor"),
|
||||
("layoutlmv2", "LayoutLMv2FeatureExtractor"),
|
||||
("layoutlmv3", "LayoutLMv3FeatureExtractor"),
|
||||
("levit", "LevitFeatureExtractor"),
|
||||
@ -91,7 +92,6 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("sew-d", "Wav2Vec2FeatureExtractor"),
|
||||
("speech_to_text", "Speech2TextFeatureExtractor"),
|
||||
("speecht5", "SpeechT5FeatureExtractor"),
|
||||
("stt", "KyutaiSpeechToTextFeatureExtractor"),
|
||||
("swiftformer", "ViTFeatureExtractor"),
|
||||
("swin", "ViTFeatureExtractor"),
|
||||
("swinv2", "ViTFeatureExtractor"),
|
||||
|
@ -89,6 +89,7 @@ else:
|
||||
("fuyu", ("FuyuImageProcessor",)),
|
||||
("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
|
||||
("git", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
|
||||
("glm4v", ("Glm4vImageProcessor", "Glm4vImageProcessorFast")),
|
||||
("glpn", ("GLPNImageProcessor",)),
|
||||
("got_ocr2", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
|
||||
("grounding-dino", ("GroundingDinoImageProcessor", "GroundingDinoImageProcessorFast")),
|
||||
|
@ -105,6 +105,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("dinov2_with_registers", "Dinov2WithRegistersModel"),
|
||||
("distilbert", "DistilBertModel"),
|
||||
("donut-swin", "DonutSwinModel"),
|
||||
("dots1", "Dots1Model"),
|
||||
("dpr", "DPRQuestionEncoder"),
|
||||
("dpt", "DPTModel"),
|
||||
("efficientformer", "EfficientFormerModel"),
|
||||
@ -133,6 +134,8 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("git", "GitModel"),
|
||||
("glm", "GlmModel"),
|
||||
("glm4", "Glm4Model"),
|
||||
("glm4v", "Glm4vModel"),
|
||||
("glm4v_text", "Glm4vTextModel"),
|
||||
("glpn", "GLPNModel"),
|
||||
("got_ocr2", "GotOcr2Model"),
|
||||
("gpt-sw3", "GPT2Model"),
|
||||
@ -171,6 +174,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("jetmoe", "JetMoeModel"),
|
||||
("jukebox", "JukeboxModel"),
|
||||
("kosmos-2", "Kosmos2Model"),
|
||||
("kyutai_speech_to_text", "KyutaiSpeechToTextModel"),
|
||||
("layoutlm", "LayoutLMModel"),
|
||||
("layoutlmv2", "LayoutLMv2Model"),
|
||||
("layoutlmv3", "LayoutLMv3Model"),
|
||||
@ -292,6 +296,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("siglip", "SiglipModel"),
|
||||
("siglip2", "Siglip2Model"),
|
||||
("siglip_vision_model", "SiglipVisionModel"),
|
||||
("smollm3", "SmolLM3Model"),
|
||||
("smolvlm", "SmolVLMModel"),
|
||||
("smolvlm_vision", "SmolVLMVisionTransformer"),
|
||||
("speech_to_text", "Speech2TextModel"),
|
||||
@ -300,7 +305,6 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("squeezebert", "SqueezeBertModel"),
|
||||
("stablelm", "StableLmModel"),
|
||||
("starcoder2", "Starcoder2Model"),
|
||||
("stt", "KyutaiSpeechToTextModel"),
|
||||
("superglue", "SuperGlueForKeypointMatching"),
|
||||
("swiftformer", "SwiftFormerModel"),
|
||||
("swin", "SwinModel"),
|
||||
@ -308,6 +312,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("swinv2", "Swinv2Model"),
|
||||
("switch_transformers", "SwitchTransformersModel"),
|
||||
("t5", "T5Model"),
|
||||
("t5gemma", "T5GemmaModel"),
|
||||
("table-transformer", "TableTransformerModel"),
|
||||
("tapas", "TapasModel"),
|
||||
("textnet", "TextNetModel"),
|
||||
@ -428,6 +433,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
|
||||
("squeezebert", "SqueezeBertForMaskedLM"),
|
||||
("switch_transformers", "SwitchTransformersForConditionalGeneration"),
|
||||
("t5", "T5ForConditionalGeneration"),
|
||||
("t5gemma", "T5GemmaForConditionalGeneration"),
|
||||
("tapas", "TapasForMaskedLM"),
|
||||
("transfo-xl", "TransfoXLLMHeadModel"),
|
||||
("tvlt", "TvltForPreTraining"),
|
||||
@ -522,6 +528,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
|
||||
("squeezebert", "SqueezeBertForMaskedLM"),
|
||||
("switch_transformers", "SwitchTransformersForConditionalGeneration"),
|
||||
("t5", "T5ForConditionalGeneration"),
|
||||
("t5gemma", "T5GemmaForConditionalGeneration"),
|
||||
("tapas", "TapasForMaskedLM"),
|
||||
("transfo-xl", "TransfoXLLMHeadModel"),
|
||||
("wav2vec2", "Wav2Vec2ForMaskedLM"),
|
||||
@ -562,6 +569,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("dbrx", "DbrxForCausalLM"),
|
||||
("deepseek_v3", "DeepseekV3ForCausalLM"),
|
||||
("diffllama", "DiffLlamaForCausalLM"),
|
||||
("dots1", "Dots1ForCausalLM"),
|
||||
("electra", "ElectraForCausalLM"),
|
||||
("emu3", "Emu3ForCausalLM"),
|
||||
("ernie", "ErnieForCausalLM"),
|
||||
@ -637,6 +645,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("roc_bert", "RoCBertForCausalLM"),
|
||||
("roformer", "RoFormerForCausalLM"),
|
||||
("rwkv", "RwkvForCausalLM"),
|
||||
("smollm3", "SmolLM3ForCausalLM"),
|
||||
("speech_to_text_2", "Speech2Text2ForCausalLM"),
|
||||
("stablelm", "StableLmForCausalLM"),
|
||||
("starcoder2", "Starcoder2ForCausalLM"),
|
||||
@ -896,6 +905,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
|
||||
("fuyu", "FuyuForCausalLM"),
|
||||
("gemma3", "Gemma3ForConditionalGeneration"),
|
||||
("git", "GitForCausalLM"),
|
||||
("glm4v", "Glm4vForConditionalGeneration"),
|
||||
("got_ocr2", "GotOcr2ForConditionalGeneration"),
|
||||
("idefics", "IdeficsForVisionText2Text"),
|
||||
("idefics2", "Idefics2ForConditionalGeneration"),
|
||||
@ -1041,6 +1051,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"),
|
||||
("switch_transformers", "SwitchTransformersForConditionalGeneration"),
|
||||
("t5", "T5ForConditionalGeneration"),
|
||||
("t5gemma", "T5GemmaForConditionalGeneration"),
|
||||
("umt5", "UMT5ForConditionalGeneration"),
|
||||
("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
|
||||
]
|
||||
@ -1049,6 +1060,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("granite_speech", "GraniteSpeechForConditionalGeneration"),
|
||||
("kyutai_speech_to_text", "KyutaiSpeechToTextForConditionalGeneration"),
|
||||
("moonshine", "MoonshineForConditionalGeneration"),
|
||||
("pop2piano", "Pop2PianoForConditionalGeneration"),
|
||||
("seamless_m4t", "SeamlessM4TForSpeechToText"),
|
||||
@ -1056,7 +1068,6 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
|
||||
("speech_to_text", "Speech2TextForConditionalGeneration"),
|
||||
("speecht5", "SpeechT5ForSpeechToText"),
|
||||
("stt", "KyutaiSpeechToTextForConditionalGeneration"),
|
||||
("whisper", "WhisperForConditionalGeneration"),
|
||||
]
|
||||
)
|
||||
@ -1149,10 +1160,12 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("roberta-prelayernorm", "RobertaPreLayerNormForSequenceClassification"),
|
||||
("roc_bert", "RoCBertForSequenceClassification"),
|
||||
("roformer", "RoFormerForSequenceClassification"),
|
||||
("smollm3", "SmolLM3ForSequenceClassification"),
|
||||
("squeezebert", "SqueezeBertForSequenceClassification"),
|
||||
("stablelm", "StableLmForSequenceClassification"),
|
||||
("starcoder2", "Starcoder2ForSequenceClassification"),
|
||||
("t5", "T5ForSequenceClassification"),
|
||||
("t5gemma", "T5GemmaForSequenceClassification"),
|
||||
("tapas", "TapasForSequenceClassification"),
|
||||
("transfo-xl", "TransfoXLForSequenceClassification"),
|
||||
("umt5", "UMT5ForSequenceClassification"),
|
||||
@ -1234,6 +1247,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||
("roberta-prelayernorm", "RobertaPreLayerNormForQuestionAnswering"),
|
||||
("roc_bert", "RoCBertForQuestionAnswering"),
|
||||
("roformer", "RoFormerForQuestionAnswering"),
|
||||
("smollm3", "SmolLM3ForQuestionAnswering"),
|
||||
("splinter", "SplinterForQuestionAnswering"),
|
||||
("squeezebert", "SqueezeBertForQuestionAnswering"),
|
||||
("t5", "T5ForQuestionAnswering"),
|
||||
@ -1342,10 +1356,12 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"),
|
||||
("roc_bert", "RoCBertForTokenClassification"),
|
||||
("roformer", "RoFormerForTokenClassification"),
|
||||
("smollm3", "SmolLM3ForTokenClassification"),
|
||||
("squeezebert", "SqueezeBertForTokenClassification"),
|
||||
("stablelm", "StableLmForTokenClassification"),
|
||||
("starcoder2", "Starcoder2ForTokenClassification"),
|
||||
("t5", "T5ForTokenClassification"),
|
||||
("t5gemma", "T5GemmaForTokenClassification"),
|
||||
("umt5", "UMT5ForTokenClassification"),
|
||||
("xlm", "XLMForTokenClassification"),
|
||||
("xlm-roberta", "XLMRobertaForTokenClassification"),
|
||||
@ -1579,6 +1595,7 @@ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
|
||||
("roformer", "RoFormerModel"),
|
||||
("squeezebert", "SqueezeBertModel"),
|
||||
("t5", "T5EncoderModel"),
|
||||
("t5gemma", "T5GemmaEncoderModel"),
|
||||
("umt5", "UMT5EncoderModel"),
|
||||
("xlm", "XLMModel"),
|
||||
("xlm-roberta", "XLMRobertaModel"),
|
||||
|
@ -66,6 +66,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("fuyu", "FuyuProcessor"),
|
||||
("gemma3", "Gemma3Processor"),
|
||||
("git", "GitProcessor"),
|
||||
("glm4v", "Glm4vProcessor"),
|
||||
("got_ocr2", "GotOcr2Processor"),
|
||||
("granite_speech", "GraniteSpeechProcessor"),
|
||||
("grounding-dino", "GroundingDinoProcessor"),
|
||||
@ -79,6 +80,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("internvl", "InternVLProcessor"),
|
||||
("janus", "JanusProcessor"),
|
||||
("kosmos-2", "Kosmos2Processor"),
|
||||
("kyutai_speech_to_text", "KyutaiSpeechToTextProcessor"),
|
||||
("layoutlmv2", "LayoutLMv2Processor"),
|
||||
("layoutlmv3", "LayoutLMv3Processor"),
|
||||
("llama4", "Llama4Processor"),
|
||||
@ -116,7 +118,6 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("speech_to_text", "Speech2TextProcessor"),
|
||||
("speech_to_text_2", "Speech2Text2Processor"),
|
||||
("speecht5", "SpeechT5Processor"),
|
||||
("stt", "KyutaiSpeechToTextProcessor"),
|
||||
("trocr", "TrOCRProcessor"),
|
||||
("tvlt", "TvltProcessor"),
|
||||
("tvp", "TvpProcessor"),
|
||||
|
@ -238,6 +238,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("glm4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("glm4v", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
|
||||
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("gpt_bigcode", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
|
||||
@ -581,6 +582,13 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
"T5TokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"t5gemma",
|
||||
(
|
||||
"GemmaTokenizer" if is_sentencepiece_available() else None,
|
||||
"GemmaTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
("tapas", ("TapasTokenizer", None)),
|
||||
("tapex", ("TapexTokenizer", None)),
|
||||
("transfo-xl", ("TransfoXLTokenizer", None)),
|
||||
|
@ -46,6 +46,7 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("glm4v", "Glm4vVideoProcessor"),
|
||||
("instructblip", "InstructBlipVideoVideoProcessor"),
|
||||
("instructblipvideo", "InstructBlipVideoVideoProcessor"),
|
||||
("internvl", "InternVLVideoProcessor"),
|
||||
|
@ -282,18 +282,6 @@ BARK_ATTENTION_CLASSES = {
|
||||
}
|
||||
|
||||
|
||||
class BarkLayerNorm(nn.Module):
|
||||
"""LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False."""
|
||||
|
||||
def __init__(self, hidden_size, bias=True):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None
|
||||
|
||||
def forward(self, input):
|
||||
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, eps=1e-5)
|
||||
|
||||
|
||||
class BarkMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
@ -315,11 +303,10 @@ class BarkBlock(GradientCheckpointingLayer):
|
||||
super().__init__()
|
||||
|
||||
if is_causal:
|
||||
# if causal, uses handmade LayerNorm, so that the layerNorm bias is optional
|
||||
# this handmade layerNorm is used to stick with Bark choice of leaving optional bias in
|
||||
# AutoRegressive models (corresponding to the "Text" and the "Coarse" modules)
|
||||
self.layernorm_1 = BarkLayerNorm(config.hidden_size, bias=config.bias)
|
||||
self.layernorm_2 = BarkLayerNorm(config.hidden_size, bias=config.bias)
|
||||
# if causal, the layerNorm bias is optional to stick with Bark choice of leaving optional bias
|
||||
# in AutoRegressive models (corresponding to the "Text" and the "Coarse" modules)
|
||||
self.layernorm_1 = nn.LayerNorm(config.hidden_size, bias=config.bias)
|
||||
self.layernorm_2 = nn.LayerNorm(config.hidden_size, bias=config.bias)
|
||||
else:
|
||||
self.layernorm_1 = nn.LayerNorm(config.hidden_size)
|
||||
self.layernorm_2 = nn.LayerNorm(config.hidden_size)
|
||||
@ -427,7 +414,7 @@ class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
|
||||
self.layers = nn.ModuleList([BarkBlock(config, is_causal=True) for _ in range(config.num_layers)])
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
|
||||
self.layernorm_final = BarkLayerNorm(config.hidden_size, bias=config.bias)
|
||||
self.layernorm_final = nn.LayerNorm(config.hidden_size, bias=config.bias)
|
||||
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.output_vocab_size, bias=False)
|
||||
self.gradient_checkpointing = False
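Side note on this refactor (a standalone sketch, not part of the diff): `torch.nn.LayerNorm` has exposed a `bias` flag since PyTorch 2.1, which is what makes the handmade `BarkLayerNorm` redundant:

```python
import torch
from torch import nn

# Since PyTorch 2.1 the bias of nn.LayerNorm is optional, matching the old
# BarkLayerNorm(hidden_size, bias=False) behaviour.
layernorm = nn.LayerNorm(64, bias=False)
assert layernorm.bias is None
out = layernorm(torch.randn(2, 8, 64))
print(out.shape)  # torch.Size([2, 8, 64])
```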
@ -266,7 +266,7 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
||||
# Check outputs on an image
|
||||
if is_semantic:
|
||||
image_processor = BeitImageProcessor(size=config.image_size, do_center_crop=False)
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test", trust_remote_code=True)
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
|
||||
image = Image.open(ds[0]["file"])
|
||||
else:
|
||||
image_processor = BeitImageProcessor(
|
||||
|
@ -318,6 +318,7 @@ class BitNetPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["BitNetDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
@ -355,6 +355,7 @@ class CoherePreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["CohereDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
@ -334,6 +334,7 @@ class Cohere2PreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Cohere2DecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
@ -226,7 +226,7 @@ def convert_wav2vec2_checkpoint(
|
||||
|
||||
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-lv60")
|
||||
|
||||
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
input_audio = [x["array"] for x in ds[:4]["audio"]]
|
||||
|
||||
inputs = processor(input_audio, return_tensors="pt", padding=True)
|
||||
|
@ -504,6 +504,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["DeepseekV3DecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
@ -556,6 +556,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["DiffLlamaDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = False
|
||||
|
27
src/transformers/models/dots1/__init__.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_dots1 import *
|
||||
from .modeling_dots1 import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
211
src/transformers/models/dots1/configuration_dots1.py
Normal file
@ -0,0 +1,211 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The rednote-hilab team and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ...configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Dots1Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Dots1Model`]. It is used to instantiate a
|
||||
`dots.llm1` model according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of
|
||||
[rednote-hilab/dots.llm1.base](https://huggingface.co/rednote-hilab/dots.llm1.base).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 152064):
|
||||
Vocabulary size of the model. Defines the number of different tokens that can be represented by the
|
||||
`input_ids` passed when calling [`Dots1Model`].
|
||||
hidden_size (`int`, *optional*, defaults to 4608):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 10944):
|
||||
Dimension of the MLP representations.
|
||||
moe_intermediate_size (`int`, *optional*, defaults to 1408):
|
||||
Dimension of the MoE representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 62):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 32):
|
||||
Number of key/value heads for Grouped Query Attention. If `num_key_value_heads=num_attention_heads`, Multi
|
||||
Head Attention (MHA) is used. If `num_key_value_heads=1`, Multi Query Attention (MQA) is used. Otherwise,
|
||||
Grouped Query Attention (GQA) is used. If not specified, defaults to `num_attention_heads`.
|
||||
n_shared_experts (`int`, *optional*, defaults to `None`):
|
||||
Number of shared experts. None means dense model.
|
||||
n_routed_experts (`int`, *optional*, defaults to `None`):
|
||||
Number of routed experts. None means dense model.
|
||||
n_group (`int`, *optional*, defaults to 1):
|
||||
Number of groups for routed experts.
|
||||
topk_group (`int`, *optional*, defaults to 1):
|
||||
Number of selected groups for each token (selected experts only within `topk_group` groups).
|
||||
num_experts_per_tok (`int`, *optional*, defaults to `None`):
|
||||
Number of selected experts. None means dense model.
|
||||
first_k_dense_replace (`int`, *optional*, defaults to 0):
|
||||
Number of dense layers at the beginning of the model before the first MoE layer.
|
||||
norm_topk_prob (`bool`, *optional*, defaults to `False`):
|
||||
Whether to normalize the weights of the routed experts.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string).
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
Maximum sequence length the model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
Standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
Epsilon used by the RMS normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions. Only relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie the input and output word embeddings.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`dict`, *optional*):
|
||||
Dictionary for scaling RoPE embeddings. Supports `{"type": strategy name, "factor": scaling factor}`.
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the self-attention projections.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
Dropout ratio for the attention probabilities.
|
||||
routed_scaling_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor for routed experts.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
Size of the sliding window for attention.
|
||||
max_window_layers (`int`, *optional*, defaults to 62):
|
||||
The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
|
||||
additional layer afterwards will use SWA (Sliding Window Attention).
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
>>> from transformers import Dots1Model, Dots1Config
|
||||
|
||||
>>> # Initializing a Dots1 style configuration
>>> configuration = Dots1Config()

>>> # Initializing a model from the configuration
>>> model = Dots1Model(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
|
||||
```
|
||||
"""
|
||||
|
||||
model_type = "dots1"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
base_model_tp_plan = { # TODO: only replicate attention layers when > first_k_dense_replace
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.experts.*.gate_proj": "local_colwise",
|
||||
"layers.*.mlp.experts.*.up_proj": "local_colwise",
|
||||
"layers.*.mlp.experts.*.down_proj": "local_rowwise",
|
||||
"layers.*.mlp.experts.*": "local", # each expert is wrapped in a module list
|
||||
"layers.*.mlp.shared_experts.gate_proj": "local_colwise",
|
||||
"layers.*.mlp.shared_experts.up_proj": "local_colwise",
|
||||
"layers.*.mlp.shared_experts.down_proj": "local_rowwise",
|
||||
"layers.*.mlp.shared_experts": "local",
|
||||
"layers.*.mlp.gate_proj": "local_colwise",
|
||||
"layers.*.mlp.up_proj": "local_colwise",
|
||||
"layers.*.mlp.down_proj": "local_rowwise",
|
||||
"layers.*.mlp": "gather", # This is the only moment where results are gathered
|
||||
}
|
||||
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=152064,
|
||||
hidden_size=4608,
|
||||
intermediate_size=10944,
|
||||
moe_intermediate_size=1408,
|
||||
num_hidden_layers=62,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=32,
|
||||
n_shared_experts=None,
|
||||
n_routed_experts=None,
|
||||
n_group=1,
|
||||
topk_group=1,
|
||||
num_experts_per_tok=None,
|
||||
first_k_dense_replace=0,
|
||||
norm_topk_prob=False,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
routed_scaling_factor=1.0,
|
||||
sliding_window=4096,
|
||||
max_window_layers=62,
|
||||
layer_types=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.n_shared_experts = n_shared_experts
|
||||
self.n_routed_experts = n_routed_experts
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.first_k_dense_replace = first_k_dense_replace
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
self.n_group = n_group
|
||||
self.topk_group = topk_group
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.sliding_window = sliding_window
|
||||
self.max_window_layers = max_window_layers
|
||||
|
||||
self.layer_types = layer_types
|
||||
if self.layer_types is None:
|
||||
self.layer_types = [
|
||||
"sliding_attention"
|
||||
if self.sliding_window is not None and i >= self.max_window_layers
|
||||
else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types)
|
||||
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Dots1Config"]
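To make the `sliding_window` / `max_window_layers` / `layer_types` interaction above concrete, a small sketch (assuming a `transformers` build that already includes this new model):

```python
from transformers import Dots1Config

# The first `max_window_layers` layers keep full attention; the remaining layers
# switch to sliding-window attention because `sliding_window` is set.
config = Dots1Config(num_hidden_layers=6, sliding_window=1024, max_window_layers=4)
print(config.layer_types)
# ['full_attention', 'full_attention', 'full_attention', 'full_attention',
#  'sliding_attention', 'sliding_attention']
```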
|
700
src/transformers/models/dots1/modeling_dots1.py
Normal file
@ -0,0 +1,700 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/dots1/modular_dots1.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_dots1.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The rednote-hilab team and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
|
||||
from .configuration_dots1 import Dots1Config
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@use_kernel_forward_from_hub("RMSNorm")
|
||||
class Dots1RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
Dots1RMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class Dots1RotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: Dots1Config, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Dots1Attention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: Dots1Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = True
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
||||
)
|
||||
self.q_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
|
||||
self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
|
||||
self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
sliding_window=self.sliding_window, # diff with Llama
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Dots1MLP(nn.Module):
|
||||
def __init__(self, config, hidden_size=None, intermediate_size=None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
||||
self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
|
||||
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class Dots1MoE(nn.Module):
|
||||
"""
|
||||
A mixed expert module containing shared experts.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.experts = nn.ModuleList(
|
||||
[Dots1MLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_routed_experts)]
|
||||
)
|
||||
self.gate = Dots1TopkRouter(config)
|
||||
self.shared_experts = Dots1MLP(
|
||||
config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
|
||||
)
|
||||
|
||||
def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
|
||||
r"""
|
||||
CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
|
||||
to not have to do a loop here (deepseek has 256 experts soooo yeah).
|
||||
"""
|
||||
final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
|
||||
expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
|
||||
expert_mask = expert_mask.permute(2, 0, 1)
|
||||
|
||||
for expert_idx in range(len(self.experts)):
|
||||
expert = self.experts[expert_idx]
|
||||
mask = expert_mask[expert_idx]
|
||||
token_indices, weight_indices = torch.where(mask)
|
||||
|
||||
if token_indices.numel() > 0:
|
||||
expert_weights = topk_weights[token_indices, weight_indices]
|
||||
expert_input = hidden_states[token_indices]
|
||||
expert_output = expert(expert_input)
|
||||
weighted_output = expert_output * expert_weights.unsqueeze(-1)
|
||||
final_hidden_states.index_add_(0, token_indices, weighted_output)
|
||||
|
||||
# in the original deepseek, the output of the experts is gathered once we leave this module
# thus the moe module is itself an IsolatedParallel module
# and all experts are "local", meaning we shard but we don't gather
|
||||
return final_hidden_states.type(hidden_states.dtype)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
residuals = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
topk_indices, topk_weights = self.gate(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
|
||||
hidden_states = hidden_states + self.shared_experts(residuals)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Dots1TopkRouter(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
|
||||
self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts))
|
||||
|
||||
@torch.no_grad()
|
||||
def get_topk_indices(self, scores):
|
||||
scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
|
||||
group_scores = (
|
||||
scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
|
||||
.topk(2, dim=-1)[0]
|
||||
.sum(dim=-1)
|
||||
)
|
||||
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
|
||||
group_mask = torch.zeros_like(group_scores)
|
||||
group_mask.scatter_(1, group_idx, 1)
|
||||
score_mask = (
|
||||
group_mask.unsqueeze(-1)
|
||||
.expand(-1, self.n_group, self.n_routed_experts // self.n_group)
|
||||
.reshape(-1, self.n_routed_experts)
|
||||
)
|
||||
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
|
||||
topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
|
||||
return topk_indices
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.view(-1, self.config.hidden_size)
|
||||
router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
|
||||
scores = router_logits.sigmoid()
|
||||
topk_indices = self.get_topk_indices(scores)
|
||||
topk_weights = scores.gather(1, topk_indices)
|
||||
if self.norm_topk_prob:
|
||||
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weights /= denominator
|
||||
topk_weights = topk_weights * self.routed_scaling_factor
|
||||
return topk_indices, topk_weights
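To illustrate the group-limited selection performed by `get_topk_indices`, here is an independent toy sketch with one token, 8 experts split into 4 groups, keeping the 2 best groups and then the top-2 experts within them (not the module itself, just the same tensor recipe):

```python
import torch

torch.manual_seed(0)
scores = torch.rand(1, 8)  # sigmoid router scores for one token, 8 experts

# Score each group by the sum of its top-2 expert scores, keep the 2 best groups.
group_scores = scores.view(1, 4, 2).topk(2, dim=-1)[0].sum(dim=-1)      # (1, 4)
group_idx = torch.topk(group_scores, k=2, dim=-1, sorted=False)[1]      # (1, 2)
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)

# Zero out experts from discarded groups, then take the global top-2.
score_mask = group_mask.unsqueeze(-1).expand(1, 4, 2).reshape(1, 8)
masked_scores = scores.masked_fill(~score_mask.bool(), 0.0)
topk_indices = torch.topk(masked_scores, k=2, dim=-1, sorted=False)[1]
print(topk_indices)  # the selected experts all come from the 2 kept groups
```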
|
||||
|
||||
|
||||
class Dots1DecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Dots1Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = Dots1Attention(config=config, layer_idx=layer_idx)
|
||||
|
||||
if layer_idx >= config.first_k_dense_replace:
|
||||
self.mlp = Dots1MoE(config)
|
||||
else:
|
||||
self.mlp = Dots1MLP(config)
|
||||
|
||||
self.input_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.attention_type = config.layer_types[layer_idx]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Dots1PreTrainedModel(PreTrainedModel):
|
||||
config_class = Dots1Config
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Dots1DecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, Dots1RMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, Dots1TopkRouter):
            module.weight.data.normal_(mean=0.0, std=std)


@auto_docstring
class Dots1Model(Dots1PreTrainedModel):
    def __init__(self, config: Dots1Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Dots1DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Dots1RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class Dots1ForCausalLM(Dots1PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Dots1Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Dots1ForCausalLM

        >>> model = Dots1ForCausalLM.from_pretrained("rednote-hilab/dots1.llm1.inst")
        >>> tokenizer = AutoTokenizer.from_pretrained("rednote-hilab/dots1.llm1.inst")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["Dots1PreTrainedModel", "Dots1Model", "Dots1ForCausalLM"]
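
# Editor's sketch (not part of the diff): a minimal manual decoding loop against the
# `Dots1ForCausalLM.forward` defined above, assuming the checkpoint id from the docstring
# example. It illustrates how `DynamicCache` and `logits_to_keep=1` keep per-step work small:
# only the last position's logits are materialized on each forward call.
import torch
from transformers import AutoTokenizer, Dots1ForCausalLM
from transformers.cache_utils import DynamicCache

tokenizer = AutoTokenizer.from_pretrained("rednote-hilab/dots1.llm1.inst")
model = Dots1ForCausalLM.from_pretrained("rednote-hilab/dots1.llm1.inst", torch_dtype=torch.bfloat16)

input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
cache = DynamicCache()
generated = []
for _ in range(8):
    out = model(input_ids=input_ids, past_key_values=cache, use_cache=True, logits_to_keep=1)
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick
    generated.append(next_token)
    input_ids = next_token  # feed only the new token; the cache already holds the prefix
print(tokenizer.decode(torch.cat(generated, dim=-1)[0]))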

src/transformers/models/dots1/modular_dots1.py (new file, 111 lines)
@ -0,0 +1,111 @@
# coding=utf-8
# Copyright 2025 The rednote-hilab team and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...modeling_outputs import CausalLMOutputWithPast
from ...processing_utils import Unpack
from ...utils import logging
from ..deepseek_v3.modeling_deepseek_v3 import (
    DeepseekV3DecoderLayer,
    DeepseekV3MLP,
    DeepseekV3MoE,
    DeepseekV3PreTrainedModel,
    DeepseekV3TopkRouter,
)
from ..qwen3.modeling_qwen3 import (
    KwargsForCausalLM,
    Qwen3Attention,
    Qwen3ForCausalLM,
    Qwen3Model,
    Qwen3RMSNorm,
    Qwen3RotaryEmbedding,
)
from .configuration_dots1 import Dots1Config


logger = logging.get_logger(__name__)


class Dots1RMSNorm(Qwen3RMSNorm):
    pass


class Dots1RotaryEmbedding(Qwen3RotaryEmbedding):
    pass


class Dots1Attention(Qwen3Attention):
    pass


class Dots1MLP(DeepseekV3MLP):
    pass


class Dots1MoE(DeepseekV3MoE):
    pass


class Dots1TopkRouter(DeepseekV3TopkRouter):
    pass


class Dots1DecoderLayer(DeepseekV3DecoderLayer):
    def __init__(self, config: Dots1Config, layer_idx: int):
        super().__init__()
        self.attention_type = config.layer_types[layer_idx]


class Dots1PreTrainedModel(DeepseekV3PreTrainedModel):
    pass


class Dots1Model(Qwen3Model):
    pass


class Dots1ForCausalLM(Qwen3ForCausalLM):
    def forward(
        self,
        **super_kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Dots1ForCausalLM

        >>> model = Dots1ForCausalLM.from_pretrained("rednote-hilab/dots1.llm1.inst")
        >>> tokenizer = AutoTokenizer.from_pretrained("rednote-hilab/dots1.llm1.inst")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        return super().forward(**super_kwargs)


__all__ = [
    "Dots1PreTrainedModel",
    "Dots1Model",
    "Dots1ForCausalLM",
]
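
# Editor's sketch (not part of the diff): the modular file above only records how dots1 differs
# from Qwen3/DeepseekV3; the generated modeling file shown earlier is what actually executes.
# Assuming the config registers `model_type = "dots1"` with the Auto classes, the model is also
# reachable through the usual Auto API (checkpoint id taken from the docstring example):
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("rednote-hilab/dots1.llm1.inst")
model = AutoModelForCausalLM.from_config(config)  # random weights, no checkpoint download
print(type(model).__name__)  # expected: "Dots1ForCausalLM"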
@ -318,6 +318,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _no_split_modules = ["GemmaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

@ -339,6 +339,7 @@ class Gemma2PreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _no_split_modules = ["Gemma2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

@ -422,6 +422,7 @@ class Gemma3PreTrainedModel(PreTrainedModel):
        "SiglipMultiheadAttentionPoolingHead",
    ]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

@ -776,6 +777,8 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_
)
class Gemma3Model(Gemma3PreTrainedModel):
    _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
    accepts_loss_kwargs = False

    def __init__(self, config: Gemma3Config):
        super().__init__(config)

@ -727,6 +727,9 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_


class Gemma3Model(PaliGemmaModel):
    # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
    accepts_loss_kwargs = False

    def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Projects the last hidden state from the vision model into language model space.

@ -335,6 +335,7 @@ class GlmPreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _no_split_modules = ["GlmDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

@ -343,6 +343,7 @@ class Glm4PreTrainedModel(PreTrainedModel):
    supports_gradient_checkpointing = True
    _no_split_modules = ["Glm4DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
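
# Editor's sketch (not part of the diff): the hunks above add `_supports_flash_attn_3 = True`
# to several PreTrainedModel subclasses. Assuming the usual backend-selection mechanism and that
# the request key is "flash_attention_3" (an assumption, not confirmed by this diff), a user
# would opt into the FlashAttention-3 kernel like this:
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",                      # any of the models whose flags are touched above
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_3",  # raises if the kernel or package is unavailable
)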

src/transformers/models/glm4v/__init__.py (new file, 28 lines)
@ -0,0 +1,28 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_glm4v import *
    from .modeling_glm4v import *
    from .processing_glm4v import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
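
# Editor's sketch (not part of the diff): the `_LazyModule` registration above means the glm4v
# submodules are only imported on first attribute access, so top-level `import transformers`
# stays cheap. A direct import still resolves through the lazy structure:
from transformers.models.glm4v import Glm4vConfig  # triggers the lazy import of configuration_glm4v

print(Glm4vConfig.model_type)  # expected: "glm4v"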

src/transformers/models/glm4v/configuration_glm4v.py (new file, 354 lines)
@ -0,0 +1,354 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/glm4v/modular_glm4v.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_glm4v.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
|
||||
|
||||
class Glm4vVisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Glm4vVisionModel`]. It is used to instantiate an Glm4vVisionModel
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield
|
||||
a similar configuration to that of
|
||||
GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking).
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 1536):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
depth (`int`, *optional*, defaults to 24):
|
||||
Number of layers (depth) in the model.
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to add a bias to the queries, keys and values.
|
||||
intermediate_size (`int`, *optional*, defaults to 13696):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
||||
hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
Dropout probability for attention weights.
|
||||
projection_dropout (`float`, *optional*, defaults to 0.0):
|
||||
Dropout probability for the projection layer.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
image_size (`int` or `list[int]`, *optional*, defaults to 336):
|
||||
The size (resolution) of each image.
|
||||
patch_size (`int`, *optional*, defaults to `14`):
|
||||
The size (resolution) of each patch.
|
||||
in_channels (`int`, *optional*, defaults to 3):
|
||||
The number of input channels.
|
||||
out_hidden_size (`int`, *optional*, defaults to 4096):
|
||||
The output hidden size of the vision model.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
spatial_merge_size (`int`, *optional*, defaults to 2):
|
||||
The size used for merging spatial dimensions.
|
||||
temporal_patch_size (`int`, *optional*, defaults to 1):
|
||||
The size used for patches along the temporal dimension.
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Glm4vVisionConfig, Glm4vVisionModel
|
||||
|
||||
>>> # Initializing a Glm4vVisionConfig GLM-4.1V-9B style configuration
|
||||
>>> configuration = Glm4vVisionConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the GLM-4.1V-9B configuration
|
||||
>>> model = Glm4vVisionModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "glm4v"
|
||||
base_config_key = "vision_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth=24,
|
||||
hidden_size=1536,
|
||||
hidden_act="silu",
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
num_heads=12,
|
||||
in_channels=3,
|
||||
image_size=336,
|
||||
patch_size=14,
|
||||
rms_norm_eps=1e-05,
|
||||
spatial_merge_size=2,
|
||||
temporal_patch_size=1,
|
||||
out_hidden_size=4096,
|
||||
intermediate_size=13696,
|
||||
initializer_range=0.02,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.depth = depth
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.num_heads = num_heads
|
||||
self.in_channels = in_channels
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.out_hidden_size = out_hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
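
# Editor's sketch (not part of the diff): back-of-the-envelope arithmetic from the defaults
# documented above (image_size=336, patch_size=14, spatial_merge_size=2). It only illustrates
# the patch/merge bookkeeping, not the exact GLM-4.1V token layout.
image_size, patch_size, spatial_merge_size = 336, 14, 2
patches_per_side = image_size // patch_size                    # 24
num_patches = patches_per_side**2                              # 576 patches per frame
tokens_after_merge = num_patches // spatial_merge_size**2      # 144 visual tokens after 2x2 merge
print(patches_per_side, num_patches, tokens_after_merge)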
|
||||
|
||||
class Glm4vTextConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Glm4vModel`]. It is used to instantiate a
|
||||
GLM-4.1V model according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of
|
||||
GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 151552):
|
||||
Vocabulary size of the Glm4v model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Glm4vModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 13696):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 40):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 2):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
image_token_id (`int`, *optional*):
|
||||
Token index used as placeholder for image embeddings.
|
||||
video_token_id (`int`, *optional*):
|
||||
Token index used as placeholder for video embeddings.
|
||||
|
||||
```python
|
||||
>>> from transformers import Glm4vTextModel, Glm4vConfig
|
||||
|
||||
>>> # Initializing a GLM-4.1V style configuration
|
||||
>>> configuration = Glm4vConfig()
|
||||
|
||||
>>> # Initializing a model from the GLM-4.1V style configuration
|
||||
>>> model = Glm4vTextModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "glm4v_text"
|
||||
base_config_key = "text_config"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
# Default tensor parallel plan for base model `Glm4v`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
|
||||
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=151552,
|
||||
hidden_size=4096,
|
||||
intermediate_size=13696,
|
||||
num_hidden_layers=40,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=2,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-05,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
attention_dropout=0.0,
|
||||
rope_scaling=None,
|
||||
image_token_id=None,
|
||||
video_token_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_dropout = attention_dropout
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self, ignore_keys={"mrope_section"})
|
||||
self.image_token_id = image_token_id
|
||||
self.video_token_id = video_token_id
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
||||
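
# Editor's sketch (not part of the diff): following the `rope_scaling` documentation above, a
# YaRN-scaled text config could be built as below. The factor and lengths are example values,
# and the top-level re-export of `Glm4vTextConfig` is assumed.
from transformers import Glm4vTextConfig

text_config = Glm4vTextConfig(
    max_position_embeddings=131072,
    rope_scaling={"rope_type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
)
print(text_config.rope_scaling)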
|
||||
class Glm4vConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Glm4vModel`]. It is used to instantiate a
|
||||
GLM-4.1V model according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of
|
||||
GLM-4.1V-9B-Thinking [THUDM/GLM-4.1V-9B-Thinking](https://huggingface.co/THUDM/GLM-4.1V-9B-Thinking).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
text_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Glm4vTextConfig`):
|
||||
The config object or dictionary of the text backbone.
|
||||
vision_config (`Union[PreTrainedConfig, dict]`, *optional*, defaults to `Glm4vVisionConfig`):
|
||||
The config object or dictionary of the vision backbone.
|
||||
image_token_id (`int`, *optional*, defaults to 151343):
|
||||
The image token index to encode the image prompt.
|
||||
video_token_id (`int`, *optional*, defaults to 151344):
|
||||
The video token index to encode the image prompt.
|
||||
image_start_token_id (`int`, *optional*, defaults to 151339):
|
||||
The image start token index to encode the start of image.
|
||||
image_end_token_id (`int`, *optional*, defaults to 151340):
|
||||
The image end token index to encode the end of image.
|
||||
video_start_token_id (`int`, *optional*, defaults to 151341):
|
||||
The video start token index to encode the start of video.
|
||||
video_end_token_id (`int`, *optional*, defaults to 151342):
|
||||
The video end token index to encode the end of video.
|
||||
|
||||
```python
|
||||
>>> from transformers import Glm4vForConditionalGeneration, Glm4vConfig
|
||||
|
||||
>>> # Initializing a GLM-4.1V style configuration
|
||||
>>> configuration = Glm4vConfig()
|
||||
|
||||
>>> # Initializing a model from the GLM-4.1V style configuration
|
||||
>>> model = Glm4vForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "glm4v"
|
||||
sub_configs = {"vision_config": Glm4vVisionConfig, "text_config": Glm4vTextConfig}
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_config=None,
|
||||
vision_config=None,
|
||||
image_token_id=151343,
|
||||
video_token_id=151344,
|
||||
image_start_token_id=151339,
|
||||
image_end_token_id=151340,
|
||||
video_start_token_id=151341,
|
||||
video_end_token_id=151342,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
if isinstance(vision_config, dict):
|
||||
self.vision_config = self.sub_configs["vision_config"](**vision_config)
|
||||
elif vision_config is None:
|
||||
self.vision_config = self.sub_configs["vision_config"]()
|
||||
|
||||
if isinstance(text_config, dict):
|
||||
self.text_config = self.sub_configs["text_config"](**text_config)
|
||||
elif text_config is None:
|
||||
# For BC use all kwargs to init `TextConfig`
|
||||
self.text_config = self.sub_configs["text_config"](**kwargs)
|
||||
|
||||
self.image_token_id = image_token_id
|
||||
self.video_token_id = video_token_id
|
||||
self.video_start_token_id = video_start_token_id
|
||||
self.video_end_token_id = video_end_token_id
|
||||
self.image_start_token_id = image_start_token_id
|
||||
self.image_end_token_id = image_end_token_id
|
||||
|
||||
|
||||
__all__ = ["Glm4vConfig", "Glm4vTextConfig"]
|
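
# Editor's sketch (not part of the diff): `Glm4vConfig` composes the two sub-configs defined
# above, and plain dicts are promoted to config objects in `__init__`. Toy sizes for illustration:
from transformers import Glm4vConfig

config = Glm4vConfig(
    text_config={"num_hidden_layers": 2, "hidden_size": 64, "intermediate_size": 128,
                 "num_attention_heads": 4, "num_key_value_heads": 2},
    vision_config={"depth": 2, "hidden_size": 32, "num_heads": 4},
)
print(type(config.text_config).__name__, type(config.vision_config).__name__)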

src/transformers/models/glm4v/convert_glm4v_mgt_weights_to_hf.py (new file, 645 lines)
@ -0,0 +1,645 @@
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
# Avoid Using Megatron Lib
|
||||
class UnpicklerWrapper(pickle.Unpickler):
|
||||
def find_class(self, mod_name, name):
|
||||
class DummyClass:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
if mod_name.startswith("megatron") or mod_name.startswith("glm") or mod_name.startswith("__main__"):
|
||||
return DummyClass
|
||||
return super().find_class(mod_name, name)
|
||||
|
||||
|
||||
pickle.Unpickler = UnpicklerWrapper
|
||||
|
||||
|
||||
def dict_access_multi(a_dict, keys):
|
||||
if len(keys) == 0:
|
||||
return a_dict
|
||||
return dict_access_multi(a_dict[keys[0]], keys[1:])
|
||||
|
||||
|
||||
def merge_qkv(
|
||||
sd_list,
|
||||
original_tp,
|
||||
num_attention_heads,
|
||||
multi_query_group_num,
|
||||
attention_dim,
|
||||
multi_query_attention,
|
||||
interleaved_qkv,
|
||||
):
|
||||
if not multi_query_attention and interleaved_qkv:
|
||||
return torch.cat(sd_list, dim=0)
|
||||
q, k, v = [], [], []
|
||||
for sd in sd_list:
|
||||
if multi_query_attention:
|
||||
q_, k_, v_ = sd.split(
|
||||
[
|
||||
num_attention_heads * attention_dim // original_tp,
|
||||
multi_query_group_num * attention_dim // original_tp,
|
||||
multi_query_group_num * attention_dim // original_tp,
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
else:
|
||||
q_, k_, v_ = sd.chunk(dim=0, chunks=3)
|
||||
q.append(q_.clone())
|
||||
k.append(k_.clone())
|
||||
v.append(v_.clone())
|
||||
q = torch.cat(q, dim=0)
|
||||
k = torch.cat(k, dim=0)
|
||||
v = torch.cat(v, dim=0)
|
||||
if not interleaved_qkv:
|
||||
rotary_dim = attention_dim // 2
|
||||
half_rot = rotary_dim // 2
|
||||
perm_rot = torch.empty(rotary_dim, dtype=torch.long)
|
||||
perm_rot[0::2] = torch.arange(0, half_rot)
|
||||
perm_rot[1::2] = torch.arange(half_rot, rotary_dim)
|
||||
if q.dim() == 2:
|
||||
qh = q.view(num_attention_heads, attention_dim, -1)
|
||||
kh = k.view(multi_query_group_num, attention_dim, -1)
|
||||
qh[:, :rotary_dim, :] = qh[:, perm_rot, :]
|
||||
kh[:, :rotary_dim, :] = kh[:, perm_rot, :]
|
||||
q = qh.reshape(-1, q.size(-1))
|
||||
k = kh.reshape(-1, k.size(-1))
|
||||
else:
|
||||
qh = q.view(num_attention_heads, attention_dim)
|
||||
kh = k.view(multi_query_group_num, attention_dim)
|
||||
qh[:, :rotary_dim] = qh[:, perm_rot]
|
||||
kh[:, :rotary_dim] = kh[:, perm_rot]
|
||||
q = qh.reshape(-1)
|
||||
k = kh.reshape(-1)
|
||||
return q, k, v
|
||||
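
# Editor's sketch (not part of the diff): the split sizes `merge_qkv` above relies on when
# `multi_query_attention=True`. Each TP shard stores [q | k | v] fused along dim 0, with q
# holding num_heads/tp head slices and k/v holding num_kv_groups/tp each. Toy values only.
import torch

tp, num_heads, num_kv_groups, head_dim, hidden = 2, 8, 2, 16, 64
fused_rows = (num_heads + 2 * num_kv_groups) * head_dim // tp  # 96 rows in this toy setup
shard = torch.randn(fused_rows, hidden)                        # one mp_rank_XX fused qkv weight
q, k, v = shard.split(
    [num_heads * head_dim // tp, num_kv_groups * head_dim // tp, num_kv_groups * head_dim // tp],
    dim=0,
)
print(q.shape, k.shape, v.shape)  # torch.Size([64, 64]) torch.Size([16, 64]) torch.Size([16, 64])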
|
||||
|
||||
def merge_glu(sd_list):
|
||||
return torch.cat(
|
||||
[sd.chunk(dim=0, chunks=2)[0].clone() for sd in sd_list]
|
||||
+ [sd.chunk(dim=0, chunks=2)[1].clone() for sd in sd_list],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
||||
def merge_glu_vit(sd_list, original_tp=None):
|
||||
gate_proj = torch.cat([sd.chunk(dim=0, chunks=2)[0].clone() for sd in sd_list], dim=0)
|
||||
up_proj = torch.cat([sd.chunk(dim=0, chunks=2)[1].clone() for sd in sd_list], dim=0)
|
||||
return gate_proj, up_proj
|
||||
|
||||
|
||||
def split_glu(sd, cnt, idx):
|
||||
return torch.cat(
|
||||
(
|
||||
sd.chunk(dim=0, chunks=2)[0].chunk(cnt, dim=0)[idx].clone(),
|
||||
sd.chunk(dim=0, chunks=2)[1].chunk(cnt, dim=0)[idx].clone(),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
||||
def merge_qkv_vit(sd_list, original_tp=None):
|
||||
q, k, v = [], [], []
|
||||
for sd in sd_list:
|
||||
q_, k_, v_ = sd.chunk(dim=0, chunks=3)
|
||||
q.append(q_.clone().contiguous())
|
||||
k.append(k_.clone().contiguous())
|
||||
v.append(v_.clone().contiguous())
|
||||
q = torch.cat(q, dim=0)
|
||||
k = torch.cat(k, dim=0)
|
||||
v = torch.cat(v, dim=0)
|
||||
combined = torch.cat([q, k, v], dim=0)
|
||||
return combined
|
||||
|
||||
|
||||
def merge_tensors_vit(
|
||||
tp_sd: list[dict],
|
||||
keys: list[str],
|
||||
original_tp: int,
|
||||
target_tp: int,
|
||||
slice_dim: Optional[int] = None,
|
||||
merge_fn: Optional[Callable] = None,
|
||||
):
|
||||
cnt = original_tp // target_tp
|
||||
sd_list = [dict_access_multi(tp_sd[i], keys) for i in range(cnt)]
|
||||
if slice_dim is not None:
|
||||
return torch.cat(sd_list, dim=slice_dim)
|
||||
assert merge_fn is not None
|
||||
return merge_fn(sd_list, original_tp)
|
||||
|
||||
|
||||
def merge_tensors(
|
||||
tp_sd,
|
||||
keys,
|
||||
original_tp,
|
||||
target_tp,
|
||||
current_tp,
|
||||
slice_dim=None,
|
||||
merge_fn=None,
|
||||
):
|
||||
cnt = original_tp // target_tp
|
||||
offset = cnt * current_tp
|
||||
sd_list = [dict_access_multi(tp_sd[i + offset], keys) for i in range(cnt)]
|
||||
if slice_dim is not None:
|
||||
return torch.cat(sd_list, dim=slice_dim)
|
||||
assert merge_fn is not None
|
||||
return merge_fn(sd_list)
|
||||
|
||||
|
||||
def save_sharded_model(state_dict, output_path, max_shard_size_gb=5, num_layers=40, vision_num_layers=24):
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
layered_dict = {}
|
||||
for layer_idx in range(num_layers):
|
||||
layer_key = f"layer_{layer_idx}"
|
||||
layered_dict[layer_key] = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
if f"model.language_model.layers.{layer_idx}." in key:
|
||||
layered_dict[layer_key][key] = value
|
||||
|
||||
for layer_idx in range(vision_num_layers):
|
||||
layer_key = f"visual_layer_{layer_idx}"
|
||||
layered_dict[layer_key] = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
if f"model.visual.blocks.{layer_idx}." in key:
|
||||
layered_dict[layer_key][key] = value
|
||||
|
||||
layered_dict["others"] = {}
|
||||
for key, value in state_dict.items():
|
||||
if not any(f"model.language_model.layers.{i}." in key for i in range(num_layers)) and not any(
|
||||
f"model.visual.blocks.{i}." in key for i in range(vision_num_layers)
|
||||
):
|
||||
layered_dict["others"][key] = value
|
||||
|
||||
# Determine layer ordering
|
||||
layer_order = []
|
||||
for i in range(40):
|
||||
layer_order.append(f"layer_{i}")
|
||||
for i in range(24):
|
||||
layer_order.append(f"visual_layer_{i}")
|
||||
layer_order.append("others")
|
||||
|
||||
# Calculate sizes and create shards by layer
|
||||
param_sizes = {}
|
||||
shards = []
|
||||
current_shard = {}
|
||||
current_shard_size = 0
|
||||
max_shard_size_bytes = max_shard_size_gb * 1024 * 1024 * 1024
|
||||
|
||||
for layer_key in layer_order:
|
||||
layer_weights = layered_dict[layer_key]
|
||||
layer_size = sum(param.numel() * param.element_size() for param in layer_weights.values())
|
||||
if current_shard_size + layer_size > max_shard_size_bytes and current_shard:
|
||||
shards.append(current_shard)
|
||||
current_shard = {}
|
||||
current_shard_size = 0
|
||||
for param_name, param in layer_weights.items():
|
||||
current_shard[param_name] = param
|
||||
current_shard_size += param.numel() * param.element_size()
|
||||
param_sizes[param_name] = param.numel() * param.element_size()
|
||||
if current_shard:
|
||||
shards.append(current_shard)
|
||||
index_dict = {"metadata": {"total_size": sum(param_sizes.values())}, "weight_map": {}}
|
||||
|
||||
for i, shard in enumerate(shards):
|
||||
shard_filename = f"model-{i + 1:05d}-of-{len(shards):05d}.safetensors"
|
||||
shard_path = os.path.join(output_path, shard_filename)
|
||||
|
||||
for param_name in shard.keys():
|
||||
index_dict["weight_map"][param_name] = shard_filename
|
||||
|
||||
save_file(shard, shard_path, metadata={"format": "pt"})
|
||||
print(f"Saved shard {i + 1}/{len(shards)}: {shard_filename}")
|
||||
print(f" Shard size: {sum(p.numel() * p.element_size() for p in shard.values()) / (1024**3):.2f} GB")
|
||||
print(f" Keys in shard: {len(shard)}")
|
||||
|
||||
index_path = os.path.join(output_path, "model.safetensors.index.json")
|
||||
with open(index_path, "w") as f:
|
||||
json.dump(index_dict, f, indent=2)
|
||||
|
||||
return len(shards)
|
||||
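
# Editor's sketch (not part of the diff): a model sharded by `save_sharded_model` above can be
# reassembled from its index file like this ("converted_model" is a hypothetical output dir):
import json
import os

from safetensors.torch import load_file

output_path = "converted_model"
with open(os.path.join(output_path, "model.safetensors.index.json")) as f:
    weight_map = json.load(f)["weight_map"]

state_dict = {}
for shard_name in sorted(set(weight_map.values())):
    state_dict.update(load_file(os.path.join(output_path, shard_name)))
print(len(state_dict), "tensors loaded")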
|
||||
|
||||
def merge_tp_weights(model_path, output_path, vllm_config_path=None):
|
||||
tp_size = 0
|
||||
for item in Path(model_path).iterdir():
|
||||
if item.is_dir():
|
||||
match = re.match(r"mp_rank_(\d{2})", item.name)
|
||||
if match:
|
||||
tp = int(match.group(1))
|
||||
tp_size = max(tp_size, tp + 1)
|
||||
|
||||
print(f"Detected tensor parallel degree TP={tp_size}")
|
||||
|
||||
if tp_size <= 1:
|
||||
print("Model is already at TP=1, no need to merge")
|
||||
return
|
||||
|
||||
print(f"Loading vLLM configuration file: {vllm_config_path}")
|
||||
with open(vllm_config_path, "r") as f:
|
||||
model_config = json.load(f)
|
||||
num_layers = model_config.get("num_layers", 40)
|
||||
vision_num_layers = model_config.get("vision_config", {}).get("num_hidden_layers", 24)
|
||||
num_heads = model_config.get("num_attention_heads", 32)
|
||||
num_kv_heads = model_config.get("num_query_groups", 2)
|
||||
hidden_size = model_config.get("hidden_size", 4096)
|
||||
head_dim = model_config.get("attention_dim", hidden_size // num_heads)
|
||||
|
||||
print(
|
||||
f"Model parameters: num_layers={num_layers}, vision_num_layers={vision_num_layers}, "
|
||||
f"num_heads={num_heads}, multi_query_group_num={num_kv_heads}, hidden_size={hidden_size}"
|
||||
)
|
||||
|
||||
weights = []
|
||||
for tp_rank in range(tp_size):
|
||||
print(f"Loading TP shard {tp_rank}...")
|
||||
weight_path = Path(model_path) / f"mp_rank_{tp_rank:02d}" / "model_optim_rng.pt"
|
||||
sd = torch.load(weight_path, map_location="cpu", pickle_module=pickle)
|
||||
|
||||
for k in list(sd.keys()):
|
||||
if "_extra_state" in k or "dummy_parameter" in k:
|
||||
sd.pop(k)
|
||||
|
||||
if "model" in sd:
|
||||
weights.append(sd["model"])
|
||||
else:
|
||||
raise ValueError(f"'model' key not found in {weight_path}")
|
||||
|
||||
if not weights:
|
||||
raise ValueError("No valid weight files found")
|
||||
|
||||
print("Merging tensor parallel weights...")
|
||||
original_pp_enabled = os.path.exists(Path(model_path) / "mp_rank_00_000")
|
||||
original_tp, original_pp = tp_size, 1
|
||||
target_tp = 1
|
||||
print(f"TP and PP INFO: original_tp: {original_tp}, original_pp:{original_pp}, target_tp: {target_tp}")
|
||||
mgt_sd = [
|
||||
[
|
||||
torch.load(
|
||||
Path(model_path)
|
||||
/ (f"mp_rank_{j:02d}_{i:03d}" if original_pp_enabled else f"mp_rank_{j:02d}")
|
||||
/ "model_optim_rng.pt",
|
||||
map_location="cpu",
|
||||
pickle_module=pickle,
|
||||
)
|
||||
for j in range(original_tp)
|
||||
]
|
||||
for i in range(original_pp)
|
||||
]
|
||||
|
||||
interleaved_qkv = False
|
||||
multi_query_attention = True
|
||||
num_attention_heads = num_heads
|
||||
multi_query_group_num = num_kv_heads
|
||||
attention_dim = head_dim
|
||||
complete_state_dict = {}
|
||||
keys = ["model"]
|
||||
rank = 0
|
||||
|
||||
# LLM
|
||||
for pp in range(original_pp):
|
||||
layer_i = 0
|
||||
mgt_encoder_tp_0 = dict_access_multi(mgt_sd[pp][rank], keys)
|
||||
|
||||
while f"decoder.layers.{layer_i}.self_attention.linear_qkv.layer_norm_weight" in mgt_encoder_tp_0:
|
||||
complete_state_dict.update(
|
||||
{
|
||||
f"model.language_model.layers.{layer_i}.input_layernorm.weight": mgt_encoder_tp_0[
|
||||
f"decoder.layers.{layer_i}.self_attention.linear_qkv.layer_norm_weight"
|
||||
],
|
||||
f"model.language_model.layers.{layer_i}.post_attention_layernorm.weight": mgt_encoder_tp_0[
|
||||
f"decoder.layers.{layer_i}.mlp.linear_fc1.layer_norm_weight"
|
||||
],
|
||||
f"model.language_model.layers.{layer_i}.post_self_attn_layernorm.weight": mgt_encoder_tp_0[
|
||||
f"decoder.layers.{layer_i}.post_self_attn_layernorm.weight"
|
||||
],
|
||||
f"model.language_model.layers.{layer_i}.post_mlp_layernorm.weight": mgt_encoder_tp_0[
|
||||
f"decoder.layers.{layer_i}.post_mlp_layernorm.weight"
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
q, k, v = merge_tensors(
|
||||
tp_sd=mgt_sd[pp],
|
||||
keys=keys + [f"decoder.layers.{layer_i}.self_attention.linear_qkv.weight"],
|
||||
original_tp=original_tp,
|
||||
target_tp=target_tp,
|
||||
current_tp=0,
|
||||
merge_fn=lambda sd_list: merge_qkv(
|
||||
sd_list,
|
||||
original_tp,
|
||||
num_attention_heads,
|
||||
multi_query_group_num,
|
||||
attention_dim,
|
||||
multi_query_attention,
|
||||
interleaved_qkv,
|
||||
),
|
||||
)
|
||||
|
||||
complete_state_dict[f"model.language_model.layers.{layer_i}.self_attn.q_proj.weight"] = q.clone()
|
||||
complete_state_dict[f"model.language_model.layers.{layer_i}.self_attn.k_proj.weight"] = k.clone()
|
||||
complete_state_dict[f"model.language_model.layers.{layer_i}.self_attn.v_proj.weight"] = v.clone()
|
||||
|
||||
if f"decoder.layers.{layer_i}.self_attention.linear_qkv.bias" in mgt_encoder_tp_0:
|
||||
q_bias, k_bias, v_bias = merge_tensors(
|
||||
tp_sd=mgt_sd[pp],
|
||||
keys=keys + [f"decoder.layers.{layer_i}.self_attention.linear_qkv.bias"],
|
||||
original_tp=original_tp,
|
||||
target_tp=target_tp,
|
||||
current_tp=0,
|
||||
merge_fn=lambda sd_list: merge_qkv(
|
||||
sd_list,
|
||||
original_tp,
|
||||
num_attention_heads,
|
||||
multi_query_group_num,
|
||||
attention_dim,
|
||||
multi_query_attention,
|
||||
interleaved_qkv,
|
||||
),
|
||||
)
|
||||
complete_state_dict[f"model.language_model.layers.{layer_i}.self_attn.q_proj.bias"] = q_bias.clone()
|
||||
complete_state_dict[f"model.language_model.layers.{layer_i}.self_attn.k_proj.bias"] = k_bias.clone()
|
||||
complete_state_dict[f"model.language_model.layers.{layer_i}.self_attn.v_proj.bias"] = v_bias.clone()
|
||||
|
||||
o_proj = merge_tensors(
|
||||
tp_sd=mgt_sd[pp],
|
||||
keys=keys + [f"decoder.layers.{layer_i}.self_attention.linear_proj.weight"],
|
||||
original_tp=original_tp,
|
||||
target_tp=target_tp,
|
||||
current_tp=0,
|
||||
slice_dim=1,
|
||||
)
|
||||
complete_state_dict[f"model.language_model.layers.{layer_i}.self_attn.o_proj.weight"] = o_proj.clone()
|
||||
|
||||
# MLP - Use gate_up_proj
|
||||
complete_state_dict[f"model.language_model.layers.{layer_i}.mlp.gate_up_proj.weight"] = merge_tensors(
|
||||
tp_sd=mgt_sd[pp],
|
||||
keys=keys + [f"decoder.layers.{layer_i}.mlp.linear_fc1.weight"],
|
||||
original_tp=original_tp,
|
||||
target_tp=target_tp,
|
||||
current_tp=0,
|
||||
merge_fn=merge_glu,
|
||||
).clone()
|
||||
complete_state_dict[f"model.language_model.layers.{layer_i}.mlp.down_proj.weight"] = merge_tensors(
|
||||
tp_sd=mgt_sd[pp],
|
||||
keys=keys + [f"decoder.layers.{layer_i}.mlp.linear_fc2.weight"],
|
||||
original_tp=original_tp,
|
||||
target_tp=target_tp,
|
||||
current_tp=0,
|
||||
slice_dim=1,
|
||||
)
|
||||
layer_i += 1
|
||||
|
||||
# Embedd Model, LM Head, and Norm
|
||||
embed_tokens = merge_tensors(
|
||||
tp_sd=mgt_sd[0],
|
||||
keys=["model", "embedding.word_embeddings.weight"],
|
||||
original_tp=original_tp,
|
||||
target_tp=target_tp,
|
||||
current_tp=0,
|
||||
slice_dim=0,
|
||||
)
|
||||
complete_state_dict["model.language_model.embed_tokens.weight"] = embed_tokens.clone()
|
||||
lm_head = merge_tensors(
|
||||
tp_sd=mgt_sd[-1],
|
||||
keys=["model", "output_layer.weight"],
|
||||
original_tp=original_tp,
|
||||
target_tp=target_tp,
|
||||
current_tp=0,
|
||||
slice_dim=0,
|
||||
)
|
||||
complete_state_dict["lm_head.weight"] = lm_head.clone()
|
||||
complete_state_dict["model.language_model.norm.weight"] = mgt_sd[-1][rank]["model"][
|
||||
"decoder.final_layernorm.weight"
|
||||
].clone()
|
||||
mgt_encoder_tp_0 = dict_access_multi(mgt_sd[0][0], keys)
|
||||
|
||||
# VLM
|
||||
for layer_i in range(vision_num_layers):
|
||||
complete_state_dict[f"model.visual.blocks.{layer_i}.norm1.weight"] = mgt_encoder_tp_0[
|
||||
f"vision_model.transformer.layers.{layer_i}.input_layernorm.weight"
|
||||
]
|
||||
complete_state_dict[f"model.visual.blocks.{layer_i}.norm2.weight"] = mgt_encoder_tp_0[
|
||||
f"vision_model.transformer.layers.{layer_i}.pre_mlp_layernorm.weight"
|
||||
]
|
||||
|
||||
qkv_weight = merge_tensors_vit(
|
||||
tp_sd=mgt_sd[0],
|
||||
keys=keys + [f"vision_model.transformer.layers.{layer_i}.self_attention.linear_qkv.weight"],
|
||||
original_tp=original_tp,
|
||||
target_tp=target_tp,
|
||||
merge_fn=merge_qkv_vit,
|
||||
)
|
||||
complete_state_dict[f"model.visual.blocks.{layer_i}.attn.qkv.weight"] = qkv_weight.clone()
|
||||
|
||||
proj_weight = merge_tensors_vit(
|
||||
tp_sd=mgt_sd[0],
|
||||
keys=keys + [f"vision_model.transformer.layers.{layer_i}.self_attention.linear_proj.weight"],
|
||||
original_tp=original_tp,
|
||||
target_tp=target_tp,
|
||||
slice_dim=1,
|
||||
)
|
||||
complete_state_dict[f"model.visual.blocks.{layer_i}.attn.proj.weight"] = proj_weight.clone()
|
||||
|
||||
gate_proj_weight, up_proj_weight = merge_tensors_vit(
|
||||
tp_sd=mgt_sd[0],
|
||||
keys=keys + [f"vision_model.transformer.layers.{layer_i}.mlp.linear_fc1.weight"],
|
||||
original_tp=original_tp,
|
||||
target_tp=target_tp,
|
||||
merge_fn=lambda sd_list, original_tp: merge_glu_vit(sd_list, original_tp),
|
||||
)
|
||||
complete_state_dict[f"model.visual.blocks.{layer_i}.mlp.gate_proj.weight"] = gate_proj_weight.clone()
|
||||
complete_state_dict[f"model.visual.blocks.{layer_i}.mlp.up_proj.weight"] = up_proj_weight.clone()
|
||||
|
||||
down_proj_weight = merge_tensors_vit(
|
||||
tp_sd=mgt_sd[0],
|
||||
keys=keys + [f"vision_model.transformer.layers.{layer_i}.mlp.linear_fc2.weight"],
|
||||
original_tp=original_tp,
|
||||
target_tp=target_tp,
|
||||
slice_dim=1,
|
||||
)
|
||||
complete_state_dict[f"model.visual.blocks.{layer_i}.mlp.down_proj.weight"] = down_proj_weight.clone()
|
||||
|
||||
complete_state_dict["model.visual.downsample.weight"] = (
|
||||
mgt_sd[0][0]["model"]["vision_model.downsample.weight"].clone().contiguous()
|
||||
)
|
||||
complete_state_dict["model.visual.downsample.bias"] = (
|
||||
mgt_sd[0][0]["model"]["vision_model.downsample.bias"].clone().contiguous()
|
||||
)
|
||||
|
||||
# Merger
|
||||
gate_proj, up_proj = merge_tensors_vit(
|
||||
tp_sd=mgt_sd[0],
|
||||
keys=keys + ["vision_projection.encoder.linear_fc1.weight"],
|
||||
original_tp=original_tp,
|
||||
target_tp=target_tp,
|
||||
merge_fn=merge_glu_vit,
|
||||
)
|
||||
|
||||
down_proj = merge_tensors_vit(
|
||||
tp_sd=mgt_sd[0],
|
||||
keys=keys + ["vision_projection.encoder.linear_fc2.weight"],
|
||||
original_tp=original_tp,
|
||||
target_tp=target_tp,
|
||||
slice_dim=1,
|
||||
)
|
||||
proj = merge_tensors_vit(
|
||||
tp_sd=mgt_sd[0],
|
||||
keys=keys + ["vision_projection.encoder.linear_fc_extra.weight"],
|
||||
original_tp=original_tp,
|
||||
target_tp=target_tp,
|
||||
slice_dim=0,
|
||||
)
|
||||
|
||||
complete_state_dict["model.visual.merger.gate_proj.weight"] = gate_proj.clone().contiguous()
|
||||
complete_state_dict["model.visual.merger.up_proj.weight"] = up_proj.clone().contiguous()
|
||||
complete_state_dict["model.visual.merger.down_proj.weight"] = down_proj.clone().contiguous()
|
||||
complete_state_dict["model.visual.merger.proj.weight"] = proj.clone().contiguous()
|
||||
|
||||
complete_state_dict["model.visual.merger.post_projection_norm.weight"] = (
|
||||
mgt_sd[0][0]["model"]["vision_projection.encoder.layer_norm.weight"].clone().contiguous()
|
||||
)
|
||||
complete_state_dict["model.visual.merger.post_projection_norm.bias"] = (
|
||||
mgt_sd[0][0]["model"]["vision_projection.encoder.layer_norm.bias"].clone().contiguous()
|
||||
)
|
||||
complete_state_dict["model.visual.embeddings.position_embedding.weight"] = (
|
||||
mgt_sd[0][0]["model"]["vision_model.position_embeddings.weight"].clone().contiguous()
|
||||
)
|
||||
complete_state_dict["model.visual.patch_embed.proj.weight"] = (
|
||||
mgt_sd[0][0]["model"]["vision_model.conv3d.weight"].clone().contiguous()
|
||||
)
|
||||
complete_state_dict["model.visual.patch_embed.proj.bias"] = (
|
||||
mgt_sd[0][0]["model"]["vision_model.conv3d.bias"].clone().contiguous()
|
||||
)
|
||||
|
||||
# Check for additional vision model norm layers mentioned in the expected output
|
||||
if "vision_model.post_conv_layernorm.weight" in mgt_encoder_tp_0:
|
||||
complete_state_dict["model.visual.post_conv_layernorm.weight"] = (
|
||||
mgt_sd[0][0]["model"]["vision_model.post_conv_layernorm.weight"].clone().contiguous()
|
||||
)
|
||||
|
||||
if "vision_model.post_layernorm.weight" in mgt_encoder_tp_0:
|
||||
complete_state_dict["model.visual.post_layernorm.weight"] = (
|
||||
mgt_sd[0][0]["model"]["vision_model.post_layernorm.weight"].clone().contiguous()
|
||||
)
|
||||
|
||||
print(f"Total keys in state dict: {len(complete_state_dict)}")
|
||||
|
||||
for key, value in complete_state_dict.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
complete_state_dict[key] = value.to(torch.bfloat16)
|
||||
print("Converted all tensors to bfloat16")
|
||||
# Save Model weight
|
||||
save_sharded_model(
|
||||
complete_state_dict,
|
||||
output_path=output_path,
|
||||
max_shard_size_gb=5,
|
||||
num_layers=num_layers,
|
||||
vision_num_layers=vision_num_layers,
|
||||
)
|
||||
|
||||
hf_config = {
|
||||
"architectures": ["Glm4vForConditionalGeneration"],
|
||||
"model_type": "glm4v",
|
||||
"attention_bias": model_config.get("add_qkv_bias", True),
|
||||
"attention_dropout": 0.0,
|
||||
"pad_token_id": model_config.get("pad_token_id", 151329),
|
||||
"eos_token_id": model_config.get("eos_token_id", [151329, 151336, 151338]),
|
||||
"image_start_token_id": model_config.get("image_start_token_id", 151339),
|
||||
"image_end_token_id": model_config.get("image_end_token_id", 151340),
|
||||
"video_start_token_id": model_config.get("video_start_token_id", 151341),
|
||||
"video_end_token_id": model_config.get("video_end_token_id", 151342),
|
||||
"image_token_id": model_config.get("image_token_id", 151343),
|
||||
"video_token_id": model_config.get("video_token_id", 151344),
|
||||
"hidden_act": model_config.get("hidden_act", "silu"),
|
||||
"hidden_size": model_config.get("hidden_size", 4096),
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": model_config.get("ffn_hidden_size", 13696),
|
||||
"max_position_embeddings": model_config.get("seq_length", 32768),
|
||||
"num_attention_heads": model_config.get("num_attention_heads", 32),
|
||||
"num_hidden_layers": model_config.get("num_layers", 40),
|
||||
"num_key_value_heads": model_config.get("multi_query_group_num", 2),
|
||||
"rms_norm_eps": model_config.get("layernorm_epsilon", 1e-05),
|
||||
"rope_theta": model_config.get("rotary_base", 10000.0),
|
||||
"tie_word_embeddings": False,
|
||||
"torch_dtype": model_config.get("torch_dtype", "bfloat16"),
|
||||
"transformers_version": "4.53.0dev",
|
||||
"use_cache": model_config.get("use_cache", True),
|
||||
"vocab_size": model_config.get("vocab_size", 151552),
|
||||
"partial_rotary_factor": 0.5,
|
||||
}
|
||||
|
||||
if "vision_config" in model_config:
|
||||
vision_config = {
|
||||
"hidden_size": model_config["vision_config"].get("hidden_size", 1536),
|
||||
"depth": model_config["vision_config"].get("num_layers", 24),
|
||||
"num_heads": model_config["vision_config"].get("num_attention_heads", 12),
|
||||
"attention_bias": model_config["vision_config"].get("attention_bias", False),
|
||||
"intermediate_size": model_config.get("ffn_hidden_size", 13696),
|
||||
"hidden_act": model_config["vision_config"].get("hidden_act", "silu"),
|
||||
"hidden_dropout_prob": model_config["vision_config"].get("hidden_dropout_prob", 0.0),
|
||||
"initializer_range": 0.02,
|
||||
"image_size": model_config["vision_config"].get("image_size", 336),
|
||||
"patch_size": model_config["vision_config"].get("patch_size", 14),
|
||||
"out_hidden_size": model_config.get("hidden_size", 4096),
|
||||
"rms_norm_eps": model_config["vision_config"].get("layernorm_epsilon", 1e-05),
|
||||
"spatial_merge_size": model_config["vision_config"].get("downsample_ratio", 2),
|
||||
"temporal_patch_size": model_config["vision_config"].get("t_patch", 2),
|
||||
}
|
||||
hf_config["vision_config"] = vision_config
|
||||
|
||||
if "rope_scaling" in model_config:
|
||||
hf_config["rope_scaling"] = model_config["rope_scaling"]
|
||||
|
||||
config_path = os.path.join(output_path, "config.json")
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(hf_config, f, indent=2)
|
||||
|
||||
print(f"Conversion complete! Model saved to {output_path}")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Convert Megatron model to HuggingFace format")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to Megatron model directory",
|
||||
)
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Output path for HuggingFace model directory")
|
||||
parser.add_argument(
|
||||
"--config_path", type=str, help="Path to vLLM configuration file for creating HuggingFace config"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
merge_tp_weights(args.model_path, args.output_path, args.config_path)
|
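With the sharded weights and config.json written above, the converted checkpoint can be loaded through the regular Transformers API. A minimal, illustrative sketch (the local path is a placeholder; the tokenizer and processor files are not produced by this script and have to be saved next to the weights separately):

import torch
from transformers import Glm4vForConditionalGeneration

# Load the checkpoint directory written by merge_tp_weights() above (placeholder path).
model = Glm4vForConditionalGeneration.from_pretrained("path/to/output_path", torch_dtype=torch.bfloat16)
print(model.config.model_type)  # "glm4v"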
467
src/transformers/models/glm4v/image_processing_glm4v.py
Normal file
@ -0,0 +1,467 @@
# coding=utf-8
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for GLM-4.1V."""

import math
from typing import Optional, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import (
    convert_to_rgb,
    resize,
    to_channel_dimension_format,
)
from ...image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    make_flat_list_of_images,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_preprocess_arguments,
)
from ...utils import TensorType, logging
from ...video_utils import VideoInput


logger = logging.get_logger(__name__)


def smart_resize(
    num_frames: int,
    height: int,
    width: int,
    temporal_factor: int = 2,
    factor: int = 28,
    min_pixels: int = 112 * 112,
    max_pixels: int = 14 * 14 * 2 * 2 * 2 * 6144,
):
    if num_frames < temporal_factor:
        raise ValueError(f"t:{num_frames} must be larger than temporal_factor:{temporal_factor}")
    if height < factor or width < factor:
        raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
    elif max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    t_bar = round(num_frames / temporal_factor) * temporal_factor

    if t_bar * h_bar * w_bar > max_pixels:
        beta = math.sqrt((num_frames * height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif t_bar * h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (num_frames * height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

    return h_bar, w_bar
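
# Illustrative note (not part of the original diff): with the defaults above, factor = 28
# (patch_size 14 * merge_size 2). For a 1000x750 input and num_frames = temporal_factor = 2:
#   h_bar = round(1000 / 28) * 28 = 1008, w_bar = round(750 / 28) * 28 = 756, t_bar = 2
#   t_bar * h_bar * w_bar = 1,524,096, which lies between min_pixels and max_pixels,
#   so smart_resize(2, 1000, 750) returns (1008, 756).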


class Glm4vImageProcessor(BaseImageProcessor):
    r"""
    Constructs a GLM-4V image processor that dynamically resizes images based on the original images.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions.
        size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000}`):
            Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter
            in the `preprocess` method. Available options are:
                - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
                    Do NOT keep the aspect ratio.
                - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
                    the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
                    less or equal to `longest_edge`.
                - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
                    aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
                    `max_width`.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use when resizing the image.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
            Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
            Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        patch_size (`int`, *optional*, defaults to 14):
            The spatial patch size of the vision encoder.
        temporal_patch_size (`int`, *optional*, defaults to 2):
            The temporal patch size of the vision encoder.
        merge_size (`int`, *optional*, defaults to 2):
            The merge size of the vision encoder to llm encoder.
    """

    model_input_names = ["pixel_values", "image_grid_thw"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        do_convert_rgb: bool = True,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        merge_size: int = 2,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
        else:
            size = {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000}
        self.size = size

        self.do_resize = do_resize
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD

        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.merge_size = merge_size
        self.do_convert_rgb = do_convert_rgb

    def _preprocess(
        self,
        images: Union[ImageInput, VideoInput],
        do_resize: Optional[bool] = None,
        size: Optional[dict[str, int]] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        patch_size: Optional[int] = None,
        temporal_patch_size: Optional[int] = None,
        merge_size: Optional[int] = None,
        do_convert_rgb: Optional[bool] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.

        Args:
            images (`ImageInput`):
                Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
            vision_info (`List[Dict]`, *optional*):
                Optional list of dictionaries containing additional information about vision inputs.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Scale factor to use if rescaling the image.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
            patch_size (`int`, *optional*, defaults to `self.patch_size`):
                The spatial patch size of the vision encoder.
            temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
                The temporal patch size of the vision encoder.
            merge_size (`int`, *optional*, defaults to `self.merge_size`):
                The merge size of the vision encoder to llm encoder.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        images = make_list_of_images(images)

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_rescale and is_scaled_image(images[0]):
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        height, width = get_image_size(images[0], channel_dim=input_data_format)
        resized_height, resized_width = height, width
        processed_images = []
        for image in images:
            if do_resize:
                resized_height, resized_width = smart_resize(
                    num_frames=temporal_patch_size,
                    height=height,
                    width=width,
                    temporal_factor=temporal_patch_size,
                    factor=patch_size * merge_size,
                )
                image = resize(
                    image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
                )

            if do_rescale:
                image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)

            if do_normalize:
                image = self.normalize(
                    image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
                )

            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
            processed_images.append(image)

        patches = np.array(processed_images)
        if data_format == ChannelDimension.LAST:
            patches = patches.transpose(0, 3, 1, 2)
        if patches.shape[0] % temporal_patch_size != 0:
            repeats = np.repeat(
                patches[-1][np.newaxis], temporal_patch_size - (patches.shape[0] % temporal_patch_size), axis=0
            )
            patches = np.concatenate([patches, repeats], axis=0)
        channel = patches.shape[1]
        grid_t = patches.shape[0] // temporal_patch_size
        grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
        patches = patches.reshape(
            grid_t,
            temporal_patch_size,
            channel,
            grid_h // merge_size,
            merge_size,
            patch_size,
            grid_w // merge_size,
            merge_size,
            patch_size,
        )
        patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size
        )

        return flatten_patches, (grid_t, grid_h, grid_w)
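
    # Illustrative note (not part of the original diff): for the 1008x756 example above with
    # patch_size=14, merge_size=2 and temporal_patch_size=2, a single image yields
    # grid_t=1, grid_h=72, grid_w=54, i.e. 3888 flattened patches, each of length
    # channel * temporal_patch_size * patch_size**2 = 3 * 2 * 14 * 14 = 1176.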
    def preprocess(
        self,
        images: ImageInput,
        videos: VideoInput = None,
        do_resize: Optional[bool] = None,
        size: Optional[dict[str, int]] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        patch_size: Optional[int] = None,
        temporal_patch_size: Optional[int] = None,
        merge_size: Optional[int] = None,
        do_convert_rgb: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            videos (`VideoInput`):
                Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
                passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
                the longest edge resized to keep the input aspect ratio.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            patch_size (`int`, *optional*, defaults to `self.patch_size`):
                The spatial patch size of the vision encoder.
            temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
                The temporal patch size of the vision encoder.
            merge_size (`int`, *optional*, defaults to `self.merge_size`):
                The merge size of the vision encoder to llm encoder.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        """

        if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
        else:
            size = {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000}

        do_resize = do_resize if do_resize is not None else self.do_resize

        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        patch_size = patch_size if patch_size is not None else self.patch_size
        temporal_patch_size = temporal_patch_size if temporal_patch_size is not None else self.temporal_patch_size
        merge_size = merge_size if merge_size is not None else self.merge_size
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        if images is not None:
            images = make_flat_list_of_images(images)

        if images is not None and not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        validate_preprocess_arguments(
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        data = {}
        if images is not None:
            pixel_values, vision_grid_thws = [], []
            for image in images:
                patches, image_grid_thw = self._preprocess(
                    image,
                    do_resize=do_resize,
                    size=size,
                    resample=resample,
                    do_rescale=do_rescale,
                    rescale_factor=rescale_factor,
                    do_normalize=do_normalize,
                    image_mean=image_mean,
                    image_std=image_std,
                    patch_size=patch_size,
                    temporal_patch_size=temporal_patch_size,
                    merge_size=merge_size,
                    data_format=data_format,
                    do_convert_rgb=do_convert_rgb,
                    input_data_format=input_data_format,
                )
                pixel_values.extend(patches)
                vision_grid_thws.append(image_grid_thw)
            pixel_values = np.array(pixel_values)
            vision_grid_thws = np.array(vision_grid_thws)
            data.update({"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws})

        return BatchFeature(data=data, tensor_type=return_tensors)

    def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
        """
        A utility that returns number of image patches for a given image size.

        Args:
            height (`int`):
                Height of the input image.
            width (`int`):
                Width of the input image.
            images_kwargs (`dict`, *optional*)
                Any kwargs to override defaults of the image processor.
        Returns:
            `int`: Number of image patches per image.
        """
        patch_size = images_kwargs.get("patch_size", None) or self.patch_size
        merge_size = images_kwargs.get("merge_size", None) or self.merge_size

        factor = patch_size * merge_size
        # Keyword names aligned with the smart_resize signature defined above.
        resized_height, resized_width = smart_resize(
            num_frames=self.temporal_patch_size,
            height=height,
            width=width,
            factor=factor,
            temporal_factor=self.temporal_patch_size,
        )
        grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
        return grid_h * grid_w


__all__ = ["Glm4vImageProcessor"]
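A minimal usage sketch for the image processor defined above (illustrative only; the random image stands in for real data):

import numpy as np
from PIL import Image

from transformers.models.glm4v.image_processing_glm4v import Glm4vImageProcessor

# Default settings: patch_size=14, temporal_patch_size=2, merge_size=2.
image_processor = Glm4vImageProcessor()
image = Image.fromarray(np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8))
inputs = image_processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # (grid_h * grid_w, 3 * 2 * 14 * 14) for a single image
print(inputs["image_grid_thw"])      # [[1, grid_h, grid_w]]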
364
src/transformers/models/glm4v/image_processing_glm4v_fast.py
Normal file
@ -0,0 +1,364 @@
# coding=utf-8
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for GLM-4.1V."""

from typing import Optional, Union

from ...image_processing_utils import (
    BatchFeature,
)
from ...image_processing_utils_fast import (
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    group_images_by_shape,
    reorder_images,
)
from ...image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    SizeDict,
    get_image_size,
    make_flat_list_of_images,
    valid_images,
)
from ...processing_utils import Unpack
from ...utils import (
    TensorType,
    auto_docstring,
    is_torch_available,
    is_torchvision_available,
    is_torchvision_v2_available,
    logging,
)
from ...video_utils import VideoInput
from .image_processing_glm4v import smart_resize


if is_torch_available():
    import torch


if is_torchvision_available():
    from ...image_utils import pil_torch_interpolation_mapping

if is_torchvision_v2_available():
    from torchvision.transforms.v2 import functional as F
else:
    from torchvision.transforms import functional as F

logger = logging.get_logger(__name__)


class Glm4vFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    """
    patch_size (`int`, *optional*, defaults to 14):
        The spatial patch size of the vision encoder.
    temporal_patch_size (`int`, *optional*, defaults to 2):
        The temporal patch size of the vision encoder.
    merge_size (`int`, *optional*, defaults to 2):
        The merge size of the vision encoder to llm encoder.
    """

    patch_size: Optional[int]
    temporal_patch_size: Optional[int]
    merge_size: Optional[int]


@auto_docstring
class Glm4vImageProcessorFast(BaseImageProcessorFast):
    do_resize = True
    resample = PILImageResampling.BICUBIC
    size = {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 15000}
    do_rescale = True
    do_normalize = True
    image_mean = OPENAI_CLIP_MEAN
    image_std = OPENAI_CLIP_STD
    do_convert_rgb = True
    patch_size = 14
    temporal_patch_size = 2
    merge_size = 2
    valid_kwargs = Glm4vFastImageProcessorKwargs
    model_input_names = ["pixel_values", "image_grid_thw"]

    def __init__(self, **kwargs: Unpack[Glm4vFastImageProcessorKwargs]):
        size = kwargs.pop("size", None)
        if size is not None and ("shortest_edge" not in size or "longest_edge" not in size):
            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
        else:
            size = self.size

        super().__init__(size=size, **kwargs)

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        patch_size: int,
        temporal_patch_size: int,
        merge_size: int,
        do_convert_rgb: bool,
        input_data_format: Optional[Union[str, ChannelDimension]],
        device: Optional[Union[str, torch.device]],
    ):
        """
        Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.

        Args:
            images (`ImageInput`):
                Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
            vision_info (`List[Dict]`, *optional*):
                Optional list of dictionaries containing additional information about vision inputs.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. `shortest_edge` and `longest_edge` keys must be present.
            interpolation (`InterpolationMode`):
                Resampling filter to use if resizing the image.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Scale factor to use if rescaling the image.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
            patch_size (`int`, *optional*, defaults to `self.patch_size`):
                The spatial patch size of the vision encoder.
            temporal_patch_size (`int`, *optional*, defaults to `self.temporal_patch_size`):
                The temporal patch size of the vision encoder.
            merge_size (`int`, *optional*, defaults to `self.merge_size`):
                The merge size of the vision encoder to llm encoder.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            device (`torch.device`, *optional*):
                The device to process the images on. If unset, the device is inferred from the input images.
        """
        images = self._prepare_input_images(
            images=images,
            do_convert_rgb=do_convert_rgb,
            input_data_format=input_data_format,
            device=device,
        )

        height, width = get_image_size(images[0], channel_dim=ChannelDimension.FIRST)
        resized_height, resized_width = height, width

        # Group images by size for batched resizing
        grouped_images, grouped_images_index = group_images_by_shape(images)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                resized_height, resized_width = smart_resize(
                    num_frames=temporal_patch_size,
                    height=height,
                    width=width,
                    temporal_factor=temporal_patch_size,
                    factor=patch_size * merge_size,
                )
                stacked_images = F.resize(
                    stacked_images, size=(resized_height, resized_width), interpolation=interpolation
                )
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            processed_images_grouped[shape] = stacked_images

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
        patches = torch.stack(processed_images, dim=0)
        if patches.shape[0] % temporal_patch_size != 0:
            repeats = patches[-1].unsqueeze(0).repeat(temporal_patch_size - 1, 1, 1, 1)
            patches = torch.cat([patches, repeats], dim=0)

        channel = patches.shape[1]
        grid_t = patches.shape[0] // temporal_patch_size
        grid_h, grid_w = resized_height // patch_size, resized_width // patch_size

        patches = patches.view(
            grid_t,
            temporal_patch_size,
            channel,
            grid_h // merge_size,
            merge_size,
            patch_size,
            grid_w // merge_size,
            merge_size,
            patch_size,
        )
        patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size
        )

        return flatten_patches, (grid_t, grid_h, grid_w)

    @auto_docstring
    def preprocess(
        self,
        images: ImageInput,
        videos: VideoInput = None,
        do_resize: Optional[bool] = None,
        size: Optional[dict[str, int]] = None,
        resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        patch_size: Optional[int] = None,
        temporal_patch_size: Optional[int] = None,
        merge_size: Optional[int] = None,
        do_convert_rgb: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        device: Optional["torch.device"] = None,
        **kwargs,
    ):
        r"""
        patch_size (`int`, *optional*, defaults to 14):
            The spatial patch size of the vision encoder.
        temporal_patch_size (`int`, *optional*, defaults to 2):
            The temporal patch size of the vision encoder.
        merge_size (`int`, *optional*, defaults to 2):
            The merge size of the vision encoder to llm encoder.
        """

        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        patch_size = patch_size if patch_size is not None else self.patch_size
        temporal_patch_size = temporal_patch_size if temporal_patch_size is not None else self.temporal_patch_size
        merge_size = merge_size if merge_size is not None else self.merge_size
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        # Make hashable for cache
        size = SizeDict(**size) if size is not None else None
        image_mean = tuple(image_mean) if image_mean is not None else None
        image_std = tuple(image_std) if image_std is not None else None

        self._validate_preprocess_kwargs(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            resample=resample,
            return_tensors=return_tensors,
            data_format=data_format,
        )
        interpolation = (
            pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
        )

        if images is not None:
            images = make_flat_list_of_images(images)

        if images is not None and not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        data = {}
        if images is not None:
            pixel_values, vision_grid_thws = [], []
            for image in images:
                patches, image_grid_thw = self._preprocess(
                    image,
                    do_resize=do_resize,
                    size=size,
                    interpolation=interpolation,
                    do_rescale=do_rescale,
                    rescale_factor=rescale_factor,
                    do_normalize=do_normalize,
                    image_mean=image_mean,
                    image_std=image_std,
                    patch_size=patch_size,
                    temporal_patch_size=temporal_patch_size,
                    merge_size=merge_size,
                    do_convert_rgb=do_convert_rgb,
                    input_data_format=input_data_format,
                    device=device,
                )
                pixel_values.extend(patches)
                vision_grid_thws.append(image_grid_thw)
            pixel_values = torch.stack(pixel_values)
            vision_grid_thws = torch.tensor(vision_grid_thws)
            data.update({"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws})

        return BatchFeature(data=data, tensor_type=return_tensors)

    def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
        """
        A utility that returns number of image patches for a given image size.

        Args:
            height (`int`):
                Height of the input image.
            width (`int`):
                Width of the input image.
            images_kwargs (`dict`, *optional*)
                Any kwargs to override defaults of the image processor.
        Returns:
            `int`: Number of image patches per image.
        """
        patch_size = images_kwargs.get("patch_size", None) or self.patch_size
        merge_size = images_kwargs.get("merge_size", None) or self.merge_size

        factor = patch_size * merge_size
        # Keyword names aligned with the smart_resize signature in image_processing_glm4v.py.
        resized_height, resized_width = smart_resize(
            num_frames=self.temporal_patch_size,
            height=height,
            width=width,
            factor=factor,
            temporal_factor=self.temporal_patch_size,
        )
        grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
        return grid_h * grid_w


__all__ = ["Glm4vImageProcessorFast"]
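The fast variant exposes the same call signature as the slow processor but works on torch tensors and groups images by shape before resizing. An illustrative sketch (assuming torchvision is installed; the random tensor stands in for a real image):

import torch

from transformers.models.glm4v.image_processing_glm4v_fast import Glm4vImageProcessorFast

fast_image_processor = Glm4vImageProcessorFast()
image = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)
outputs = fast_image_processor(images=image, return_tensors="pt")
print(outputs["pixel_values"].shape, outputs["image_grid_thw"])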
1667
src/transformers/models/glm4v/modeling_glm4v.py
Normal file
File diff suppressed because it is too large

1733
src/transformers/models/glm4v/modular_glm4v.py
Normal file
File diff suppressed because it is too large

289
src/transformers/models/glm4v/processing_glm4v.py
Normal file
@ -0,0 +1,289 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/glm4v/modular_glm4v.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_glm4v.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...video_utils import VideoInput


class Glm4vVideosProcessorKwargs(VideosKwargs, total=False):
    fps: Union[list[float], float]


class Glm4vImagesKwargs(ImagesKwargs):
    patch_size: Optional[int]
    temporal_patch_size: Optional[int]
    merge_size: Optional[int]


class Glm4vProcessorKwargs(ProcessingKwargs, total=False):
    images_kwargs: Glm4vImagesKwargs
    videos_kwargs: Glm4vVideosProcessorKwargs
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
    }


class Glm4vProcessor(ProcessorMixin):
    r"""
    Constructs a GLM-4V processor which wraps a GLM-4V image processor and a GLM-4 tokenizer into a single processor.
    [`~Glm4vProcessor.__call__`] and [`~Glm4vProcessor.decode`] for more information.
    Args:
        image_processor ([`Glm4vProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`PreTrainedTokenizerFast`], *optional*):
            The tokenizer is a required input.
        video_processor ([`Glm4vVideoProcessor`], *optional*):
            The video processor is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer", "video_processor"]

    image_processor_class = "AutoImageProcessor"
    video_processor_class = "AutoVideoProcessor"

    tokenizer_class = ("PreTrainedTokenizer", "PreTrainedTokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
        super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
        self.image_token = "<|image|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
        self.video_token = "<|video|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
        self.image_token_id = (
            tokenizer.image_token_id
            if getattr(tokenizer, "image_token_id", None)
            else tokenizer.convert_tokens_to_ids(self.image_token)
        )
        self.video_token_id = (
            tokenizer.video_token_id
            if getattr(tokenizer, "video_token_id", None)
            else tokenizer.convert_tokens_to_ids(self.video_token)
        )

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
        videos: VideoInput = None,
        **kwargs: Unpack[Glm4vProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
        and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
        the text.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
                tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
        """
        output_kwargs = self._merge_kwargs(
            Glm4vProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if images is not None:
            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        if videos is not None:
            videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
            timestamps = videos_inputs.pop("timestamps")
            video_grid_thw = videos_inputs["video_grid_thw"]
        else:
            videos_inputs = {}
            timestamps = []
            video_grid_thw = None

        if not isinstance(text, list):
            text = [text]

        text = text.copy()  # below lines change text in-place
        if image_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.image_token in text[i]:
                    num_image_tokens = image_grid_thw[index].prod() // merge_length
                    text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.image_token)

        if video_grid_thw is not None:
            merge_length = self.video_processor.merge_size**2
            video_index = 0
            for i in range(len(text)):
                while self.video_token in text[i]:
                    num_frames = len(video_grid_thw)
                    video_structure = ""

                    if hasattr(timestamps, "tolist"):
                        timestamps_list = timestamps.tolist()[0]
                    else:
                        timestamps_list = timestamps[0] if isinstance(timestamps[0], list) else timestamps
                    unique_timestamps = []
                    for idx in range(0, len(timestamps_list)):
                        unique_timestamps.append(timestamps_list[idx])
                    selected_timestamps = unique_timestamps[:num_frames]
                    while len(selected_timestamps) < num_frames:
                        selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
                    for frame_idx in range(num_frames):
                        timestamp_sec = selected_timestamps[frame_idx]
                        frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{timestamp_sec}"
                        video_structure += frame_structure
                    text[i] = text[i].replace(self.video_token, video_structure, 1)
                    video_index += 1

                for frame_idx in range(len(video_grid_thw)):
                    if self.image_token in text[i]:
                        num_image_tokens = video_grid_thw[frame_idx].prod() // merge_length
                        text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
                text[i] = text[i].replace("<|placeholder|>", self.image_token)

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])

        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)

    def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
        """
        Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
        Args:
            image_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (height, width) per each image.
            video_sizes (`list[list[int]]`, *optional*):
                The input sizes formatted as (num_frames, height, width) per each video.
        Returns:
            `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
            input modalities, along with other useful data.
        """

        vision_data = {}
        if image_sizes is not None:
            images_kwargs = Glm4vProcessorKwargs._defaults.get("images_kwargs", {})
            images_kwargs.update(kwargs)
            merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size

            num_image_patches = [
                self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
                for image_size in image_sizes
            ]
            num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
            vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})

        if video_sizes is not None:
            videos_kwargs = Glm4vProcessorKwargs._defaults.get("videos_kwargs", {})
            videos_kwargs.update(kwargs)
            num_video_patches = [
                self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
                for video_size in video_sizes
            ]
            num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
            vision_data["num_video_tokens"] = num_video_tokens

        return MultiModalData(**vision_data)

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_image_text_to_text(
        self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
    ):
        """
        Post-process the output of the model to decode the text.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
                or `(sequence_length,)`.
            skip_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
            **kwargs:
                Additional arguments to be passed to the tokenizer's `batch_decode method`.

        Returns:
            `list[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
        return names_from_processor + ["second_per_grid_ts"]


__all__ = ["Glm4vProcessor"]
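An illustrative end-to-end sketch of the processor above (the checkpoint path is a placeholder; a real GLM-4.1V repository with tokenizer, image-processor and video-processor configs is assumed):

import numpy as np
from PIL import Image

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("path/to/glm4v-checkpoint")  # placeholder path
image = Image.fromarray(np.random.randint(0, 255, (336, 336, 3), dtype=np.uint8))
text = "<|image|>Describe this image."
inputs = processor(images=image, text=text, return_tensors="pt")
# The single <|image|> placeholder is expanded to image_grid_thw.prod() // merge_size**2 image tokens.
print(inputs["input_ids"].shape, inputs["image_grid_thw"])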
262
src/transformers/models/glm4v/video_processing_glm4v.py
Normal file
@ -0,0 +1,262 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""video processor class for GLM-4.1V."""
|
||||
|
||||
import math
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import (
|
||||
BatchFeature,
|
||||
)
|
||||
from ...image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ChannelDimension,
|
||||
SizeDict,
|
||||
get_image_size,
|
||||
)
|
||||
from ...processing_utils import Unpack, VideosKwargs
|
||||
from ...utils import (
|
||||
TensorType,
|
||||
add_start_docstrings,
|
||||
is_torch_available,
|
||||
is_vision_available,
|
||||
)
|
||||
from .image_processing_glm4v import smart_resize
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from ...utils.import_utils import requires
|
||||
from ...video_processing_utils import (
|
||||
BASE_VIDEO_PROCESSOR_DOCSTRING,
|
||||
BaseVideoProcessor,
|
||||
)
|
||||
from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from ...image_utils import PILImageResampling
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Glm4vVideoProcessorInitKwargs(VideosKwargs):
|
||||
max_image_size: dict[str, int] = None
|
||||
patch_size: Optional[int] = None
|
||||
temporal_patch_size: Optional[int] = None
|
||||
merge_size: Optional[int] = None
|
||||
image_mean: Optional[list[float]] = None
|
||||
image_std: Optional[list[float]] = None
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Constructs a fast GLM-4V image processor that dynamically resizes videos based on the original videos.",
|
||||
BASE_VIDEO_PROCESSOR_DOCSTRING,
|
||||
"""
|
||||
patch_size (`int`, *optional*, defaults to 14):
|
||||
The spatial patch size of the vision encoder.
|
||||
temporal_patch_size (`int`, *optional*, defaults to 2):
|
||||
The temporal patch size of the vision encoder.
|
||||
merge_size (`int`, *optional*, defaults to 2):
|
||||
The merge size of the vision encoder to llm encoder.
|
||||
""",
|
||||
)
|
||||
@requires(backends=("torchvision",))
|
||||
class Glm4vVideoProcessor(BaseVideoProcessor):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
size = {"shortest_edge": 112 * 112, "longest_edge": 28 * 28 * 2 * 30000}
|
||||
max_image_size = {"longest_edge": 28 * 28 * 2 * 30000}
|
||||
image_mean = OPENAI_CLIP_MEAN
|
||||
image_std = OPENAI_CLIP_STD
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
do_sample_frames = True
|
||||
patch_size = 14
|
||||
temporal_patch_size = 2
|
||||
max_duration = 300
|
||||
merge_size = 2
|
||||
valid_kwargs = Glm4vVideoProcessorInitKwargs
|
||||
num_frames = 16
|
||||
fps = 2
|
||||
|
||||
model_input_names = ["pixel_values_videos", "video_grid_thw"]
|
||||
|
||||
    def __init__(self, **kwargs: Unpack[Glm4vVideoProcessorInitKwargs]):
        super().__init__(**kwargs)

    def sample_frames(
        self,
        video: torch.Tensor,
        metadata: Union[VideoMetadata, dict],
    ):
        total_frames = video.shape[0]
        video_fps = getattr(metadata, "fps", 2.0)
        meta_frames = getattr(metadata, "total_num_frames", total_frames)
        max_frame_idx = meta_frames - 1
        duration = getattr(metadata, "duration", None)
        if duration is None:
            duration = round(max_frame_idx / video_fps) + 1

        if duration <= self.max_duration:
            n = int(math.floor(duration * self.fps))
            frame_indices = [min(max_frame_idx, int(math.ceil(i * video_fps / self.fps))) for i in range(n)]
        else:
            num_samples = int(self.max_duration * self.fps)
            if num_samples >= meta_frames:
                frame_indices = list(range(meta_frames))
            else:
                target_seconds = np.linspace(0, duration, num_samples, endpoint=True)
                frame_indices = [min(max_frame_idx, int(math.ceil(t * video_fps))) for t in target_seconds]

        seen, uniq = set(), []
        for idx in frame_indices:
            if idx not in seen:
                seen.add(idx)
                uniq.append(idx)

        if len(uniq) & 1:
            uniq.append(uniq[-1])

        frame_indices = uniq
        sampled_video = video[frame_indices]
        full_second_idxs = [int(idx / video_fps) for idx in frame_indices]
        second_idxs = full_second_idxs[::2]  # mrope
        return sampled_video, second_idxs
|
||||
|
||||
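    # Worked example of the sampling above (illustrative numbers, not from the original code):
    # assume a 10-second clip at 24 fps (240 frames) with the class defaults fps=2 and max_duration=300.
    #   duration (10) <= max_duration, so n = floor(10 * 2) = 20 sampled frames
    #   frame_indices = [ceil(i * 24 / 2) for i in range(20)] = [0, 12, 24, ..., 228]
    #   all 20 indices are unique and the count is even, so no duplicate-padding is needed
    #   full_second_idxs = [0, 0, 1, 1, ..., 9, 9] and second_idxs = [0, 1, ..., 9],
    #   i.e. one timestamp per temporal patch of 2 frames, which is what mrope consumes.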
def _preprocess(
|
||||
self,
|
||||
videos: list[torch.Tensor],
|
||||
video_metadata: Optional[Union[list[VideoMetadata], list[dict]]] = None,
|
||||
do_convert_rgb: bool = True,
|
||||
do_resize: bool = True,
|
||||
size: SizeDict = None,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: float = 1 / 255.0,
|
||||
do_normalize: bool = True,
|
||||
do_sample_frames: bool = True,
|
||||
image_mean: Optional[Union[float, list[float]]] = None,
|
||||
image_std: Optional[Union[float, list[float]]] = None,
|
||||
patch_size: Optional[int] = None,
|
||||
temporal_patch_size: Optional[int] = None,
|
||||
merge_size: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
timestamps_list = []
|
||||
if do_sample_frames:
|
||||
if video_metadata is None or (isinstance(video_metadata, list) and video_metadata[0] is None):
|
||||
raise ValueError(
|
||||
"Frame sampling is enabled but no video metadata was found. "
|
||||
"Please pass in `VideoMetadata` object per each input video or set `do_sample_frames=False`"
|
||||
)
|
||||
processed_videos = []
|
||||
for video, metadata in zip(videos, video_metadata):
|
||||
video, timestamps = self.sample_frames(video, metadata)
|
||||
timestamps_list.append(timestamps)
|
||||
processed_videos.append(video)
|
||||
else:
|
||||
raise AssertionError("Must set `do_sample_frames=True` to sample frames from GLM-4.1V Model.")
|
||||
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape(processed_videos)
|
||||
resized_videos_grouped = {}
|
||||
|
||||
for shape, stacked_videos in grouped_videos.items():
|
||||
B, T, C, H, W = stacked_videos.shape
|
||||
num_frames, height, width = T, H, W
|
||||
if do_resize:
|
||||
resized_height, resized_width = smart_resize(
|
||||
num_frames=num_frames,
|
||||
height=height,
|
||||
width=width,
|
||||
temporal_factor=temporal_patch_size,
|
||||
factor=patch_size * merge_size,
|
||||
max_pixels=self.max_image_size["longest_edge"],
|
||||
)
|
||||
stacked_videos = stacked_videos.view(B * T, C, H, W)
|
||||
stacked_videos = F.interpolate(
|
||||
stacked_videos, size=(resized_height, resized_width), mode="bicubic", align_corners=False
|
||||
)
|
||||
stacked_videos = stacked_videos.view(B, T, C, resized_height, resized_width)
|
||||
resized_videos_grouped[shape] = stacked_videos
|
||||
resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)
|
||||
|
||||
# Group videos by size for further processing
|
||||
# Needed in case do_resize is False, or resize returns videos with different sizes
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos)
|
||||
processed_videos_grouped = {}
|
||||
processed_grids = {}
|
||||
for shape, stacked_videos in grouped_videos.items():
|
||||
resized_height, resized_width = get_image_size(stacked_videos[0], channel_dim=ChannelDimension.FIRST)
|
||||
|
||||
# Fused rescale and normalize
|
||||
stacked_videos = self.rescale_and_normalize(
|
||||
stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std
|
||||
)
|
||||
patches = stacked_videos
|
||||
|
||||
# Check that videos have `num_frames` divisible by `temporal_patch_size`
|
||||
if patches.shape[1] % temporal_patch_size != 0:
|
||||
repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
|
||||
patches = torch.cat([patches, repeats], dim=1)
|
||||
batch_size, grid_t, channel = patches.shape[:3]
|
||||
grid_t = grid_t // temporal_patch_size
|
||||
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
|
||||
|
||||
patches = patches.view(
|
||||
batch_size,
|
||||
grid_t,
|
||||
temporal_patch_size,
|
||||
channel,
|
||||
grid_h // merge_size,
|
||||
merge_size,
|
||||
patch_size,
|
||||
grid_w // merge_size,
|
||||
merge_size,
|
||||
patch_size,
|
||||
)
|
||||
patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
|
||||
flatten_patches = patches.reshape(
|
||||
batch_size,
|
||||
grid_t * grid_h * grid_w,
|
||||
channel * temporal_patch_size * patch_size * patch_size,
|
||||
)
|
||||
|
||||
processed_videos_grouped[shape] = flatten_patches
|
||||
processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
|
||||
|
||||
processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
|
||||
processed_grids = reorder_videos(processed_grids, grouped_videos_index)
|
||||
pixel_values_videos = torch.cat(processed_videos, dim=0)
|
||||
video_grid_thw = torch.tensor(processed_grids)
|
||||
total_frames = video_grid_thw[0][0].item()
|
||||
h = video_grid_thw[0][1].item()
|
||||
w = video_grid_thw[0][2].item()
|
||||
video_grid_thw = [[1, h, w] for _ in range(total_frames)]
|
||||
data = {
|
||||
"pixel_values_videos": pixel_values_videos,
|
||||
"video_grid_thw": video_grid_thw,
|
||||
"timestamps": timestamps_list,
|
||||
}
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
|
||||
__all__ = ["Glm4vVideoProcessor"]
|
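To make the grid bookkeeping in `_preprocess` above concrete, here is a short shape walk-through with illustrative numbers (they are assumptions, not values from the PR):

```python
# Shapes produced by the patchification in Glm4vVideoProcessor._preprocess, assuming
# 16 sampled frames resized to 448x448 with patch_size=14, merge_size=2, temporal_patch_size=2.
T, H, W = 16, 448, 448
patch_size, merge_size, temporal_patch_size = 14, 2, 2

grid_t = T // temporal_patch_size  # 8 temporal patches
grid_h = H // patch_size           # 32
grid_w = W // patch_size           # 32
tokens_per_video = grid_t * grid_h * grid_w                              # 8192 rows in pixel_values_videos
features_per_token = 3 * temporal_patch_size * patch_size * patch_size   # 1176 values per row
print(grid_t, grid_h, grid_w, tokens_per_video, features_per_token)
```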
@@ -1163,6 +1163,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs,
     ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
         r"""
@@ -1208,25 +1209,26 @@ class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
             torch.cuda.set_device(self.transformer.first_device)
             hidden_states = hidden_states.to(self.lm_head.weight.device)

-        lm_logits = self.lm_head(hidden_states)
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])

         loss = None
         if labels is not None:
             # Flatten the tokens
             loss = self.loss_function(
-                lm_logits,
+                logits,
                 labels,
                 vocab_size=self.config.vocab_size,
                 **kwargs,
             )

         if not return_dict:
-            output = (lm_logits,) + transformer_outputs[1:]
+            output = (logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output

         return CausalLMOutputWithCrossAttentions(
             loss=loss,
-            logits=lm_logits,
+            logits=logits,
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
||||
|
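With `logits_to_keep` added to the signature above, the LM head can be applied to only the final positions rather than the whole sequence. A hedged usage sketch (the argument name comes from the diff; the checkpoint and prompt are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    # Only materialize logits for the last token, which is all greedy decoding needs.
    out = model(**inputs, logits_to_keep=1)
print(out.logits.shape)  # (batch_size, 1, vocab_size)
```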
@ -292,6 +292,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["GPTNeoXLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
@ -305,6 +305,7 @@ class GranitePreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["GraniteDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
@ -320,6 +320,7 @@ class HeliumPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["HeliumDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
@ -94,12 +94,18 @@ class InstructBlipVideoVideoProcessor(BaseVideoProcessor):
|
||||
fps: Optional[int] = None,
|
||||
num_frames: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
device: Optional["torch.Tensor"] = None,
|
||||
) -> BatchFeature:
|
||||
if do_sample_frames:
|
||||
videos = [
|
||||
self.sample_frames(video, metadata, num_frames, fps) for video, metadata in zip(videos, video_metadata)
|
||||
]
|
||||
|
||||
# We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
|
||||
# moving the whole video incurs high GPU mem usage for long videos
|
||||
if device is not None:
|
||||
videos = [video.to(device) for video in videos]
|
||||
|
||||
# Group videos by size for batched resizing
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
|
||||
resized_videos_grouped = {}
|
||||
|
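The comment added in these video processors documents an ordering constraint: frames are sampled on CPU first, and only the kept frames are moved to the accelerator. A rough sketch of what that looks like from the caller's side (default-constructed processor and kwargs are illustrative; exact arguments may differ per processor, and a CUDA device is assumed):

```python
import torch
from transformers import InstructBlipVideoVideoProcessor

processor = InstructBlipVideoVideoProcessor()  # defaults shown for illustration; normally from_pretrained
video = torch.randint(0, 256, (300, 3, 224, 224), dtype=torch.uint8)  # 300 raw frames, kept on CPU

# Frames are subsampled on CPU first; only the sampled subset ever reaches GPU memory.
out = processor(videos=[video], do_sample_frames=True, num_frames=16, device="cuda", return_tensors="pt")
```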
@ -147,6 +147,7 @@ class InternVLVideoProcessor(BaseVideoProcessor):
|
||||
num_frames: Optional[int] = None,
|
||||
initial_shift: Optional[Union[bool, float, int]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
device: Optional["torch.Tensor"] = None,
|
||||
) -> BatchFeature:
|
||||
if do_sample_frames:
|
||||
# Sample video frames
|
||||
@ -155,6 +156,11 @@ class InternVLVideoProcessor(BaseVideoProcessor):
|
||||
for video, metadata in zip(videos, video_metadata)
|
||||
]
|
||||
|
||||
# We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
|
||||
# moving the whole video incurs high GPU mem usage for long videos
|
||||
if device is not None:
|
||||
videos = [video.to(device) for video in videos]
|
||||
|
||||
# Group videos by size for batched resizing
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
|
||||
resized_videos_grouped = {}
|
||||
|
@ -28,7 +28,7 @@ class KyutaiSpeechToTextConfig(PretrainedConfig):
|
||||
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
|
||||
2.6b-en model.
|
||||
|
||||
e.g. [kyutai/stt-2.6b-en](https://huggingface.co/kyutai/stt-2.6b-en)
|
||||
e.g. [kyutai/stt-2.6b-en-trfs](https://huggingface.co/kyutai/stt-2.6b-en-trfs)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
@ -110,8 +110,7 @@ class KyutaiSpeechToTextConfig(PretrainedConfig):
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
# not the best naming here for `model_type`, but original codebase already uses model type:`stt` for in the config so we keep it to simplify
|
||||
model_type = "stt"
|
||||
model_type = "kyutai_speech_to_text"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
sub_configs = {"codec_config": AutoConfig}
|
||||
|
@ -190,7 +190,14 @@ def write_model(
|
||||
print("Converting the model.")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
config = KyutaiSpeechToTextConfig()
|
||||
config = KyutaiSpeechToTextConfig(
|
||||
vocab_size=8001,
|
||||
max_position_embeddings=375,
|
||||
num_hidden_layers=16,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=16,
|
||||
head_dim=128,
|
||||
)
|
||||
config.use_cache = True
|
||||
config.codec_config.sliding_window = 250
|
||||
|
@ -1,5 +1,5 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/stt/modular_kyutai_speech_to_text.py.
|
||||
# This file was automatically generated from src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.
|
@ -1,5 +1,5 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/stt/modular_kyutai_speech_to_text.py.
|
||||
# This file was automatically generated from src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.
|
||||
@ -713,7 +713,7 @@ class KyutaiSpeechToTextSdpaAttention(KyutaiSpeechToTextAttention):
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
STT_ATTENTION_CLASSES = {
|
||||
KYUTAI_SPEECH_TO_TEXT_ATTENTION_CLASSES = {
|
||||
"eager": KyutaiSpeechToTextAttention,
|
||||
"flash_attention_2": KyutaiSpeechToTextFlashAttention2,
|
||||
"sdpa": KyutaiSpeechToTextSdpaAttention,
|
||||
@ -726,7 +726,7 @@ class KyutaiSpeechToTextDecoderLayer(GradientCheckpointingLayer):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.use_flexible_linear = use_flexible_linear
|
||||
|
||||
self.self_attn = STT_ATTENTION_CLASSES[config._attn_implementation](
|
||||
self.self_attn = KYUTAI_SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
|
||||
config=config, layer_idx=layer_idx, use_flexible_linear=use_flexible_linear, use_rope=use_rope
|
||||
)
|
||||
|
||||
@ -1169,7 +1169,7 @@ class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedMod
|
||||
>>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
|
||||
|
||||
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
>>> model_id = "kyutai/stt-2.6b-en"
|
||||
>>> model_id = "kyutai/stt-2.6b-en-trfs"
|
||||
|
||||
>>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
|
||||
>>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
|
@ -278,7 +278,7 @@ class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMix
|
||||
>>> from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration
|
||||
|
||||
>>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
>>> model_id = "kyutai/stt-2.6b-en"
|
||||
>>> model_id = "kyutai/stt-2.6b-en-trfs"
|
||||
|
||||
>>> processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
|
||||
>>> model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
|
@ -1212,7 +1212,7 @@ class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("impira/layoutlm-document-qa", add_prefix_space=True)
|
||||
>>> model = LayoutLMForQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="1e3ebac")
|
||||
|
||||
>>> dataset = load_dataset("nielsr/funsd", split="train", trust_remote_code=True)
|
||||
>>> dataset = load_dataset("nielsr/funsd", split="train")
|
||||
>>> example = dataset[0]
|
||||
>>> question = "what's his name?"
|
||||
>>> words = example["words"]
|
||||
|
@ -1601,7 +1601,7 @@ class TFLayoutLMForQuestionAnswering(TFLayoutLMPreTrainedModel, TFQuestionAnswer
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("impira/layoutlm-document-qa", add_prefix_space=True)
|
||||
>>> model = TFLayoutLMForQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="1e3ebac")
|
||||
|
||||
>>> dataset = load_dataset("nielsr/funsd", split="train", trust_remote_code=True)
|
||||
>>> dataset = load_dataset("nielsr/funsd", split="train")
|
||||
>>> example = dataset[0]
|
||||
>>> question = "what's his name?"
|
||||
>>> words = example["words"]
|
||||
|
@ -753,9 +753,8 @@ class LayoutLMv2Model(LayoutLMv2PreTrainedModel):
|
||||
>>> model = LayoutLMv2Model.from_pretrained("microsoft/layoutlmv2-base-uncased")
|
||||
|
||||
|
||||
>>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa", trust_remote_code=True)
|
||||
>>> image_path = dataset["test"][0]["file"]
|
||||
>>> image = Image.open(image_path).convert("RGB")
|
||||
>>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa")
|
||||
>>> image = dataset["test"][0]["image"]
|
||||
|
||||
>>> encoding = processor(image, return_tensors="pt")
|
||||
|
||||
@ -943,7 +942,7 @@ class LayoutLMv2ForSequenceClassification(LayoutLMv2PreTrainedModel):
|
||||
|
||||
>>> set_seed(0)
|
||||
|
||||
>>> dataset = load_dataset("aharley/rvl_cdip", split="train", streaming=True, trust_remote_code=True)
|
||||
>>> dataset = load_dataset("aharley/rvl_cdip", split="train", streaming=True)
|
||||
>>> data = next(iter(dataset))
|
||||
>>> image = data["image"].convert("RGB")
|
||||
|
||||
@ -1145,7 +1144,7 @@ class LayoutLMv2ForTokenClassification(LayoutLMv2PreTrainedModel):
|
||||
|
||||
>>> set_seed(0)
|
||||
|
||||
>>> datasets = load_dataset("nielsr/funsd", split="test", trust_remote_code=True)
|
||||
>>> datasets = load_dataset("nielsr/funsd", split="test")
|
||||
>>> labels = datasets.features["ner_tags"].feature.names
|
||||
>>> id2label = {v: k for v, k in enumerate(labels)}
|
||||
|
||||
@ -1302,9 +1301,8 @@ class LayoutLMv2ForQuestionAnswering(LayoutLMv2PreTrainedModel):
|
||||
>>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
|
||||
>>> model = LayoutLMv2ForQuestionAnswering.from_pretrained("microsoft/layoutlmv2-base-uncased")
|
||||
|
||||
>>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa", trust_remote_code=True)
|
||||
>>> image_path = dataset["test"][0]["file"]
|
||||
>>> image = Image.open(image_path).convert("RGB")
|
||||
>>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa")
|
||||
>>> image = dataset["test"][0]["image"]
|
||||
>>> question = "When is coffee break?"
|
||||
>>> encoding = processor(image, question, return_tensors="pt")
|
||||
|
||||
|
@ -736,7 +736,7 @@ class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
|
||||
>>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
|
||||
>>> model = AutoModel.from_pretrained("microsoft/layoutlmv3-base")
|
||||
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
|
||||
>>> example = dataset[0]
|
||||
>>> image = example["image"]
|
||||
>>> words = example["tokens"]
|
||||
@ -951,7 +951,7 @@ class LayoutLMv3ForTokenClassification(LayoutLMv3PreTrainedModel):
|
||||
>>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
|
||||
>>> model = AutoModelForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=7)
|
||||
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
|
||||
>>> example = dataset[0]
|
||||
>>> image = example["image"]
|
||||
>>> words = example["tokens"]
|
||||
@ -1052,7 +1052,7 @@ class LayoutLMv3ForQuestionAnswering(LayoutLMv3PreTrainedModel):
|
||||
>>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
|
||||
>>> model = AutoModelForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base")
|
||||
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
|
||||
>>> example = dataset[0]
|
||||
>>> image = example["image"]
|
||||
>>> question = "what's his name?"
|
||||
@ -1172,7 +1172,7 @@ class LayoutLMv3ForSequenceClassification(LayoutLMv3PreTrainedModel):
|
||||
>>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
|
||||
>>> model = AutoModelForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base")
|
||||
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
|
||||
>>> example = dataset[0]
|
||||
>>> image = example["image"]
|
||||
>>> words = example["tokens"]
|
||||
|
@ -1296,7 +1296,7 @@ class TFLayoutLMv3Model(TFLayoutLMv3PreTrainedModel):
|
||||
>>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
|
||||
>>> model = TFAutoModel.from_pretrained("microsoft/layoutlmv3-base")
|
||||
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
|
||||
>>> example = dataset[0]
|
||||
>>> image = example["image"]
|
||||
>>> words = example["tokens"]
|
||||
@ -1439,7 +1439,7 @@ class TFLayoutLMv3ForSequenceClassification(TFLayoutLMv3PreTrainedModel, TFSeque
|
||||
>>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
|
||||
>>> model = TFAutoModelForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base")
|
||||
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
|
||||
>>> example = dataset[0]
|
||||
>>> image = example["image"]
|
||||
>>> words = example["tokens"]
|
||||
@ -1566,7 +1566,7 @@ class TFLayoutLMv3ForTokenClassification(TFLayoutLMv3PreTrainedModel, TFTokenCla
|
||||
>>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
|
||||
>>> model = TFAutoModelForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=7)
|
||||
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
|
||||
>>> example = dataset[0]
|
||||
>>> image = example["image"]
|
||||
>>> words = example["tokens"]
|
||||
@ -1703,7 +1703,7 @@ class TFLayoutLMv3ForQuestionAnswering(TFLayoutLMv3PreTrainedModel, TFQuestionAn
|
||||
>>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
|
||||
>>> model = TFAutoModelForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base")
|
||||
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
|
||||
>>> example = dataset[0]
|
||||
>>> image = example["image"]
|
||||
>>> question = "what's his name?"
|
||||
|
@ -516,16 +516,15 @@ class LightGlueForKeypointMatching(LightGluePreTrainedModel):
|
||||
|
||||
self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)
|
||||
|
||||
self.keypoint_detector_descriptor_dim = config.keypoint_detector_config.descriptor_decoder_dim
|
||||
self.descriptor_dim = config.descriptor_dim
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.filter_threshold = config.filter_threshold
|
||||
self.depth_confidence = config.depth_confidence
|
||||
self.width_confidence = config.width_confidence
|
||||
|
||||
if self.descriptor_dim != config.keypoint_detector_config.descriptor_decoder_dim:
|
||||
self.input_projection = nn.Linear(
|
||||
config.keypoint_detector_config.descriptor_decoder_dim, self.descriptor_dim, bias=True
|
||||
)
|
||||
if self.descriptor_dim != self.keypoint_detector_descriptor_dim:
|
||||
self.input_projection = nn.Linear(self.keypoint_detector_descriptor_dim, self.descriptor_dim, bias=True)
|
||||
else:
|
||||
self.input_projection = nn.Identity()
|
||||
|
||||
@ -721,7 +720,7 @@ class LightGlueForKeypointMatching(LightGluePreTrainedModel):
|
||||
# (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2)
|
||||
keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2)
|
||||
mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None
|
||||
descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.descriptor_dim)
|
||||
descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.keypoint_detector_descriptor_dim)
|
||||
image_indices = torch.arange(batch_size * 2, device=device)
|
||||
# Keypoint normalization
|
||||
keypoints = normalize_keypoints(keypoints, height, width)
|
||||
@ -892,7 +891,7 @@ class LightGlueForKeypointMatching(LightGluePreTrainedModel):
|
||||
|
||||
keypoints, _, descriptors, mask = keypoint_detections[:4]
|
||||
keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values)
|
||||
descriptors = descriptors.reshape(batch_size, 2, -1, self.descriptor_dim).to(pixel_values)
|
||||
descriptors = descriptors.reshape(batch_size, 2, -1, self.keypoint_detector_descriptor_dim).to(pixel_values)
|
||||
mask = mask.reshape(batch_size, 2, -1)
|
||||
|
||||
absolute_keypoints = keypoints.clone()
|
||||
|
@ -587,16 +587,15 @@ class LightGlueForKeypointMatching(LightGluePreTrainedModel):
|
||||
|
||||
self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)
|
||||
|
||||
self.keypoint_detector_descriptor_dim = config.keypoint_detector_config.descriptor_decoder_dim
|
||||
self.descriptor_dim = config.descriptor_dim
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.filter_threshold = config.filter_threshold
|
||||
self.depth_confidence = config.depth_confidence
|
||||
self.width_confidence = config.width_confidence
|
||||
|
||||
if self.descriptor_dim != config.keypoint_detector_config.descriptor_decoder_dim:
|
||||
self.input_projection = nn.Linear(
|
||||
config.keypoint_detector_config.descriptor_decoder_dim, self.descriptor_dim, bias=True
|
||||
)
|
||||
if self.descriptor_dim != self.keypoint_detector_descriptor_dim:
|
||||
self.input_projection = nn.Linear(self.keypoint_detector_descriptor_dim, self.descriptor_dim, bias=True)
|
||||
else:
|
||||
self.input_projection = nn.Identity()
|
||||
|
||||
@ -792,7 +791,7 @@ class LightGlueForKeypointMatching(LightGluePreTrainedModel):
|
||||
# (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2)
|
||||
keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2)
|
||||
mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None
|
||||
descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.descriptor_dim)
|
||||
descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.keypoint_detector_descriptor_dim)
|
||||
image_indices = torch.arange(batch_size * 2, device=device)
|
||||
# Keypoint normalization
|
||||
keypoints = normalize_keypoints(keypoints, height, width)
|
||||
@ -963,7 +962,7 @@ class LightGlueForKeypointMatching(LightGluePreTrainedModel):
|
||||
|
||||
keypoints, _, descriptors, mask = keypoint_detections[:4]
|
||||
keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values)
|
||||
descriptors = descriptors.reshape(batch_size, 2, -1, self.descriptor_dim).to(pixel_values)
|
||||
descriptors = descriptors.reshape(batch_size, 2, -1, self.keypoint_detector_descriptor_dim).to(pixel_values)
|
||||
mask = mask.reshape(batch_size, 2, -1)
|
||||
|
||||
absolute_keypoints = keypoints.clone()
|
||||
|
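The LightGlue refactor above caches the detector's `descriptor_decoder_dim` as `keypoint_detector_descriptor_dim` and only inserts a projection when it differs from LightGlue's working `descriptor_dim`. A small sketch of that decision (the dimensions are illustrative):

```python
import torch.nn as nn

keypoint_detector_descriptor_dim = 256  # e.g. the keypoint detector's descriptor_decoder_dim
descriptor_dim = 256                    # LightGlue's working descriptor dimension

if descriptor_dim != keypoint_detector_descriptor_dim:
    input_projection = nn.Linear(keypoint_detector_descriptor_dim, descriptor_dim, bias=True)
else:
    input_projection = nn.Identity()  # dimensions already match, descriptors pass through unchanged

print(input_projection)
```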
@ -644,7 +644,7 @@ class LiltModel(LiltPreTrainedModel):
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
|
||||
>>> model = AutoModel.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
|
||||
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
|
||||
>>> example = dataset[0]
|
||||
>>> words = example["tokens"]
|
||||
>>> boxes = example["bboxes"]
|
||||
@ -784,7 +784,7 @@ class LiltForSequenceClassification(LiltPreTrainedModel):
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
|
||||
>>> model = AutoModelForSequenceClassification.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
|
||||
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
|
||||
>>> example = dataset[0]
|
||||
>>> words = example["tokens"]
|
||||
>>> boxes = example["bboxes"]
|
||||
@ -899,7 +899,7 @@ class LiltForTokenClassification(LiltPreTrainedModel):
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
|
||||
>>> model = AutoModelForTokenClassification.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
|
||||
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
|
||||
>>> example = dataset[0]
|
||||
>>> words = example["tokens"]
|
||||
>>> boxes = example["bboxes"]
|
||||
@ -1016,7 +1016,7 @@ class LiltForQuestionAnswering(LiltPreTrainedModel):
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
|
||||
>>> model = AutoModelForQuestionAnswering.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
|
||||
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
|
||||
>>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
|
||||
>>> example = dataset[0]
|
||||
>>> words = example["tokens"]
|
||||
>>> boxes = example["bboxes"]
|
||||
|
@ -320,6 +320,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["LlamaDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
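The `_supports_flash_attn_3 = True` flags added across these model classes advertise support for the FlashAttention-3 backend. A hedged sketch of selecting an attention backend at load time; the backend string mirrors the flag name but is an assumption here, and a FlashAttention-3 capable environment is assumed (fall back to `"sdpa"` otherwise):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM3-3B",            # any causal LM whose class sets _supports_flash_attn_3
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_3",
)
```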
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
@ -104,7 +103,6 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
language_code_re = re.compile(">>.+<<") # type: re.Pattern
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -186,9 +184,11 @@ class MarianTokenizer(PreTrainedTokenizer):
|
||||
|
||||
def remove_language_code(self, text: str):
|
||||
"""Remove language codes like >>fr<< before sentencepiece"""
|
||||
match = self.language_code_re.match(text)
|
||||
code: list = [match.group(0)] if match else []
|
||||
return code, self.language_code_re.sub("", text)
|
||||
code = []
|
||||
if text.startswith(">>") and (end_loc := text.find("<<")) != -1:
|
||||
code.append(text[: end_loc + 2])
|
||||
text = text[end_loc + 2 :]
|
||||
return code, text
|
||||
|
||||
def _tokenize(self, text: str) -> list[str]:
|
||||
code, text = self.remove_language_code(text)
|
||||
|
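The rewritten `remove_language_code` above drops the regex match in favour of a plain `startswith`/`find` scan. A small standalone illustration of the behaviour (re-implemented here for clarity, not the tokenizer itself):

```python
def remove_language_code(text: str):
    # Mirrors the new MarianTokenizer logic: strip a leading ">>xx<<" target-language tag if present.
    code = []
    if text.startswith(">>") and (end_loc := text.find("<<")) != -1:
        code.append(text[: end_loc + 2])
        text = text[end_loc + 2 :]
    return code, text

print(remove_language_code(">>fr<< hello world"))  # ([">>fr<<"], " hello world")
print(remove_language_code("no tag here"))         # ([], "no tag here")
```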
@ -172,6 +172,8 @@ class MimiEncoderOutput(ModelOutput):
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't
|
||||
have their past key value states given to this model).
|
||||
padding_cache (<fill_type>):
|
||||
<fill_docstring>
|
||||
"""
|
||||
|
||||
audio_codes: Optional[torch.LongTensor] = None
|
||||
|
@ -590,6 +590,7 @@ class MiniMaxPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["MiniMaxDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
@ -262,6 +262,7 @@ class MistralPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["MistralDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
@ -417,6 +417,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["MixtralDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
@ -1071,7 +1071,7 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)
|
||||
loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
|
||||
|
@ -1201,7 +1201,7 @@ class ModernBertForMaskedLM(ModernBertPreTrainedModel):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)
|
||||
loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
|
||||
|
@ -301,6 +301,7 @@ class OlmoPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["OlmoDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
@ -305,6 +305,7 @@ class Olmo2PreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Olmo2DecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
@ -132,6 +132,8 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
|
||||
)
|
||||
class PaliGemmaModel(PaliGemmaPreTrainedModel):
|
||||
_checkpoint_conversion_mapping = {"language_model.model": "language_model"}
|
||||
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
|
||||
accepts_loss_kwargs = False
|
||||
|
||||
def __init__(self, config: PaliGemmaConfig):
|
||||
super().__init__(config)
|
||||
|
@ -295,6 +295,7 @@ class PhiPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["PhiDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
@ -316,6 +316,7 @@ class Phi3PreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Phi3DecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
@ -1622,6 +1622,7 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Phi4MultimodalDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
@ -266,6 +266,7 @@ class Qwen2PreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Qwen2DecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
@ -213,6 +213,7 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
|
||||
min_frames: Optional[int] = None,
|
||||
max_frames: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
device: Optional["torch.Tensor"] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if do_sample_frames:
|
||||
@ -230,6 +231,11 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor):
|
||||
for video, metadata in zip(videos, video_metadata)
|
||||
]
|
||||
|
||||
# We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
|
||||
# moving the whole video incurs high GPU mem usage for long videos
|
||||
if device is not None:
|
||||
videos = [video.to(device) for video in videos]
|
||||
|
||||
# Group videos by size for batched resizing
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
|
||||
resized_videos_grouped = {}
|
||||
|
@ -292,6 +292,7 @@ class Qwen3PreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Qwen3DecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
@ -424,6 +424,7 @@ class Qwen3MoePreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Qwen3MoeDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_3 = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
src/transformers/models/smollm3/__init__.py (new file, 27 lines)
@@ -0,0 +1,27 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_smollm3 import *
|
||||
from .modeling_smollm3 import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
src/transformers/models/smollm3/configuration_smollm3.py (new file, 245 lines)
@@ -0,0 +1,245 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/smollm3/modular_smollm3.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_smollm3.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
|
||||
|
||||
class SmolLM3Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`SmolLM3Model`]. It is used to instantiate a
|
||||
SmolLM3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the SmolLM3 3B.
|
||||
e.g. [HuggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 128256):
|
||||
Vocabulary size of the SmolLM3 model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`SmolLM3Model`]
|
||||
hidden_size (`int`, *optional*, defaults to 2048):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 36):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 4):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `16`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*, defaults to 128004):
|
||||
The id of the padding token.
|
||||
bos_token_id (`int`, *optional*, defaults to 128000):
|
||||
The id of the beginning of sentence token.
|
||||
eos_token_id (`int`, *optional*, defaults to 128001):
|
||||
The id of the end of sentence token.
|
||||
rope_theta (`float`, *optional*, defaults to 2000000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use sliding window attention.
|
||||
sliding_window (`int`, *optional*):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `None`.
|
||||
no_rope_layers (`List[int]`, *optional*):
|
||||
List with at least the same length as the number of layers in the model.
|
||||
A `1` at an index position indicates that the corresponding layer will use RoPE,
|
||||
while a `0` indicates that it's a NoPE layer.
|
||||
no_rope_layer_interval (`int`, *optional*, defaults to 4):
|
||||
If `no_rope_layers` is `None`, it will be created using a NoPE layer every
|
||||
`no_rope_layer_interval` layers.
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer. Automatically computed based on sliding window and NoPE settings.
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
|
||||
```python
|
||||
>>> from transformers import SmolLM3Model, SmolLM3Config
|
||||
|
||||
>>> # Initializing a SmolLM3 style configuration
|
||||
>>> configuration = SmolLM3Config()
|
||||
|
||||
>>> # Initializing a model from the SmolLM3 style configuration
|
||||
>>> model = SmolLM3Model(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "smollm3"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=128256,
|
||||
hidden_size=2048,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=36,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=4,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=128004,
|
||||
bos_token_id=128000,
|
||||
eos_token_id=128001,
|
||||
rope_theta=2000000.0,
|
||||
rope_scaling=None,
|
||||
use_sliding_window=False,
|
||||
sliding_window=None,
|
||||
no_rope_layers=None,
|
||||
no_rope_layer_interval=4,
|
||||
layer_types=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
        if no_rope_layers is None:
            self.no_rope_layers = [
                int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(num_hidden_layers)
            ]
        else:
            self.no_rope_layers = no_rope_layers

        self.no_rope_layer_interval = no_rope_layer_interval

        # Update layer_types based on sliding window and NoPE pattern
        if layer_types is None:
            layer_types = []
            for layer_idx in range(num_hidden_layers):
                has_rope = self.no_rope_layers[layer_idx]
                if use_sliding_window and sliding_window is not None and not has_rope:
                    layer_types.append("sliding_attention")
                else:
                    layer_types.append("full_attention")

        self.layer_types = layer_types
        layer_type_validation(self.layer_types)

        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)


__all__ = ["SmolLM3Config"]
|
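The `no_rope_layers` default above encodes SmolLM3's NoPE pattern: with `no_rope_layer_interval=4`, every fourth layer skips rotary embeddings. A quick illustration of the pattern (small layer count chosen only for display; the expression mirrors the one in the config):

```python
no_rope_layer_interval = 4
num_hidden_layers = 8

no_rope_layers = [
    int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(num_hidden_layers)
]
print(no_rope_layers)  # [1, 1, 1, 0, 1, 1, 1, 0] -> layers 3 and 7 (0-indexed) are NoPE layers
```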
src/transformers/models/smollm3/modeling_smollm3.py (new file, 845 lines)
@@ -0,0 +1,845 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/smollm3/modular_smollm3.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_smollm3.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Union

import torch
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging
from .configuration_smollm3 import SmolLM3Config


logger = logging.get_logger(__name__)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

class SmolLM3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: SmolLM3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

        self.use_rope = config.no_rope_layers[layer_idx]
        self.sliding_window = (
            config.sliding_window
            if config.use_sliding_window and config.layer_types[layer_idx] == "sliding_attention"
            else None
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        if self.use_rope:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


@use_kernel_forward_from_hub("RMSNorm")
class SmolLM3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        SmolLM3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


@auto_docstring
class SmolLM3PreTrainedModel(PreTrainedModel):
    config_class = SmolLM3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SmolLM3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, SmolLM3RMSNorm):
            module.weight.data.fill_(1.0)


class SmolLM3MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class SmolLM3DecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: SmolLM3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = SmolLM3Attention(config=config, layer_idx=layer_idx)

        self.mlp = SmolLM3MLP(config)
        self.input_layernorm = SmolLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SmolLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attention_type = config.layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


class SmolLM3RotaryEmbedding(nn.Module):
    def __init__(self, config: SmolLM3Config, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


@auto_docstring
class SmolLM3Model(SmolLM3PreTrainedModel):
    def __init__(self, config: SmolLM3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [SmolLM3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = SmolLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = SmolLM3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class SmolLM3ForCausalLM(SmolLM3PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = SmolLM3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, SmolLM3ForCausalLM

        >>> model = SmolLM3ForCausalLM.from_pretrained("HuggingFaceTB/SmolLM3-3B")
        >>> tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

@auto_docstring(
    custom_intro="""
    The SmolLM3 Model transformer with a sequence classification head on top (linear layer).

    [`SmolLM3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """
)
class SmolLM3ForSequenceClassification(SmolLM3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = SmolLM3Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        transformer_outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

@auto_docstring
class SmolLM3ForTokenClassification(SmolLM3PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = SmolLM3Model(config)
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """

        outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

@auto_docstring
class SmolLM3ForQuestionAnswering(SmolLM3PreTrainedModel):
    base_model_prefix = "transformer"

    def __init__(self, config):
        super().__init__(config)
        self.transformer = SmolLM3Model(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.embed_tokens

    def set_input_embeddings(self, value):
        self.transformer.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> QuestionAnsweringModelOutput:
        outputs: BaseModelOutputWithPast = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = outputs.last_hidden_state

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        loss = None
        if start_positions is not None and end_positions is not None:
            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = [
    "SmolLM3PreTrainedModel",
    "SmolLM3Model",
    "SmolLM3ForCausalLM",
    "SmolLM3ForSequenceClassification",
    "SmolLM3ForTokenClassification",
    "SmolLM3ForQuestionAnswering",
]
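For orientation while reviewing, here is a minimal generation sketch with the classes defined above. It assumes this branch of transformers is installed; the `HuggingFaceTB/SmolLM3-3B` checkpoint name is taken from the configuration docstring in `modular_smollm3.py` below, and the rest is standard `generate` usage, not part of this diff:

```python
from transformers import AutoTokenizer, SmolLM3ForCausalLM

# Checkpoint name taken from the SmolLM3Config docstring; adjust if the released repo differs.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")
model = SmolLM3ForCausalLM.from_pretrained("HuggingFaceTB/SmolLM3-3B")

inputs = tokenizer("Gravity is", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```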
src/transformers/models/smollm3/modular_smollm3.py (new file, 350 lines)
@ -0,0 +1,350 @
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional

import torch

from ...cache_utils import Cache
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import logging
from ..llama.modeling_llama import (
    LlamaAttention,
    LlamaForCausalLM,
    LlamaForQuestionAnswering,
    LlamaForSequenceClassification,
    LlamaForTokenClassification,
    LlamaPreTrainedModel,
    apply_rotary_pos_emb,
    eager_attention_forward,
)
from ..qwen2.modeling_qwen2 import Qwen2Model


logger = logging.get_logger(__name__)

class SmolLM3Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`SmolLM3Model`]. It is used to instantiate a
    SmolLM3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the SmolLM3 3B.
    e.g. [HuggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 128256):
            Vocabulary size of the SmolLM3 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`SmolLM3Model`]
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 36):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 4):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `16`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 128004):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 128000):
            The id of the beginning of sentence token.
        eos_token_id (`int`, *optional*, defaults to 128001):
            The id of the end of sentence token.
        rope_theta (`float`, *optional*, defaults to 2000000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*):
            Sliding window attention (SWA) window size. If not specified, will default to `None`.
        no_rope_layers (`List[int]`, *optional*):
            List with at least the same length as the number of layers in the model.
            A `1` at an index position indicates that the corresponding layer will use RoPE,
            while a `0` indicates that it's a NoPE layer.
        no_rope_layer_interval (`int`, *optional*, defaults to 4):
            If `no_rope_layers` is `None`, it will be created using a NoPE layer every
            `no_rope_layer_interval` layers.
        layer_types (`list`, *optional*):
            Attention pattern for each layer. Automatically computed based on sliding window and NoPE settings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import SmolLM3Model, SmolLM3Config

    >>> # Initializing a SmolLM3 style configuration
    >>> configuration = SmolLM3Config()

    >>> # Initializing a model from the SmolLM3 style configuration
    >>> model = SmolLM3Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "smollm3"
    keys_to_ignore_at_inference = ["past_key_values"]

    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=128256,
        hidden_size=2048,
        intermediate_size=11008,
        num_hidden_layers=36,
        num_attention_heads=16,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=128004,
        bos_token_id=128000,
        eos_token_id=128001,
        rope_theta=2000000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=None,
        no_rope_layers=None,
        no_rope_layer_interval=4,
        layer_types=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        if no_rope_layers is None:
            self.no_rope_layers = [
                int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(num_hidden_layers)
            ]
        else:
            self.no_rope_layers = no_rope_layers

        self.no_rope_layer_interval = no_rope_layer_interval

        # Update layer_types based on sliding window and NoPE pattern
        if layer_types is None:
            layer_types = []
            for layer_idx in range(num_hidden_layers):
                has_rope = self.no_rope_layers[layer_idx]
                if use_sliding_window and sliding_window is not None and not has_rope:
                    layer_types.append("sliding_attention")
                else:
                    layer_types.append("full_attention")

        self.layer_types = layer_types
        layer_type_validation(self.layer_types)

        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

class SmolLM3Attention(LlamaAttention):
    def __init__(self, config: SmolLM3Config, layer_idx: int):
        super().__init__(config, layer_idx)

        self.use_rope = config.no_rope_layers[layer_idx]
        self.sliding_window = (
            config.sliding_window
            if config.use_sliding_window and config.layer_types[layer_idx] == "sliding_attention"
            else None
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        if self.use_rope:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class SmolLM3PreTrainedModel(LlamaPreTrainedModel):
    pass


class SmolLM3Model(Qwen2Model):
    pass


class SmolLM3ForCausalLM(LlamaForCausalLM):
    pass


class SmolLM3ForSequenceClassification(LlamaForSequenceClassification):
    pass


class SmolLM3ForTokenClassification(LlamaForTokenClassification):
    pass


class SmolLM3ForQuestionAnswering(LlamaForQuestionAnswering):
    pass


__all__ = [
    "SmolLM3Config",
    "SmolLM3PreTrainedModel",
    "SmolLM3Model",
    "SmolLM3ForCausalLM",
    "SmolLM3ForSequenceClassification",
    "SmolLM3ForTokenClassification",
    "SmolLM3ForQuestionAnswering",
]
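A small configuration sketch, illustrative only and using a reduced layer count, showing how `use_sliding_window` interacts with the NoPE pattern in `SmolLM3Config` above: only layers that skip RoPE are switched to `sliding_attention`, all others remain `full_attention`:

```python
from transformers import SmolLM3Config

# 8 layers for readability; every 4th layer is a NoPE layer by default.
config = SmolLM3Config(num_hidden_layers=8, use_sliding_window=True, sliding_window=4096)

print(config.no_rope_layers)  # [1, 1, 1, 0, 1, 1, 1, 0]
print(config.layer_types)
# ['full_attention', 'full_attention', 'full_attention', 'sliding_attention',
#  'full_attention', 'full_attention', 'full_attention', 'sliding_attention']
```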
@ -332,6 +332,7 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):
        num_frames: Optional[int] = None,
        skip_secs: Optional[int] = 0,
        return_tensors: Optional[Union[str, TensorType]] = None,
        device: Optional["torch.Tensor"] = None,
        **kwargs,
    ):
        # Group videos by size for batched resizing
@ -356,6 +357,11 @@ class SmolVLMVideoProcessor(BaseVideoProcessor):
        ]
        durations_list = [len(video) // 24 for video in videos]

        # We need to sample frames first before moving to device, if `do_sample_frames=True`. Otherwise
        # moving the whole video incurs high GPU mem usage for long videos
        if device is not None:
            videos = [video.to(device) for video in videos]

        grouped_videos, grouped_videos_index = group_videos_by_shape(processed_videos)
        resized_videos_grouped = {}
        for shape, stacked_videos in grouped_videos.items():
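The comment in the hunk above states the rationale: subsample frames on CPU first, then move only the reduced clip to the accelerator to avoid high GPU memory usage for long videos. A rough sketch of that ordering, where `sample_frames` is a hypothetical helper and only the sample-then-move order is taken from the diff:

```python
import torch

def preprocess_video(video: torch.Tensor, num_frames: int, device: torch.device) -> torch.Tensor:
    # Hypothetical frame sampler; stands in for the processor's `do_sample_frames` path.
    video = sample_frames(video, num_frames=num_frames)
    # Move the much smaller sampled clip, not the full-length video, to the device.
    return video.to(device)
```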
@ -2197,7 +2197,7 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel, GenerationMixin):
        >>> from datasets import load_dataset

        >>> dataset = load_dataset(
-       ...     "hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True
+       ...     "hf-internal-testing/librispeech_asr_demo", "clean", split="validation"
        ... )  # doctest: +IGNORE_RESULT
        >>> dataset = dataset.sort("id")
        >>> sampling_rate = dataset.features["audio"].sampling_rate
@ -2878,7 +2878,7 @@ class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
        >>> import torch

        >>> dataset = load_dataset(
-       ...     "hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True
+       ...     "hf-internal-testing/librispeech_asr_demo", "clean", split="validation"
        ... )  # doctest: +IGNORE_RESULT
        >>> dataset = dataset.sort("id")
        >>> sampling_rate = dataset.features["audio"].sampling_rate
Some files were not shown because too many files have changed in this diff.