Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

Commit ef97fe7e0e: add plm
Parent: 6daa3eeba5
@@ -229,6 +229,7 @@ if TYPE_CHECKING:
     from .pix2struct import *
     from .pixtral import *
     from .plbart import *
+    from .plm import *
     from .poolformer import *
     from .pop2piano import *
     from .prophetnet import *
@@ -250,6 +250,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("pix2struct", "Pix2StructConfig"),
         ("pixtral", "PixtralVisionConfig"),
         ("plbart", "PLBartConfig"),
+        ("plm", "PLMConfig"),
         ("poolformer", "PoolFormerConfig"),
         ("pop2piano", "Pop2PianoConfig"),
         ("prompt_depth_anything", "PromptDepthAnythingConfig"),
@@ -621,6 +622,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("pix2struct", "Pix2Struct"),
         ("pixtral", "Pixtral"),
         ("plbart", "PLBart"),
+        ("plm", "PLM"),
         ("poolformer", "PoolFormer"),
         ("pop2piano", "Pop2Piano"),
         ("prompt_depth_anything", "PromptDepthAnything"),
@@ -228,6 +228,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("phimoe", "PhimoeModel"),
         ("pixtral", "PixtralVisionModel"),
         ("plbart", "PLBartModel"),
+        ("plm", "PLMModel"),
         ("poolformer", "PoolFormerModel"),
         ("prophetnet", "ProphetNetModel"),
         ("pvt", "PvtModel"),
@@ -587,6 +588,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
         ("phi4_multimodal", "Phi4MultimodalForCausalLM"),
         ("phimoe", "PhimoeForCausalLM"),
         ("plbart", "PLBartForCausalLM"),
+        ("plm", "PLMForCausalLM"),
         ("prophetnet", "ProphetNetForCausalLM"),
         ("qdqbert", "QDQBertLMHeadModel"),
         ("qwen2", "Qwen2ForCausalLM"),
@@ -1095,6 +1097,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("phi3", "Phi3ForSequenceClassification"),
         ("phimoe", "PhimoeForSequenceClassification"),
         ("plbart", "PLBartForSequenceClassification"),
+        ("plm", "PLMForSequenceClassification"),
         ("qdqbert", "QDQBertForSequenceClassification"),
         ("qwen2", "Qwen2ForSequenceClassification"),
         ("qwen2_moe", "Qwen2MoeForSequenceClassification"),
@@ -1285,6 +1288,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("persimmon", "PersimmonForTokenClassification"),
         ("phi", "PhiForTokenClassification"),
         ("phi3", "Phi3ForTokenClassification"),
+        ("plm", "PLMForTokenClassification"),
         ("qdqbert", "QDQBertForTokenClassification"),
         ("qwen2", "Qwen2ForTokenClassification"),
         ("qwen2_moe", "Qwen2MoeForTokenClassification"),
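As context for the hunks above (not part of the diff): once "plm" is registered in the auto mappings, the `Auto*` classes can resolve the new model type by name. A minimal sketch, assuming this branch is installed:

```python
from transformers import AutoConfig, PLMConfig

config = AutoConfig.for_model("plm")  # looked up via CONFIG_MAPPING_NAMES["plm"] -> PLMConfig
print(type(config) is PLMConfig, config.model_type)  # True "plm"
```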
src/transformers/models/plm/__init__.py (new file, 27 lines)

# Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_plm import *
    from .modeling_plm import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
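A small usage sketch (not part of the diff) of what the `_LazyModule` wiring above provides: attributes of `transformers.models.plm` resolve on first access, and the top-level and submodule import paths land on the same objects. Assumes this branch is installed.

```python
from transformers import PLMConfig                    # resolved lazily at the package root
from transformers.models.plm import PLMConfig as Cfg  # resolved lazily by the __init__ above

assert Cfg is PLMConfig  # both paths return the same class object
```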
src/transformers/models/plm/configuration_plm.py (new file, 150 lines)

# coding=utf-8
# Copyright 2025 The PLM team and The HuggingFace Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PLM model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class PLMConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`PLMModel`]. It is used to instantiate a
    PLM model according to the specified arguments, defining the model architecture.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the PLM model.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the PLM model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`PLMModel`].
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 8192):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        kv_lora_rank (`int`, *optional*, defaults to 512):
            Rank of the low-rank compression applied to the keys and values (the width of the latent produced by the
            `kv_a_proj_with_mqa` projection).
        q_lora_rank (`int`, *optional*):
            Rank of the low-rank compression applied to the queries. If `None`, the queries use a single full-rank
            projection instead.
        qk_rope_head_dim (`int`, *optional*, defaults to 64):
            Dimension of the rotary (position-dependent) part of each query/key head.
        v_head_dim (`int`, *optional*, defaults to 128):
            Dimension of each value head.
        qk_nope_head_dim (`int`, *optional*, defaults to 128):
            Dimension of the non-rotary part of each query/key head.
        hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 100000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently only the standard
            (unscaled) RoPE is supported.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        rope_interleave (`bool`, *optional*, defaults to `True`):
            Whether the rotary position embedding is applied to an interleaved layout of the head dimensions.

    ```python
    >>> from transformers import PLMModel, PLMConfig

    >>> # Initializing a PLM style configuration
    >>> configuration = PLMConfig()

    >>> # Initializing a model from the PLM style configuration
    >>> model = PLMModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "plm"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=32,
        num_attention_heads=16,
        num_key_value_heads=16,
        kv_lora_rank=512,
        q_lora_rank=None,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        hidden_act="relu2",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pretraining_tp=1,
        tie_word_embeddings=True,
        rope_theta=100000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        rope_interleave=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.rope_interleave = rope_interleave
        self.head_dim = qk_rope_head_dim
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


__all__ = ["PLMConfig"]
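A quick sketch (not part of the diff) of how the MLA dimensions in `PLMConfig.__init__` combine; the numbers are the defaults above and the assertions only restate what the constructor computes:

```python
from transformers import PLMConfig

config = PLMConfig()  # defaults: 16 heads, kv_lora_rank=512, qk_rope_head_dim=64, qk_nope_head_dim=128, v_head_dim=128

# each query/key head is the non-rotary part plus the rotary part
assert config.qk_head_dim == config.qk_nope_head_dim + config.qk_rope_head_dim == 192

# widths of the two key/value projections used by the attention module
kv_a_out = config.kv_lora_rank + config.qk_rope_head_dim                               # 512 + 64 = 576
kv_b_out = config.num_attention_heads * (config.qk_nope_head_dim + config.v_head_dim)  # 16 * 256 = 4096
print(kv_a_out, kv_b_out)
```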
src/transformers/models/plm/modeling_plm.py (new file, 1107 lines)
(File diff suppressed because it is too large.)
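Since the generated modeling file is collapsed here, a tiny randomly initialized smoke test (not part of the diff, and not the released checkpoint) can still illustrate the exported classes; the sizes below are arbitrary small values chosen only to keep the model cheap to build:

```python
import torch
from transformers import PLMConfig, PLMForCausalLM

config = PLMConfig(
    vocab_size=128, hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=4, kv_lora_rank=16,
    qk_rope_head_dim=16, qk_nope_head_dim=16, v_head_dim=16,
)
model = PLMForCausalLM(config).eval()  # random weights

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    out = model(input_ids)
print(out.logits.shape)  # torch.Size([1, 8, 128])
```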
src/transformers/models/plm/modular_plm.py (new file, 262 lines)

import math
from typing import Callable, Optional, Tuple

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import logging
from ..llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaForSequenceClassification,
    LlamaForTokenClassification,
    LlamaModel,
    LlamaPreTrainedModel,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    eager_attention_forward,
    rotate_half,
)
from .configuration_plm import PLMConfig


logger = logging.get_logger(__name__)


class PLMRMSNorm(LlamaRMSNorm):
    pass


class PLMRotaryEmbedding(LlamaRotaryEmbedding):
    pass


def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    r"""
    TODO: use the original freqs_cis computation to avoid the view + transpose + reshape below; this path is not
    optimized.

    Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # de-interleave the last dimension: [x0, y0, x1, y1, ...] -> [x0, x1, ..., y0, y1, ...]
    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class PLMMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        h = self.up_proj(hidden_state)
        h = self.act_fn(h)
        h = self.down_proj(h)
        return h


class PLMAttention(nn.Module):
    """Multi-head latent attention: a compressed-KV variant of the multi-headed attention from the
    'Attention Is All You Need' paper."""

    def __init__(self, config: PLMConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.num_heads = config.num_attention_heads
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_head_dim = config.qk_head_dim

        self.is_causal = True
        if config.q_lora_rank is not None:
            self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
            self.q_a_layernorm = PLMRMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
        else:
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)

        self.kv_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = PLMRMSNorm(self.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        self.scaling = self.qk_head_dim ** (-0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

        # queries: optionally compressed through q_a_proj/q_b_proj, then split into non-rotary and rotary parts
        if self.q_lora_rank is not None:
            q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2)
        else:
            q_states = self.q_proj(hidden_states).view(query_shape).transpose(1, 2)
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # keys/values: a shared compressed latent plus a single shared rotary key head
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        cos, sin = position_embeddings
        if self.config.rope_interleave:  # support using interleaved weights for efficiency
            q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
        else:
            q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class PLMDecoderLayer(LlamaDecoderLayer, nn.Module):
    def __init__(self, config: PLMConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.hidden_size = config.hidden_size
        self.self_attn = PLMAttention(config, layer_idx)
        self.mlp = PLMMLP(config)
        self.input_layernorm = PLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = PLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)


class PLMPreTrainedModel(LlamaPreTrainedModel):
    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Parameter):
            module.data.normal_(mean=0.0, std=std)


class PLMForTokenClassification(LlamaForTokenClassification):
    pass


class PLMForCausalLM(LlamaForCausalLM):
    pass


class PLMModel(LlamaModel):
    pass


class PLMForSequenceClassification(LlamaForSequenceClassification):
    pass


__all__ = [
    "PLMPreTrainedModel",
    "PLMModel",
    "PLMForCausalLM",
    "PLMForSequenceClassification",
    "PLMForTokenClassification",
]
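A self-contained sketch (not part of the diff) of the interleave trick used by `apply_rotary_pos_emb_interleave` above: for checkpoints whose rotary dimensions are stored pairwise interleaved (`[x0, y0, x1, y1, ...]`), the `view`/`transpose`/`reshape` regroups each head dimension into the `[x0, x1, ..., y0, y1, ...]` layout that `rotate_half` expects. Shapes and cos/sin values are arbitrary; `rotate_half` is re-implemented inline so the snippet runs on its own:

```python
import torch


def rotate_half(x):
    # same helper as in modeling_llama: negate the second half and swap the halves
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


b, h, s, d = 1, 2, 4, 8                  # batch, heads, seq_len, rotary head dim
q = torch.randn(b, h, s, d)
k = torch.randn(b, 1, s, d)              # PLM keeps a single shared rotary key head
cos = torch.randn(b, s, d).unsqueeze(1)  # unsqueeze the head dim for broadcasting
sin = torch.randn(b, s, d).unsqueeze(1)

# de-interleave: [x0, y0, x1, y1, ...] -> [x0, x1, ..., y0, y1, ...]
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
k = k.view(b, 1, s, d // 2, 2).transpose(4, 3).reshape(b, 1, s, d)

q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
print(q_embed.shape, k_embed.shape)  # torch.Size([1, 2, 4, 8]) torch.Size([1, 1, 4, 8])
```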
tests/models/plm/__init__.py (new file, empty)

tests/models/plm/test_modeling_plm.py (new file, 597 lines)
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch PLM model."""

import unittest

from packaging import version
from parameterized import parameterized

from transformers import AutoTokenizer, PLMConfig, is_torch_available, set_seed
from transformers.testing_utils import (
    require_read_token,
    require_torch,
    require_torch_accelerator,
    require_torch_sdpa,
    slow,
    torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        PLMForCausalLM,
        PLMModel,
    )
    from transformers.models.plm.modeling_plm import (
        PLMRotaryEmbedding,
    )


class PLMModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        intermediate_size=37,
        num_hidden_layers=5,
        num_attention_heads=4,
        num_key_value_heads=4,
        kv_lora_rank=16,
        q_lora_rank=32,
        qk_rope_head_dim=16,
        v_head_dim=32,
        qk_nope_head_dim=32,
        n_group=2,
        first_k_dense_replace=2,
        norm_topk_prob=True,
        hidden_act="relu2",
        max_position_embeddings=512,
        initializer_range=0.02,
        attention_probs_dropout_prob=0.1,
        type_vocab_size=16,
        type_sequence_label_size=2,
        num_labels=3,
        num_choices=4,
        pad_token_id=0,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.pad_token_id = pad_token_id
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        return PLMConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            kv_lora_rank=self.kv_lora_rank,
            q_lora_rank=self.q_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            qk_nope_head_dim=self.qk_nope_head_dim,
            hidden_act=self.hidden_act,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            use_cache=True,
            pad_token_id=self.pad_token_id,
            attention_dropout=self.attention_probs_dropout_prob,
        )

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = PLMModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = PLMModel(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        model = PLMForCausalLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.is_decoder = True
        config.add_cross_attention = True
        model = PLMForCausalLM(config=config)
        model.to(torch_device)
        model.eval()

        # first forward pass
        outputs = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append to next input_ids and attention_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_hidden_states=True,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
        )["hidden_states"][0]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (
        (
            PLMModel,
            PLMForCausalLM,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (PLMForCausalLM,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": PLMModel,
            "text-generation": PLMForCausalLM,
        }
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False
    fx_compatible = False

    # Need to use `0.8` instead of `0.9` for `test_cpu_offload`
    # This is because we are hitting edge cases with the causal_mask buffer
    model_split_percents = [0.5, 0.7, 0.8]

    # used in `test_torch_compile_for_training`
    _torch_compile_train_cls = PLMForCausalLM if is_torch_available() else None

    def setUp(self):
        self.model_tester = PLMModelTester(self)
        self.config_tester = ConfigTester(self, config_class=PLMConfig, hidden_size=37)

    @unittest.skip("Failing because of unique cache (HybridCache)")
    def test_model_outputs_equivalence(self, **kwargs):
        pass

    @parameterized.expand([("random",), ("same",)])
    @unittest.skip("PLM has HybridCache which is not compatible with assisted decoding")
    def test_assisted_decoding_matches_greedy_search(self, assistant_type):
        pass

    @unittest.skip("PLM has HybridCache which is not compatible with assisted decoding")
    def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
        pass

    @unittest.skip("PLM has HybridCache which is not compatible with assisted decoding")
    def test_assisted_decoding_sample(self):
        pass

    @unittest.skip("PLM has HybridCache which is not compatible with dola decoding")
    def test_dola_decoding_sample(self):
        pass

    @unittest.skip("PLM has HybridCache and doesn't support continue from past kv")
    def test_generate_continue_from_past_key_values(self):
        pass

    @unittest.skip("PLM has HybridCache and doesn't support low_memory generation")
    def test_beam_search_low_memory(self):
        pass

    @unittest.skip("PLM has HybridCache and doesn't support contrastive generation")
    def test_contrastive_generate(self):
        pass

    @unittest.skip("PLM has HybridCache and doesn't support contrastive generation")
    def test_contrastive_generate_dict_outputs_use_cache(self):
        pass

    @unittest.skip("PLM has HybridCache and doesn't support contrastive generation")
    def test_contrastive_generate_low_memory(self):
        pass

    @unittest.skip(
        "PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
    )
    def test_generate_with_static_cache(self):
        pass

    @unittest.skip(
        "PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
    )
    def test_generate_from_inputs_embeds_with_static_cache(self):
        pass

    @unittest.skip(
        "PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
    )
    def test_generate_continue_from_inputs_embeds(self):
        pass

    @unittest.skip("PLM's eager attn/sdpa attn outputs are expected to be different")
    def test_sdpa_equivalence(self):
        pass

    @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format")
    def test_beam_search_generate_dict_outputs_use_cache(self):
        pass

    @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format")
    def test_generate_compilation_all_outputs(self):
        pass

    @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format")
    def test_generate_compile_model_forward(self):
        pass

    @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format")
    def test_greedy_generate_dict_outputs_use_cache(self):
        pass

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    # def test_model_various_embeddings(self):
    #     config_and_inputs = self.model_tester.prepare_config_and_inputs()
    #     for type in ["absolute", "relative_key", "relative_key_query"]:
    #         config_and_inputs[0].position_embedding_type = type
    #         self.model_tester.create_and_check_model(*config_and_inputs)

    # @parameterized.expand([("yarn",)])
    # def test_model_rope_scaling_from_config(self, scaling_type):
    #     config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    #     short_input = ids_tensor([1, 10], config.vocab_size)
    #     long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)

    #     set_seed(42)  # Fixed seed at init time so the two models get the same random weights
    #     original_model = PLMModel(config)
    #     original_model.to(torch_device)
    #     original_model.eval()
    #     original_short_output = original_model(short_input).last_hidden_state
    #     original_long_output = original_model(long_input).last_hidden_state

    #     set_seed(42)  # Fixed seed at init time so the two models get the same random weights
    #     config.rope_scaling = {"type": scaling_type, "factor": 10.0}
    #     scaled_model = PLMModel(config)
    #     scaled_model.to(torch_device)
    #     scaled_model.eval()
    #     scaled_short_output = scaled_model(short_input).last_hidden_state
    #     scaled_long_output = scaled_model(long_input).last_hidden_state

    #     # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
    #     # maximum sequence length, so the outputs for the short input should match.
    #     if scaling_type == "dynamic":
    #         torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
    #     else:
    #         self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))

    #     # The output should be different for long inputs
    #     self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

    # def test_model_rope_scaling(self):
    #     config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    #     scaling_factor = 10
    #     short_input_length = 10
    #     long_input_length = int(config.max_position_embeddings * 1.5)

    #     # Inputs
    #     x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exclusively to get the dtype and the device
    #     position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
    #     position_ids_short = position_ids_short.unsqueeze(0)
    #     position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
    #     position_ids_long = position_ids_long.unsqueeze(0)

    #     # Sanity check original RoPE
    #     original_rope = PLMRotaryEmbedding(config=config).to(torch_device)
    #     original_cos_short, original_sin_short = original_rope(x, position_ids_short)
    #     original_cos_long, original_sin_long = original_rope(x, position_ids_long)
    #     torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
    #     torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])

    # @unittest.skip(reason="PLM uses MLA on all models so the KV cache is a non standard format")
    # def test_past_key_values_format(self):
    #     pass

    # @require_torch_sdpa
    # @slow
    # def test_eager_matches_sdpa_generate(self):
    #     """
    #     Overwriting the common test as the test is flaky on tiny models
    #     """
    #     max_new_tokens = 30

    #     tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base")

    #     model_sdpa = PLMForCausalLM.from_pretrained(
    #         "PLM-Team/PLM-1.8B-Base",
    #         torch_dtype=torch.float16,
    #         low_cpu_mem_usage=True,
    #         trust_remote_code=True,
    #     ).to(torch_device)

    #     self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")

    #     model_eager = PLMForCausalLM.from_pretrained(
    #         "PLM-Team/PLM-1.8B-Base",
    #         torch_dtype=torch.float16,
    #         low_cpu_mem_usage=True,
    #         attn_implementation="eager",
    #         trust_remote_code=True,
    #     ).to(torch_device)
    #     self.assertTrue(model_eager.config._attn_implementation == "eager")

    #     texts = [
    #         "hi here's a longer context, getting longer and",
    #         "Hello this is a very long sentence my friend, very long for real",
    #         "Today I am in Paris and",
    #     ]

    #     for padding_side in ["left", "right"]:
    #         tokenizer.padding_side = padding_side
    #         tokenizer.pad_token = tokenizer.eos_token

    #         inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)

    #         res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    #         res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

    #         with self.subTest(f"{padding_side}"):
    #             torch.testing.assert_close(
    #                 res_eager,
    #                 res_sdpa,
    #                 msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
    #             )


@require_torch_accelerator
class PLMIntegrationTest(unittest.TestCase):
    # This variable is used to determine which CUDA device we are using for our runners (A10 or T4)
    # Depending on the hardware we get different logits / generations
    cuda_compute_capability_major_version = None

    @classmethod
    def setUpClass(cls):
        if is_torch_available() and torch.cuda.is_available():
            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]

    @slow
    @require_torch_accelerator
    @require_read_token
    def test_compile_static_cache(self):
        # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
        # work as intended. See https://github.com/pytorch/pytorch/issues/121943
        if version.parse(torch.__version__) < version.parse("2.3.0"):
            self.skipTest(reason="This test requires torch >= 2.3 to run.")

        NUM_TOKENS_TO_GENERATE = 40
        # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
        # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
        EXPECTED_TEXT_COMPLETION = [
            "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
            "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
            "theory of relativ",
            "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
            "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
        ]

        prompts = [
            "Simply put, the theory of relativity states that ",
            "My favorite all time favorite condiment is ketchup.",
        ]
        tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base", use_fast=False)
        model = PLMForCausalLM.from_pretrained(
            "PLM-Team/PLM-1.8B-Base", device_map=torch_device, torch_dtype=torch.float16
        )
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

        # Dynamic Cache
        generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
        dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)

        # Static Cache
        generated_ids = model.generate(
            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
        )
        static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)

        # Static Cache + compile
        model._cache = None  # clear cache object, initialized when we pass `cache_implementation="static"`
        model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
        generated_ids = model.generate(
            **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
        )
        static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)