mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
Add DeepSeek V2 Model into Transformers (#36400)
* add initial structure * doc fixes, add model base logic * update init files * some fixes to config and modular * some improvements for attention * format * remove unused attn * some fixes for moe layer and for decoder * adapt _compute_yarn_parameters for deepseek * format * small fix * fix for decoder forward * add tests, small refactoring * fix dummies * fix init * fix doc * fix config docs * add sequce doc, fix init for gate * fix issues in tests * fix config doc * remove unused args * some fixes and refactoring after review * fix doc for config * small fixes for config args * revert config refactoring * small refactoring * minor fixes after rebase * small fix after merge * fix modular * remove rotaryembd from public init * small test fix * some rotary pos calculation improvement * fix format * some improvements and fixes * fix config * some refactoring * adjust some unit tests * skip test * small fixes and tests adjustment * reapply modular * fix all tests except Integration * fix integration testzs * cleanup BC stuff * rope * fix integrations tests based on a10 * style --------- Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co> Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
parent
accbd8e0fe
commit
c980904204
@ -709,6 +709,8 @@
|
||||
title: D-FINE
|
||||
- local: model_doc/dab-detr
|
||||
title: DAB-DETR
|
||||
- local: model_doc/deepseek_v2
|
||||
title: DeepSeek-V2
|
||||
- local: model_doc/deformable_detr
|
||||
title: Deformable DETR
|
||||
- local: model_doc/deit
|
||||
|
49
docs/source/en/model_doc/deepseek_v2.md
Normal file
49
docs/source/en/model_doc/deepseek_v2.md
Normal file
@ -0,0 +1,49 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# DeepSeek-V2
|
||||
|
||||
## Overview
|
||||
|
||||
The DeepSeek-V2 model was proposed in [DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model](https://arxiv.org/abs/2405.04434) by DeepSeek-AI Team.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
We present DeepSeek-V2, a strong Mixture-of-Experts (MoE) language model characterized by economical training and efficient inference. It comprises 236B total parameters, of which 21B are activated for each token, and supports a context length of 128K tokens. DeepSeek-V2 adopts innovative architectures including Multi-head Latent Attention (MLA) and DeepSeekMoE. MLA guarantees efficient inference through significantly compressing the Key-Value (KV) cache into a latent vector, while DeepSeekMoE enables training strong models at an economical cost through sparse computation. Compared with DeepSeek 67B, DeepSeek-V2 achieves significantly stronger performance, and meanwhile saves 42.5% of training costs, reduces the KV cache by 93.3%, and boosts the maximum generation throughput to 5.76 times. We pretrain DeepSeek-V2 on a high-quality and multi-source corpus consisting of 8.1T tokens, and further perform Supervised Fine-Tuning (SFT) and Reinforcement Learning (RL) to fully unlock its potential. Evaluation results show that, even with only 21B activated parameters, DeepSeek-V2 and its chat versions still achieve top-tier performance among open-source models.
|
||||
|
||||
This model was contributed by [VladOS95-cyber](https://github.com/VladOS95-cyber).
|
||||
The original code can be found [here](https://huggingface.co/deepseek-ai/DeepSeek-V2).
|
||||
|
||||
### Usage tips
|
||||
The model uses Multi-head Latent Attention (MLA) and DeepSeekMoE architectures for efficient inference and cost-effective training. It employs an auxiliary-loss-free strategy for load balancing and multi-token prediction training objective. The model can be used for various language tasks after being pre-trained on 14.8 trillion tokens and going through Supervised Fine-Tuning and Reinforcement Learning stages.
|
||||
|
||||
## DeepseekV2Config
|
||||
|
||||
[[autodoc]] DeepseekV2Config
|
||||
|
||||
## DeepseekV2Model
|
||||
|
||||
[[autodoc]] DeepseekV2Model
|
||||
- forward
|
||||
|
||||
## DeepseekV2ForCausalLM
|
||||
|
||||
[[autodoc]] DeepseekV2ForCausalLM
|
||||
- forward
|
||||
|
||||
## DeepseekV2ForSequenceClassification
|
||||
|
||||
[[autodoc]] DeepseekV2ForSequenceClassification
|
||||
- forward
|
@ -82,6 +82,7 @@ if TYPE_CHECKING:
|
||||
from .deberta import *
|
||||
from .deberta_v2 import *
|
||||
from .decision_transformer import *
|
||||
from .deepseek_v2 import *
|
||||
from .deepseek_v3 import *
|
||||
from .deformable_detr import *
|
||||
from .deit import *
|
||||
|
@ -101,6 +101,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("deberta", "DebertaConfig"),
|
||||
("deberta-v2", "DebertaV2Config"),
|
||||
("decision_transformer", "DecisionTransformerConfig"),
|
||||
("deepseek_v2", "DeepseekV2Config"),
|
||||
("deepseek_v3", "DeepseekV3Config"),
|
||||
("deformable_detr", "DeformableDetrConfig"),
|
||||
("deit", "DeiTConfig"),
|
||||
@ -481,6 +482,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
("deberta", "DeBERTa"),
|
||||
("deberta-v2", "DeBERTa-v2"),
|
||||
("decision_transformer", "Decision Transformer"),
|
||||
("deepseek_v2", "DeepSeek-V2"),
|
||||
("deepseek_v3", "DeepSeek-V3"),
|
||||
("deformable_detr", "Deformable DETR"),
|
||||
("deit", "DeiT"),
|
||||
|
@ -95,6 +95,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("deberta", "DebertaModel"),
|
||||
("deberta-v2", "DebertaV2Model"),
|
||||
("decision_transformer", "DecisionTransformerModel"),
|
||||
("deepseek_v2", "DeepseekV2Model"),
|
||||
("deepseek_v3", "DeepseekV3Model"),
|
||||
("deformable_detr", "DeformableDetrModel"),
|
||||
("deit", "DeiTModel"),
|
||||
@ -577,6 +578,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("ctrl", "CTRLLMHeadModel"),
|
||||
("data2vec-text", "Data2VecTextForCausalLM"),
|
||||
("dbrx", "DbrxForCausalLM"),
|
||||
("deepseek_v2", "DeepseekV2ForCausalLM"),
|
||||
("deepseek_v3", "DeepseekV3ForCausalLM"),
|
||||
("diffllama", "DiffLlamaForCausalLM"),
|
||||
("doge", "DogeForCausalLM"),
|
||||
@ -1108,6 +1110,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("data2vec-text", "Data2VecTextForSequenceClassification"),
|
||||
("deberta", "DebertaForSequenceClassification"),
|
||||
("deberta-v2", "DebertaV2ForSequenceClassification"),
|
||||
("deepseek_v2", "DeepseekV2ForSequenceClassification"),
|
||||
("diffllama", "DiffLlamaForSequenceClassification"),
|
||||
("distilbert", "DistilBertForSequenceClassification"),
|
||||
("doge", "DogeForSequenceClassification"),
|
||||
|
@ -177,6 +177,13 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
"DebertaV2TokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"deepseek_v2",
|
||||
(
|
||||
"LlamaTokenizer" if is_sentencepiece_available() else None,
|
||||
"LlamaTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"deepseek_v3",
|
||||
(
|
||||
|
29
src/transformers/models/deepseek_v2/__init__.py
Normal file
29
src/transformers/models/deepseek_v2/__init__.py
Normal file
@ -0,0 +1,29 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_deepseek_v2 import *
|
||||
from .modeling_deepseek_v2 import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
239
src/transformers/models/deepseek_v2/configuration_deepseek_v2.py
Normal file
239
src/transformers/models/deepseek_v2/configuration_deepseek_v2.py
Normal file
@ -0,0 +1,239 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/deepseek_v2/modular_deepseek_v2.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_deepseek_v2.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
|
||||
|
||||
class DeepseekV2Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate a DeepSeek
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of DeepSeek-V2-Lite" [deepseek-ai/DeepSeek-V2-Lite"](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite").
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by the
|
||||
`input_ids` passed when calling [`DeepseekV2Model`].
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
The number of key-value heads used to implement Grouped Query Attention (GQA). If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi-Head Attention (MHA). If
|
||||
`num_key_value_heads=1`, the model will use Multi-Query Attention (MQA). Otherwise, GQA is used.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated normal initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon value used by the RMS normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/value attentions (useful for inference optimization).
|
||||
pad_token_id (`int`, *optional*):
|
||||
Padding token ID.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
Beginning-of-sequence token ID.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
End-of-sequence token ID.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie input and output embeddings.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the Rotary Position Embeddings (RoPE).
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Configuration for scaling RoPE embeddings. Supports `linear` and `dynamic` scaling strategies.
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value, and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability applied to attention weights.
|
||||
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias term in the MLP layers.
|
||||
aux_loss_alpha (`float`, *optional*, defaults to 0.001):
|
||||
Weight coefficient for auxiliary loss in Mixture of Experts (MoE) models.
|
||||
first_k_dense_replace (`int`, *optional*, defaults to 0):
|
||||
Number of dense layers in the shallow layers before switching to MoE layers.
|
||||
kv_lora_rank (`int`, *optional*, defaults to 512):
|
||||
Rank of the LoRA decomposition for key-value projections.
|
||||
q_lora_rank (`int`, *optional*, defaults to 1536):
|
||||
Rank of the LoRA decomposition for query projections.
|
||||
Specifically, it determines the dimensionality to which the query (q) vectors are compressed before being expanded back to their original size.
|
||||
It reduces computational overhead while maintaining model performance.
|
||||
n_group (`int`, *optional*):
|
||||
Number of groups for routed experts.
|
||||
n_routed_experts (`int`, *optional*, defaults to 64):
|
||||
Number of routed experts (None indicates a dense model).
|
||||
n_shared_experts (`int`, *optional*, defaults to 2):
|
||||
Number of shared experts (None indicates a dense model).
|
||||
qk_nope_head_dim (`int`, *optional*, defaults to 128):
|
||||
The head dimension for the QK (query-key) projections when using NOPE (Neural Operator Position Encoding).
|
||||
qk_rope_head_dim (`int`, *optional*, defaults to 64):
|
||||
The head dimension for QK projections when using RoPE.
|
||||
routed_scaling_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor for routed experts in MoE models.
|
||||
seq_aux (`bool`, *optional*, defaults to `True`):
|
||||
Whether to compute the auxiliary loss for each individual sequence.
|
||||
topk_group (`int`, *optional*):
|
||||
Number of selected groups per token for expert selection.
|
||||
topk_method (`str`, *optional*, defaults to `"greedy"`):
|
||||
The method used for selecting top-k experts in the routed gate mechanism.
|
||||
v_head_dim (`int`, *optional*, defaults to 128):
|
||||
The dimension of value projections in the attention layers.
|
||||
num_experts_per_tok (`int`, *optional*):
|
||||
The number of experts selected per token. If `None`, the model behaves as a dense Transformer.
|
||||
norm_topk_prob (`bool`, *optional*, defaults to `False`):
|
||||
Whether to normalize the probability distribution over top-k selected experts.
|
||||
moe_intermediate_size (`int`, *optional*, defaults to 1407):
|
||||
Dimension of the MoE (Mixture of Experts) representations.
|
||||
|
||||
```python
|
||||
>>> from transformers import DeepseekV2Model, DeepseekV2Config
|
||||
>>> # Initializing a DeepSeek-V2 style configuration
|
||||
>>> configuration = DeepseekV2Config()
|
||||
>>> # Accessing the model configuration
|
||||
>>> model = DeepseekV2Model(configuration)
|
||||
>>> print(model.config)
|
||||
```
|
||||
"""
|
||||
|
||||
model_type = "deepseek_v2"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.q_a_proj": "colwise",
|
||||
"layers.*.self_attn.q_b_proj": "colwise",
|
||||
"layers.*.self_attn.kv_b_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
mlp_bias=False,
|
||||
aux_loss_alpha=0.001,
|
||||
first_k_dense_replace=0,
|
||||
kv_lora_rank=512,
|
||||
q_lora_rank=1536,
|
||||
n_group=None,
|
||||
n_routed_experts=64,
|
||||
n_shared_experts=2,
|
||||
qk_nope_head_dim=128,
|
||||
qk_rope_head_dim=64,
|
||||
routed_scaling_factor=1.0,
|
||||
seq_aux=True,
|
||||
topk_group=None,
|
||||
topk_method="greedy",
|
||||
v_head_dim=128,
|
||||
num_experts_per_tok=None,
|
||||
norm_topk_prob=False,
|
||||
moe_intermediate_size=1407,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.mlp_bias = mlp_bias
|
||||
self.head_dim = qk_rope_head_dim
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, copy it it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
self.aux_loss_alpha = aux_loss_alpha
|
||||
self.first_k_dense_replace = first_k_dense_replace
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.n_group = n_group
|
||||
self.n_routed_experts = n_routed_experts
|
||||
self.n_shared_experts = n_shared_experts
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.seq_aux = seq_aux
|
||||
self.topk_group = topk_group
|
||||
self.topk_method = topk_method
|
||||
self.v_head_dim = v_head_dim
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
|
||||
|
||||
__all__ = ["DeepseekV2Config"]
|
773
src/transformers/models/deepseek_v2/modeling_deepseek_v2.py
Normal file
773
src/transformers/models/deepseek_v2/modeling_deepseek_v2.py
Normal file
@ -0,0 +1,773 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/deepseek_v2/modular_deepseek_v2.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_deepseek_v2.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
|
||||
from ...utils.generic import check_model_inputs
|
||||
from .configuration_deepseek_v2 import DeepseekV2Config
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class DeepseekV2MoEGate(nn.Module):
|
||||
def __init__(self, config: DeepseekV2Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.num_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.alpha = config.aux_loss_alpha
|
||||
self.seq_aux = config.seq_aux
|
||||
self.topk_method = config.topk_method
|
||||
self.num_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
|
||||
# topk selection algorithm
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
self.gating_dim = config.hidden_size
|
||||
self.weight = nn.Parameter(torch.empty((self.num_experts, self.gating_dim)))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, seq_len, hidden_dim = hidden_states.shape
|
||||
### compute gating score
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)
|
||||
scores = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
# select top-k experts
|
||||
# greedy method is used for DeepSeek-V2-Lite
|
||||
# group_limited_greedy for DeepSeek-V2 and DeepSeek-V2-Chat
|
||||
if self.topk_method == "greedy":
|
||||
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
|
||||
elif self.topk_method == "group_limited_greedy":
|
||||
group_scores = scores.view(batch_size * seq_len, self.num_group, -1).max(dim=-1).values # [n, num_group]
|
||||
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
|
||||
group_mask = torch.zeros_like(group_scores) # [n, num_group]
|
||||
group_mask.scatter_(1, group_idx, 1) # [n, num_group]
|
||||
score_mask = (
|
||||
group_mask.unsqueeze(-1)
|
||||
.expand(batch_size * seq_len, self.num_group, self.num_experts // self.num_group)
|
||||
.reshape(batch_size * seq_len, -1)
|
||||
) # [n, e]
|
||||
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||
topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
|
||||
|
||||
topk_weight = topk_weight * self.routed_scaling_factor
|
||||
### expert-level computation auxiliary loss
|
||||
return topk_idx, topk_weight
|
||||
|
||||
|
||||
class DeepseekV2MoE(nn.Module):
|
||||
"""
|
||||
A mixed expert module containing shared experts.
|
||||
"""
|
||||
|
||||
def __init__(self, config: DeepseekV2Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_experts_per_tok = config.num_experts_per_tok
|
||||
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
(DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size))
|
||||
for _ in range(config.n_routed_experts)
|
||||
]
|
||||
)
|
||||
self.gate = DeepseekV2MoEGate(config)
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size)
|
||||
self.ep_rank = 0
|
||||
self.experts_per_rank = config.n_routed_experts
|
||||
|
||||
def moe(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
|
||||
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
|
||||
cnts.scatter_(1, topk_ids, 1)
|
||||
tokens_per_expert = cnts.sum(dim=0)
|
||||
indicies = topk_ids.view(-1).argsort()
|
||||
sorted_tokens = hidden_states[indicies // topk_ids.shape[1]]
|
||||
|
||||
# Process experts
|
||||
outputs = []
|
||||
start_idx = 0
|
||||
for i, num_tokens in enumerate(tokens_per_expert):
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
end_idx = start_idx + num_tokens
|
||||
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
|
||||
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
||||
expert_out = expert(tokens_for_this_expert)
|
||||
outputs.append(expert_out)
|
||||
start_idx = end_idx
|
||||
|
||||
outs = torch.cat(outputs, dim=0) if outputs else sorted_tokens.new_empty(0)
|
||||
|
||||
# Reorder and combine outputs
|
||||
new_x = torch.empty_like(outs)
|
||||
new_x[indicies] = outs
|
||||
hidden_states = (
|
||||
new_x.view(*topk_ids.shape, -1)
|
||||
.type(topk_weight.dtype)
|
||||
.mul_(topk_weight.unsqueeze(dim=-1))
|
||||
.sum(dim=1)
|
||||
.type(new_x.dtype)
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residuals = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
topk_indices, topk_weights = self.gate(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
|
||||
hidden_states = hidden_states + self.shared_experts(residuals)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DeepseekV2MLP(nn.Module):
|
||||
def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
||||
self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
@use_kernel_forward_from_hub("RMSNorm")
|
||||
class DeepseekV2RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
DeepseekV2RMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class DeepseekV2RotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: DeepseekV2Config, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
self.rope_type = (
|
||||
config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
if config.rope_scaling is not None
|
||||
else "default"
|
||||
)
|
||||
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Convert to complex representation
|
||||
freqs_cis = freqs_cis * self.attention_scaling
|
||||
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
|
||||
# Broadcast to [1, 1, seq_len, dim // 2]
|
||||
freqs_cis = freqs_cis.unsqueeze(1).to(xq_.device)
|
||||
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
|
||||
return xq_out, xk_out
|
||||
|
||||
|
||||
class DeepseekV2Attention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.v_head_dim = config.v_head_dim
|
||||
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||
self.qk_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
|
||||
self.is_causal = True
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
|
||||
else:
|
||||
self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
|
||||
self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank)
|
||||
self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
|
||||
|
||||
self.kv_a_proj_with_mqa = nn.Linear(
|
||||
self.hidden_size,
|
||||
config.kv_lora_rank + config.qk_rope_head_dim,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank)
|
||||
self.kv_b_proj = nn.Linear(
|
||||
config.kv_lora_rank,
|
||||
self.num_heads * (self.qk_head_dim - self.qk_rope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
|
||||
self.scaling = self.qk_head_dim ** (-0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
batch_size, seq_length = hidden_states.shape[:-1]
|
||||
query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
|
||||
key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(hidden_states)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
||||
q = q.view(query_shape).transpose(1, 2)
|
||||
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||
k_nope, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
k_nope = self.kv_b_proj(self.kv_a_layernorm(k_nope)).view(key_shape).transpose(1, 2)
|
||||
k_nope, value_states = torch.split(k_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
k_pe = k_pe.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
|
||||
q_pe, k_pe = apply_rotary_emb(q_pe, k_pe, position_embeddings.to(q_pe.device))
|
||||
|
||||
k_pe = k_pe.expand(*k_nope.shape[:-1], -1)
|
||||
query_states = torch.cat((q_nope, q_pe), dim=-1)
|
||||
key_states = torch.cat((k_nope, k_pe), dim=-1)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
|
||||
value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
|
||||
attn_output = attn_output[:, :, :, : self.v_head_dim]
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class DeepseekV2DecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: DeepseekV2Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = DeepseekV2Attention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = DeepseekV2MoE(config) if layer_idx >= config.first_k_dense_replace else DeepseekV2MLP(config)
|
||||
|
||||
self.input_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor]:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
# Self Attention
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class DeepseekV2PreTrainedModel(PreTrainedModel):
|
||||
config_class = DeepseekV2Config
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["DeepseekV2DecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_attention_backend = True
|
||||
_can_record_outputs = {
|
||||
"hidden_states": DeepseekV2DecoderLayer,
|
||||
"attentions": DeepseekV2Attention,
|
||||
}
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, DeepseekV2RMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, DeepseekV2MoEGate):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class DeepseekV2Model(DeepseekV2PreTrainedModel):
|
||||
def __init__(self, config: DeepseekV2Config):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList(
|
||||
[DeepseekV2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = DeepseekV2RotaryEmbedding(config=config)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
@check_model_inputs
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position: torch.Tensor = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
hidden_states = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = DeepseekV2Model(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM
|
||||
|
||||
>>> model = DeepseekV2ForCausalLM.from_pretrained("meta-deepseek_v2/DeepseekV2-2-7b-hf")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v2/DeepseekV2-2-7b-hf")
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The DeepseekV2 Model transformer with a sequence classification head on top (linear layer).
|
||||
|
||||
[`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
||||
(e.g. GPT-2) do.
|
||||
|
||||
Since it does classification on the last token, it requires to know the position of the last token. If a
|
||||
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
||||
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
||||
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
||||
each row of the batch).
|
||||
"""
|
||||
)
|
||||
class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = DeepseekV2Model(config)
|
||||
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> SequenceClassifierOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
|
||||
transformer_outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = transformer_outputs.last_hidden_state
|
||||
logits = self.score(hidden_states)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = input_ids.shape[0]
|
||||
else:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
|
||||
if self.config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.pad_token_id is None:
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
return SequenceClassifierOutputWithPast(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DeepseekV2PreTrainedModel",
|
||||
"DeepseekV2Model",
|
||||
"DeepseekV2ForCausalLM",
|
||||
"DeepseekV2ForSequenceClassification",
|
||||
]
|
539
src/transformers/models/deepseek_v2/modular_deepseek_v2.py
Normal file
539
src/transformers/models/deepseek_v2/modular_deepseek_v2.py
Normal file
@ -0,0 +1,539 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import Cache
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...utils import (
|
||||
logging,
|
||||
)
|
||||
from ..llama.configuration_llama import LlamaConfig
|
||||
from ..llama.modeling_llama import (
|
||||
LlamaDecoderLayer,
|
||||
LlamaForCausalLM,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaMLP,
|
||||
LlamaModel,
|
||||
LlamaPreTrainedModel,
|
||||
LlamaRMSNorm,
|
||||
eager_attention_forward,
|
||||
)
|
||||
from ..llama4.modeling_llama4 import Llama4TextRotaryEmbedding
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class DeepseekV2Config(LlamaConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate a DeepSeek
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of DeepSeek-V2-Lite" [deepseek-ai/DeepSeek-V2-Lite"](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite").
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by the
|
||||
`input_ids` passed when calling [`DeepseekV2Model`].
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
The number of key-value heads used to implement Grouped Query Attention (GQA). If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi-Head Attention (MHA). If
|
||||
`num_key_value_heads=1`, the model will use Multi-Query Attention (MQA). Otherwise, GQA is used.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated normal initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon value used by the RMS normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/value attentions (useful for inference optimization).
|
||||
pad_token_id (`int`, *optional*):
|
||||
Padding token ID.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
Beginning-of-sequence token ID.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
End-of-sequence token ID.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie input and output embeddings.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the Rotary Position Embeddings (RoPE).
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Configuration for scaling RoPE embeddings. Supports `linear` and `dynamic` scaling strategies.
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value, and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability applied to attention weights.
|
||||
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias term in the MLP layers.
|
||||
aux_loss_alpha (`float`, *optional*, defaults to 0.001):
|
||||
Weight coefficient for auxiliary loss in Mixture of Experts (MoE) models.
|
||||
first_k_dense_replace (`int`, *optional*, defaults to 0):
|
||||
Number of dense layers in the shallow layers before switching to MoE layers.
|
||||
kv_lora_rank (`int`, *optional*, defaults to 512):
|
||||
Rank of the LoRA decomposition for key-value projections.
|
||||
q_lora_rank (`int`, *optional*, defaults to 1536):
|
||||
Rank of the LoRA decomposition for query projections.
|
||||
Specifically, it determines the dimensionality to which the query (q) vectors are compressed before being expanded back to their original size.
|
||||
It reduces computational overhead while maintaining model performance.
|
||||
n_group (`int`, *optional*):
|
||||
Number of groups for routed experts.
|
||||
n_routed_experts (`int`, *optional*, defaults to 64):
|
||||
Number of routed experts (None indicates a dense model).
|
||||
n_shared_experts (`int`, *optional*, defaults to 2):
|
||||
Number of shared experts (None indicates a dense model).
|
||||
qk_nope_head_dim (`int`, *optional*, defaults to 128):
|
||||
The head dimension for the QK (query-key) projections when using NOPE (Neural Operator Position Encoding).
|
||||
qk_rope_head_dim (`int`, *optional*, defaults to 64):
|
||||
The head dimension for QK projections when using RoPE.
|
||||
routed_scaling_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor for routed experts in MoE models.
|
||||
seq_aux (`bool`, *optional*, defaults to `True`):
|
||||
Whether to compute the auxiliary loss for each individual sequence.
|
||||
topk_group (`int`, *optional*):
|
||||
Number of selected groups per token for expert selection.
|
||||
topk_method (`str`, *optional*, defaults to `"greedy"`):
|
||||
The method used for selecting top-k experts in the routed gate mechanism.
|
||||
v_head_dim (`int`, *optional*, defaults to 128):
|
||||
The dimension of value projections in the attention layers.
|
||||
num_experts_per_tok (`int`, *optional*):
|
||||
The number of experts selected per token. If `None`, the model behaves as a dense Transformer.
|
||||
norm_topk_prob (`bool`, *optional*, defaults to `False`):
|
||||
Whether to normalize the probability distribution over top-k selected experts.
|
||||
moe_intermediate_size (`int`, *optional*, defaults to 1407):
|
||||
Dimension of the MoE (Mixture of Experts) representations.
|
||||
|
||||
```python
|
||||
>>> from transformers import DeepseekV2Model, DeepseekV2Config
|
||||
>>> # Initializing a DeepSeek-V2 style configuration
|
||||
>>> configuration = DeepseekV2Config()
|
||||
>>> # Accessing the model configuration
|
||||
>>> model = DeepseekV2Model(configuration)
|
||||
>>> print(model.config)
|
||||
```
|
||||
"""
|
||||
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.q_a_proj": "colwise",
|
||||
"layers.*.self_attn.q_b_proj": "colwise",
|
||||
"layers.*.self_attn.kv_b_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
|
||||
model_type = "deepseek_v2"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
mlp_bias=False,
|
||||
aux_loss_alpha=0.001,
|
||||
first_k_dense_replace=0,
|
||||
kv_lora_rank=512,
|
||||
q_lora_rank=1536,
|
||||
n_group=None,
|
||||
n_routed_experts=64,
|
||||
n_shared_experts=2,
|
||||
qk_nope_head_dim=128,
|
||||
qk_rope_head_dim=64,
|
||||
routed_scaling_factor=1.0,
|
||||
seq_aux=True,
|
||||
topk_group=None,
|
||||
topk_method="greedy",
|
||||
v_head_dim=128,
|
||||
num_experts_per_tok=None,
|
||||
norm_topk_prob=False,
|
||||
moe_intermediate_size=1407,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
del self.pretraining_tp
|
||||
self.aux_loss_alpha = aux_loss_alpha
|
||||
self.first_k_dense_replace = first_k_dense_replace
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.n_group = n_group
|
||||
self.n_routed_experts = n_routed_experts
|
||||
self.n_shared_experts = n_shared_experts
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.seq_aux = seq_aux
|
||||
self.topk_group = topk_group
|
||||
self.topk_method = topk_method
|
||||
self.v_head_dim = v_head_dim
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.head_dim = qk_rope_head_dim
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
|
||||
# Broadcast to [1, 1, seq_len, dim // 2]
|
||||
freqs_cis = freqs_cis.unsqueeze(1).to(xq_.device)
|
||||
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
|
||||
return xq_out, xk_out
|
||||
|
||||
|
||||
class DeepseekV2MoEGate(nn.Module):
|
||||
def __init__(self, config: DeepseekV2Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.num_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.alpha = config.aux_loss_alpha
|
||||
self.seq_aux = config.seq_aux
|
||||
self.topk_method = config.topk_method
|
||||
self.num_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
|
||||
# topk selection algorithm
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
self.gating_dim = config.hidden_size
|
||||
self.weight = nn.Parameter(torch.empty((self.num_experts, self.gating_dim)))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, seq_len, hidden_dim = hidden_states.shape
|
||||
### compute gating score
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)
|
||||
scores = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
# select top-k experts
|
||||
# greedy method is used for DeepSeek-V2-Lite
|
||||
# group_limited_greedy for DeepSeek-V2 and DeepSeek-V2-Chat
|
||||
if self.topk_method == "greedy":
|
||||
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
|
||||
elif self.topk_method == "group_limited_greedy":
|
||||
group_scores = scores.view(batch_size * seq_len, self.num_group, -1).max(dim=-1).values # [n, num_group]
|
||||
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
|
||||
group_mask = torch.zeros_like(group_scores) # [n, num_group]
|
||||
group_mask.scatter_(1, group_idx, 1) # [n, num_group]
|
||||
score_mask = (
|
||||
group_mask.unsqueeze(-1)
|
||||
.expand(batch_size * seq_len, self.num_group, self.num_experts // self.num_group)
|
||||
.reshape(batch_size * seq_len, -1)
|
||||
) # [n, e]
|
||||
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||
topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
|
||||
|
||||
topk_weight = topk_weight * self.routed_scaling_factor
|
||||
### expert-level computation auxiliary loss
|
||||
return topk_idx, topk_weight
|
||||
|
||||
|
||||
class DeepseekV2MoE(nn.Module):
|
||||
"""
|
||||
A mixed expert module containing shared experts.
|
||||
"""
|
||||
|
||||
def __init__(self, config: DeepseekV2Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_experts_per_tok = config.num_experts_per_tok
|
||||
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
(DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size))
|
||||
for _ in range(config.n_routed_experts)
|
||||
]
|
||||
)
|
||||
self.gate = DeepseekV2MoEGate(config)
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size)
|
||||
self.ep_rank = 0
|
||||
self.experts_per_rank = config.n_routed_experts
|
||||
|
||||
def moe(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
|
||||
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
|
||||
cnts.scatter_(1, topk_ids, 1)
|
||||
tokens_per_expert = cnts.sum(dim=0)
|
||||
indicies = topk_ids.view(-1).argsort()
|
||||
sorted_tokens = hidden_states[indicies // topk_ids.shape[1]]
|
||||
|
||||
# Process experts
|
||||
outputs = []
|
||||
start_idx = 0
|
||||
for i, num_tokens in enumerate(tokens_per_expert):
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
end_idx = start_idx + num_tokens
|
||||
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
|
||||
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
||||
expert_out = expert(tokens_for_this_expert)
|
||||
outputs.append(expert_out)
|
||||
start_idx = end_idx
|
||||
|
||||
outs = torch.cat(outputs, dim=0) if outputs else sorted_tokens.new_empty(0)
|
||||
|
||||
# Reorder and combine outputs
|
||||
new_x = torch.empty_like(outs)
|
||||
new_x[indicies] = outs
|
||||
hidden_states = (
|
||||
new_x.view(*topk_ids.shape, -1)
|
||||
.type(topk_weight.dtype)
|
||||
.mul_(topk_weight.unsqueeze(dim=-1))
|
||||
.sum(dim=1)
|
||||
.type(new_x.dtype)
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residuals = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
topk_indices, topk_weights = self.gate(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
|
||||
hidden_states = hidden_states + self.shared_experts(residuals)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DeepseekV2MLP(LlamaMLP):
|
||||
def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None):
|
||||
super().__init__(config)
|
||||
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
||||
self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
|
||||
|
||||
|
||||
class DeepseekV2RMSNorm(LlamaRMSNorm):
|
||||
pass
|
||||
|
||||
|
||||
class DeepseekV2RotaryEmbedding(Llama4TextRotaryEmbedding):
|
||||
def __init__(self, config: DeepseekV2Config, device=None):
|
||||
super().__init__(config=config, device=device)
|
||||
# BC: "rope_type" was originally "type"
|
||||
self.rope_type = (
|
||||
config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
if config.rope_scaling is not None
|
||||
else "default"
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV2Attention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.v_head_dim = config.v_head_dim
|
||||
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||
self.qk_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
|
||||
self.is_causal = True
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
|
||||
else:
|
||||
self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
|
||||
self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank)
|
||||
self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
|
||||
|
||||
self.kv_a_proj_with_mqa = nn.Linear(
|
||||
self.hidden_size,
|
||||
config.kv_lora_rank + config.qk_rope_head_dim,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank)
|
||||
self.kv_b_proj = nn.Linear(
|
||||
config.kv_lora_rank,
|
||||
self.num_heads * (self.qk_head_dim - self.qk_rope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
|
||||
self.scaling = self.qk_head_dim ** (-0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
batch_size, seq_length = hidden_states.shape[:-1]
|
||||
query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
|
||||
key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(hidden_states)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
||||
q = q.view(query_shape).transpose(1, 2)
|
||||
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||
k_nope, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
k_nope = self.kv_b_proj(self.kv_a_layernorm(k_nope)).view(key_shape).transpose(1, 2)
|
||||
k_nope, value_states = torch.split(k_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
k_pe = k_pe.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
|
||||
q_pe, k_pe = apply_rotary_emb(q_pe, k_pe, position_embeddings.to(q_pe.device))
|
||||
|
||||
k_pe = k_pe.expand(*k_nope.shape[:-1], -1)
|
||||
query_states = torch.cat((q_nope, q_pe), dim=-1)
|
||||
key_states = torch.cat((k_nope, k_pe), dim=-1)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
|
||||
value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
|
||||
attn_output = attn_output[:, :, :, : self.v_head_dim]
|
||||
|
||||
attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class DeepseekV2DecoderLayer(LlamaDecoderLayer):
|
||||
def __init__(self, config: DeepseekV2Config, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
|
||||
self.self_attn = DeepseekV2Attention(config=config, layer_idx=layer_idx)
|
||||
self.mlp = DeepseekV2MoE(config) if layer_idx >= config.first_k_dense_replace else DeepseekV2MLP(config)
|
||||
|
||||
self.input_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
|
||||
class DeepseekV2PreTrainedModel(LlamaPreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, DeepseekV2RMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, DeepseekV2MoEGate):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
|
||||
|
||||
class DeepseekV2Model(LlamaModel):
|
||||
pass
|
||||
|
||||
|
||||
class DeepseekV2ForCausalLM(LlamaForCausalLM):
|
||||
pass
|
||||
|
||||
|
||||
class DeepseekV2ForSequenceClassification(LlamaForSequenceClassification):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DeepseekV2PreTrainedModel",
|
||||
"DeepseekV2Model",
|
||||
"DeepseekV2ForCausalLM",
|
||||
"DeepseekV2ForSequenceClassification",
|
||||
"DeepseekV2Config",
|
||||
]
|
0
tests/models/deepseek_v2/__init__.py
Normal file
0
tests/models/deepseek_v2/__init__.py
Normal file
269
tests/models/deepseek_v2/test_modeling_deepseek_v2.py
Normal file
269
tests/models/deepseek_v2/test_modeling_deepseek_v2.py
Normal file
@ -0,0 +1,269 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch DeepSeekV2 model."""
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import BitsAndBytesConfig, Cache, DeepseekV2Config, is_torch_available
|
||||
from transformers.testing_utils import require_read_token, require_torch, require_torch_accelerator, slow, torch_device
|
||||
|
||||
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import AutoTokenizer, DeepseekV2ForCausalLM, DeepseekV2ForSequenceClassification, DeepseekV2Model
|
||||
from transformers.models.deepseek_v2.modeling_deepseek_v2 import DeepseekV2RotaryEmbedding
|
||||
|
||||
|
||||
class DeepseekV2ModelTester(CausalLMModelTester):
|
||||
if is_torch_available():
|
||||
config_class = DeepseekV2Config
|
||||
base_model_class = DeepseekV2Model
|
||||
causal_lm_class = DeepseekV2ForCausalLM
|
||||
sequence_class = DeepseekV2ForSequenceClassification
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
n_routed_experts=8,
|
||||
kv_lora_rank=32,
|
||||
q_lora_rank=16,
|
||||
qk_nope_head_dim=64,
|
||||
qk_rope_head_dim=64,
|
||||
):
|
||||
super().__init__(parent=parent)
|
||||
self.n_routed_experts = n_routed_experts
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
|
||||
|
||||
@require_torch
|
||||
class DeepseekV2ModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
DeepseekV2ForCausalLM,
|
||||
DeepseekV2ForSequenceClassification,
|
||||
DeepseekV2Model,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": DeepseekV2Model,
|
||||
"text-classification": DeepseekV2ForSequenceClassification,
|
||||
"text-generation": DeepseekV2ForCausalLM,
|
||||
"zero-shot": DeepseekV2ForSequenceClassification,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
fx_compatible = False
|
||||
test_torchscript = False
|
||||
model_tester_class = DeepseekV2ModelTester
|
||||
rotary_embedding_layer = DeepseekV2RotaryEmbedding
|
||||
model_split_percents = [0.5, 0.7, 0.8]
|
||||
|
||||
# used in `test_torch_compile_for_training`
|
||||
_torch_compile_train_cls = DeepseekV2ForCausalLM if is_torch_available() else None
|
||||
|
||||
def test_model_rope_scaling(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
scaling_factor = 10
|
||||
short_input_length = 10
|
||||
long_input_length = int(config.max_position_embeddings * 1.5)
|
||||
|
||||
# Inputs
|
||||
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device
|
||||
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_short = position_ids_short.unsqueeze(0)
|
||||
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_long = position_ids_long.unsqueeze(0)
|
||||
|
||||
# Sanity check original RoPE
|
||||
original_rope = DeepseekV2RotaryEmbedding(config=config).to(torch_device)
|
||||
original_freqs_cis_short = original_rope(x, position_ids_short)
|
||||
original_freqs_cis_long = original_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(original_freqs_cis_short, original_freqs_cis_long[:, :short_input_length, :])
|
||||
|
||||
# Sanity check linear RoPE scaling
|
||||
# New position "x" should match original position with index "x/scaling_factor"
|
||||
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
|
||||
linear_scaling_rope = DeepseekV2RotaryEmbedding(config=config).to(torch_device)
|
||||
linear_freqs_cis_short = linear_scaling_rope(x, position_ids_short)
|
||||
linear_freqs_cis_long = linear_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(linear_freqs_cis_short, linear_freqs_cis_long[:, :short_input_length, :])
|
||||
|
||||
# Sanity check Dynamic NTK RoPE scaling
|
||||
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
|
||||
# with scaling_factor (or that `inv_freq` decreases)
|
||||
config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
|
||||
ntk_scaling_rope = DeepseekV2RotaryEmbedding(config=config).to(torch_device)
|
||||
ntk_freqs_cis_short = ntk_scaling_rope(x, position_ids_short)
|
||||
ntk_freqs_cis_long = ntk_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(ntk_freqs_cis_short, original_freqs_cis_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(ntk_freqs_cis_long, original_freqs_cis_long)
|
||||
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
|
||||
|
||||
# Sanity check Yarn RoPE scaling
|
||||
# Scaling should be over the entire input
|
||||
config.rope_scaling = {"type": "yarn", "factor": scaling_factor}
|
||||
yarn_scaling_rope = DeepseekV2RotaryEmbedding(config=config).to(torch_device)
|
||||
yarn_freqs_cis_short = yarn_scaling_rope(x, position_ids_short)
|
||||
yarn_freqs_cis_long = yarn_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(yarn_freqs_cis_short, yarn_freqs_cis_long[:, :short_input_length, :])
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_freqs_cis_short, original_freqs_cis_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_freqs_cis_long, original_freqs_cis_long)
|
||||
|
||||
def test_past_key_values_format(self):
|
||||
"""
|
||||
Overwriting to pass the expected cache shapes (Deepseek-V3 uses MLA so the cache shapes are non-standard)
|
||||
"""
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
batch_size, seq_length = inputs["input_ids"].shape
|
||||
# difference: last dim
|
||||
k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
v_embed_dim = config.v_head_dim
|
||||
self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
|
||||
self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
|
||||
# build the full cache shapes
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
all_cache_shapes = [
|
||||
[self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers)
|
||||
]
|
||||
super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)
|
||||
|
||||
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
|
||||
"""Needs to be overriden as deepseek has special MLA cache format (though we don't really use the MLA)"""
|
||||
self.assertIsInstance(decoder_past_key_values, Cache)
|
||||
|
||||
# (batch, head, seq_length, head_features)
|
||||
expected_common_shape = (
|
||||
batch_size,
|
||||
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
|
||||
cache_length,
|
||||
)
|
||||
expected_key_shape = expected_common_shape + (config.qk_nope_head_dim + config.qk_rope_head_dim,)
|
||||
expected_value_shape = expected_common_shape + (config.v_head_dim,)
|
||||
|
||||
if isinstance(decoder_past_key_values, Cache):
|
||||
self.assertListEqual(
|
||||
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
|
||||
[expected_key_shape] * len(decoder_past_key_values.key_cache),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
|
||||
[expected_value_shape] * len(decoder_past_key_values.value_cache),
|
||||
)
|
||||
|
||||
@unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape")
|
||||
def test_generate_compilation_all_outputs(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape")
|
||||
def test_generate_compile_model_forward(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape")
|
||||
def test_generate_from_inputs_embeds_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape")
|
||||
def test_generate_with_static_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Dynamic control flow in MoE")
|
||||
def test_torch_compile_for_training(self):
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
@require_torch_accelerator
|
||||
class DeepseekV2IntegrationTest(unittest.TestCase):
|
||||
def test_deepseek_v2_lite(self):
|
||||
EXPECTED_TEXT = ['An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors.\n\nAttention functions are used in a variety of applications, including natural language processing, computer vision, and reinforcement learning.\n\nThe attention function is a function that takes a query and a set of key-value pairs as input and outputs a vector'] # fmt: skip
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V2-Lite")
|
||||
model = DeepseekV2ForCausalLM.from_pretrained(
|
||||
"deepseek-ai/DeepSeek-V2-Lite",
|
||||
device_map=torch_device,
|
||||
torch_dtype=torch.bfloat16,
|
||||
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
|
||||
)
|
||||
|
||||
input_text = [
|
||||
"An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors." # fmt: skip
|
||||
]
|
||||
model_inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
|
||||
|
||||
generated_ids = model.generate(**model_inputs, max_new_tokens=50, do_sample=False)
|
||||
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(generated_text, EXPECTED_TEXT)
|
||||
|
||||
def test_logits_eager(self):
|
||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||
|
||||
model = DeepseekV2ForCausalLM.from_pretrained(
|
||||
"deepseek-ai/DeepSeek-V2-Lite",
|
||||
device_map=torch_device,
|
||||
torch_dtype=torch.bfloat16,
|
||||
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
|
||||
attn_implementation="eager",
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
out = model(torch.tensor([input_ids]).to(torch_device))
|
||||
|
||||
EXPECTED_MEAN = torch.tensor([[-6.1232, -5.0952, -4.4493, -2.6536, -2.0608, -2.3991, -3.8013, -2.8681]], device=torch_device) # fmt: skip
|
||||
torch.testing.assert_close(out.logits.float().mean(-1), EXPECTED_MEAN, atol=1e-3, rtol=1e-3)
|
||||
|
||||
EXPECTED_SLICE = torch.tensor([-1.2500, -0.9961, -0.0194, -3.1562, 1.2812, -2.7656, -0.8438, -3.0469, -2.7812, -0.6328, -0.4160, -1.9688, -2.4219, -1.0391, -3.8906], device=torch_device) # fmt: skip
|
||||
torch.testing.assert_close(out.logits[0, 0, :15].float(), EXPECTED_SLICE, atol=1e-3, rtol=1e-3)
|
||||
|
||||
def test_batch_fa2(self):
|
||||
EXPECTED_TEXT = [
|
||||
"Simply put, the theory of relativity states that \nthe laws of physics are the same for all observers, regardless of their \nrelative motion.\nThe theory of relativity is a theory of space, time, and gravity.\nThe theory of", # fmt: skip
|
||||
"My favorite all time favorite condiment is ketchup. I love ketchup. I love ketchup on my hot dogs, hamburgers, french fries, and even on my eggs. I love ketchup. I love ketchup so much that I", # fmt: skip
|
||||
]
|
||||
|
||||
prompts = [
|
||||
"Simply put, the theory of relativity states that ",
|
||||
"My favorite all time favorite condiment is ketchup.",
|
||||
]
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"deepseek-ai/DeepSeek-V2-Lite", pad_token="</s>", padding_side="right"
|
||||
)
|
||||
|
||||
model = DeepseekV2ForCausalLM.from_pretrained(
|
||||
"deepseek-ai/DeepSeek-V2-Lite",
|
||||
device_map=torch_device,
|
||||
torch_dtype=torch.bfloat16,
|
||||
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
|
||||
)
|
||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False)
|
||||
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT, generated_text)
|
Loading…
Reference in New Issue
Block a user