Merge d88b28348c into 37a239ca50 (commit 9d7240bc3f)
docs/source/en/model_doc/plm.md (new file, 83 lines)

@@ -0,0 +1,83 @@
<!--Copyright 2025 The PLM Team and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# PLM

<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

## Overview

The PLM model was proposed in [PLM: Efficient Peripheral Language Models Hardware-Co-Designed for Ubiquitous Computing](https://arxiv.org/abs/2503.12167) by PLM-Team.

### Summary

The PLM (Peripheral Language Model) series introduces a novel model architecture to peripheral computing by delivering powerful language capabilities within the constraints of resource-limited devices. Through a modeling and system co-design strategy, PLM optimizes model performance while meeting edge system requirements. PLM employs Multi-head Latent Attention and squared ReLU activation to achieve sparsity, significantly reducing memory footprint and computational demands. Coupled with a meticulously crafted training regimen using curated datasets and a Warmup-Stable-Decay-Constant learning rate scheduler, PLM demonstrates superior performance compared to existing small language models, all while maintaining the lowest activated parameters, making it ideally suited for deployment on diverse peripheral platforms like mobile phones and Raspberry Pis.

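As a rough illustration of the sparsity point above: squared ReLU (the `relu2` activation PLM configures for its MLP) zeroes every negative pre-activation before squaring, so a large fraction of hidden units are exactly zero. The snippet below is a minimal sketch of that function, not the library's internal implementation.

```python
import torch

def relu_squared(x: torch.Tensor) -> torch.Tensor:
    # max(x, 0) ** 2: negative pre-activations become exactly zero, giving sparse hidden states
    return torch.relu(x).square()

x = torch.randn(8)
print(relu_squared(x))  # zeros wherever x <= 0
```
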
## Usage tips

Ensure your Transformers library version is up to date; PLM requires `transformers>=4.51.3` for full support.

`PLM-1.8B-Instruct` can be found on the [Hugging Face Hub](https://huggingface.co/PLM-Team/PLM-1.8B-Instruct).

The following example demonstrates how to use the model for inference:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Instruct")
model = AutoModelForCausalLM.from_pretrained("PLM-Team/PLM-1.8B-Instruct", torch_dtype=torch.bfloat16)

# Input text
input_text = "Tell me something about reinforcement learning."
inputs = tokenizer(input_text, return_tensors="pt")

# Completion
output = model.generate(inputs["input_ids"], max_new_tokens=100)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

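For the instruct checkpoint you may prefer to wrap the prompt in the chat format. This is a hedged variant of the example above, assuming the checkpoint ships a chat template (check the tokenizer configuration before relying on it):

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Instruct")
model = AutoModelForCausalLM.from_pretrained("PLM-Team/PLM-1.8B-Instruct", torch_dtype=torch.bfloat16)

# Wrap the prompt in the chat format before generating
messages = [{"role": "user", "content": "Tell me something about reinforcement learning."}]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
```
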
## PLMConfig

[[autodoc]] PLMConfig

## PLMModel

[[autodoc]] PLMModel
    - forward

## PLMForCausalLM

[[autodoc]] PLMForCausalLM
    - forward

## PLMForSequenceClassification

[[autodoc]] PLMForSequenceClassification
    - forward

## PLMForTokenClassification

[[autodoc]] PLMForTokenClassification
    - forward
@@ -243,6 +243,7 @@ if TYPE_CHECKING:
     from .pix2struct import *
     from .pixtral import *
     from .plbart import *
+    from .plm import *
     from .poolformer import *
     from .pop2piano import *
     from .prompt_depth_anything import *

@@ -275,6 +275,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
         ("pix2struct", "Pix2StructConfig"),
         ("pixtral", "PixtralVisionConfig"),
         ("plbart", "PLBartConfig"),
+        ("plm", "PLMConfig"),
         ("poolformer", "PoolFormerConfig"),
         ("pop2piano", "Pop2PianoConfig"),
         ("prompt_depth_anything", "PromptDepthAnythingConfig"),

@@ -672,6 +673,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
         ("pix2struct", "Pix2Struct"),
         ("pixtral", "Pixtral"),
         ("plbart", "PLBart"),
+        ("plm", "PLM"),
         ("poolformer", "PoolFormer"),
         ("pop2piano", "Pop2Piano"),
         ("prompt_depth_anything", "PromptDepthAnything"),

@@ -262,6 +262,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("phimoe", "PhimoeModel"),
         ("pixtral", "PixtralVisionModel"),
         ("plbart", "PLBartModel"),
+        ("plm", "PLMModel"),
         ("poolformer", "PoolFormerModel"),
         ("prophetnet", "ProphetNetModel"),
         ("pvt", "PvtModel"),

@@ -640,6 +641,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
         ("phi4_multimodal", "Phi4MultimodalForCausalLM"),
         ("phimoe", "PhimoeForCausalLM"),
         ("plbart", "PLBartForCausalLM"),
+        ("plm", "PLMForCausalLM"),
         ("prophetnet", "ProphetNetForCausalLM"),
         ("qdqbert", "QDQBertLMHeadModel"),
         ("qwen2", "Qwen2ForCausalLM"),

@@ -1161,6 +1163,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("phi3", "Phi3ForSequenceClassification"),
         ("phimoe", "PhimoeForSequenceClassification"),
         ("plbart", "PLBartForSequenceClassification"),
+        ("plm", "PLMForSequenceClassification"),
         ("qdqbert", "QDQBertForSequenceClassification"),
         ("qwen2", "Qwen2ForSequenceClassification"),
         ("qwen2_moe", "Qwen2MoeForSequenceClassification"),

@@ -1358,6 +1361,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("persimmon", "PersimmonForTokenClassification"),
         ("phi", "PhiForTokenClassification"),
         ("phi3", "Phi3ForTokenClassification"),
+        ("plm", "PLMForTokenClassification"),
         ("qdqbert", "QDQBertForTokenClassification"),
         ("qwen2", "Qwen2ForTokenClassification"),
         ("qwen2_moe", "Qwen2MoeForTokenClassification"),

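These mapping entries are what let the `Auto*` classes resolve the new `"plm"` model type. The snippet below is a rough sketch of that effect, assuming the PLM classes from this branch are importable; the small config values are arbitrary and only keep the model tiny:

```python
from transformers import AutoModelForCausalLM, PLMConfig

config = PLMConfig(
    num_hidden_layers=2, hidden_size=64, intermediate_size=128,
    num_attention_heads=2, num_key_value_heads=2, vocab_size=128,
)
# resolves to PLMForCausalLM via the ("plm", "PLMForCausalLM") mapping entry above
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # PLMForCausalLM
```
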
src/transformers/models/plm/__init__.py (new file, 28 lines)

@@ -0,0 +1,28 @@
# coding=utf-8
# Copyright 2025 The PLM team and the HuggingFace Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_plm import *
    from .modeling_plm import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

src/transformers/models/plm/configuration_plm.py (new file, 159 lines)

@@ -0,0 +1,159 @@
# coding=utf-8
# Copyright 2025 The PLM team and The HuggingFace Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PLM model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class PLMConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`PLMModel`]. It is used to instantiate a
    PLM model according to the specified arguments, defining the model architecture.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the
    defaults will yield a similar configuration to that of
    PLM-1.8B-Base [PLM-Team/PLM-1.8B-Base](https://huggingface.co/PLM-Team/PLM-1.8B-Base).

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the PLM model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`PLMModel`].
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 8192):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`.
        kv_lora_rank (`int`, *optional*, defaults to 512):
            Rank of the low-rank compression applied to the key and value projections (Multi-head Latent Attention).
        q_lora_rank (`int`, *optional*):
            Rank of the low-rank compression applied to the query projection. If `None`, no query compression is used.
        qk_rope_head_dim (`int`, *optional*, defaults to 64):
            Dimension of the rotary (RoPE) part of the query and key heads.
        v_head_dim (`int`, *optional*, defaults to 128):
            Dimension of the value heads.
        qk_nope_head_dim (`int`, *optional*, defaults to 128):
            Dimension of the non-rotary part of the query and key heads.
        hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 100000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports normal rope.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        rope_interleave (`bool`, *optional*, defaults to `True`):
            Whether rotary position embeddings are applied to interleaved (even/odd) dimensions of the query and key
            heads rather than to the first and second halves.

    ```python
    >>> from transformers import PLMModel, PLMConfig

    >>> # Initializing a PLM style configuration
    >>> configuration = PLMConfig()

    >>> # Initializing a model from the PLM style configuration
    >>> model = PLMModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "plm"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=32,
        num_attention_heads=16,
        num_key_value_heads=16,
        kv_lora_rank=512,
        q_lora_rank=None,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        hidden_act="relu2",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=100000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        rope_interleave=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.rope_interleave = rope_interleave
        self.head_dim = qk_rope_head_dim
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


__all__ = ["PLMConfig"]

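To make the head-dimension bookkeeping above concrete: each query/key head is the concatenation of a non-rotary part and a rotary part, so `qk_head_dim` is derived in `__init__` rather than passed in. A small illustrative check, using nothing beyond what the constructor above stores:

```python
from transformers import PLMConfig

config = PLMConfig()
# qk_head_dim is derived as the sum of the non-rotary and rotary parts
assert config.qk_head_dim == config.qk_nope_head_dim + config.qk_rope_head_dim  # 128 + 64 = 192
print(config.qk_head_dim, config.v_head_dim)  # 192 128
```
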
src/transformers/models/plm/modeling_plm.py (new file, 1112 lines)

(File diff suppressed because it is too large.)
src/transformers/models/plm/modular_plm.py (new file, 268 lines)

@@ -0,0 +1,268 @@
# coding=utf-8
# Copyright 2025 The PLM team and the HuggingFace Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Tuple

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn


from ...cache_utils import Cache
from ..clip.modeling_clip import CLIPMLP
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import logging
from ..llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaForSequenceClassification,
    LlamaForTokenClassification,
    LlamaModel,
    LlamaPreTrainedModel,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    eager_attention_forward,
    rotate_half,
)
from .configuration_plm import PLMConfig


logger = logging.get_logger(__name__)


class PLMRMSNorm(LlamaRMSNorm):
    pass


class PLMRotaryEmbedding(LlamaRotaryEmbedding):
    pass


def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    r"""
    TODO: let's just use the original freqs_cis computation to not have the view
    transpose + reshape! This is not optimized!
    Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


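A tiny shape-level sanity check of the helper above (the dimensions are arbitrary and purely illustrative): with `unsqueeze_dim=1`, `cos`/`sin` of shape `[batch, seq, head_dim]` broadcast against `[batch, heads, seq, head_dim]` queries and keys, and the outputs keep the input shapes. The import path below assumes this branch is installed; the helper also ends up in the generated `modeling_plm.py`.

```python
import torch
# assumption: importing the helper from the generated modeling file of this branch
from transformers.models.plm.modeling_plm import apply_rotary_pos_emb_interleave

q = torch.randn(1, 2, 5, 8)   # [batch, heads, seq, head_dim]
k = torch.randn(1, 1, 5, 8)   # single shared RoPE key head, as in the attention module below
cos = torch.randn(1, 5, 8)    # [batch, seq, head_dim]
sin = torch.randn(1, 5, 8)

q_embed, k_embed = apply_rotary_pos_emb_interleave(q, k, cos, sin)
print(q_embed.shape, k_embed.shape)  # torch.Size([1, 2, 5, 8]) torch.Size([1, 1, 5, 8])
```
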
class PLMMLP(CLIPMLP):
    pass


class PLMAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: PLMConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.num_heads = config.num_attention_heads
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_head_dim = config.qk_head_dim

        self.is_causal = True
        if config.q_lora_rank is not None:
            self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
            self.q_a_layernorm = PLMRMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
        else:
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)

        self.kv_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = PLMRMSNorm(self.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        self.scaling = self.qk_head_dim ** (-0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (
            batch_size,
            seq_length,
            -1,
            self.qk_nope_head_dim + self.v_head_dim,
        )
        if self.q_lora_rank is not None:
            q_states = (
                self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2)
            )
        else:
            q_states = self.q_proj(hidden_states).view(query_shape).transpose(1, 2)
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        cos, sin = position_embeddings
        if self.config.rope_interleave:  # support using interleaved weights for efficiency
            q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
        else:
            q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


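For readers skimming the attention module above: with the default `PLMConfig`, the Multi-head Latent Attention projections compress the hidden state into a small latent before re-expanding it per head. A small numeric sketch of those projection sizes, derived only from the config fields used above:

```python
from transformers import PLMConfig

cfg = PLMConfig()
# kv_a_proj_with_mqa: hidden_size -> kv_lora_rank + qk_rope_head_dim (compressed KV latent + shared RoPE key)
print(cfg.hidden_size, "->", cfg.kv_lora_rank + cfg.qk_rope_head_dim)  # 2048 -> 576
# kv_b_proj: kv_lora_rank -> num_heads * (qk_nope_head_dim + v_head_dim)
print(cfg.kv_lora_rank, "->", cfg.num_attention_heads * (cfg.qk_nope_head_dim + cfg.v_head_dim))  # 512 -> 4096
```
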
class PLMDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: PLMConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.self_attn = PLMAttention(config=config, layer_idx=layer_idx)
        self.mlp = PLMMLP(config)


class PLMPreTrainedModel(LlamaPreTrainedModel):
    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Parameter):
            # a bare Parameter has no `.weight` attribute; initialize its data directly
            module.data.normal_(mean=0.0, std=std)


class PLMForTokenClassification(LlamaForTokenClassification):
    pass


class PLMForCausalLM(LlamaForCausalLM):
    pass


class PLMModel(LlamaModel):
    pass


class PLMForSequenceClassification(LlamaForSequenceClassification):
    pass


__all__ = [
    "PLMPreTrainedModel",
    "PLMModel",
    "PLMForCausalLM",
    "PLMForSequenceClassification",
    "PLMForTokenClassification",
]

tests/models/plm/__init__.py (new file, 0 lines)

tests/models/plm/test_modeling_plm.py (new file, 516 lines)

@@ -0,0 +1,516 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch PLM model."""

import unittest

from packaging import version
from parameterized import parameterized

from transformers import AutoTokenizer, PLMConfig, is_torch_available
from transformers.testing_utils import (
    require_read_token,
    require_torch,
    require_torch_accelerator,
    slow,
    torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        PLMForCausalLM,
        PLMForSequenceClassification,
        PLMForTokenClassification,
        PLMModel,
    )


class PLMModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        intermediate_size=37,
        num_hidden_layers=5,
        num_attention_heads=4,
        num_key_value_heads=4,
        kv_lora_rank=16,
        q_lora_rank=32,
        qk_rope_head_dim=16,
        v_head_dim=32,
        qk_nope_head_dim=32,
        n_group=2,
        first_k_dense_replace=2,
        norm_topk_prob=True,
        hidden_act="relu2",
        max_position_embeddings=512,
        initializer_range=0.02,
        attention_probs_dropout_prob=0.1,
        type_vocab_size=16,
        type_sequence_label_size=2,
        num_labels=3,
        num_choices=4,
        pad_token_id=0,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.pad_token_id = pad_token_id
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def get_config(self):
        return PLMConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            kv_lora_rank=self.kv_lora_rank,
            q_lora_rank=self.q_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            qk_nope_head_dim=self.qk_nope_head_dim,
            hidden_act=self.hidden_act,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            use_cache=True,
            pad_token_id=self.pad_token_id,
            attention_dropout=self.attention_probs_dropout_prob,
        )

    def create_and_check_model(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
    ):
        model = PLMModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)
        self.parent.assertEqual(
            result.last_hidden_state.shape,
            (self.batch_size, self.seq_length, self.hidden_size),
        )

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = PLMModel(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask)
        self.parent.assertEqual(
            result.last_hidden_state.shape,
            (self.batch_size, self.seq_length, self.hidden_size),
        )

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        model = PLMForCausalLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.is_decoder = True
        config.add_cross_attention = True
        model = PLMForCausalLM(config=config)
        model.to(torch_device)
        model.eval()

        # first forward pass
        outputs = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append to next input_ids and attention_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_hidden_states=True,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
        )["hidden_states"][0]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (
        (
            PLMModel,
            PLMForCausalLM,
            PLMForSequenceClassification,
            PLMForTokenClassification,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (PLMForCausalLM,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": PLMModel,
            "text-classification": PLMForSequenceClassification,
            "token-classification": PLMForTokenClassification,
            "text-generation": PLMForCausalLM,
            "zero-shot": PLMForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False
    fx_compatible = False

    # Need to use `0.8` instead of `0.9` for `test_cpu_offload`
    # This is because we are hitting edge cases with the causal_mask buffer
    model_split_percents = [0.5, 0.7, 0.8]

    # used in `test_torch_compile_for_training`
    _torch_compile_train_cls = PLMForCausalLM if is_torch_available() else None

    def setUp(self):
        self.model_tester = PLMModelTester(self)
        self.config_tester = ConfigTester(self, config_class=PLMConfig, hidden_size=37)

    @unittest.skip("Failing because of unique cache (HybridCache)")
    def test_model_outputs_equivalence(self, **kwargs):
        pass

    @parameterized.expand([("random",), ("same",)])
    @unittest.skip("PLM has HybridCache which is not compatible with assisted decoding")
    def test_assisted_decoding_matches_greedy_search(self, assistant_type):
        pass

    @unittest.skip("PLM has HybridCache which is not compatible with assisted decoding")
    def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
        pass

    @unittest.skip("PLM has HybridCache which is not compatible with assisted decoding")
    def test_assisted_decoding_sample(self):
        pass

    @unittest.skip("PLM has HybridCache which is not compatible with dola decoding")
    def test_dola_decoding_sample(self):
        pass

    @unittest.skip("PLM has HybridCache and doesn't support continue from past kv")
    def test_generate_continue_from_past_key_values(self):
        pass

    @unittest.skip("PLM has HybridCache and doesn't support low_memory generation")
    def test_beam_search_low_memory(self):
        pass

    @unittest.skip("PLM has HybridCache and doesn't support contrastive generation")
    def test_contrastive_generate(self):
        pass

    @unittest.skip("PLM has HybridCache and doesn't support contrastive generation")
    def test_contrastive_generate_dict_outputs_use_cache(self):
        pass

    @unittest.skip("PLM has HybridCache and doesn't support contrastive generation")
    def test_contrastive_generate_low_memory(self):
        pass

    @unittest.skip("PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
    def test_generate_with_static_cache(self):
        pass

    @unittest.skip("PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
    def test_generate_from_inputs_embeds_with_static_cache(self):
        pass

    @unittest.skip("PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
    def test_generate_continue_from_inputs_embeds(self):
        pass

    @unittest.skip("PLM's eager attn/sdpa attn outputs are expected to be different")
    def test_sdpa_equivalence(self):
        pass

    @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format")
    def test_beam_search_generate_dict_outputs_use_cache(self):
        pass

    @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format")
    def test_generate_compilation_all_outputs(self):
        pass

    @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format")
    def test_generate_compile_model_forward(self):
        pass

    @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format")
    def test_greedy_generate_dict_outputs_use_cache(self):
        pass

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_PLM_token_classification_model(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
        model = PLMForTokenClassification(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
        self.assertEqual(
            result.logits.shape,
            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
        )

    def test_PLM_sequence_classification_model(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = PLMForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))


@require_torch_accelerator
class PLMIntegrationTest(unittest.TestCase):
    # This variable is used to determine which CUDA device we are using for our runners (A10 or T4)
    # Depending on the hardware we get different logits / generations
    cuda_compute_capability_major_version = None

    @classmethod
    def setUpClass(cls):
        if is_torch_available() and torch.cuda.is_available():
            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

    @slow
    @require_torch_accelerator
    @require_read_token
    def test_compile_static_cache(self):
        # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
        # work as intended. See https://github.com/pytorch/pytorch/issues/121943
        if version.parse(torch.__version__) < version.parse("2.3.0"):
            self.skipTest(reason="This test requires torch >= 2.3 to run.")

        NUM_TOKENS_TO_GENERATE = 40
        # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
        # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
        EXPECTED_TEXT_COMPLETION = [
            "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
            "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
            "theory of relativ",
            "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
            "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
        ]

        prompts = [
            "Simply put, the theory of relativity states that ",
            "My favorite all time favorite condiment is ketchup.",
        ]
        tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base", use_fast=False)
        model = PLMForCausalLM.from_pretrained(
            "PLM-Team/PLM-1.8B-Base", device_map=torch_device, torch_dtype=torch.float16
        )
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

        # Dynamic Cache
        generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
        dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)

        # Static Cache
        generated_ids = model.generate(
            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
        )
        static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)

        # Static Cache + compile
        model._cache = None  # clear cache object, initialized when we pass `cache_implementation="static"`
        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
        generated_ids = model.generate(
            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
        )
        static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)