mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
[WIP] add deepseek-v3 (#35926)
* init commit
* style
* take comments into account
* add deepseekv3 modeling
* remove redundant code
* apply make style
* apply fix-copies
* make format
* add init files
* rename deepseekv3 into deepseek_v3 based on its model_type
* rename deepseekv3 into deepseek_v3 based on its model_type
* deepseek-v3 not deepseek_v3
* set model_type as deepseek_v3
* use default docs
* apply make
* fill type and docstring
* add rope_config_validation
* use custom DeepseekV3MLP
* hold code only for checkpoints congifuration; remove redundant
* revise rope yarn for DeepSeek variation
* rename DeepSeek-V3
* some refactoring
* revise load_hook to work properly; make moe func trainable; use llama instead of mixtral
* fix attention forward
* use -1 for not-changing dim when to use exapnd
* refactor DeepseekV3TopkRouter
* use reshape_for_rope instead of load_hook; revise attention forward for TP; rename q_head_dim with qk_head_dim
* register pre_hook and hook both
* make style
* use n_shared_experts
* Update src/transformers/models/deepseek_v3/configuration_deepseek_v3.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* add test file
* update modeling_file according to modular file
* make style
* add mapping for DeepseekV3ForSequenceClassification
* remove aux_loss_alpha
* add deepseek_v3 for perf
* add deepseek_v3
* rename test as deepseekv3
* use tiny-deepseek-v3
* remove DeepseekV3ForSequenceClassification
* cache before padding
* remote output_router_logits
* Revert "remote output_router_logits"
This reverts commit f264f800d0
.
* remove output_router_logits
* make e_score_correction_bias as buffer
* skip tests not compatible
* make style
* make e_score_correction_bias as buffer
* use rope_interleave instead of load_hook
* skip tests not compatible with MLA
* add doc for rope_interleave
* fix typo
* remove torch.no_grad for selecting topk
* fix post merge issue
* mrege with main and simplify
* nits
* final
* small fixes
* fix
* support TP better
* stash
* changes currently requires
* remove synch
* more fixes for TP
* temp fix for TP : some attention layers's FP8 scales are too small + shared is local colwise and anything is local if FP8 because weights are used
* updates to have generation work!
* push most of the changes
* reorder functions + call for contributions!
* update readme
* nits
* update
* ruff was updated on main
* merge with main and fix copies
* revert unrelated changes
* route all tokens to all experts when testing to avoid no gradient iddues
* finish fixing all tests
* fixup
* nit
* clean config
* last readme changes
* nit
* do cnit
* typo
* last nit
* one more one more
---------
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: arthur@huggingface.co <arthur@ip-26-0-165-131.ec2.internal>
This commit is contained in:
parent
52cc204dd7
commit
eca74d1367
@ -415,6 +415,8 @@
|
|||||||
title: DeBERTa
|
title: DeBERTa
|
||||||
- local: model_doc/deberta-v2
|
- local: model_doc/deberta-v2
|
||||||
title: DeBERTa-v2
|
title: DeBERTa-v2
|
||||||
|
- local: model_doc/deepseek_v3
|
||||||
|
title: DeepSeek-V3
|
||||||
- local: model_doc/dialogpt
|
- local: model_doc/dialogpt
|
||||||
title: DialoGPT
|
title: DialoGPT
|
||||||
- local: model_doc/diffllama
|
- local: model_doc/diffllama
|
||||||
|
184
docs/source/en/model_doc/deepseek_v3.md
Normal file
184
docs/source/en/model_doc/deepseek_v3.md
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||||
|
the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
|
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||||
|
rendered properly in your Markdown viewer.
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
# DeepSeek-V3
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The DeepSeek-V3 model was proposed in [DeepSeek-V3 Technical Report](https://arxiv.org/abs/2412.19437) by DeepSeek-AI Team.
|
||||||
|
|
||||||
|
The abstract from the paper is the following:
|
||||||
|
We present DeepSeek-V3, a strong Mixture-of-Experts (MoE) language model with 671B total parameters with 37B activated for each token. To achieve efficient inference and cost-effective training, DeepSeek-V3 adopts Multi-head Latent Attention (MLA) and DeepSeekMoE architectures, which were thoroughly validated in DeepSeek-V2. Furthermore, DeepSeek-V3 pioneers an auxiliary-loss-free strategy for load balancing and sets a multi-token prediction training objective for stronger performance. We pre-train DeepSeek-V3 on 14.8 trillion diverse and high-quality tokens, followed by Supervised Fine-Tuning and Reinforcement Learning stages to fully harness its capabilities. Comprehensive evaluations reveal that DeepSeek-V3 outperforms other open-source models and achieves performance comparable to leading closed-source models. Despite its excellent performance, DeepSeek-V3 requires only 2.788M H800 GPU hours for its full training. In addition, its training process is remarkably stable. Throughout the entire training process, we did not experience any irrecoverable loss spikes or perform any rollbacks. The model checkpoints are available at https://github.com/deepseek-ai/DeepSeek-V3.
|
||||||
|
|
||||||
|
## Limitations and call for contribution!
|
||||||
|
|
||||||
|
We are super happy to make this code community-powered, and would love to see how you can best optimize the following:
|
||||||
|
|
||||||
|
- current implementation uses the "naive" attention compution (so not really MLA)
|
||||||
|
- current implementation loops through the experts. This should be replaced. Pointers to use `get_packed_weights` from `intetrations/tensor_parallel`.
|
||||||
|
- current implementation uses the eleuther formula for ROPE, using the orginal one would be more efficient! (should still follow our API)
|
||||||
|
- static cache is not supported (this should be just a generation config issue / config shape issues)
|
||||||
|
|
||||||
|
### Usage tips
|
||||||
|
The model uses Multi-head Latent Attention (MLA) and DeepSeekMoE architectures for efficient inference and cost-effective training. It employs an auxiliary-loss-free strategy for load balancing and multi-token prediction training objective. The model can be used for various language tasks after being pre-trained on 14.8 trillion tokens and going through Supervised Fine-Tuning and Reinforcement Learning stages.
|
||||||
|
|
||||||
|
You can run the model in `FP8` automatically, using 2 nodes of 8 H100 should be more than enough!
|
||||||
|
|
||||||
|
```python
|
||||||
|
# `run_deepseek_v1.py`
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
import torch
|
||||||
|
torch.manual_seed(30)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("deepseek-r1")
|
||||||
|
|
||||||
|
chat = [
|
||||||
|
{"role": "user", "content": "Hello, how are you?"},
|
||||||
|
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
|
||||||
|
{"role": "user", "content": "I'd like to show off how chat templating works!"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("deepseek-r1", device_map="auto", torch_dtype=torch.bfloat16)
|
||||||
|
inputs = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
||||||
|
import time
|
||||||
|
start = time.time()
|
||||||
|
outputs = model.generate(inputs, max_new_tokens=50)
|
||||||
|
print(tokenizer.batch_decode(outputs))
|
||||||
|
print(time.time()-start)
|
||||||
|
```
|
||||||
|
This generated:
|
||||||
|
|
||||||
|
``````
|
||||||
|
<|Assistant|><think>
|
||||||
|
Okay, the user wants to demonstrate how chat templating works. Let me break down what that means. Chat templating is about structuring the conversation data, especially for models that need specific input formats. Maybe they're referring to something like how messages are formatted with roles (user, assistant, system) in APIs like OpenAI.
|
||||||
|
|
||||||
|
First, I should explain what chat templating is. It's the process of formatting conversation data into a structured format that the model can understand. This usually includes roles and content. For example, user messages, assistant responses, and system messages each have their own role tags.
|
||||||
|
|
||||||
|
They might want an example. Let me think of a simple conversation. The user says "Hello, how are you?" and the assistant responds "I'm doing great. How can I help you today?" Then the user follows up with wanting to show off chat templating. So the example should include the history and the new message.
|
||||||
|
|
||||||
|
In some frameworks, like Hugging Face's Transformers, chat templates are applied using Jinja2 templates. The template might look something like combining system messages, then looping through user and assistant messages with appropriate tags. For instance, using {% for message in messages %} and assigning roles like <|user|>, <|assistant|>, etc.
|
||||||
|
|
||||||
|
I should structure the example with the messages array, showing each role and content. Then apply a hypothetical template to convert that into a formatted string the model uses. Also, mention that different models have different templating requirements, like using special tokens or varying role labels.
|
||||||
|
|
||||||
|
Wait, the user mentioned "chat templating" in the context of showing off. Maybe they want a practical example they can present. So providing a code snippet or a structured data example would be helpful. Let me outline a typical messages array and then the templated output.
|
||||||
|
|
||||||
|
Also, it's important to note that proper templating ensures the model knows the conversation flow, which is crucial for generating coherent responses. Maybe include a note about why it's important, like maintaining context and role-specific processing.
|
||||||
|
|
||||||
|
Let me check if there are any common mistakes or things to avoid. For example, not closing tags properly, or mismatching roles. But maybe that's too detailed unless the user asks. Focus on the positive example first.
|
||||||
|
|
||||||
|
Putting it all together, the response should have an example messages array, the applied template, and the final formatted string. Maybe use angle brackets or special tokens as placeholders. Also, mention that this helps in training or fine-tuning models with structured data.
|
||||||
|
|
||||||
|
I think that's a solid approach. Let me structure it step by step to make it clear.
|
||||||
|
</think>
|
||||||
|
|
||||||
|
Chat templating is a way to structure conversation data (e.g., user/assistant interactions) into a format that language models understand. This is especially important for models trained to handle multi-turn dialogues, where the input must explicitly separate roles (user, assistant, system, etc.) and messages. Let’s break this down with an example!
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### **Step 1: Raw Conversation History**
|
||||||
|
Suppose we have this conversation:
|
||||||
|
- **User**: "Hello, how are you?"
|
||||||
|
- **Assistant**: "I'm doing great. How can I help you today?"
|
||||||
|
- **User**: "I'd like to show off how chat templating works!"
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### **Step 2: Structured Messages**
|
||||||
|
In frameworks like Hugging Face Transformers or OpenAI, conversations are often formatted as a list of dictionaries with `role` and `content`:
|
||||||
|
```python
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "Hello, how are you?"},
|
||||||
|
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
|
||||||
|
{"role": "user", "content": "I'd like to show off how chat templating works!"},
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### **Step 3: Apply a Chat Template**
|
||||||
|
A **chat template** converts this structured data into a single string formatted for the model. For example, using a Jinja-style template (common in Hugging Face):
|
||||||
|
|
||||||
|
```jinja
|
||||||
|
{% for message in messages %}
|
||||||
|
{% if message['role'] == 'user' %}
|
||||||
|
<|user|>{{ message['content'] }}<|end|>
|
||||||
|
{% elif message['role'] == 'assistant' %}
|
||||||
|
<|assistant|>{{ message['content'] }}<|end|>
|
||||||
|
{% endif %}
|
||||||
|
{% endfor %}
|
||||||
|
<|assistant|>
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### **Step 4: Final Templated Output**
|
||||||
|
Applying the template to our `messages` list would produce:
|
||||||
|
```text
|
||||||
|
<|user|>Hello, how are you?<|end|>
|
||||||
|
<|assistant|>I'm doing great. How can I help you today?<|end|>
|
||||||
|
<|user|>I'd like to show off how chat templating works!<|end|>
|
||||||
|
<|assistant|>
|
||||||
|
```
|
||||||
|
|
||||||
|
This tells the model:
|
||||||
|
1. The conversation history (user/assistant turns).
|
||||||
|
2. The model’s turn to generate a response (`<|assistant|>` at the end).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### **Key Notes**:
|
||||||
|
- **Role Separation**: Tags like `<|user|>` and `<|assistant|>` help the model distinguish speakers.
|
||||||
|
- **Special Tokens**: Models often use unique tokens (e.g., `<|end|>`) to mark message boundaries.
|
||||||
|
- **Flexibility**: Templates vary by model (e.g., OpenAI uses `{"role": "user", "content": "..."}` instead of tags).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### **Why This Matters**:
|
||||||
|
- **Consistency**: Ensures the model understands dialogue structure.
|
||||||
|
- **Context Preservation**: Maintains the flow of multi-turn conversations.
|
||||||
|
- **Alignment**: Matches the format the model was trained on for better performance.
|
||||||
|
|
||||||
|
Want to dive deeper or see a specific framework’s implementation (e.g., OpenAI, Llama, Mistral)? Let me know! 😊<|end▁of▁sentence|>
|
||||||
|
``````
|
||||||
|
|
||||||
|
Use the following to run it
|
||||||
|
```bash
|
||||||
|
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0|1 --rdzv-id an_id --rdzv-backend c10d --rdzv-endpoint master_addr:master_port run_deepseek_r1.py
|
||||||
|
```
|
||||||
|
|
||||||
|
If you have:
|
||||||
|
```bash
|
||||||
|
[rank0]: ncclInternalError: Internal check failed.
|
||||||
|
[rank0]: Last error:
|
||||||
|
[rank0]: Bootstrap : no socket interface found
|
||||||
|
```
|
||||||
|
error, it means NCCL was probably not loaded.
|
||||||
|
|
||||||
|
|
||||||
|
## DeepseekV3Config
|
||||||
|
|
||||||
|
[[autodoc]] DeepseekV3Config
|
||||||
|
|
||||||
|
## DeepseekV3Model
|
||||||
|
|
||||||
|
[[autodoc]] DeepseekV3Model
|
||||||
|
- forward
|
||||||
|
|
||||||
|
## DeepseekV3ForCausalLM
|
||||||
|
|
||||||
|
[[autodoc]] DeepseekV3ForCausalLM
|
||||||
|
- forward
|
@ -345,6 +345,7 @@ _import_structure = {
|
|||||||
],
|
],
|
||||||
"models.deberta_v2": ["DebertaV2Config"],
|
"models.deberta_v2": ["DebertaV2Config"],
|
||||||
"models.decision_transformer": ["DecisionTransformerConfig"],
|
"models.decision_transformer": ["DecisionTransformerConfig"],
|
||||||
|
"models.deepseek_v3": ["DeepseekV3Config"],
|
||||||
"models.deformable_detr": ["DeformableDetrConfig"],
|
"models.deformable_detr": ["DeformableDetrConfig"],
|
||||||
"models.deit": ["DeiTConfig"],
|
"models.deit": ["DeiTConfig"],
|
||||||
"models.deprecated": [],
|
"models.deprecated": [],
|
||||||
@ -2023,6 +2024,13 @@ else:
|
|||||||
"DecisionTransformerPreTrainedModel",
|
"DecisionTransformerPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
_import_structure["models.deepseek_v3"].extend(
|
||||||
|
[
|
||||||
|
"DeepseekV3ForCausalLM",
|
||||||
|
"DeepseekV3Model",
|
||||||
|
"DeepseekV3PreTrainedModel",
|
||||||
|
]
|
||||||
|
)
|
||||||
_import_structure["models.deformable_detr"].extend(
|
_import_structure["models.deformable_detr"].extend(
|
||||||
[
|
[
|
||||||
"DeformableDetrForObjectDetection",
|
"DeformableDetrForObjectDetection",
|
||||||
@ -5546,6 +5554,9 @@ if TYPE_CHECKING:
|
|||||||
from .models.decision_transformer import (
|
from .models.decision_transformer import (
|
||||||
DecisionTransformerConfig,
|
DecisionTransformerConfig,
|
||||||
)
|
)
|
||||||
|
from .models.deepseek_v3 import (
|
||||||
|
DeepseekV3Config,
|
||||||
|
)
|
||||||
from .models.deformable_detr import (
|
from .models.deformable_detr import (
|
||||||
DeformableDetrConfig,
|
DeformableDetrConfig,
|
||||||
)
|
)
|
||||||
@ -7175,6 +7186,11 @@ if TYPE_CHECKING:
|
|||||||
DecisionTransformerModel,
|
DecisionTransformerModel,
|
||||||
DecisionTransformerPreTrainedModel,
|
DecisionTransformerPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
from .models.deepseek_v3 import (
|
||||||
|
DeepseekV3ForCausalLM,
|
||||||
|
DeepseekV3Model,
|
||||||
|
DeepseekV3PreTrainedModel,
|
||||||
|
)
|
||||||
from .models.deformable_detr import (
|
from .models.deformable_detr import (
|
||||||
DeformableDetrForObjectDetection,
|
DeformableDetrForObjectDetection,
|
||||||
DeformableDetrModel,
|
DeformableDetrModel,
|
||||||
|
@ -291,7 +291,7 @@ def w8a8_block_fp8_matmul_compile(
|
|||||||
return output.to(output_dtype)
|
return output.to(output_dtype)
|
||||||
|
|
||||||
|
|
||||||
class FP8Linear(nn.Module):
|
class FP8Linear(nn.Linear):
|
||||||
dtype = torch.float8_e4m3fn
|
dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -304,17 +304,20 @@ class FP8Linear(nn.Module):
|
|||||||
device=None,
|
device=None,
|
||||||
activation_scheme="dynamic",
|
activation_scheme="dynamic",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__(in_features, out_features)
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
|
|
||||||
self.register_buffer("weight", torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device))
|
self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device))
|
||||||
|
|
||||||
scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
|
if self.weight.element_size() == 1:
|
||||||
scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
|
scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
|
||||||
self.register_buffer(
|
scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
|
||||||
"weight_scale_inv", torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device)
|
self.weight_scale_inv = nn.Parameter(
|
||||||
)
|
torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.register_parameter("weight_scale_inv", None)
|
||||||
|
|
||||||
self.block_size = block_size
|
self.block_size = block_size
|
||||||
|
|
||||||
@ -330,11 +333,11 @@ class FP8Linear(nn.Module):
|
|||||||
return F.linear(input, self.weight, self.bias)
|
return F.linear(input, self.weight, self.bias)
|
||||||
else:
|
else:
|
||||||
# Context manager used to switch among the available cuda devices
|
# Context manager used to switch among the available cuda devices
|
||||||
with torch.cuda.device(input.device):
|
# with torch.cuda.device(input.device):
|
||||||
qinput, scale = act_quant(input, self.block_size[1])
|
qinput, scale = act_quant(input, self.block_size[1])
|
||||||
# Blocks the CPU until all CUDA operations on the specified device are complete. It is used to ensure that the results of the
|
# Blocks the CPU until all CUDA operations on the specified device are complete. It is used to ensure that the results of the
|
||||||
# preceding operations are ready before proceeding
|
# preceding operations are ready before proceeding
|
||||||
torch.cuda.synchronize(device=input.device)
|
# torch.cuda.synchronize(device=self.weight.device)
|
||||||
with torch.cuda.device(input.device):
|
with torch.cuda.device(input.device):
|
||||||
output = w8a8_block_fp8_matmul_triton(
|
output = w8a8_block_fp8_matmul_triton(
|
||||||
qinput,
|
qinput,
|
||||||
@ -344,7 +347,7 @@ class FP8Linear(nn.Module):
|
|||||||
self.block_size,
|
self.block_size,
|
||||||
output_dtype=input.dtype,
|
output_dtype=input.dtype,
|
||||||
)
|
)
|
||||||
torch.cuda.synchronize(device=input.device)
|
torch.cuda.synchronize()
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
output = output + self.bias
|
output = output + self.bias
|
||||||
return output.to(dtype=input.dtype)
|
return output.to(dtype=input.dtype)
|
||||||
@ -352,6 +355,7 @@ class FP8Linear(nn.Module):
|
|||||||
|
|
||||||
def _replace_with_fp8_linear(
|
def _replace_with_fp8_linear(
|
||||||
model,
|
model,
|
||||||
|
tp_plan=None,
|
||||||
modules_to_not_convert=None,
|
modules_to_not_convert=None,
|
||||||
current_key_name=None,
|
current_key_name=None,
|
||||||
quantization_config=None,
|
quantization_config=None,
|
||||||
@ -378,10 +382,12 @@ def _replace_with_fp8_linear(
|
|||||||
block_size=quantization_config.weight_block_size,
|
block_size=quantization_config.weight_block_size,
|
||||||
)
|
)
|
||||||
has_been_replaced = True
|
has_been_replaced = True
|
||||||
|
# when changing a layer the TP PLAN for that layer should be updated. TODO
|
||||||
|
|
||||||
if len(list(module.children())) > 0:
|
if len(list(module.children())) > 0:
|
||||||
_, has_been_replaced = _replace_with_fp8_linear(
|
_, has_been_replaced = _replace_with_fp8_linear(
|
||||||
module,
|
module,
|
||||||
|
tp_plan,
|
||||||
modules_to_not_convert,
|
modules_to_not_convert,
|
||||||
current_key_name,
|
current_key_name,
|
||||||
quantization_config,
|
quantization_config,
|
||||||
@ -404,9 +410,9 @@ def replace_with_fp8_linear(
|
|||||||
if quantization_config.modules_to_not_convert is not None:
|
if quantization_config.modules_to_not_convert is not None:
|
||||||
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
|
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
|
||||||
modules_to_not_convert = list(set(modules_to_not_convert))
|
modules_to_not_convert = list(set(modules_to_not_convert))
|
||||||
|
|
||||||
model, has_been_replaced = _replace_with_fp8_linear(
|
model, has_been_replaced = _replace_with_fp8_linear(
|
||||||
model,
|
model,
|
||||||
|
tp_plan=model._tp_plan,
|
||||||
modules_to_not_convert=modules_to_not_convert,
|
modules_to_not_convert=modules_to_not_convert,
|
||||||
quantization_config=quantization_config,
|
quantization_config=quantization_config,
|
||||||
)
|
)
|
||||||
|
@ -231,8 +231,8 @@ class IsolatedParallel(TensorParallelLayer):
|
|||||||
distribute_module(
|
distribute_module(
|
||||||
module,
|
module,
|
||||||
device_mesh,
|
device_mesh,
|
||||||
partial(self._prepare_input_fn),
|
partial(self._prepare_input_fn, None, None),
|
||||||
partial(self._prepare_output_fn),
|
partial(self._prepare_output_fn, None, None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -484,7 +484,12 @@ def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, curr
|
|||||||
# 1. We add hooks to the layer being loaded:
|
# 1. We add hooks to the layer being loaded:
|
||||||
if current_module_plan is not None:
|
if current_module_plan is not None:
|
||||||
tp_layer = translate_to_torch_parallel_style(current_module_plan)
|
tp_layer = translate_to_torch_parallel_style(current_module_plan)
|
||||||
tp_layer.prepare_module_tp(module, device_mesh)
|
try:
|
||||||
|
tp_layer.prepare_module_tp(module, device_mesh)
|
||||||
|
except NotImplementedError as e:
|
||||||
|
print(
|
||||||
|
f"Trying to prepare {layer_name}, but it's not supported. Corresponding module: {module} Fix it's TP plan: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
# 2. We add hooks to the parrent module if needed
|
# 2. We add hooks to the parrent module if needed
|
||||||
if "." in layer_name:
|
if "." in layer_name:
|
||||||
@ -531,6 +536,7 @@ def shard_and_distribute_module(
|
|||||||
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
|
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# TODO log no plan modules in set
|
||||||
param = param[...].to(param_casting_dtype)
|
param = param[...].to(param_casting_dtype)
|
||||||
if is_contiguous:
|
if is_contiguous:
|
||||||
param = param.contiguous()
|
param = param.contiguous()
|
||||||
|
@ -189,13 +189,31 @@ def _compute_yarn_parameters(
|
|||||||
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
||||||
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||||
dim = int(head_dim * partial_rotary_factor)
|
dim = int(head_dim * partial_rotary_factor)
|
||||||
max_position_embeddings = config.max_position_embeddings
|
|
||||||
factor = config.rope_scaling["factor"]
|
factor = config.rope_scaling["factor"]
|
||||||
|
attention_factor = config.rope_scaling.get("attention_factor")
|
||||||
|
mscale = config.rope_scaling.get("mscale")
|
||||||
|
mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
|
||||||
|
|
||||||
|
# NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
|
||||||
|
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
|
||||||
|
# values to compute the default attention scaling factor, instead of using `factor`.
|
||||||
|
if "original_max_position_embeddings" in config.rope_scaling:
|
||||||
|
original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
|
||||||
|
factor = config.max_position_embeddings / original_max_position_embeddings
|
||||||
|
else:
|
||||||
|
original_max_position_embeddings = config.max_position_embeddings
|
||||||
|
|
||||||
|
def get_mscale(scale, mscale=1):
|
||||||
|
if scale <= 1:
|
||||||
|
return 1.0
|
||||||
|
return 0.1 * mscale * math.log(scale) + 1.0
|
||||||
|
|
||||||
# Sets the attention factor as suggested in the paper
|
# Sets the attention factor as suggested in the paper
|
||||||
attention_factor = config.rope_scaling.get("attention_factor")
|
|
||||||
if attention_factor is None:
|
if attention_factor is None:
|
||||||
attention_factor = 0.1 * math.log(factor) + 1.0
|
if mscale and mscale_all_dim:
|
||||||
|
attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
|
||||||
|
else:
|
||||||
|
attention_factor = get_mscale(factor)
|
||||||
|
|
||||||
# Optional config options
|
# Optional config options
|
||||||
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
|
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
|
||||||
@ -227,7 +245,7 @@ def _compute_yarn_parameters(
|
|||||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||||
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
||||||
|
|
||||||
low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
|
low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
|
||||||
|
|
||||||
# Get n-dimensional rotational scaling corrected for extrapolation
|
# Get n-dimensional rotational scaling corrected for extrapolation
|
||||||
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
|
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
|
||||||
@ -235,7 +253,6 @@ def _compute_yarn_parameters(
|
|||||||
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
|
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
|
||||||
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
|
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
|
||||||
)
|
)
|
||||||
|
|
||||||
return inv_freq, attention_factor
|
return inv_freq, attention_factor
|
||||||
|
|
||||||
|
|
||||||
@ -425,7 +442,14 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se
|
|||||||
rope_scaling = config.rope_scaling
|
rope_scaling = config.rope_scaling
|
||||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||||
required_keys = {"rope_type", "factor"}
|
required_keys = {"rope_type", "factor"}
|
||||||
optional_keys = {"attention_factor", "beta_fast", "beta_slow", "original_max_position_embeddings"}
|
optional_keys = {
|
||||||
|
"attention_factor",
|
||||||
|
"beta_fast",
|
||||||
|
"beta_slow",
|
||||||
|
"original_max_position_embeddings",
|
||||||
|
"mscale",
|
||||||
|
"mscale_all_dim",
|
||||||
|
}
|
||||||
received_keys = set(rope_scaling.keys())
|
received_keys = set(rope_scaling.keys())
|
||||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
|
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
|
@ -779,8 +779,7 @@ def _load_state_dict_into_meta_model(
|
|||||||
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
|
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
|
||||||
|
|
||||||
is_quantized = hf_quantizer is not None
|
is_quantized = hf_quantizer is not None
|
||||||
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_quantized
|
is_meta_state_dict = shard_file.endswith(".safetensors")
|
||||||
|
|
||||||
file_pointer = None
|
file_pointer = None
|
||||||
if is_meta_state_dict:
|
if is_meta_state_dict:
|
||||||
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
||||||
@ -795,7 +794,7 @@ def _load_state_dict_into_meta_model(
|
|||||||
serialized_param_name = reverse_renaming_mapping[param_name]
|
serialized_param_name = reverse_renaming_mapping[param_name]
|
||||||
param = file_pointer.get_slice(serialized_param_name)
|
param = file_pointer.get_slice(serialized_param_name)
|
||||||
else:
|
else:
|
||||||
param = empty_param # It is actually not empty!
|
param = empty_param.to(tensor_device) # It is actually not empty!
|
||||||
|
|
||||||
to_contiguous, casting_dtype = _infer_parameter_dtype(
|
to_contiguous, casting_dtype = _infer_parameter_dtype(
|
||||||
model,
|
model,
|
||||||
@ -813,7 +812,7 @@ def _load_state_dict_into_meta_model(
|
|||||||
param_name,
|
param_name,
|
||||||
casting_dtype,
|
casting_dtype,
|
||||||
to_contiguous,
|
to_contiguous,
|
||||||
tensor_device, # the rank
|
int(os.environ["RANK"]), # the rank
|
||||||
device_mesh,
|
device_mesh,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -4102,11 +4101,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
|
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
|
||||||
if not torch.distributed.is_initialized():
|
if not torch.distributed.is_initialized():
|
||||||
try:
|
try:
|
||||||
logger.warning("Tensor Parallel requires torch.distributed to be initialized first.")
|
|
||||||
rank = int(os.environ["RANK"])
|
rank = int(os.environ["RANK"])
|
||||||
world_size = int(os.environ["WORLD_SIZE"])
|
world_size = int(os.environ["WORLD_SIZE"])
|
||||||
torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
|
torch.distributed.init_process_group(
|
||||||
torch.cuda.set_device(rank)
|
"nccl", rank=rank, world_size=world_size, init_method="env://"
|
||||||
|
)
|
||||||
|
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
"We tried to initialize torch.distributed for you, but it failed, make"
|
"We tried to initialize torch.distributed for you, but it failed, make"
|
||||||
@ -4115,12 +4115,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
|
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
|
||||||
device_type = torch._C._get_accelerator().type
|
device_type = torch._C._get_accelerator().type
|
||||||
device_module = torch.get_device_module(device_type)
|
tp_device = torch.device(device_type, torch.cuda.current_device())
|
||||||
# Get device with index assuming equal number of devices per host
|
if tp_device.index > 0:
|
||||||
tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
|
import sys
|
||||||
|
|
||||||
|
sys.stdout = open(os.devnull, "w")
|
||||||
# This is the easiest way to dispatch to the current process device
|
# This is the easiest way to dispatch to the current process device
|
||||||
device_map = tp_device
|
device_map = tp_device
|
||||||
|
|
||||||
# Assuming sharding the model onto the world
|
# Assuming sharding the model onto the world
|
||||||
world_size = torch.distributed.get_world_size()
|
world_size = torch.distributed.get_world_size()
|
||||||
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
|
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
|
||||||
@ -4871,9 +4872,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
|
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
|
||||||
|
|
||||||
# Warmup cuda to load the weights much faster on devices
|
# Warmup cuda to load the weights much faster on devices
|
||||||
if device_map is not None and hf_quantizer is None:
|
if device_map is not None: # and hf_quantizer is None:
|
||||||
expanded_device_map = expand_device_map(device_map, expected_keys)
|
expanded_device_map = expand_device_map(device_map, expected_keys)
|
||||||
caching_allocator_warmup(model_to_load, expanded_device_map)
|
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
|
||||||
|
|
||||||
error_msgs = []
|
error_msgs = []
|
||||||
mismatched_keys = []
|
mismatched_keys = []
|
||||||
@ -5834,7 +5835,7 @@ def expand_device_map(device_map, param_names):
|
|||||||
return new_device_map
|
return new_device_map
|
||||||
|
|
||||||
|
|
||||||
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict):
|
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, factor=2):
|
||||||
"""This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
|
"""This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
|
||||||
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
|
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
|
||||||
the model, which is actually the loading speed botteneck.
|
the model, which is actually the loading speed botteneck.
|
||||||
@ -5865,7 +5866,6 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict):
|
|||||||
if _torch_distributed_available and torch.distributed.is_initialized()
|
if _torch_distributed_available and torch.distributed.is_initialized()
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
total_byte_count = defaultdict(lambda: 0)
|
total_byte_count = defaultdict(lambda: 0)
|
||||||
for param_name, device in accelerator_device_map.items():
|
for param_name, device in accelerator_device_map.items():
|
||||||
param = model.get_parameter_or_buffer(param_name)
|
param = model.get_parameter_or_buffer(param_name)
|
||||||
@ -5886,7 +5886,7 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict):
|
|||||||
# Allow up to 95% of max device memory
|
# Allow up to 95% of max device memory
|
||||||
byte_count = min(byte_count, int(0.95 * device_memory))
|
byte_count = min(byte_count, int(0.95 * device_memory))
|
||||||
# Allocate memory
|
# Allocate memory
|
||||||
_ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)
|
_ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
|
||||||
|
|
||||||
|
|
||||||
def get_disk_only_shard_files(device_map, weight_map):
|
def get_disk_only_shard_files(device_map, weight_map):
|
||||||
|
@ -71,6 +71,7 @@ from . import (
|
|||||||
deberta,
|
deberta,
|
||||||
deberta_v2,
|
deberta_v2,
|
||||||
decision_transformer,
|
decision_transformer,
|
||||||
|
deepseek_v3,
|
||||||
deformable_detr,
|
deformable_detr,
|
||||||
deit,
|
deit,
|
||||||
deprecated,
|
deprecated,
|
||||||
|
@ -89,6 +89,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
|||||||
("deberta", "DebertaConfig"),
|
("deberta", "DebertaConfig"),
|
||||||
("deberta-v2", "DebertaV2Config"),
|
("deberta-v2", "DebertaV2Config"),
|
||||||
("decision_transformer", "DecisionTransformerConfig"),
|
("decision_transformer", "DecisionTransformerConfig"),
|
||||||
|
("deepseek_v3", "DeepseekV3Config"),
|
||||||
("deformable_detr", "DeformableDetrConfig"),
|
("deformable_detr", "DeformableDetrConfig"),
|
||||||
("deit", "DeiTConfig"),
|
("deit", "DeiTConfig"),
|
||||||
("depth_anything", "DepthAnythingConfig"),
|
("depth_anything", "DepthAnythingConfig"),
|
||||||
@ -423,6 +424,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
|||||||
("deberta", "DeBERTa"),
|
("deberta", "DeBERTa"),
|
||||||
("deberta-v2", "DeBERTa-v2"),
|
("deberta-v2", "DeBERTa-v2"),
|
||||||
("decision_transformer", "Decision Transformer"),
|
("decision_transformer", "Decision Transformer"),
|
||||||
|
("deepseek_v3", "DeepSeek-V3"),
|
||||||
("deformable_detr", "Deformable DETR"),
|
("deformable_detr", "Deformable DETR"),
|
||||||
("deit", "DeiT"),
|
("deit", "DeiT"),
|
||||||
("deplot", "DePlot"),
|
("deplot", "DePlot"),
|
||||||
|
@ -88,6 +88,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("deberta", "DebertaModel"),
|
("deberta", "DebertaModel"),
|
||||||
("deberta-v2", "DebertaV2Model"),
|
("deberta-v2", "DebertaV2Model"),
|
||||||
("decision_transformer", "DecisionTransformerModel"),
|
("decision_transformer", "DecisionTransformerModel"),
|
||||||
|
("deepseek_v3", "DeepseekV3Model"),
|
||||||
("deformable_detr", "DeformableDetrModel"),
|
("deformable_detr", "DeformableDetrModel"),
|
||||||
("deit", "DeiTModel"),
|
("deit", "DeiTModel"),
|
||||||
("depth_pro", "DepthProModel"),
|
("depth_pro", "DepthProModel"),
|
||||||
@ -514,6 +515,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
|||||||
("ctrl", "CTRLLMHeadModel"),
|
("ctrl", "CTRLLMHeadModel"),
|
||||||
("data2vec-text", "Data2VecTextForCausalLM"),
|
("data2vec-text", "Data2VecTextForCausalLM"),
|
||||||
("dbrx", "DbrxForCausalLM"),
|
("dbrx", "DbrxForCausalLM"),
|
||||||
|
("deepseek_v3", "DeepseekV3ForCausalLM"),
|
||||||
("diffllama", "DiffLlamaForCausalLM"),
|
("diffllama", "DiffLlamaForCausalLM"),
|
||||||
("electra", "ElectraForCausalLM"),
|
("electra", "ElectraForCausalLM"),
|
||||||
("emu3", "Emu3ForCausalLM"),
|
("emu3", "Emu3ForCausalLM"),
|
||||||
|
@ -171,6 +171,13 @@ else:
|
|||||||
"DebertaV2TokenizerFast" if is_tokenizers_available() else None,
|
"DebertaV2TokenizerFast" if is_tokenizers_available() else None,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
"deepseek_v3",
|
||||||
|
(
|
||||||
|
"LlamaTokenizer" if is_sentencepiece_available() else None,
|
||||||
|
"LlamaTokenizerFast" if is_tokenizers_available() else None,
|
||||||
|
),
|
||||||
|
),
|
||||||
(
|
(
|
||||||
"diffllama",
|
"diffllama",
|
||||||
(
|
(
|
||||||
|
27
src/transformers/models/deepseek_v3/__init__.py
Normal file
27
src/transformers/models/deepseek_v3/__init__.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...utils import _LazyModule
|
||||||
|
from ...utils.import_utils import define_import_structure
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .configuration_deepseek_v3 import *
|
||||||
|
from .modeling_deepseek_v3 import *
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
_file = globals()["__file__"]
|
||||||
|
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
247
src/transformers/models/deepseek_v3/configuration_deepseek_v3.py
Normal file
247
src/transformers/models/deepseek_v3/configuration_deepseek_v3.py
Normal file
@ -0,0 +1,247 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 bzantium and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on the DeepSeekV3 implementations from the DeepSeek AI team. (https://huggingface.co/deepseek-ai/DeepSeek-V3)
|
||||||
|
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""DeepSeekV3 model configuration"""
|
||||||
|
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...modeling_rope_utils import rope_config_validation
|
||||||
|
|
||||||
|
|
||||||
|
DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3Config(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
|
||||||
|
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||||
|
defaults will yield a similar configuration to that of the DeepSeek-V3.
|
||||||
|
e.g. [bzantium/tiny-deepseek-v3](https://huggingface.co/bzantium/tiny-deepseek-v3)
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 129280):
|
||||||
|
Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids` passed when calling [`DeepseekV3Model`]
|
||||||
|
hidden_size (`int`, *optional*, defaults to 7168):
|
||||||
|
Dimension of the hidden representations.
|
||||||
|
intermediate_size (`int`, *optional*, defaults to 18432):
|
||||||
|
Dimension of the MLP representations.
|
||||||
|
moe_intermediate_size (`int`, *optional*, defaults to 2048):
|
||||||
|
Dimension of the MoE representations.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 61):
|
||||||
|
Number of hidden layers in the Transformer decoder.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 128):
|
||||||
|
Number of attention heads for each attention layer in the Transformer decoder.
|
||||||
|
num_key_value_heads (`int`, *optional*, defaults to 128):
|
||||||
|
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||||
|
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||||
|
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||||
|
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||||
|
by meanpooling all the original heads within that group. For more details checkout [this
|
||||||
|
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||||
|
`num_attention_heads`.
|
||||||
|
n_shared_experts (`int`, *optional*, defaults to 1):
|
||||||
|
Number of shared experts.
|
||||||
|
n_routed_experts (`int`, *optional*, defaults to 256):
|
||||||
|
Number of routed experts.
|
||||||
|
routed_scaling_factor (`float`, *optional*, defaults to 2.5):
|
||||||
|
Scaling factor or routed experts.
|
||||||
|
kv_lora_rank (`int`, *optional*, defaults to 512):
|
||||||
|
Rank of the LoRA matrices for key and value projections.
|
||||||
|
q_lora_rank (`int`, *optional*, defaults to 1536):
|
||||||
|
Rank of the LoRA matrices for query projections.
|
||||||
|
qk_rope_head_dim (`int`, *optional*, defaults to 64):
|
||||||
|
Dimension of the query/key heads that use rotary position embeddings.
|
||||||
|
v_head_dim (`int`, *optional*, defaults to 128):
|
||||||
|
Dimension of the value heads.
|
||||||
|
qk_nope_head_dim (`int`, *optional*, defaults to 128):
|
||||||
|
Dimension of the query/key heads that don't use rotary position embeddings.
|
||||||
|
n_group (`int`, *optional*, defaults to 8):
|
||||||
|
Number of groups for routed experts.
|
||||||
|
topk_group (`int`, *optional*, defaults to 4):
|
||||||
|
Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
|
||||||
|
num_experts_per_tok (`int`, *optional*, defaults to 8):
|
||||||
|
Number of selected experts, None means dense model.
|
||||||
|
first_k_dense_replace (`int`, *optional*, defaults to 3):
|
||||||
|
Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
|
||||||
|
\--k dense layers--/
|
||||||
|
norm_topk_prob (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to normalize the weights of the routed experts.
|
||||||
|
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||||
|
The non-linear activation function (function or string) in the decoder.
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to 4096):
|
||||||
|
The maximum sequence length that this model might ever be used with.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||||
|
The epsilon used by the rms normalization layers.
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if `config.is_decoder=True`.
|
||||||
|
pad_token_id (`int`, *optional*):
|
||||||
|
Padding token id.
|
||||||
|
bos_token_id (`int`, *optional*, defaults to 0):
|
||||||
|
Beginning of stream token id.
|
||||||
|
eos_token_id (`int`, *optional*, defaults to 1):
|
||||||
|
End of stream token id.
|
||||||
|
pretraining_tp (`int`, *optional*, defaults to 1):
|
||||||
|
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
||||||
|
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
|
||||||
|
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
|
||||||
|
issue](https://github.com/pytorch/pytorch/issues/76232).
|
||||||
|
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to tie weight embeddings
|
||||||
|
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||||
|
The base period of the RoPE embeddings.
|
||||||
|
rope_scaling (`Dict`, *optional*):
|
||||||
|
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||||
|
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
||||||
|
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||||
|
`max_position_embeddings` to the expected new maximum.
|
||||||
|
rope_interleave (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to interleave the rotary position embeddings.
|
||||||
|
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||||
|
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import DeepseekV3Model, DeepseekV3Config
|
||||||
|
|
||||||
|
>>> # Initializing a Deepseek-V3 style configuration
|
||||||
|
>>> configuration = DeepseekV3Config()
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = "deepseek_v3"
|
||||||
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
base_model_tp_plan = { # TODO: only replicate attention layers when > first_k_dense_replace
|
||||||
|
"layers.*.mlp.experts.*.gate_proj": "local_colwise",
|
||||||
|
"layers.*.mlp.experts.*.up_proj": "local_colwise",
|
||||||
|
"layers.*.mlp.experts.*.down_proj": "local_rowwise",
|
||||||
|
"layers.*.mlp.experts.*": "local", # each expert is wrapped in a module list
|
||||||
|
"layers.*.mlp.shared_experts.gate_proj": "local_colwise",
|
||||||
|
"layers.*.mlp.shared_experts.up_proj": "local_colwise",
|
||||||
|
"layers.*.mlp.shared_experts.down_proj": "local_rowwise",
|
||||||
|
"layers.*.mlp.shared_experts": "local",
|
||||||
|
"layers.*.mlp.gate_proj": "local_colwise",
|
||||||
|
"layers.*.mlp.up_proj": "local_colwise",
|
||||||
|
"layers.*.mlp.down_proj": "local_rowwise",
|
||||||
|
"layers.*.mlp": "gather", # This is the only moment where results are gathered
|
||||||
|
}
|
||||||
|
base_model_pp_plan = {
|
||||||
|
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||||
|
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||||
|
"norm": (["hidden_states"], ["hidden_states"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=129280,
|
||||||
|
hidden_size=7168,
|
||||||
|
intermediate_size=18432,
|
||||||
|
moe_intermediate_size=2048,
|
||||||
|
num_hidden_layers=61,
|
||||||
|
num_attention_heads=128,
|
||||||
|
num_key_value_heads=128,
|
||||||
|
n_shared_experts=1,
|
||||||
|
n_routed_experts=256,
|
||||||
|
routed_scaling_factor=2.5,
|
||||||
|
kv_lora_rank=512,
|
||||||
|
q_lora_rank=1536,
|
||||||
|
qk_rope_head_dim=64,
|
||||||
|
v_head_dim=128,
|
||||||
|
qk_nope_head_dim=128,
|
||||||
|
n_group=8,
|
||||||
|
topk_group=4,
|
||||||
|
num_experts_per_tok=8,
|
||||||
|
first_k_dense_replace=3,
|
||||||
|
norm_topk_prob=True,
|
||||||
|
hidden_act="silu",
|
||||||
|
max_position_embeddings=4096,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-6,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=None,
|
||||||
|
bos_token_id=0,
|
||||||
|
eos_token_id=1,
|
||||||
|
pretraining_tp=1,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
rope_scaling=None,
|
||||||
|
rope_interleave=True,
|
||||||
|
attention_bias=False,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.moe_intermediate_size = moe_intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.n_shared_experts = n_shared_experts
|
||||||
|
self.n_routed_experts = n_routed_experts
|
||||||
|
self.routed_scaling_factor = routed_scaling_factor
|
||||||
|
self.kv_lora_rank = kv_lora_rank
|
||||||
|
self.q_lora_rank = q_lora_rank
|
||||||
|
self.qk_rope_head_dim = qk_rope_head_dim
|
||||||
|
self.v_head_dim = v_head_dim
|
||||||
|
self.qk_nope_head_dim = qk_nope_head_dim
|
||||||
|
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||||
|
self.head_dim = qk_rope_head_dim
|
||||||
|
self.n_group = n_group
|
||||||
|
self.topk_group = topk_group
|
||||||
|
self.num_experts_per_tok = num_experts_per_tok
|
||||||
|
self.first_k_dense_replace = first_k_dense_replace
|
||||||
|
self.norm_topk_prob = norm_topk_prob
|
||||||
|
self.rope_interleave = rope_interleave
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.pretraining_tp = pretraining_tp
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
self.attention_bias = attention_bias
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
# Validate the correctness of rotary position embeddings parameters
|
||||||
|
# BC: if there is a 'type' field, copy it it to 'rope_type'.
|
||||||
|
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||||
|
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||||
|
rope_config_validation(self)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["DeepseekV3Config"]
|
1061
src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
Normal file
1061
src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
Normal file
File diff suppressed because it is too large
Load Diff
368
src/transformers/models/deepseek_v3/modular_deepseek_v3.py
Normal file
368
src/transformers/models/deepseek_v3/modular_deepseek_v3.py
Normal file
@ -0,0 +1,368 @@
|
|||||||
|
import math
|
||||||
|
from typing import Callable, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from ...activations import ACT2FN
|
||||||
|
from ...cache_utils import Cache
|
||||||
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
from ...processing_utils import Unpack
|
||||||
|
from ...utils import logging
|
||||||
|
from ..llama.modeling_llama import (
|
||||||
|
LlamaDecoderLayer,
|
||||||
|
LlamaForCausalLM,
|
||||||
|
LlamaModel,
|
||||||
|
LlamaPreTrainedModel,
|
||||||
|
LlamaRMSNorm,
|
||||||
|
LlamaRotaryEmbedding,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
eager_attention_forward,
|
||||||
|
rotate_half,
|
||||||
|
)
|
||||||
|
from .configuration_deepseek_v3 import DeepseekV3Config
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3RMSNorm(LlamaRMSNorm):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3RotaryEmbedding(LlamaRotaryEmbedding):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||||
|
r"""
|
||||||
|
TODO let's just use the original freqcis computation to not have the view
|
||||||
|
transpose + reshape! This is not optimized!
|
||||||
|
Applies Rotary Position Embedding to the query and key tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q (`torch.Tensor`): The query tensor.
|
||||||
|
k (`torch.Tensor`): The key tensor.
|
||||||
|
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||||
|
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||||
|
position_ids (`torch.Tensor`):
|
||||||
|
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||||
|
used to pass offsetted position ids when working with a KV-cache.
|
||||||
|
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||||
|
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||||
|
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||||
|
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||||
|
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||||
|
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||||
|
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||||
|
Returns:
|
||||||
|
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||||
|
"""
|
||||||
|
cos = cos.unsqueeze(unsqueeze_dim)
|
||||||
|
sin = sin.unsqueeze(unsqueeze_dim)
|
||||||
|
|
||||||
|
b, h, s, d = q.shape
|
||||||
|
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
||||||
|
|
||||||
|
b, h, s, d = k.shape
|
||||||
|
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
||||||
|
|
||||||
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
def yarn_get_mscale(scale=1, mscale=1):
|
||||||
|
if scale <= 1:
|
||||||
|
return 1.0
|
||||||
|
return 0.1 * mscale * math.log(scale) + 1.0
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3MLP(nn.Module):
|
||||||
|
def __init__(self, config, hidden_size=None, intermediate_size=None):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
||||||
|
self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
|
||||||
|
|
||||||
|
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||||
|
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||||
|
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||||
|
self.act_fn = ACT2FN[config.hidden_act]
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||||
|
return down_proj
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3TopkRouter(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.top_k = config.num_experts_per_tok
|
||||||
|
self.n_routed_experts = config.n_routed_experts
|
||||||
|
self.routed_scaling_factor = config.routed_scaling_factor
|
||||||
|
self.n_group = config.n_group
|
||||||
|
self.topk_group = config.topk_group
|
||||||
|
self.norm_topk_prob = config.norm_topk_prob
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
|
||||||
|
self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts)))
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_topk_indices(self, scores):
|
||||||
|
scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
|
||||||
|
group_scores = (
|
||||||
|
scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
|
||||||
|
.topk(2, dim=-1)[0]
|
||||||
|
.sum(dim=-1)
|
||||||
|
)
|
||||||
|
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
|
||||||
|
group_mask = torch.zeros_like(group_scores)
|
||||||
|
group_mask.scatter_(1, group_idx, 1)
|
||||||
|
score_mask = (
|
||||||
|
group_mask.unsqueeze(-1)
|
||||||
|
.expand(-1, self.n_group, self.n_routed_experts // self.n_group)
|
||||||
|
.reshape(-1, self.n_routed_experts)
|
||||||
|
)
|
||||||
|
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
|
||||||
|
topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
|
||||||
|
return topk_indices
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = hidden_states.view(-1, self.config.hidden_size)
|
||||||
|
router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
|
||||||
|
scores = router_logits.sigmoid()
|
||||||
|
topk_indices = self.get_topk_indices(scores)
|
||||||
|
topk_weights = scores.gather(1, topk_indices)
|
||||||
|
if self.norm_topk_prob:
|
||||||
|
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
|
||||||
|
topk_weights /= denominator
|
||||||
|
topk_weights = topk_weights * self.routed_scaling_factor
|
||||||
|
return topk_indices, topk_weights
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3MoE(nn.Module):
|
||||||
|
"""
|
||||||
|
A mixed expert module containing shared experts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.experts = nn.ModuleList(
|
||||||
|
[
|
||||||
|
DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
|
||||||
|
for _ in range(config.n_routed_experts)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.gate = DeepseekV3TopkRouter(config)
|
||||||
|
self.shared_experts = DeepseekV3MLP(
|
||||||
|
config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
|
||||||
|
)
|
||||||
|
|
||||||
|
def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
|
||||||
|
r"""
|
||||||
|
CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
|
||||||
|
to not have to do a loop here (deepseek has 256 experts soooo yeah).
|
||||||
|
"""
|
||||||
|
final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
|
||||||
|
expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
|
||||||
|
expert_mask = expert_mask.permute(2, 0, 1)
|
||||||
|
|
||||||
|
for expert_idx in range(len(self.experts)):
|
||||||
|
expert = self.experts[expert_idx]
|
||||||
|
mask = expert_mask[expert_idx]
|
||||||
|
token_indices, weight_indices = torch.where(mask)
|
||||||
|
|
||||||
|
if token_indices.numel() > 0:
|
||||||
|
expert_weights = topk_weights[token_indices, weight_indices]
|
||||||
|
expert_input = hidden_states[token_indices]
|
||||||
|
expert_output = expert(expert_input)
|
||||||
|
weighted_output = expert_output * expert_weights.unsqueeze(-1)
|
||||||
|
final_hidden_states.index_add_(0, token_indices, weighted_output)
|
||||||
|
|
||||||
|
# in original deepseek, the output of the experts are gathered once we leave this module
|
||||||
|
# thus the moe module is itelsf an IsolatedParallel module
|
||||||
|
# and all expert are "local" meaning we shard but we don't gather
|
||||||
|
return final_hidden_states.type(hidden_states.dtype)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
residuals = hidden_states
|
||||||
|
orig_shape = hidden_states.shape
|
||||||
|
topk_indices, topk_weights = self.gate(hidden_states)
|
||||||
|
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
|
hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
|
||||||
|
hidden_states = hidden_states + self.shared_experts(residuals)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3Attention(nn.Module):
|
||||||
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
|
def __init__(self, config: DeepseekV3Config, layer_idx: int):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||||
|
self.attention_dropout = config.attention_dropout
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.rope_theta = config.rope_theta
|
||||||
|
self.q_lora_rank = config.q_lora_rank
|
||||||
|
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||||
|
self.kv_lora_rank = config.kv_lora_rank
|
||||||
|
self.v_head_dim = config.v_head_dim
|
||||||
|
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||||
|
self.qk_head_dim = config.qk_head_dim
|
||||||
|
|
||||||
|
self.is_causal = True
|
||||||
|
self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
|
||||||
|
self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
|
||||||
|
self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
|
||||||
|
|
||||||
|
self.kv_a_proj_with_mqa = nn.Linear(
|
||||||
|
config.hidden_size,
|
||||||
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
|
bias=config.attention_bias,
|
||||||
|
)
|
||||||
|
self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
|
||||||
|
self.kv_b_proj = nn.Linear(
|
||||||
|
self.kv_lora_rank,
|
||||||
|
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.o_proj = nn.Linear(
|
||||||
|
self.num_heads * self.v_head_dim,
|
||||||
|
config.hidden_size,
|
||||||
|
bias=config.attention_bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.scaling = self.qk_head_dim ** (-0.5)
|
||||||
|
if self.config.rope_scaling is not None:
|
||||||
|
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
||||||
|
scaling_factor = self.config.rope_scaling["factor"]
|
||||||
|
if mscale_all_dim:
|
||||||
|
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
||||||
|
self.scaling = self.scaling * mscale * mscale
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
attention_mask: Optional[torch.Tensor],
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
batch_size, seq_length = hidden_states.shape[:-1]
|
||||||
|
query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
|
||||||
|
key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
|
||||||
|
|
||||||
|
q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2)
|
||||||
|
q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||||
|
|
||||||
|
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||||
|
k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||||
|
|
||||||
|
k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
|
||||||
|
k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||||
|
|
||||||
|
k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
|
||||||
|
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
if self.config.rope_interleave: # support using interleaved weights for efficiency
|
||||||
|
q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
|
||||||
|
else:
|
||||||
|
q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
|
||||||
|
k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
|
||||||
|
|
||||||
|
query_states = torch.cat((q_pass, q_rot), dim=-1)
|
||||||
|
key_states = torch.cat((k_pass, k_rot), dim=-1)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||||
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||||
|
|
||||||
|
if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
|
||||||
|
value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
|
||||||
|
|
||||||
|
attention_interface: Callable = eager_attention_forward
|
||||||
|
if self.config._attn_implementation != "eager":
|
||||||
|
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||||
|
logger.warning_once(
|
||||||
|
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||||
|
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||||
|
|
||||||
|
attn_output, attn_weights = attention_interface(
|
||||||
|
self,
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attention_mask,
|
||||||
|
dropout=0.0 if not self.training else self.attention_dropout,
|
||||||
|
scaling=self.scaling,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
|
||||||
|
attn_output = attn_output[:, :, :, : self.v_head_dim]
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3DecoderLayer(LlamaDecoderLayer, nn.Module):
|
||||||
|
def __init__(self, config: DeepseekV3Config, layer_idx: int):
|
||||||
|
nn.Module().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
|
self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)
|
||||||
|
|
||||||
|
if layer_idx >= config.first_k_dense_replace:
|
||||||
|
self.mlp = DeepseekV3MoE(config)
|
||||||
|
else:
|
||||||
|
self.mlp = DeepseekV3MLP(config)
|
||||||
|
|
||||||
|
self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3PreTrainedModel(LlamaPreTrainedModel):
|
||||||
|
def _init_weights(self, module):
|
||||||
|
std = self.config.initializer_range
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
elif isinstance(module, nn.Embedding):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
|
if module.padding_idx is not None:
|
||||||
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
elif isinstance(module, DeepseekV3TopkRouter):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
|
elif isinstance(module, nn.Parameter):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=std)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3Model(LlamaModel):
|
||||||
|
_keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"]
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3ForCausalLM(LlamaForCausalLM):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DeepseekV3PreTrainedModel",
|
||||||
|
"DeepseekV3Model",
|
||||||
|
"DeepseekV3ForCausalLM",
|
||||||
|
]
|
@ -2812,6 +2812,27 @@ class DecisionTransformerPreTrainedModel(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3ForCausalLM(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3Model(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3PreTrainedModel(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class DeformableDetrForObjectDetection(metaclass=DummyObject):
|
class DeformableDetrForObjectDetection(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
0
tests/models/deepseek_v3/__init__.py
Normal file
0
tests/models/deepseek_v3/__init__.py
Normal file
657
tests/models/deepseek_v3/test_modeling_deepseek_v3.py
Normal file
657
tests/models/deepseek_v3/test_modeling_deepseek_v3.py
Normal file
@ -0,0 +1,657 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Testing suite for the PyTorch DeepseekV3 model."""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer, DeepseekV3Config, is_torch_available, set_seed
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
require_read_token,
|
||||||
|
require_torch,
|
||||||
|
require_torch_accelerator,
|
||||||
|
require_torch_sdpa,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
|
from ...test_configuration_common import ConfigTester
|
||||||
|
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
DeepseekV3ForCausalLM,
|
||||||
|
DeepseekV3Model,
|
||||||
|
)
|
||||||
|
from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
|
||||||
|
DeepseekV3RotaryEmbedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3ModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=13,
|
||||||
|
seq_length=7,
|
||||||
|
is_training=True,
|
||||||
|
use_input_mask=True,
|
||||||
|
use_token_type_ids=False,
|
||||||
|
use_labels=True,
|
||||||
|
vocab_size=99,
|
||||||
|
hidden_size=32,
|
||||||
|
intermediate_size=37,
|
||||||
|
moe_intermediate_size=12,
|
||||||
|
num_hidden_layers=5,
|
||||||
|
num_attention_heads=4,
|
||||||
|
num_key_value_heads=4,
|
||||||
|
n_shared_experts=1,
|
||||||
|
n_routed_experts=8,
|
||||||
|
routed_scaling_factor=2.5,
|
||||||
|
kv_lora_rank=16,
|
||||||
|
q_lora_rank=32,
|
||||||
|
qk_rope_head_dim=16,
|
||||||
|
v_head_dim=32,
|
||||||
|
qk_nope_head_dim=32,
|
||||||
|
n_group=2,
|
||||||
|
topk_group=1,
|
||||||
|
num_experts_per_tok=8,
|
||||||
|
first_k_dense_replace=2,
|
||||||
|
norm_topk_prob=True,
|
||||||
|
aux_loss_alpha=0.001,
|
||||||
|
hidden_act="silu",
|
||||||
|
max_position_embeddings=512,
|
||||||
|
initializer_range=0.02,
|
||||||
|
attention_probs_dropout_prob=0.1,
|
||||||
|
type_vocab_size=16,
|
||||||
|
type_sequence_label_size=2,
|
||||||
|
num_labels=3,
|
||||||
|
num_choices=4,
|
||||||
|
pad_token_id=0,
|
||||||
|
scope=None,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.use_input_mask = use_input_mask
|
||||||
|
self.use_token_type_ids = use_token_type_ids
|
||||||
|
self.use_labels = use_labels
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.moe_intermediate_size = moe_intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.n_shared_experts = n_shared_experts
|
||||||
|
self.n_routed_experts = n_routed_experts
|
||||||
|
self.routed_scaling_factor = routed_scaling_factor
|
||||||
|
self.kv_lora_rank = kv_lora_rank
|
||||||
|
self.q_lora_rank = q_lora_rank
|
||||||
|
self.qk_rope_head_dim = qk_rope_head_dim
|
||||||
|
self.v_head_dim = v_head_dim
|
||||||
|
self.qk_nope_head_dim = qk_nope_head_dim
|
||||||
|
self.n_group = n_group
|
||||||
|
self.topk_group = topk_group
|
||||||
|
self.num_experts_per_tok = num_experts_per_tok
|
||||||
|
self.first_k_dense_replace = first_k_dense_replace
|
||||||
|
self.norm_topk_prob = norm_topk_prob
|
||||||
|
self.aux_loss_alpha = aux_loss_alpha
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
|
self.type_vocab_size = type_vocab_size
|
||||||
|
self.type_sequence_label_size = type_sequence_label_size
|
||||||
|
self.num_labels = num_labels
|
||||||
|
self.num_choices = num_choices
|
||||||
|
self.pad_token_id = pad_token_id
|
||||||
|
self.scope = scope
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
input_mask = None
|
||||||
|
if self.use_input_mask:
|
||||||
|
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||||
|
|
||||||
|
token_type_ids = None
|
||||||
|
if self.use_token_type_ids:
|
||||||
|
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||||
|
|
||||||
|
sequence_labels = None
|
||||||
|
token_labels = None
|
||||||
|
choice_labels = None
|
||||||
|
if self.use_labels:
|
||||||
|
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||||
|
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||||
|
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||||
|
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return DeepseekV3Config(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
intermediate_size=self.intermediate_size,
|
||||||
|
moe_intermediate_size=self.moe_intermediate_size,
|
||||||
|
num_hidden_layers=self.num_hidden_layers,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
num_key_value_heads=self.num_key_value_heads,
|
||||||
|
n_shared_experts=self.n_shared_experts,
|
||||||
|
n_routed_experts=self.n_routed_experts,
|
||||||
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
|
kv_lora_rank=self.kv_lora_rank,
|
||||||
|
q_lora_rank=self.q_lora_rank,
|
||||||
|
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||||
|
v_head_dim=self.v_head_dim,
|
||||||
|
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||||
|
n_group=self.n_group,
|
||||||
|
topk_group=self.topk_group,
|
||||||
|
num_experts_per_tok=self.num_experts_per_tok,
|
||||||
|
first_k_dense_replace=self.first_k_dense_replace,
|
||||||
|
norm_topk_prob=self.norm_topk_prob,
|
||||||
|
aux_loss_alpha=self.aux_loss_alpha,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
|
max_position_embeddings=self.max_position_embeddings,
|
||||||
|
initializer_range=self.initializer_range,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=self.pad_token_id,
|
||||||
|
attention_dropout=self.attention_probs_dropout_prob,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_model(
|
||||||
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
):
|
||||||
|
model = DeepseekV3Model(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(input_ids, attention_mask=input_mask)
|
||||||
|
result = model(input_ids)
|
||||||
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||||
|
|
||||||
|
def create_and_check_model_as_decoder(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
):
|
||||||
|
config.add_cross_attention = True
|
||||||
|
model = DeepseekV3Model(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
)
|
||||||
|
result = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
)
|
||||||
|
result = model(input_ids, attention_mask=input_mask)
|
||||||
|
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||||
|
|
||||||
|
def create_and_check_for_causal_lm(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
):
|
||||||
|
model = DeepseekV3ForCausalLM(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||||
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_past_large_inputs(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
):
|
||||||
|
config.is_decoder = True
|
||||||
|
config.add_cross_attention = True
|
||||||
|
model = DeepseekV3ForCausalLM(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
outputs = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
use_cache=True,
|
||||||
|
)
|
||||||
|
past_key_values = outputs.past_key_values
|
||||||
|
|
||||||
|
# create hypothetical multiple next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||||
|
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
|
||||||
|
|
||||||
|
# append to next input_ids and
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
|
||||||
|
|
||||||
|
output_from_no_past = model(
|
||||||
|
next_input_ids,
|
||||||
|
attention_mask=next_attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)["hidden_states"][0]
|
||||||
|
output_from_past = model(
|
||||||
|
next_tokens,
|
||||||
|
attention_mask=next_attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)["hidden_states"][0]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||||
|
|
||||||
|
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = config_and_inputs
|
||||||
|
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (
|
||||||
|
(
|
||||||
|
DeepseekV3Model,
|
||||||
|
DeepseekV3ForCausalLM,
|
||||||
|
)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
|
all_generative_model_classes = (DeepseekV3ForCausalLM,) if is_torch_available() else ()
|
||||||
|
pipeline_model_mapping = (
|
||||||
|
{
|
||||||
|
"feature-extraction": DeepseekV3Model,
|
||||||
|
"text-generation": DeepseekV3ForCausalLM,
|
||||||
|
}
|
||||||
|
if is_torch_available()
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
test_headmasking = False
|
||||||
|
test_pruning = False
|
||||||
|
fx_compatible = False
|
||||||
|
|
||||||
|
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
||||||
|
# This is because we are hitting edge cases with the causal_mask buffer
|
||||||
|
model_split_percents = [0.5, 0.7, 0.8]
|
||||||
|
|
||||||
|
# used in `test_torch_compile_for_training`
|
||||||
|
_torch_compile_train_cls = DeepseekV3ForCausalLM if is_torch_available() else None
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = DeepseekV3ModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=DeepseekV3Config, hidden_size=37)
|
||||||
|
|
||||||
|
@unittest.skip("Failing because of unique cache (HybridCache)")
|
||||||
|
def test_model_outputs_equivalence(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@parameterized.expand([("random",), ("same",)])
|
||||||
|
@unittest.skip("DeepseekV3 has HybridCache which is not compatible with assisted decoding")
|
||||||
|
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("DeepseekV3 has HybridCache which is not compatible with assisted decoding")
|
||||||
|
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("DeepseekV3 has HybridCache which is not compatible with assisted decoding")
|
||||||
|
def test_assisted_decoding_sample(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("DeepseekV3 has HybridCache which is not compatible with dola decoding")
|
||||||
|
def test_dola_decoding_sample(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("DeepseekV3 has HybridCache and doesn't support continue from past kv")
|
||||||
|
def test_generate_continue_from_past_key_values(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("DeepseekV3 has HybridCache and doesn't support low_memory generation")
|
||||||
|
def test_beam_search_low_memory(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("DeepseekV3 has HybridCache and doesn't support contrastive generation")
|
||||||
|
def test_contrastive_generate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("DeepseekV3 has HybridCache and doesn't support contrastive generation")
|
||||||
|
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("DeepseekV3 has HybridCache and doesn't support contrastive generation")
|
||||||
|
def test_contrastive_generate_low_memory(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
"DeepseekV3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
|
||||||
|
)
|
||||||
|
def test_generate_with_static_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
"DeepseekV3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
|
||||||
|
)
|
||||||
|
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
"DeepseekV3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
|
||||||
|
)
|
||||||
|
def test_generate_continue_from_inputs_embeds(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("DeepseekV3's eager attn/sdpa attn outputs are expected to be different")
|
||||||
|
def test_sdpa_equivalence(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Deepseek-V3 uses MLA so it is not compatible with the standard cache format")
|
||||||
|
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Deepseek-V3 uses MLA so it is not compatible with the standard cache format")
|
||||||
|
def test_generate_compilation_all_outputs(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Deepseek-V3 uses MLA so it is not compatible with the standard cache format")
|
||||||
|
def test_generate_compile_model_forward(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip("Deepseek-V3 uses MLA so it is not compatible with the standard cache format")
|
||||||
|
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_model(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_model_various_embeddings(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||||
|
config_and_inputs[0].position_embedding_type = type
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
@parameterized.expand([("yarn",)])
|
||||||
|
def test_model_rope_scaling_from_config(self, scaling_type):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
short_input = ids_tensor([1, 10], config.vocab_size)
|
||||||
|
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
|
||||||
|
|
||||||
|
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||||
|
original_model = DeepseekV3Model(config)
|
||||||
|
original_model.to(torch_device)
|
||||||
|
original_model.eval()
|
||||||
|
original_short_output = original_model(short_input).last_hidden_state
|
||||||
|
original_long_output = original_model(long_input).last_hidden_state
|
||||||
|
|
||||||
|
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||||
|
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
|
||||||
|
scaled_model = DeepseekV3Model(config)
|
||||||
|
scaled_model.to(torch_device)
|
||||||
|
scaled_model.eval()
|
||||||
|
scaled_short_output = scaled_model(short_input).last_hidden_state
|
||||||
|
scaled_long_output = scaled_model(long_input).last_hidden_state
|
||||||
|
|
||||||
|
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
|
||||||
|
# maximum sequence length, so the outputs for the short input should match.
|
||||||
|
if scaling_type == "dynamic":
|
||||||
|
torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
|
||||||
|
else:
|
||||||
|
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||||
|
|
||||||
|
# The output should be different for long inputs
|
||||||
|
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||||
|
|
||||||
|
def test_model_rope_scaling(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
scaling_factor = 10
|
||||||
|
short_input_length = 10
|
||||||
|
long_input_length = int(config.max_position_embeddings * 1.5)
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device
|
||||||
|
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
|
||||||
|
position_ids_short = position_ids_short.unsqueeze(0)
|
||||||
|
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
|
||||||
|
position_ids_long = position_ids_long.unsqueeze(0)
|
||||||
|
|
||||||
|
# Sanity check original RoPE
|
||||||
|
original_rope = DeepseekV3RotaryEmbedding(config=config).to(torch_device)
|
||||||
|
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
|
||||||
|
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
|
||||||
|
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
|
||||||
|
torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
|
||||||
|
|
||||||
|
# Sanity check linear RoPE scaling
|
||||||
|
# New position "x" should match original position with index "x/scaling_factor"
|
||||||
|
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
|
||||||
|
linear_scaling_rope = DeepseekV3RotaryEmbedding(config=config).to(torch_device)
|
||||||
|
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
|
||||||
|
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
|
||||||
|
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
|
||||||
|
torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
|
||||||
|
for new_position in range(0, long_input_length, scaling_factor):
|
||||||
|
original_position = int(new_position // scaling_factor)
|
||||||
|
torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
|
||||||
|
torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
|
||||||
|
|
||||||
|
# Sanity check Dynamic NTK RoPE scaling
|
||||||
|
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
|
||||||
|
# with scaling_factor (or that `inv_freq` decreases)
|
||||||
|
config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
|
||||||
|
ntk_scaling_rope = DeepseekV3RotaryEmbedding(config=config).to(torch_device)
|
||||||
|
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
|
||||||
|
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
|
||||||
|
torch.testing.assert_close(ntk_cos_short, original_cos_short)
|
||||||
|
torch.testing.assert_close(ntk_sin_short, original_sin_short)
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
torch.testing.assert_close(ntk_cos_long, original_cos_long)
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
torch.testing.assert_close(ntk_sin_long, original_sin_long)
|
||||||
|
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
|
||||||
|
|
||||||
|
# Sanity check Yarn RoPE scaling
|
||||||
|
# Scaling should be over the entire input
|
||||||
|
config.rope_scaling = {"type": "yarn", "factor": scaling_factor}
|
||||||
|
yarn_scaling_rope = DeepseekV3RotaryEmbedding(config=config).to(torch_device)
|
||||||
|
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
|
||||||
|
yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
|
||||||
|
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
|
||||||
|
torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
torch.testing.assert_close(yarn_cos_short, original_cos_short)
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
torch.testing.assert_close(yarn_sin_short, original_sin_short)
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
torch.testing.assert_close(yarn_cos_long, original_cos_long)
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
torch.testing.assert_close(yarn_sin_long, original_sin_long)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Deepseek-V3 uses MLA on all models so the KV cache is a non standard format")
|
||||||
|
def test_past_key_values_format(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@require_torch_sdpa
|
||||||
|
@slow
|
||||||
|
def test_eager_matches_sdpa_generate(self):
|
||||||
|
"""
|
||||||
|
Overwritting the common test as the test is flaky on tiny models
|
||||||
|
"""
|
||||||
|
max_new_tokens = 30
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("bzantium/tiny-deepseek-v3")
|
||||||
|
|
||||||
|
model_sdpa = DeepseekV3ForCausalLM.from_pretrained(
|
||||||
|
"bzantium/tiny-deepseek-v3",
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
|
||||||
|
|
||||||
|
model_eager = DeepseekV3ForCausalLM.from_pretrained(
|
||||||
|
"bzantium/tiny-deepseek-v3",
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
attn_implementation="eager",
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
self.assertTrue(model_eager.config._attn_implementation == "eager")
|
||||||
|
|
||||||
|
texts = [
|
||||||
|
"hi here's a longer context, getting longer and",
|
||||||
|
"Hello this is a very long sentence my friend, very long for real",
|
||||||
|
"Today I am in Paris and",
|
||||||
|
]
|
||||||
|
|
||||||
|
for padding_side in ["left", "right"]:
|
||||||
|
tokenizer.padding_side = padding_side
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
|
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
|
||||||
|
|
||||||
|
res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
|
||||||
|
res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
|
||||||
|
|
||||||
|
with self.subTest(f"{padding_side}"):
|
||||||
|
torch.testing.assert_close(
|
||||||
|
res_eager,
|
||||||
|
res_sdpa,
|
||||||
|
msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch_accelerator
|
||||||
|
class DeepseekV3IntegrationTest(unittest.TestCase):
|
||||||
|
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
|
||||||
|
# Depending on the hardware we get different logits / generations
|
||||||
|
cuda_compute_capability_major_version = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
if is_torch_available() and torch.cuda.is_available():
|
||||||
|
# 8 is for A100 / A10 and 7 for T4
|
||||||
|
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_accelerator
|
||||||
|
@require_read_token
|
||||||
|
def test_compile_static_cache(self):
|
||||||
|
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
|
||||||
|
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
|
||||||
|
if version.parse(torch.__version__) < version.parse("2.3.0"):
|
||||||
|
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||||
|
|
||||||
|
NUM_TOKENS_TO_GENERATE = 40
|
||||||
|
# Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
|
||||||
|
# was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
|
||||||
|
EXPECTED_TEXT_COMPLETION = [
|
||||||
|
"Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
|
||||||
|
"reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
|
||||||
|
"theory of relativ",
|
||||||
|
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
|
||||||
|
"my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
|
||||||
|
]
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
"Simply put, the theory of relativity states that ",
|
||||||
|
"My favorite all time favorite condiment is ketchup.",
|
||||||
|
]
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("bzantium/tiny-deepseek-v3", pad_token="</s>", padding_side="right")
|
||||||
|
model = DeepseekV3ForCausalLM.from_pretrained(
|
||||||
|
"bzantium/tiny-deepseek-v3", device_map=torch_device, torch_dtype=torch.float16
|
||||||
|
)
|
||||||
|
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||||
|
|
||||||
|
# Dynamic Cache
|
||||||
|
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
|
||||||
|
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
|
self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)
|
||||||
|
|
||||||
|
# Static Cache
|
||||||
|
generated_ids = model.generate(
|
||||||
|
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
|
||||||
|
)
|
||||||
|
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
|
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)
|
||||||
|
|
||||||
|
# Static Cache + compile
|
||||||
|
model._cache = None # clear cache object, initialized when we pass `cache_implementation="static"`
|
||||||
|
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
|
||||||
|
generated_ids = model.generate(
|
||||||
|
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
|
||||||
|
)
|
||||||
|
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||||
|
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)
|
Loading…
Reference in New Issue
Block a user