mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Add diffllama (#34083)
* first adding diffllama * add Diff Attention and other but still with errors * complate make attention Diff-Attention * fix some bugs which may be caused by transformer-cli while adding model * fix a bug caused by forgetting KV cache... * Update src/transformers/models/diffllama/modeling_diffllama.py You don't need to divide by 2 if we use same number of attention heads as llama. instead you can just split in forward. Co-authored-by: Minho Ryu <ryumin93@gmail.com> * Update src/transformers/models/diffllama/modeling_diffllama.py fit to changeing "num_heads // 2" place Co-authored-by: Minho Ryu <ryumin93@gmail.com> * Update src/transformers/models/diffllama/modeling_diffllama.py new codes are more meaningful than before Co-authored-by: Minho Ryu <ryumin93@gmail.com> * Update src/transformers/models/diffllama/modeling_diffllama.py new codes are more meaningful than before Co-authored-by: Minho Ryu <ryumin93@gmail.com> * Update src/transformers/models/diffllama/modeling_diffllama.py fit to changeing "num_heads // 2" place Co-authored-by: Minho Ryu <ryumin93@gmail.com> * Update src/transformers/models/diffllama/modeling_diffllama.py fix 2times divide by sqrt(self.head_dim) Co-authored-by: Minho Ryu <ryumin93@gmail.com> * Update src/transformers/models/diffllama/modeling_diffllama.py fix 2times divide by sqrt(self.head_dim) Co-authored-by: Minho Ryu <ryumin93@gmail.com> * Update src/transformers/models/diffllama/modeling_diffllama.py fit to changeing "num_heads // 2" place. and more visible Co-authored-by: Minho Ryu <ryumin93@gmail.com> * I found Attention missed implemented from paper still one072544a3b
. * re-implemented * adding groupnorm Co-authored-by: Minho Ryu <ryumin93@gmail.com> * align with transformers code style Co-authored-by: Minho Ryu <ryumin93@gmail.com> * fix typo Co-authored-by: Minho Ryu <ryumin93@gmail.com> * adding groupnorm Co-authored-by: Minho Ryu <ryumin93@gmail.com> * change SdpaAttention to DiffSdpaAttention Co-authored-by: Minho Ryu <ryumin93@gmail.com> * fix bug * Update src/transformers/models/diffllama/modeling_diffllama.py resolve "not same outputs" problem Co-authored-by: Minho Ryu <ryumin93@gmail.com> * fix bugs of places of "GroupNorm with scale" and etc * Revert "fix bugs of places of "GroupNorm with scale" and etc" This reverts commit26307d92f6
. * simplify multiple of attention (matmul) operations into one by repeating value_states Co-authored-by: Minho Ryu <ryumin93@gmail.com> * simplify multiple of attention (matmul) operations into one by repeating value_states Co-authored-by: Minho Ryu <ryumin93@gmail.com> * simplify multiple of attention (matmul) operations into one by repeating value_states Co-authored-by: Minho Ryu <ryumin93@gmail.com> * remove missed type * add diffllama model_doc * apply make style/quality * apply review comment about model * apply review comment about test * place diffllama alphabetically on the src/transformers/__init__.py * fix forgot code * Supports parameters that are not initialized with standard deviation 0 in the conventional method * add DiffLlamaConfig to CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK on utils/check_config_docstrings.py * remove unused property of config * add to supported model list * add to spda supported model list * fix copyright, remove pretraining_tensor_parallel, and modify for initialization test * remove unused import and etc. * empty commit * empty commit * empty commit * apply modular transformers but with bugs * revert prev commit * create src/transformers/model/diffllama/modular_diffllama.py * run utils/modular_model_converter.py * empty commit * leaner modular diffllama * remove more and more in modular_diffllama.pt * remove more and more in modular_diffllama.pt * resolve missing docstring entries * force reset * convert modular --------- Co-authored-by: Minho Ryu <ryumin93@gmail.com>
This commit is contained in:
parent
ed73ae210b
commit
96bf3d6cc5
@ -386,6 +386,8 @@
|
||||
title: DeBERTa-v2
|
||||
- local: model_doc/dialogpt
|
||||
title: DialoGPT
|
||||
- local: model_doc/diffllama
|
||||
title: DiffLlama
|
||||
- local: model_doc/distilbert
|
||||
title: DistilBERT
|
||||
- local: model_doc/dpr
|
||||
|
@ -125,6 +125,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| [DETA](model_doc/deta) | ✅ | ❌ | ❌ |
|
||||
| [DETR](model_doc/detr) | ✅ | ❌ | ❌ |
|
||||
| [DialoGPT](model_doc/dialogpt) | ✅ | ✅ | ✅ |
|
||||
| [DiffLlama](model_doc/diffllama) | ✅ | ❌ | ❌ |
|
||||
| [DiNAT](model_doc/dinat) | ✅ | ❌ | ❌ |
|
||||
| [DINOv2](model_doc/dinov2) | ✅ | ❌ | ✅ |
|
||||
| [DINOv2 with Registers](model_doc/dinov2_with_registers) | ✅ | ❌ | ❌ |
|
||||
|
59
docs/source/en/model_doc/diffllama.md
Normal file
59
docs/source/en/model_doc/diffllama.md
Normal file
@ -0,0 +1,59 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# DiffLlama
|
||||
|
||||
## Overview
|
||||
|
||||
The DiffLlama model was proposed in [Differential Transformer](https://arxiv.org/abs/2410.05258) by Kazuma Matsumoto and .
|
||||
This model is combine Llama model and Differential Transformer's Attention.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Transformer tends to overallocate attention to irrelevant context. In this work, we introduce Diff Transformer, which amplifies attention to the relevant context while canceling noise. Specifically, the differential attention mechanism calculates attention scores as the difference between two separate softmax attention maps. The subtraction cancels noise, promoting the emergence of sparse attention patterns. Experimental results on language modeling show that Diff Transformer outperforms Transformer in various settings of scaling up model size and training tokens. More intriguingly, it offers notable advantages in practical applications, such as long-context modeling, key information retrieval, hallucination mitigation, in-context learning, and reduction of activation outliers. By being less distracted by irrelevant context, Diff Transformer can mitigate hallucination in question answering and text summarization. For in-context learning, Diff Transformer not only enhances accuracy but is also more robust to order permutation, which was considered as a chronic robustness issue. The results position Diff Transformer as a highly effective and promising architecture to advance large language models.*
|
||||
|
||||
### Usage tips
|
||||
The hyperparameters of this model is the same as Llama model.
|
||||
|
||||
|
||||
## DiffLlamaConfig
|
||||
|
||||
[[autodoc]] DiffLlamaConfig
|
||||
|
||||
## DiffLlamaModel
|
||||
|
||||
[[autodoc]] DiffLlamaModel
|
||||
- forward
|
||||
|
||||
## DiffLlamaForCausalLM
|
||||
|
||||
[[autodoc]] DiffLlamaForCausalLM
|
||||
- forward
|
||||
|
||||
## DiffLlamaForSequenceClassification
|
||||
|
||||
[[autodoc]] DiffLlamaForSequenceClassification
|
||||
- forward
|
||||
|
||||
## DiffLlamaForQuestionAnswering
|
||||
|
||||
[[autodoc]] DiffLlamaForQuestionAnswering
|
||||
- forward
|
||||
|
||||
## DiffLlamaForTokenClassification
|
||||
|
||||
[[autodoc]] DiffLlamaForTokenClassification
|
||||
- forward
|
@ -47,6 +47,7 @@ FlashAttention-2 is currently supported for the following architectures:
|
||||
* [Cohere2](https://huggingface.co/docs/transformers/model_doc/cohere2#transformers.Cohere2Model)
|
||||
* [GLM](https://huggingface.co/docs/transformers/model_doc/glm#transformers.GLMModel)
|
||||
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
|
||||
* [DiffLlama](https://huggingface.co/docs/transformers/model_doc/diffllama#transformers.DiffLlamaModel)
|
||||
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
|
||||
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
|
||||
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
|
||||
@ -237,6 +238,7 @@ For now, Transformers supports SDPA inference and training for the following arc
|
||||
* [data2vec_vision](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecVisionModel)
|
||||
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
|
||||
* [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel)
|
||||
* [DiffLlama](https://huggingface.co/docs/transformers/model_doc/diffllama#transformers.DiffLlamaModel)
|
||||
* [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
|
||||
* [Dinov2_with_registers](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
|
||||
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
|
||||
|
@ -402,6 +402,7 @@ _import_structure = {
|
||||
"models.depth_anything": ["DepthAnythingConfig"],
|
||||
"models.detr": ["DetrConfig"],
|
||||
"models.dialogpt": [],
|
||||
"models.diffllama": ["DiffLlamaConfig"],
|
||||
"models.dinat": ["DinatConfig"],
|
||||
"models.dinov2": ["Dinov2Config"],
|
||||
"models.dinov2_with_registers": ["Dinov2WithRegistersConfig"],
|
||||
@ -2145,6 +2146,16 @@ else:
|
||||
"DetrPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.diffllama"].extend(
|
||||
[
|
||||
"DiffLlamaForCausalLM",
|
||||
"DiffLlamaForQuestionAnswering",
|
||||
"DiffLlamaForSequenceClassification",
|
||||
"DiffLlamaForTokenClassification",
|
||||
"DiffLlamaModel",
|
||||
"DiffLlamaPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.dinat"].extend(
|
||||
[
|
||||
"DinatBackbone",
|
||||
@ -5369,6 +5380,7 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .models.depth_anything import DepthAnythingConfig
|
||||
from .models.detr import DetrConfig
|
||||
from .models.diffllama import DiffLlamaConfig
|
||||
from .models.dinat import DinatConfig
|
||||
from .models.dinov2 import Dinov2Config
|
||||
from .models.dinov2_with_registers import Dinov2WithRegistersConfig
|
||||
@ -7017,6 +7029,14 @@ if TYPE_CHECKING:
|
||||
DetrModel,
|
||||
DetrPreTrainedModel,
|
||||
)
|
||||
from .models.diffllama import (
|
||||
DiffLlamaForCausalLM,
|
||||
DiffLlamaForQuestionAnswering,
|
||||
DiffLlamaForSequenceClassification,
|
||||
DiffLlamaForTokenClassification,
|
||||
DiffLlamaModel,
|
||||
DiffLlamaPreTrainedModel,
|
||||
)
|
||||
from .models.dinat import (
|
||||
DinatBackbone,
|
||||
DinatForImageClassification,
|
||||
|
@ -75,6 +75,7 @@ from . import (
|
||||
depth_anything,
|
||||
detr,
|
||||
dialogpt,
|
||||
diffllama,
|
||||
dinat,
|
||||
dinov2,
|
||||
dinov2_with_registers,
|
||||
|
@ -92,6 +92,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("depth_anything", "DepthAnythingConfig"),
|
||||
("deta", "DetaConfig"),
|
||||
("detr", "DetrConfig"),
|
||||
("diffllama", "DiffLlamaConfig"),
|
||||
("dinat", "DinatConfig"),
|
||||
("dinov2", "Dinov2Config"),
|
||||
("dinov2_with_registers", "Dinov2WithRegistersConfig"),
|
||||
@ -403,6 +404,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("deta", "DETA"),
|
||||
("detr", "DETR"),
|
||||
("dialogpt", "DialoGPT"),
|
||||
("diffllama", "DiffLlama"),
|
||||
("dinat", "DiNAT"),
|
||||
("dinov2", "DINOv2"),
|
||||
("dinov2_with_registers", "DINOv2 with Registers"),
|
||||
|
@ -90,6 +90,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("deit", "DeiTModel"),
|
||||
("deta", "DetaModel"),
|
||||
("detr", "DetrModel"),
|
||||
("diffllama", "DiffLlamaModel"),
|
||||
("dinat", "DinatModel"),
|
||||
("dinov2", "Dinov2Model"),
|
||||
("dinov2_with_registers", "Dinov2WithRegistersModel"),
|
||||
@ -493,6 +494,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
("ctrl", "CTRLLMHeadModel"),
|
||||
("data2vec-text", "Data2VecTextForCausalLM"),
|
||||
("dbrx", "DbrxForCausalLM"),
|
||||
("diffllama", "DiffLlamaForCausalLM"),
|
||||
("electra", "ElectraForCausalLM"),
|
||||
("ernie", "ErnieForCausalLM"),
|
||||
("falcon", "FalconForCausalLM"),
|
||||
@ -961,6 +963,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("data2vec-text", "Data2VecTextForSequenceClassification"),
|
||||
("deberta", "DebertaForSequenceClassification"),
|
||||
("deberta-v2", "DebertaV2ForSequenceClassification"),
|
||||
("diffllama", "DiffLlamaForSequenceClassification"),
|
||||
("distilbert", "DistilBertForSequenceClassification"),
|
||||
("electra", "ElectraForSequenceClassification"),
|
||||
("ernie", "ErnieForSequenceClassification"),
|
||||
@ -1056,6 +1059,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
|
||||
("data2vec-text", "Data2VecTextForQuestionAnswering"),
|
||||
("deberta", "DebertaForQuestionAnswering"),
|
||||
("deberta-v2", "DebertaV2ForQuestionAnswering"),
|
||||
("diffllama", "DiffLlamaForQuestionAnswering"),
|
||||
("distilbert", "DistilBertForQuestionAnswering"),
|
||||
("electra", "ElectraForQuestionAnswering"),
|
||||
("ernie", "ErnieForQuestionAnswering"),
|
||||
@ -1153,6 +1157,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("data2vec-text", "Data2VecTextForTokenClassification"),
|
||||
("deberta", "DebertaForTokenClassification"),
|
||||
("deberta-v2", "DebertaV2ForTokenClassification"),
|
||||
("diffllama", "DiffLlamaForTokenClassification"),
|
||||
("distilbert", "DistilBertForTokenClassification"),
|
||||
("electra", "ElectraForTokenClassification"),
|
||||
("ernie", "ErnieForTokenClassification"),
|
||||
|
@ -170,6 +170,13 @@ else:
|
||||
"DebertaV2TokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
(
|
||||
"diffllama",
|
||||
(
|
||||
"LlamaTokenizer" if is_sentencepiece_available() else None,
|
||||
"LlamaTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
"dpr",
|
||||
|
27
src/transformers/models/diffllama/__init__.py
Normal file
27
src/transformers/models/diffllama/__init__.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_diffllama import *
|
||||
from .modeling_diffllama import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
199
src/transformers/models/diffllama/configuration_diffllama.py
Normal file
199
src/transformers/models/diffllama/configuration_diffllama.py
Normal file
@ -0,0 +1,199 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on Llama implementations in this library and Microsoft's
|
||||
# Differential Transformer implementations.
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""DiffLlama model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
|
||||
|
||||
class DiffLlamaConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`DiffLlamaModel`]. It is used to instantiate an DiffLlama
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults
|
||||
will yield a similar configuration to that of the [kajuma/DiffLlama-0.3B-handcut](https://huggingface.co/kajuma/DiffLlama-0.3B-handcut).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the DiffLlama model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`DiffLlamaModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 2048):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 8192):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 16):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*):
|
||||
Padding token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
Beginning of stream token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
End of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'diffllama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'diffllama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'diffllama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'diffllama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
lambda_std_dev (`float`, *optional*, defaults to 0.1):
|
||||
The standard deviation for initialization of parameter lambda in attention layer.
|
||||
head_dim (`int`, *optional*):
|
||||
The attention head dimension. If None, it will default to hidden_size // num_heads
|
||||
|
||||
```python
|
||||
>>> from transformers import DiffLlamaModel, DiffLlamaConfig
|
||||
|
||||
>>> # Initializing a DiffLlama diffllama-7b style configuration
|
||||
>>> configuration = DiffLlamaConfig()
|
||||
|
||||
>>> # Initializing a model from the diffllama-7b style configuration
|
||||
>>> model = DiffLlamaModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "diffllama"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=2048,
|
||||
intermediate_size=8192,
|
||||
num_hidden_layers=16,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-5,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
lambda_std_dev=0.1,
|
||||
head_dim=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.lambda_std_dev = lambda_std_dev
|
||||
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, copy it it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["DiffLlamaConfig"]
|
1426
src/transformers/models/diffllama/modeling_diffllama.py
Normal file
1426
src/transformers/models/diffllama/modeling_diffllama.py
Normal file
File diff suppressed because it is too large
Load Diff
464
src/transformers/models/diffllama/modular_diffllama.py
Normal file
464
src/transformers/models/diffllama/modular_diffllama.py
Normal file
@ -0,0 +1,464 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on Llama implementations in this library and Microsoft's
|
||||
# Differential Transformer implementations.
|
||||
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import Cache, StaticCache
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
from ...utils import (
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
logging,
|
||||
)
|
||||
from ..gemma.modeling_gemma import GemmaForCausalLM
|
||||
from ..llama.modeling_llama import (
|
||||
LlamaDecoderLayer,
|
||||
LlamaForQuestionAnswering,
|
||||
LlamaForSequenceClassification,
|
||||
LlamaForTokenClassification,
|
||||
LlamaModel,
|
||||
LlamaPreTrainedModel,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
from ..mistral.modeling_mistral import MistralMLP
|
||||
from .configuration_diffllama import DiffLlamaConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut"
|
||||
_CONFIG_FOR_DOC = "DiffLlamaConfig"
|
||||
|
||||
|
||||
class DiffLlamaMLP(MistralMLP):
|
||||
pass
|
||||
|
||||
|
||||
def lambda_init_fn(layer_idx):
|
||||
return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)
|
||||
|
||||
|
||||
class DiffLlamaAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
if layer_idx is None:
|
||||
logger.warning_once(
|
||||
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
||||
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
# under this are not used
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.is_causal = True
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
|
||||
|
||||
self.lambda_init = lambda_init_fn(layer_idx)
|
||||
self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
|
||||
self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
|
||||
self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
|
||||
self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
|
||||
self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, target_len, _ = hidden_states.size()
|
||||
q_len = target_len
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
|
||||
value_states = value_states.repeat(1, 2, 1, 1)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||
lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
|
||||
query_states.dtype
|
||||
)
|
||||
lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
|
||||
query_states.dtype
|
||||
)
|
||||
lambda_full = lambda_1 - lambda_2 + self.lambda_init
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
|
||||
|
||||
attn_output = attn_output1 - lambda_full * attn_output2
|
||||
attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class DiffLlamaFlashAttention2(DiffLlamaAttention):
|
||||
"""
|
||||
DiffLlama flash attention module. This module inherits from `DiffLlamaAttention` as the weights of the module stays
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if isinstance(past_key_value, StaticCache):
|
||||
raise ValueError(
|
||||
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
|
||||
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
|
||||
)
|
||||
|
||||
output_attentions = False
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# Flash attention requires the input to have the shape
|
||||
# batch_size x seq_length x head_dim x hidden_dim
|
||||
# therefore we just need to keep the original shape
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||
# to be able to avoid many of these transpose/reshape/view.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||
# in fp32. (DiffLlamaRMSNorm handles it correctly)
|
||||
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
value_states1, value_states2 = torch.chunk(value_states, 2, dim=2)
|
||||
value_states1 = value_states1.repeat(1, 1, 2, 1)
|
||||
value_states2 = value_states2.repeat(1, 1, 2, 1)
|
||||
|
||||
attn_output1 = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states1,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=getattr(self, "sliding_window", None),
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
is_causal=self.is_causal,
|
||||
)
|
||||
|
||||
attn_output2 = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states2,
|
||||
attention_mask,
|
||||
q_len,
|
||||
position_ids=position_ids,
|
||||
dropout=dropout_rate,
|
||||
sliding_window=getattr(self, "sliding_window", None),
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
is_causal=self.is_causal,
|
||||
)
|
||||
|
||||
attn_output = torch.cat([attn_output1, attn_output2], dim=-1)
|
||||
attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2)
|
||||
|
||||
lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
|
||||
query_states.dtype
|
||||
)
|
||||
lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
|
||||
query_states.dtype
|
||||
)
|
||||
lambda_full = lambda_1 - lambda_2 + self.lambda_init
|
||||
|
||||
attn_output = attn_output1 - lambda_full * attn_output2
|
||||
attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class DiffLlamaSdpaAttention(DiffLlamaAttention):
|
||||
"""
|
||||
DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
`DiffLlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
|
||||
SDPA API.
|
||||
"""
|
||||
|
||||
# Adapted from DiffLlamaAttention.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
logger.warning_once(
|
||||
"DiffLlamaModel is using DiffLlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
||||
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
return super().forward(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
|
||||
value_states = value_states.repeat(1, 2, 1, 1)
|
||||
|
||||
causal_mask = attention_mask
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
|
||||
# Reference: https://github.com/pytorch/pytorch/issues/112577.
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
|
||||
|
||||
lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
|
||||
query_states.dtype
|
||||
)
|
||||
lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
|
||||
query_states.dtype
|
||||
)
|
||||
lambda_full = lambda_1 - lambda_2 + self.lambda_init
|
||||
|
||||
attn_output = attn_output1 - lambda_full * attn_output2
|
||||
attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None
|
||||
|
||||
|
||||
DIFFLLAMA_ATTENTION_CLASSES = {
|
||||
"eager": DiffLlamaAttention,
|
||||
"flash_attention_2": DiffLlamaFlashAttention2,
|
||||
"sdpa": DiffLlamaSdpaAttention,
|
||||
}
|
||||
|
||||
|
||||
class DiffLlamaDecoderLayer(LlamaDecoderLayer):
|
||||
def __init__(self, config: DiffLlamaConfig, layer_idx: int):
|
||||
super().__init__(config, layer_idx)
|
||||
|
||||
self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
|
||||
|
||||
|
||||
class DiffLlamaPreTrainedModel(LlamaPreTrainedModel):
|
||||
pass
|
||||
|
||||
|
||||
class DiffLlamaModel(LlamaModel):
|
||||
pass
|
||||
|
||||
|
||||
class DiffLlamaForCausalLM(GemmaForCausalLM):
|
||||
pass
|
||||
|
||||
|
||||
class DiffLlamaForSequenceClassification(LlamaForSequenceClassification):
|
||||
pass
|
||||
|
||||
|
||||
class DiffLlamaForQuestionAnswering(LlamaForQuestionAnswering):
|
||||
pass
|
||||
|
||||
|
||||
class DiffLlamaForTokenClassification(LlamaForTokenClassification):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DiffLlamaPreTrainedModel",
|
||||
"DiffLlamaModel", # noqa: F822
|
||||
"DiffLlamaForCausalLM",
|
||||
"DiffLlamaForSequenceClassification",
|
||||
"DiffLlamaForQuestionAnswering",
|
||||
"DiffLlamaForTokenClassification",
|
||||
]
|
@ -3579,6 +3579,48 @@ class DetrPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class DiffLlamaForCausalLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class DiffLlamaForQuestionAnswering(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class DiffLlamaForSequenceClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class DiffLlamaForTokenClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class DiffLlamaModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class DiffLlamaPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class DinatBackbone(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
0
tests/models/diffllama/__init__.py
Normal file
0
tests/models/diffllama/__init__.py
Normal file
992
tests/models/diffllama/test_modeling_diffllama.py
Normal file
992
tests/models/diffllama/test_modeling_diffllama.py
Normal file
@ -0,0 +1,992 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch DiffLlama model."""
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, DiffLlamaConfig, StaticCache, is_torch_available, set_seed
|
||||
from transformers.testing_utils import (
|
||||
backend_empty_cache,
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
require_torch_sdpa,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
DiffLlamaForCausalLM,
|
||||
DiffLlamaForQuestionAnswering,
|
||||
DiffLlamaForSequenceClassification,
|
||||
DiffLlamaForTokenClassification,
|
||||
DiffLlamaModel,
|
||||
)
|
||||
from transformers.models.diffllama.modeling_diffllama import (
|
||||
DiffLlamaRotaryEmbedding,
|
||||
)
|
||||
|
||||
|
||||
class DiffLlamaModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=False,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
pad_token_id=0,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.pad_token_id = pad_token_id
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def get_config(self):
|
||||
return DiffLlamaConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = DiffLlamaModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_model_as_decoder(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
config.add_cross_attention = True
|
||||
model = DiffLlamaModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_for_causal_lm(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
model = DiffLlamaForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_decoder_model_past_large_inputs(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
config.is_decoder = True
|
||||
config.add_cross_attention = True
|
||||
model = DiffLlamaForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# first forward pass
|
||||
outputs = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=True,
|
||||
)
|
||||
past_key_values = outputs.past_key_values
|
||||
|
||||
# create hypothetical multiple next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
|
||||
|
||||
# append to next input_ids and
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
|
||||
|
||||
output_from_no_past = model(
|
||||
next_input_ids,
|
||||
attention_mask=next_attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_hidden_states=True,
|
||||
)["hidden_states"][0]
|
||||
output_from_past = model(
|
||||
next_tokens,
|
||||
attention_mask=next_attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
output_hidden_states=True,
|
||||
)["hidden_states"][0]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||
|
||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class DiffLlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
DiffLlamaModel,
|
||||
DiffLlamaForCausalLM,
|
||||
DiffLlamaForSequenceClassification,
|
||||
DiffLlamaForQuestionAnswering,
|
||||
DiffLlamaForTokenClassification,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (DiffLlamaForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": DiffLlamaModel,
|
||||
"text-classification": DiffLlamaForSequenceClassification,
|
||||
"text-generation": DiffLlamaForCausalLM,
|
||||
"zero-shot": DiffLlamaForSequenceClassification,
|
||||
"question-answering": DiffLlamaForQuestionAnswering,
|
||||
"token-classification": DiffLlamaForTokenClassification,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
fx_compatible = False
|
||||
|
||||
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
||||
# This is because we are hitting edge cases with the causal_mask buffer
|
||||
model_split_percents = [0.5, 0.7, 0.8]
|
||||
|
||||
# used in `test_torch_compile_for_training`
|
||||
_torch_compile_train_cls = DiffLlamaForCausalLM if is_torch_available() else None
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DiffLlamaModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=DiffLlamaConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_various_embeddings(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||
config_and_inputs[0].position_embedding_type = type
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_diffllama_sequence_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
||||
model = DiffLlamaForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
def test_diffllama_sequence_classification_model_for_single_label(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
config.problem_type = "single_label_classification"
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
||||
model = DiffLlamaForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
def test_diffllama_sequence_classification_model_for_multi_label(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
config.problem_type = "multi_label_classification"
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
sequence_labels = ids_tensor(
|
||||
[self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
|
||||
).to(torch.float)
|
||||
model = DiffLlamaForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
def test_diffllama_token_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
model = DiffLlamaForTokenClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
self.assertEqual(
|
||||
result.logits.shape,
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
)
|
||||
|
||||
@unittest.skip(reason="DiffLlama buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("linear",), ("dynamic",), ("yarn",)])
|
||||
def test_model_rope_scaling_from_config(self, scaling_type):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
short_input = ids_tensor([1, 10], config.vocab_size)
|
||||
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
|
||||
|
||||
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||
original_model = DiffLlamaModel(config)
|
||||
original_model.to(torch_device)
|
||||
original_model.eval()
|
||||
original_short_output = original_model(short_input).last_hidden_state
|
||||
original_long_output = original_model(long_input).last_hidden_state
|
||||
|
||||
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
|
||||
scaled_model = DiffLlamaModel(config)
|
||||
scaled_model.to(torch_device)
|
||||
scaled_model.eval()
|
||||
scaled_short_output = scaled_model(short_input).last_hidden_state
|
||||
scaled_long_output = scaled_model(long_input).last_hidden_state
|
||||
|
||||
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
|
||||
# maximum sequence length, so the outputs for the short input should match.
|
||||
if scaling_type == "dynamic":
|
||||
self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||
else:
|
||||
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||
|
||||
# The output should be different for long inputs
|
||||
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||
|
||||
def test_model_rope_scaling(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
scaling_factor = 10
|
||||
short_input_length = 10
|
||||
long_input_length = int(config.max_position_embeddings * 1.5)
|
||||
|
||||
# Inputs
|
||||
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device
|
||||
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_short = position_ids_short.unsqueeze(0)
|
||||
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_long = position_ids_long.unsqueeze(0)
|
||||
|
||||
# Sanity check original RoPE
|
||||
original_rope = DiffLlamaRotaryEmbedding(config=config).to(torch_device)
|
||||
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
|
||||
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
|
||||
|
||||
# Sanity check linear RoPE scaling
|
||||
# New position "x" should match original position with index "x/scaling_factor"
|
||||
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
|
||||
linear_scaling_rope = DiffLlamaRotaryEmbedding(config=config).to(torch_device)
|
||||
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
|
||||
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
|
||||
for new_position in range(0, long_input_length, scaling_factor):
|
||||
original_position = int(new_position // scaling_factor)
|
||||
torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
|
||||
torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
|
||||
|
||||
# Sanity check Dynamic NTK RoPE scaling
|
||||
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
|
||||
# with scaling_factor (or that `inv_freq` decreases)
|
||||
config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
|
||||
ntk_scaling_rope = DiffLlamaRotaryEmbedding(config=config).to(torch_device)
|
||||
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
|
||||
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(ntk_cos_short, original_cos_short)
|
||||
torch.testing.assert_close(ntk_sin_short, original_sin_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(ntk_cos_long, original_cos_long)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(ntk_sin_long, original_sin_long)
|
||||
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
|
||||
|
||||
# Sanity check Yarn RoPE scaling
|
||||
# Scaling should be over the entire input
|
||||
config.rope_scaling = {"type": "yarn", "factor": scaling_factor}
|
||||
yarn_scaling_rope = DiffLlamaRotaryEmbedding(config=config).to(torch_device)
|
||||
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
|
||||
yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_cos_short, original_cos_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_sin_short, original_sin_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_cos_long, original_cos_long)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_sin_long, original_sin_long)
|
||||
|
||||
def test_model_loading_old_rope_configs(self):
|
||||
def _reinitialize_config(base_config, new_kwargs):
|
||||
# Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation
|
||||
# steps.
|
||||
base_config_dict = base_config.to_dict()
|
||||
new_config = DiffLlamaConfig.from_dict(config_dict={**base_config_dict, **new_kwargs})
|
||||
return new_config
|
||||
|
||||
# from untouched config -> ✅
|
||||
base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
original_model = DiffLlamaForCausalLM(base_config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
|
||||
# from a config with the expected rope configuration -> ✅
|
||||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}})
|
||||
original_model = DiffLlamaForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
|
||||
# from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC
|
||||
config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}})
|
||||
original_model = DiffLlamaForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
|
||||
# from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config)
|
||||
config = _reinitialize_config(
|
||||
base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}}
|
||||
)
|
||||
self.assertTrue(config.rope_scaling["type"] == "linear")
|
||||
self.assertTrue(config.rope_scaling["rope_type"] == "linear")
|
||||
original_model = DiffLlamaForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
|
||||
# from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning
|
||||
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
|
||||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}})
|
||||
original_model = DiffLlamaForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
self.assertEqual(len(logs.output), 1)
|
||||
self.assertIn("factor field", logs.output[0])
|
||||
|
||||
# from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning
|
||||
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
|
||||
config = _reinitialize_config(
|
||||
base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}}
|
||||
)
|
||||
original_model = DiffLlamaForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
self.assertEqual(len(logs.output), 1)
|
||||
self.assertIn("Unrecognized keys", logs.output[0])
|
||||
|
||||
# from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception
|
||||
with self.assertRaises(KeyError):
|
||||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor"
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@require_bitsandbytes
|
||||
@pytest.mark.flash_attn_test
|
||||
@require_read_token
|
||||
@slow
|
||||
def test_flash_attn_2_generate_padding_right(self):
|
||||
"""
|
||||
Overwritting the common test as the test is flaky on tiny models
|
||||
"""
|
||||
model = DiffLlamaForCausalLM.from_pretrained(
|
||||
"kajuma/DiffLlama-0.3B-handcut",
|
||||
load_in_4bit=True,
|
||||
device_map={"": 0},
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("kajuma/DiffLlama-0.3B-handcut")
|
||||
|
||||
texts = ["hi", "Hello this is a very long sentence"]
|
||||
|
||||
tokenizer.padding_side = "right"
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
|
||||
|
||||
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_native = tokenizer.batch_decode(output_native)
|
||||
|
||||
model = DiffLlamaForCausalLM.from_pretrained(
|
||||
"kajuma/DiffLlama-0.3B-handcut",
|
||||
load_in_4bit=True,
|
||||
device_map={"": 0},
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
|
||||
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_fa_2 = tokenizer.batch_decode(output_fa_2)
|
||||
|
||||
self.assertListEqual(output_native, output_fa_2)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
@pytest.mark.flash_attn_test
|
||||
def test_use_flash_attention_2_true(self):
|
||||
"""
|
||||
NOTE: this is the only test testing that the legacy `use_flash_attention=2` argument still works as intended.
|
||||
"""
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model = model_class(config)
|
||||
model.save_pretrained(tmp_dir)
|
||||
|
||||
new_model = DiffLlamaForCausalLM.from_pretrained(
|
||||
tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")
|
||||
|
||||
has_flash = False
|
||||
for name, submodule in new_model.named_modules():
|
||||
if "FlashAttention" in submodule.__class__.__name__:
|
||||
has_flash = True
|
||||
break
|
||||
if not has_flash:
|
||||
raise ValueError("The flash model should have flash attention layers")
|
||||
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
def test_eager_matches_sdpa_generate(self):
|
||||
"""
|
||||
Overwritting the common test as the test is flaky on tiny models
|
||||
"""
|
||||
max_new_tokens = 30
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("kajuma/DiffLlama-0.3B-handcut")
|
||||
|
||||
model_sdpa = DiffLlamaForCausalLM.from_pretrained(
|
||||
"kajuma/DiffLlama-0.3B-handcut",
|
||||
torch_dtype=torch.float16,
|
||||
low_cpu_mem_usage=True,
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
|
||||
|
||||
model_eager = DiffLlamaForCausalLM.from_pretrained(
|
||||
"kajuma/DiffLlama-0.3B-handcut",
|
||||
torch_dtype=torch.float16,
|
||||
low_cpu_mem_usage=True,
|
||||
attn_implementation="eager",
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(model_eager.config._attn_implementation == "eager")
|
||||
|
||||
for name, submodule in model_eager.named_modules():
|
||||
if "SdpaAttention" in submodule.__class__.__name__:
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
has_sdpa = False
|
||||
for name, submodule in model_sdpa.named_modules():
|
||||
if "SdpaAttention" in submodule.__class__.__name__:
|
||||
has_sdpa = True
|
||||
break
|
||||
if not has_sdpa:
|
||||
raise ValueError("The SDPA model should have SDPA attention layers")
|
||||
|
||||
texts = [
|
||||
"hi here's a longer context, getting longer and",
|
||||
"Hello this is a very long sentence my friend, very long for real",
|
||||
"Today I am in Paris and",
|
||||
]
|
||||
|
||||
for padding_side in ["left", "right"]:
|
||||
tokenizer.padding_side = padding_side
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
|
||||
res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
|
||||
|
||||
with self.subTest(f"{padding_side}"):
|
||||
torch.testing.assert_close(
|
||||
res_eager,
|
||||
res_sdpa,
|
||||
msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
|
||||
)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class DiffLlamaIntegrationTest(unittest.TestCase):
|
||||
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
|
||||
# Depending on the hardware we get different logits / generations
|
||||
cuda_compute_capability_major_version = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if is_torch_available() and torch.cuda.is_available():
|
||||
# 8 is for A100 / A10 and 7 for T4
|
||||
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_read_token
|
||||
def test_compile_static_cache(self):
|
||||
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
|
||||
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
|
||||
if version.parse(torch.__version__) < version.parse("2.3.0"):
|
||||
self.skipTest(reason="This test requires torch >= 2.3 to run.")
|
||||
|
||||
NUM_TOKENS_TO_GENERATE = 40
|
||||
# Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
|
||||
# was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
|
||||
EXPECTED_TEXT_COMPLETION = [
|
||||
"Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
|
||||
"reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
|
||||
"theory of relativ",
|
||||
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
|
||||
"my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
|
||||
]
|
||||
|
||||
prompts = [
|
||||
"Simply put, the theory of relativity states that ",
|
||||
"My favorite all time favorite condiment is ketchup.",
|
||||
]
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"kajuma/DiffLlama-0.3B-handcut", pad_token="</s>", padding_side="right"
|
||||
)
|
||||
model = DiffLlamaForCausalLM.from_pretrained(
|
||||
"kajuma/DiffLlama-0.3B-handcut", device_map=torch_device, torch_dtype=torch.float16
|
||||
)
|
||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||
|
||||
# Dynamic Cache
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
|
||||
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)
|
||||
|
||||
# Static Cache
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
|
||||
)
|
||||
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)
|
||||
|
||||
# Static Cache + compile
|
||||
model._cache = None # clear cache object, initialized when we pass `cache_implementation="static"`
|
||||
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
|
||||
)
|
||||
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
class Mask4DTestHard(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def setUp(self):
|
||||
model_name = "kajuma/DiffLlama-0.3B-handcut"
|
||||
self.model_dtype = torch.float32
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.model = DiffLlamaForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)
|
||||
|
||||
def get_test_data(self):
|
||||
template = "my favorite {}"
|
||||
items = ("pet is a", "artist plays a", "name is L") # same number of tokens in each item
|
||||
|
||||
batch_separate = [template.format(x) for x in items] # 3 separate lines
|
||||
batch_shared_prefix = template.format(" ".join(items)) # 1 line with options concatenated
|
||||
|
||||
input_ids = self.tokenizer(batch_separate, return_tensors="pt").input_ids.to(torch_device)
|
||||
input_ids_shared_prefix = self.tokenizer(batch_shared_prefix, return_tensors="pt").input_ids.to(torch_device)
|
||||
|
||||
mask_shared_prefix = torch.tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1],
|
||||
]
|
||||
]
|
||||
],
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
position_ids = torch.arange(input_ids.shape[1]).tile(input_ids.shape[0], 1).to(torch_device)
|
||||
|
||||
# building custom positions ids based on custom mask
|
||||
position_ids_shared_prefix = (mask_shared_prefix.sum(dim=-1) - 1).reshape(1, -1)
|
||||
# effectively: position_ids_shared_prefix = torch.tensor([[0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5]]).to(device)
|
||||
|
||||
# inverting the mask
|
||||
min_dtype = torch.finfo(self.model_dtype).min
|
||||
mask_shared_prefix = (mask_shared_prefix.eq(0.0)).to(dtype=self.model_dtype) * min_dtype
|
||||
|
||||
return input_ids, position_ids, input_ids_shared_prefix, mask_shared_prefix, position_ids_shared_prefix
|
||||
|
||||
def test_stacked_causal_mask(self):
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
position_ids_shared_prefix,
|
||||
) = self.get_test_data()
|
||||
|
||||
# regular batch
|
||||
logits = self.model.forward(input_ids, position_ids=position_ids).logits
|
||||
logits_last = logits[:, -1, :] # last tokens in each batch line
|
||||
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
|
||||
|
||||
# single forward run with 4D custom mask
|
||||
logits_shared_prefix = self.model.forward(
|
||||
input_ids_shared_prefix, attention_mask=mask_shared_prefix, position_ids=position_ids_shared_prefix
|
||||
).logits
|
||||
logits_shared_prefix_last = logits_shared_prefix[
|
||||
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
|
||||
] # last three tokens
|
||||
decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
|
||||
|
||||
self.assertEqual(decoded, decoded_shared_prefix)
|
||||
|
||||
def test_partial_stacked_causal_mask(self):
|
||||
# Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks
|
||||
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
position_ids_shared_prefix,
|
||||
) = self.get_test_data()
|
||||
|
||||
# regular batch
|
||||
logits = self.model.forward(input_ids, position_ids=position_ids).logits
|
||||
logits_last = logits[:, -1, :] # last tokens in each batch line
|
||||
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
|
||||
|
||||
# 2 forward runs with custom 4D masks
|
||||
part_a = 3 # split point
|
||||
|
||||
input_1a = input_ids_shared_prefix[:, :part_a]
|
||||
position_ids_1a = position_ids_shared_prefix[:, :part_a]
|
||||
mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
|
||||
|
||||
outs_1a = self.model.forward(input_1a, attention_mask=mask_1a, position_ids=position_ids_1a)
|
||||
past_key_values_a = outs_1a["past_key_values"]
|
||||
|
||||
# Case 1: we pass a 4D attention mask regarding the current sequence length (i.e. [..., seq_len, full_len])
|
||||
input_1b = input_ids_shared_prefix[:, part_a:]
|
||||
position_ids_1b = position_ids_shared_prefix[:, part_a:]
|
||||
mask_1b = mask_shared_prefix[:, :, part_a:, :]
|
||||
outs_1b = self.model.forward(
|
||||
input_1b,
|
||||
attention_mask=mask_1b,
|
||||
position_ids=position_ids_1b,
|
||||
past_key_values=past_key_values_a,
|
||||
)
|
||||
decoded_1b = [
|
||||
self.tokenizer.decode(t)
|
||||
for t in outs_1b.logits.argmax(-1)[
|
||||
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
|
||||
]
|
||||
]
|
||||
self.assertEqual(decoded, decoded_1b)
|
||||
|
||||
def test_stacked_causal_mask_static_cache(self):
|
||||
"""same as above but with StaticCache"""
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
position_ids_shared_prefix,
|
||||
) = self.get_test_data()
|
||||
|
||||
# regular batch
|
||||
logits = self.model.forward(input_ids, position_ids=position_ids).logits
|
||||
logits_last = logits[:, -1, :] # last tokens in each batch line
|
||||
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
|
||||
|
||||
# upgrade the model with StaticCache
|
||||
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
|
||||
past_key_values = StaticCache(
|
||||
config=self.model.config,
|
||||
batch_size=1,
|
||||
max_cache_len=max_cache_len,
|
||||
device=torch_device,
|
||||
dtype=self.model.dtype,
|
||||
)
|
||||
|
||||
padded_attention_mask = torch.nn.functional.pad(
|
||||
input=mask_shared_prefix,
|
||||
pad=(0, max_cache_len - mask_shared_prefix.shape[-1]),
|
||||
mode="constant",
|
||||
value=torch.finfo(self.model_dtype).min,
|
||||
)
|
||||
|
||||
# single forward run with 4D custom mask
|
||||
logits_shared_prefix = self.model.forward(
|
||||
input_ids_shared_prefix,
|
||||
attention_mask=padded_attention_mask,
|
||||
position_ids=position_ids_shared_prefix,
|
||||
cache_position=torch.arange(input_ids_shared_prefix.shape[-1], device=torch_device),
|
||||
past_key_values=past_key_values,
|
||||
).logits
|
||||
logits_shared_prefix_last = logits_shared_prefix[
|
||||
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1], :
|
||||
] # last three tokens
|
||||
decoded_shared_prefix = [self.tokenizer.decode(t) for t in logits_shared_prefix_last.argmax(dim=-1)]
|
||||
|
||||
self.assertEqual(decoded, decoded_shared_prefix)
|
||||
|
||||
def test_partial_stacked_causal_mask_static_cache(self):
|
||||
# Same as the test above, but the input is passed in two groups. It tests that we can pass partial 4D attention masks
|
||||
# we pass a 4D attention mask shaped [..., seq_len, full_static_cache_len])
|
||||
(
|
||||
input_ids,
|
||||
position_ids,
|
||||
input_ids_shared_prefix,
|
||||
mask_shared_prefix,
|
||||
position_ids_shared_prefix,
|
||||
) = self.get_test_data()
|
||||
|
||||
# regular batch
|
||||
logits = self.model.forward(input_ids, position_ids=position_ids).logits
|
||||
logits_last = logits[:, -1, :] # last tokens in each batch line
|
||||
decoded = [self.tokenizer.decode(t) for t in logits_last.argmax(dim=-1)]
|
||||
|
||||
# upgrade the model with StaticCache
|
||||
max_cache_len = 16 # note that max_cache_len is greater than the attention_mask.shape[-1]
|
||||
past_key_values = StaticCache(
|
||||
config=self.model.config,
|
||||
batch_size=1,
|
||||
max_cache_len=max_cache_len,
|
||||
device=torch_device,
|
||||
dtype=self.model.dtype,
|
||||
)
|
||||
|
||||
# forward run for the first part of input
|
||||
part_a = 3 # split point
|
||||
|
||||
input_1a = input_ids_shared_prefix[:, :part_a]
|
||||
position_ids_1a = position_ids_shared_prefix[:, :part_a]
|
||||
mask_1a = mask_shared_prefix[:, :, :part_a, :part_a]
|
||||
|
||||
padded_mask_1a = torch.nn.functional.pad(
|
||||
input=mask_1a,
|
||||
pad=(0, max_cache_len - mask_1a.shape[-1]),
|
||||
mode="constant",
|
||||
value=torch.finfo(self.model_dtype).min,
|
||||
)
|
||||
|
||||
_ = self.model.forward(
|
||||
input_1a,
|
||||
attention_mask=padded_mask_1a,
|
||||
position_ids=position_ids_1a,
|
||||
cache_position=torch.arange(part_a, device=torch_device),
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
# forward run for the second part of input
|
||||
input_1b = input_ids_shared_prefix[:, part_a:]
|
||||
position_ids_1b = position_ids_shared_prefix[:, part_a:]
|
||||
mask_1b = mask_shared_prefix[:, :, part_a:, :]
|
||||
|
||||
padded_mask_1b = torch.nn.functional.pad(
|
||||
input=mask_1b, pad=(0, max_cache_len - mask_1b.shape[-1]), mode="constant", value=0
|
||||
)
|
||||
|
||||
outs_1b = self.model.forward(
|
||||
input_1b,
|
||||
attention_mask=padded_mask_1b,
|
||||
position_ids=position_ids_1b,
|
||||
cache_position=torch.arange(
|
||||
part_a,
|
||||
input_ids_shared_prefix.shape[-1],
|
||||
device=torch_device,
|
||||
),
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
decoded_1b = [
|
||||
self.tokenizer.decode(t)
|
||||
for t in outs_1b.logits.argmax(-1)[
|
||||
0, torch.where(position_ids_shared_prefix == position_ids_shared_prefix.max())[1] - part_a
|
||||
]
|
||||
]
|
||||
self.assertEqual(decoded, decoded_1b)
|
Loading…
Reference in New Issue
Block a user