From ef97fe7e0e12c7f2b9170c217fbda9458fd0ace7 Mon Sep 17 00:00:00 2001 From: JiwenJ <3522936020@qq.com> Date: Sun, 20 Apr 2025 05:28:48 +0000 Subject: [PATCH 01/17] add plm --- src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 4 + src/transformers/models/plm/__init__.py | 27 + .../models/plm/configuration_plm.py | 150 +++ src/transformers/models/plm/modeling_plm.py | 1107 +++++++++++++++++ src/transformers/models/plm/modular_plm.py | 262 ++++ tests/models/plm/__init__.py | 0 tests/models/plm/test_modeling_plm.py | 597 +++++++++ 9 files changed, 2150 insertions(+) create mode 100644 src/transformers/models/plm/__init__.py create mode 100644 src/transformers/models/plm/configuration_plm.py create mode 100644 src/transformers/models/plm/modeling_plm.py create mode 100644 src/transformers/models/plm/modular_plm.py create mode 100644 tests/models/plm/__init__.py create mode 100644 tests/models/plm/test_modeling_plm.py diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 26e6e0a9799..9ed43f0ff30 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -229,6 +229,7 @@ if TYPE_CHECKING: from .pix2struct import * from .pixtral import * from .plbart import * + from .plm import * from .poolformer import * from .pop2piano import * from .prophetnet import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 7d13bc788d4..eaf278732e8 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -250,6 +250,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("pix2struct", "Pix2StructConfig"), ("pixtral", "PixtralVisionConfig"), ("plbart", "PLBartConfig"), + ("plm", "PLMConfig"), ("poolformer", "PoolFormerConfig"), ("pop2piano", "Pop2PianoConfig"), ("prompt_depth_anything", "PromptDepthAnythingConfig"), @@ -621,6 +622,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("pix2struct", "Pix2Struct"), ("pixtral", "Pixtral"), ("plbart", "PLBart"), + ("plm", "PLM"), ("poolformer", "PoolFormer"), ("pop2piano", "Pop2Piano"), ("prompt_depth_anything", "PromptDepthAnything"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index a7271d04f60..5199d757306 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -228,6 +228,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("phimoe", "PhimoeModel"), ("pixtral", "PixtralVisionModel"), ("plbart", "PLBartModel"), + ("plm", "PLMModel"), ("poolformer", "PoolFormerModel"), ("prophetnet", "ProphetNetModel"), ("pvt", "PvtModel"), @@ -587,6 +588,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("phi4_multimodal", "Phi4MultimodalForCausalLM"), ("phimoe", "PhimoeForCausalLM"), ("plbart", "PLBartForCausalLM"), + ("plm", "PLMForCausalLM"), ("prophetnet", "ProphetNetForCausalLM"), ("qdqbert", "QDQBertLMHeadModel"), ("qwen2", "Qwen2ForCausalLM"), @@ -1095,6 +1097,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("phi3", "Phi3ForSequenceClassification"), ("phimoe", "PhimoeForSequenceClassification"), ("plbart", "PLBartForSequenceClassification"), + ("plm", "PLMForSequenceClassification"), ("qdqbert", "QDQBertForSequenceClassification"), ("qwen2", "Qwen2ForSequenceClassification"), ("qwen2_moe", "Qwen2MoeForSequenceClassification"), @@ -1285,6 +1288,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ("persimmon", "PersimmonForTokenClassification"), ("phi", "PhiForTokenClassification"), ("phi3", "Phi3ForTokenClassification"), + ("plm", "PLMForTokenClassification"), ("qdqbert", "QDQBertForTokenClassification"), ("qwen2", "Qwen2ForTokenClassification"), ("qwen2_moe", "Qwen2MoeForTokenClassification"), diff --git a/src/transformers/models/plm/__init__.py b/src/transformers/models/plm/__init__.py new file mode 100644 index 00000000000..389e9e39abc --- /dev/null +++ b/src/transformers/models/plm/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_plm import * + from .modeling_plm import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/plm/configuration_plm.py b/src/transformers/models/plm/configuration_plm.py new file mode 100644 index 00000000000..92956103b56 --- /dev/null +++ b/src/transformers/models/plm/configuration_plm.py @@ -0,0 +1,150 @@ +# coding=utf-8 +# Copyright 2025 The PLM team and The HuggingFace Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PLM model configuration""" +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class PLMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PLMModel`]. It is used to instantiate a + PLM model according to the specified arguments, defining the model architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the + defaults will yield a similar configuration to that of the PLM model. + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the PLM model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`PLMModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether the model's input and output word embeddings should be tied. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports normal rope. + rope_theta (`float`, *optional*, defaults to 100000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + ```python + >>> from transformers import PLMModel, PLMConfig + >>> # Initializing a PLM style configuration + >>> configuration = PLMConfig() + >>> # Initializing a model from the PLM style configuration + >>> model = PLMModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "plm" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=151936, + hidden_size=2048, + intermediate_size=8192, + num_hidden_layers=32, + num_attention_heads=16, + num_key_value_heads=16, + kv_lora_rank = 512, + q_lora_rank = None, + qk_rope_head_dim = 64, + v_head_dim = 128, + qk_nope_head_dim = 128, + hidden_act="relu2", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pretraining_tp=1, + tie_word_embeddings=True, + rope_theta=100000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + rope_interleave=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.rope_interleave = rope_interleave + self.head_dim = qk_rope_head_dim + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + +__all__ = ["PLMConfig"] \ No newline at end of file diff --git a/src/transformers/models/plm/modeling_plm.py b/src/transformers/models/plm/modeling_plm.py new file mode 100644 index 00000000000..0514633f2e3 --- /dev/null +++ b/src/transformers/models/plm/modeling_plm.py @@ -0,0 +1,1107 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/plm/modular_plm.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_plm.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from functools import partial +from typing import Callable, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + is_torch_flex_attn_available, + logging, + replace_return_docstrings, +) +from .configuration_plm import PLMConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "meta-plm/PLM-2-7b-hf" +_CONFIG_FOR_DOC = "PLMConfig" + + +@use_kernel_forward_from_hub("RMSNorm") +class PLMRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + PLMRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class PLMRotaryEmbedding(nn.Module): + def __init__(self, config: PLMConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class PLMMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + h = self.up_proj(hidden_state) + h = self.act_fn(h) + h = self.down_proj(h) + return h + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + r""" + TODO let's just use the original freqcis computation to not have the view + transpose + reshape! This is not optimized! + Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class PLMAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: PLMConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.attention_dropout = config.attention_dropout + self.num_heads = config.num_attention_heads + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_head_dim = config.qk_head_dim + + self.is_causal = True + if config.q_lora_rank is not None: + self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = PLMRMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) + else: + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) + + self.kv_a_proj_with_mqa = nn.Linear( + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = PLMRMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + self.scaling = self.qk_head_dim ** (-0.5) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) + if self.q_lora_rank is not None: + q_states = ( + self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2) + ) + else: + q_states = self.q_proj(hidden_states).view(query_shape).transpose(1, 2) + q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) + k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + + cos, sin = position_embeddings + if self.config.rope_interleave: # support using interleaved weights for efficiency + q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) + else: + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + k_rot = k_rot.expand(*k_pass.shape[:-1], -1) + + query_states = torch.cat((q_pass, q_rot), dim=-1) + key_states = torch.cat((k_pass, k_rot), dim=-1) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class PLMDecoderLayer(nn.Module): + def __init__(self, config: PLMConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = PLMAttention(config, layer_idx) + self.mlp = PLMMLP(config) + self.input_layernorm = PLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = PLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +PLM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`PLMConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare PLM Model outputting raw hidden-states without any specific head on top.", + PLM_START_DOCSTRING, +) +class PLMPreTrainedModel(PreTrainedModel): + config_class = PLMConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["PLMDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.Parameter): + module.weight.data.normal_(mean=0.0, std=std) + + +PLM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + """ + The PLM Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + PLM_START_DOCSTRING, +) +class PLMForTokenClassification(PLMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = PLMModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(PLM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> TokenClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = outputs.last_hidden_state + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class PLMForCausalLM(PLMPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = PLMModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @add_start_docstrings_to_model_forward(PLM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, PLMForCausalLM + + >>> model = PLMForCausalLM.from_pretrained("meta-plm/PLM-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-plm/PLM-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + "The bare PLM Model outputting raw hidden-states without any specific head on top.", + PLM_START_DOCSTRING, +) +class PLMModel(PLMPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PLMDecoderLayer`] + + Args: + config: PLMConfig + """ + + def __init__(self, config: PLMConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [PLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = PLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = PLMRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(PLM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + partial(decoder_layer.__call__, **flash_attn_kwargs), + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +@add_start_docstrings( + """ + The PLM Model transformer with a sequence classification head on top (linear layer). + + [`PLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + PLM_START_DOCSTRING, +) +class PLMForSequenceClassification(PLMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = PLMModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(PLM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + transformer_outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +__all__ = [ + "PLMPreTrainedModel", + "PLMModel", + "PLMForCausalLM", + "PLMForSequenceClassification", + "PLMForTokenClassification", +] diff --git a/src/transformers/models/plm/modular_plm.py b/src/transformers/models/plm/modular_plm.py new file mode 100644 index 00000000000..4d5592afe9c --- /dev/null +++ b/src/transformers/models/plm/modular_plm.py @@ -0,0 +1,262 @@ +import math +from typing import Callable, Optional, Tuple + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import logging +from ..llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaForSequenceClassification, + LlamaRotaryEmbedding, + LlamaForTokenClassification, + apply_rotary_pos_emb, + eager_attention_forward, + rotate_half, +) +from .configuration_plm import PLMConfig + + +logger = logging.get_logger(__name__) + +class PLMRMSNorm(LlamaRMSNorm): + pass + + +class PLMRotaryEmbedding(LlamaRotaryEmbedding): + pass + + + + +def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + r""" + TODO let's just use the original freqcis computation to not have the view + transpose + reshape! This is not optimized! + Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + + +class PLMMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + h = self.up_proj(hidden_state) + h = self.act_fn(h) + h = self.down_proj(h) + return h + + + +class PLMAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: PLMConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.attention_dropout = config.attention_dropout + self.num_heads = config.num_attention_heads + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_head_dim = config.qk_head_dim + + self.is_causal = True + if config.q_lora_rank is not None: + self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = PLMRMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) + else: + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) + + + self.kv_a_proj_with_mqa = nn.Linear( + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = PLMRMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + self.scaling = self.qk_head_dim ** (-0.5) + + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) + if self.q_lora_rank is not None: + q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2) + else: + q_states = self.q_proj(hidden_states).view(query_shape).transpose(1, 2) + q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) + k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + + cos, sin = position_embeddings + if self.config.rope_interleave: # support using interleaved weights for efficiency + q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) + else: + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + k_rot = k_rot.expand(*k_pass.shape[:-1], -1) + + query_states = torch.cat((q_pass, q_rot), dim=-1) + key_states = torch.cat((k_pass, k_rot), dim=-1) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + + +class PLMDecoderLayer(LlamaDecoderLayer, nn.Module): + def __init__(self, config: PLMConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = PLMAttention(config, layer_idx) + self.mlp = PLMMLP(config) + self.input_layernorm = PLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = PLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class PLMPreTrainedModel(LlamaPreTrainedModel): + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.Parameter): + module.weight.data.normal_(mean=0.0, std=std) + +class PLMForTokenClassification(LlamaForTokenClassification): + pass + +class PLMForCausalLM(LlamaForCausalLM): + pass + +class PLMModel(LlamaModel): + pass +class PLMForSequenceClassification(LlamaForSequenceClassification): + pass + +__all__ = [ + "PLMPreTrainedModel", + "PLMModel", + "PLMForCausalLM", + "PLMForSequenceClassification", + "PLMForTokenClassification" +] \ No newline at end of file diff --git a/tests/models/plm/__init__.py b/tests/models/plm/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/models/plm/test_modeling_plm.py b/tests/models/plm/test_modeling_plm.py new file mode 100644 index 00000000000..9fb377c87f5 --- /dev/null +++ b/tests/models/plm/test_modeling_plm.py @@ -0,0 +1,597 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch PLM model.""" + +import unittest + +from packaging import version +from parameterized import parameterized + +from transformers import AutoTokenizer, PLMConfig, is_torch_available, set_seed +from transformers.testing_utils import ( + require_read_token, + require_torch, + require_torch_accelerator, + require_torch_sdpa, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + PLMForCausalLM, + PLMModel, + ) + from transformers.models.plm.modeling_plm import ( + PLMRotaryEmbedding, + ) + + + + +class PLMModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=False, + use_labels=True, + vocab_size=99, + hidden_size=32, + intermediate_size=37, + num_hidden_layers=5, + num_attention_heads=4, + num_key_value_heads=4, + kv_lora_rank=16, + q_lora_rank=32, + qk_rope_head_dim=16, + v_head_dim=32, + qk_nope_head_dim=32, + n_group=2, + first_k_dense_replace=2, + norm_topk_prob=True, + hidden_act="relu2", + max_position_embeddings=512, + initializer_range=0.02, + attention_probs_dropout_prob=0.1, + type_vocab_size=16, + type_sequence_label_size=2, + num_labels=3, + num_choices=4, + pad_token_id=0, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.num_labels = num_labels + self.num_choices = num_choices + self.pad_token_id = pad_token_id + self.scope = scope + # breakpoint() + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return PLMConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + kv_lora_rank=self.kv_lora_rank, + q_lora_rank=self.q_lora_rank, + qk_rope_head_dim=self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + qk_nope_head_dim=self.qk_nope_head_dim, + hidden_act=self.hidden_act, + max_position_embeddings=self.max_position_embeddings, + initializer_range=self.initializer_range, + use_cache=True, + pad_token_id=self.pad_token_id, + attention_dropout=self.attention_probs_dropout_prob, + ) + + def create_and_check_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = PLMModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + model = PLMModel(config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + result = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + ) + result = model(input_ids, attention_mask=input_mask) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_causal_lm( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + model = PLMForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.is_decoder = True + config.add_cross_attention = True + model = PLMForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + + + + + +@require_torch +class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + # breakpoint() + all_model_classes = ( + ( + PLMModel, + PLMForCausalLM, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = (PLMForCausalLM,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": PLMModel, + "text-generation": PLMForCausalLM, + } + if is_torch_available() + else {} + ) + test_headmasking = False + test_pruning = False + fx_compatible = False + + # Need to use `0.8` instead of `0.9` for `test_cpu_offload` + # This is because we are hitting edge cases with the causal_mask buffer + model_split_percents = [0.5, 0.7, 0.8] + + # used in `test_torch_compile_for_training` + _torch_compile_train_cls = PLMForCausalLM if is_torch_available() else None + + def setUp(self): + self.model_tester = PLMModelTester(self) + self.config_tester = ConfigTester(self, config_class=PLMConfig, hidden_size=37) + + @unittest.skip("Failing because of unique cache (HybridCache)") + def test_model_outputs_equivalence(self, **kwargs): + pass + + @parameterized.expand([("random",), ("same",)]) + @unittest.skip("PLM has HybridCache which is not compatible with assisted decoding") + def test_assisted_decoding_matches_greedy_search(self, assistant_type): + pass + + @unittest.skip("PLM has HybridCache which is not compatible with assisted decoding") + def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type): + pass + + @unittest.skip("PLM has HybridCache which is not compatible with assisted decoding") + def test_assisted_decoding_sample(self): + pass + + @unittest.skip("PLM has HybridCache which is not compatible with dola decoding") + def test_dola_decoding_sample(self): + pass + + @unittest.skip("PLM has HybridCache and doesn't support continue from past kv") + def test_generate_continue_from_past_key_values(self): + pass + + @unittest.skip("PLM has HybridCache and doesn't support low_memory generation") + def test_beam_search_low_memory(self): + pass + + @unittest.skip("PLM has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate(self): + pass + + @unittest.skip("PLM has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip("PLM has HybridCache and doesn't support contrastive generation") + def test_contrastive_generate_low_memory(self): + pass + + @unittest.skip( + "PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support." + ) + def test_generate_with_static_cache(self): + pass + + @unittest.skip( + "PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support." + ) + def test_generate_from_inputs_embeds_with_static_cache(self): + pass + + @unittest.skip( + "PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support." + ) + def test_generate_continue_from_inputs_embeds(self): + pass + + @unittest.skip("PLM's eager attn/sdpa attn outputs are expected to be different") + def test_sdpa_equivalence(self): + pass + + @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format") + def test_beam_search_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format") + def test_generate_compilation_all_outputs(self): + pass + + @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format") + def test_generate_compile_model_forward(self): + pass + + @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format") + def test_greedy_generate_dict_outputs_use_cache(self): + pass + # breakpoint() + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + # config_and_inputs = self.model_tester.prepare_config_and_inputs() + # self.model_tester.create_and_check_model(*config_and_inputs) + + # def test_model_various_embeddings(self): + # config_and_inputs = self.model_tester.prepare_config_and_inputs() + # for type in ["absolute", "relative_key", "relative_key_query"]: + # config_and_inputs[0].position_embedding_type = type + # self.model_tester.create_and_check_model(*config_and_inputs) + + # @parameterized.expand([("yarn",)]) + # def test_model_rope_scaling_from_config(self, scaling_type): + # config, _ = self.model_tester.prepare_config_and_inputs_for_common() + # short_input = ids_tensor([1, 10], config.vocab_size) + # long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + # set_seed(42) # Fixed seed at init time so the two models get the same random weights + # original_model = PLMModel(config) + # original_model.to(torch_device) + # original_model.eval() + # original_short_output = original_model(short_input).last_hidden_state + # original_long_output = original_model(long_input).last_hidden_state + + # set_seed(42) # Fixed seed at init time so the two models get the same random weights + # config.rope_scaling = {"type": scaling_type, "factor": 10.0} + # scaled_model = PLMModel(config) + # scaled_model.to(torch_device) + # scaled_model.eval() + # scaled_short_output = scaled_model(short_input).last_hidden_state + # scaled_long_output = scaled_model(long_input).last_hidden_state + + # # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # # maximum sequence length, so the outputs for the short input should match. + # if scaling_type == "dynamic": + # torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) + # else: + # self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # # The output should be different for long inputs + # self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + + # def test_model_rope_scaling(self): + # config, _ = self.model_tester.prepare_config_and_inputs_for_common() + # scaling_factor = 10 + # short_input_length = 10 + # long_input_length = int(config.max_position_embeddings * 1.5) + + # # Inputs + # x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device + # position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) + # position_ids_short = position_ids_short.unsqueeze(0) + # position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) + # position_ids_long = position_ids_long.unsqueeze(0) + + # # Sanity check original RoPE + # original_rope = PLMRotaryEmbedding(config=config).to(torch_device) + # original_cos_short, original_sin_short = original_rope(x, position_ids_short) + # original_cos_long, original_sin_long = original_rope(x, position_ids_long) + # torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) + # torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) + + + # @unittest.skip(reason="PLM uses MLA on all models so the KV cache is a non standard format") + # def test_past_key_values_format(self): + # pass + + # @require_torch_sdpa + # @slow + # def test_eager_matches_sdpa_generate(self): + # """ + # Overwritting the common test as the test is flaky on tiny models + # """ + # max_new_tokens = 30 + + # tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base") + + # model_sdpa = PLMForCausalLM.from_pretrained( + # "PLM-Team/PLM-1.8B-Base", + # torch_dtype=torch.float16, + # low_cpu_mem_usage=True, + # trust_remote_code=True, + # ).to(torch_device) + + # self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + # model_eager = PLMForCausalLM.from_pretrained( + # "PLM-Team/PLM-1.8B-Base", + # torch_dtype=torch.float16, + # low_cpu_mem_usage=True, + # attn_implementation="eager", + # trust_remote_code=True, + # ).to(torch_device) + # breakpoint() + # self.assertTrue(model_eager.config._attn_implementation == "eager") + + # texts = [ + # "hi here's a longer context, getting longer and", + # "Hello this is a very long sentence my friend, very long for real", + # "Today I am in Paris and", + # ] + + # for padding_side in ["left", "right"]: + # tokenizer.padding_side = padding_side + # tokenizer.pad_token = tokenizer.eos_token + + # inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device) + + # res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) + # res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) + + # with self.subTest(f"{padding_side}"): + # torch.testing.assert_close( + # res_eager, + # res_sdpa, + # msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}", + # ) + + + + +@require_torch_accelerator +class PLMIntegrationTest(unittest.TestCase): + # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) + # Depending on the hardware we get different logits / generations + cuda_compute_capability_major_version = None + + @classmeth#od + def setUpClass(cls): + if is_torch_available() and torch.cuda.is_available(): + cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] + + @slow + @require_torch_accelerator + @require_read_token + def test_compile_static_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest(reason="This test requires torch >= 2.3 to run.") + + NUM_TOKENS_TO_GENERATE = 40 + # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test + # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs. + EXPECTED_TEXT_COMPLETION = [ + "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " + "theory of relativ", + "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " + "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", + ] + + prompts = [ + "Simply put, the theory of relativity states that ", + "My favorite all time favorite condiment is ketchup.", + ] + tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base", use_fast=False) + model = PLMForCausalLM.from_pretrained( + "PLM-Team/PLM-1.8B-Base", device_map=torch_device, torch_dtype=torch.float16 + ) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) + + # Static Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + + # Static Cache + compile + model._cache = None # clear cache object, initialized when we pass `cache_implementation="static"` + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) From b525fba43f04609e32ed87c0ef60df03bd812e5c Mon Sep 17 00:00:00 2001 From: JiwenJ <3522936020@qq.com> Date: Sun, 20 Apr 2025 05:35:07 +0000 Subject: [PATCH 02/17] test plm model --- tests/models/plm/test_modeling_plm.py | 188 +++++++++++++------------- 1 file changed, 94 insertions(+), 94 deletions(-) diff --git a/tests/models/plm/test_modeling_plm.py b/tests/models/plm/test_modeling_plm.py index 9fb377c87f5..31dbe31247f 100644 --- a/tests/models/plm/test_modeling_plm.py +++ b/tests/models/plm/test_modeling_plm.py @@ -415,121 +415,121 @@ class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, self.config_tester.run_common_tests() def test_model(self): - # config_and_inputs = self.model_tester.prepare_config_and_inputs() - # self.model_tester.create_and_check_model(*config_and_inputs) + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) - # def test_model_various_embeddings(self): - # config_and_inputs = self.model_tester.prepare_config_and_inputs() - # for type in ["absolute", "relative_key", "relative_key_query"]: - # config_and_inputs[0].position_embedding_type = type - # self.model_tester.create_and_check_model(*config_and_inputs) + def test_model_various_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) - # @parameterized.expand([("yarn",)]) - # def test_model_rope_scaling_from_config(self, scaling_type): - # config, _ = self.model_tester.prepare_config_and_inputs_for_common() - # short_input = ids_tensor([1, 10], config.vocab_size) - # long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + @parameterized.expand([("yarn",)]) + def test_model_rope_scaling_from_config(self, scaling_type): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) - # set_seed(42) # Fixed seed at init time so the two models get the same random weights - # original_model = PLMModel(config) - # original_model.to(torch_device) - # original_model.eval() - # original_short_output = original_model(short_input).last_hidden_state - # original_long_output = original_model(long_input).last_hidden_state + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = PLMModel(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state - # set_seed(42) # Fixed seed at init time so the two models get the same random weights - # config.rope_scaling = {"type": scaling_type, "factor": 10.0} - # scaled_model = PLMModel(config) - # scaled_model.to(torch_device) - # scaled_model.eval() - # scaled_short_output = scaled_model(short_input).last_hidden_state - # scaled_long_output = scaled_model(long_input).last_hidden_state + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + scaled_model = PLMModel(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = scaled_model(long_input).last_hidden_state - # # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original - # # maximum sequence length, so the outputs for the short input should match. - # if scaling_type == "dynamic": - # torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) - # else: - # self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. + if scaling_type == "dynamic": + torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) - # # The output should be different for long inputs - # self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - # def test_model_rope_scaling(self): - # config, _ = self.model_tester.prepare_config_and_inputs_for_common() - # scaling_factor = 10 - # short_input_length = 10 - # long_input_length = int(config.max_position_embeddings * 1.5) + def test_model_rope_scaling(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + scaling_factor = 10 + short_input_length = 10 + long_input_length = int(config.max_position_embeddings * 1.5) - # # Inputs - # x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device - # position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) - # position_ids_short = position_ids_short.unsqueeze(0) - # position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) - # position_ids_long = position_ids_long.unsqueeze(0) + # Inputs + x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device + position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) + position_ids_short = position_ids_short.unsqueeze(0) + position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) + position_ids_long = position_ids_long.unsqueeze(0) - # # Sanity check original RoPE - # original_rope = PLMRotaryEmbedding(config=config).to(torch_device) - # original_cos_short, original_sin_short = original_rope(x, position_ids_short) - # original_cos_long, original_sin_long = original_rope(x, position_ids_long) - # torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) - # torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) + # Sanity check original RoPE + original_rope = PLMRotaryEmbedding(config=config).to(torch_device) + original_cos_short, original_sin_short = original_rope(x, position_ids_short) + original_cos_long, original_sin_long = original_rope(x, position_ids_long) + torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) - # @unittest.skip(reason="PLM uses MLA on all models so the KV cache is a non standard format") - # def test_past_key_values_format(self): - # pass + @unittest.skip(reason="PLM uses MLA on all models so the KV cache is a non standard format") + def test_past_key_values_format(self): + pass - # @require_torch_sdpa - # @slow - # def test_eager_matches_sdpa_generate(self): - # """ - # Overwritting the common test as the test is flaky on tiny models - # """ - # max_new_tokens = 30 + @require_torch_sdpa + @slow + def test_eager_matches_sdpa_generate(self): + """ + Overwritting the common test as the test is flaky on tiny models + """ + max_new_tokens = 30 - # tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base") + tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base") - # model_sdpa = PLMForCausalLM.from_pretrained( - # "PLM-Team/PLM-1.8B-Base", - # torch_dtype=torch.float16, - # low_cpu_mem_usage=True, - # trust_remote_code=True, - # ).to(torch_device) + model_sdpa = PLMForCausalLM.from_pretrained( + "PLM-Team/PLM-1.8B-Base", + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + trust_remote_code=True, + ).to(torch_device) - # self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - # model_eager = PLMForCausalLM.from_pretrained( - # "PLM-Team/PLM-1.8B-Base", - # torch_dtype=torch.float16, - # low_cpu_mem_usage=True, - # attn_implementation="eager", - # trust_remote_code=True, - # ).to(torch_device) - # breakpoint() - # self.assertTrue(model_eager.config._attn_implementation == "eager") + model_eager = PLMForCausalLM.from_pretrained( + "PLM-Team/PLM-1.8B-Base", + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + attn_implementation="eager", + trust_remote_code=True, + ).to(torch_device) + breakpoint() + self.assertTrue(model_eager.config._attn_implementation == "eager") - # texts = [ - # "hi here's a longer context, getting longer and", - # "Hello this is a very long sentence my friend, very long for real", - # "Today I am in Paris and", - # ] + texts = [ + "hi here's a longer context, getting longer and", + "Hello this is a very long sentence my friend, very long for real", + "Today I am in Paris and", + ] - # for padding_side in ["left", "right"]: - # tokenizer.padding_side = padding_side - # tokenizer.pad_token = tokenizer.eos_token + for padding_side in ["left", "right"]: + tokenizer.padding_side = padding_side + tokenizer.pad_token = tokenizer.eos_token - # inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device) + inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device) - # res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) - # res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) + res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) + res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) - # with self.subTest(f"{padding_side}"): - # torch.testing.assert_close( - # res_eager, - # res_sdpa, - # msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}", - # ) + with self.subTest(f"{padding_side}"): + torch.testing.assert_close( + res_eager, + res_sdpa, + msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}", + ) From 1ec7ee8280195185c875a61b85415a1aec17730b Mon Sep 17 00:00:00 2001 From: JiwenJ <3522936020@qq.com> Date: Sun, 20 Apr 2025 06:08:40 +0000 Subject: [PATCH 03/17] test plm model --- tests/models/plm/test_modeling_plm.py | 190 +++++++++++++------------- 1 file changed, 95 insertions(+), 95 deletions(-) diff --git a/tests/models/plm/test_modeling_plm.py b/tests/models/plm/test_modeling_plm.py index 31dbe31247f..3e14e903dbe 100644 --- a/tests/models/plm/test_modeling_plm.py +++ b/tests/models/plm/test_modeling_plm.py @@ -414,122 +414,122 @@ class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, def test_config(self): self.config_tester.run_common_tests() - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) + # def test_model(self): + # config_and_inputs = self.model_tester.prepare_config_and_inputs() + # self.model_tester.create_and_check_model(*config_and_inputs) - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) + # def test_model_various_embeddings(self): + # config_and_inputs = self.model_tester.prepare_config_and_inputs() + # for type in ["absolute", "relative_key", "relative_key_query"]: + # config_and_inputs[0].position_embedding_type = type + # self.model_tester.create_and_check_model(*config_and_inputs) - @parameterized.expand([("yarn",)]) - def test_model_rope_scaling_from_config(self, scaling_type): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - short_input = ids_tensor([1, 10], config.vocab_size) - long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + # @parameterized.expand([("yarn",)]) + # def test_model_rope_scaling_from_config(self, scaling_type): + # config, _ = self.model_tester.prepare_config_and_inputs_for_common() + # short_input = ids_tensor([1, 10], config.vocab_size) + # long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) - set_seed(42) # Fixed seed at init time so the two models get the same random weights - original_model = PLMModel(config) - original_model.to(torch_device) - original_model.eval() - original_short_output = original_model(short_input).last_hidden_state - original_long_output = original_model(long_input).last_hidden_state + # set_seed(42) # Fixed seed at init time so the two models get the same random weights + # original_model = PLMModel(config) + # original_model.to(torch_device) + # original_model.eval() + # original_short_output = original_model(short_input).last_hidden_state + # original_long_output = original_model(long_input).last_hidden_state - set_seed(42) # Fixed seed at init time so the two models get the same random weights - config.rope_scaling = {"type": scaling_type, "factor": 10.0} - scaled_model = PLMModel(config) - scaled_model.to(torch_device) - scaled_model.eval() - scaled_short_output = scaled_model(short_input).last_hidden_state - scaled_long_output = scaled_model(long_input).last_hidden_state + # set_seed(42) # Fixed seed at init time so the two models get the same random weights + # config.rope_scaling = {"type": scaling_type, "factor": 10.0} + # scaled_model = PLMModel(config) + # scaled_model.to(torch_device) + # scaled_model.eval() + # scaled_short_output = scaled_model(short_input).last_hidden_state + # scaled_long_output = scaled_model(long_input).last_hidden_state - # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original - # maximum sequence length, so the outputs for the short input should match. - if scaling_type == "dynamic": - torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) - else: - self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + # # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # # maximum sequence length, so the outputs for the short input should match. + # if scaling_type == "dynamic": + # torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) + # else: + # self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) - # The output should be different for long inputs - self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + # # The output should be different for long inputs + # self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - def test_model_rope_scaling(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - scaling_factor = 10 - short_input_length = 10 - long_input_length = int(config.max_position_embeddings * 1.5) + # def test_model_rope_scaling(self): + # config, _ = self.model_tester.prepare_config_and_inputs_for_common() + # scaling_factor = 10 + # short_input_length = 10 + # long_input_length = int(config.max_position_embeddings * 1.5) - # Inputs - x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device - position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) - position_ids_short = position_ids_short.unsqueeze(0) - position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) - position_ids_long = position_ids_long.unsqueeze(0) + # # Inputs + # x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device + # position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) + # position_ids_short = position_ids_short.unsqueeze(0) + # position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) + # position_ids_long = position_ids_long.unsqueeze(0) - # Sanity check original RoPE - original_rope = PLMRotaryEmbedding(config=config).to(torch_device) - original_cos_short, original_sin_short = original_rope(x, position_ids_short) - original_cos_long, original_sin_long = original_rope(x, position_ids_long) - torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) + # # Sanity check original RoPE + # original_rope = PLMRotaryEmbedding(config=config).to(torch_device) + # original_cos_short, original_sin_short = original_rope(x, position_ids_short) + # original_cos_long, original_sin_long = original_rope(x, position_ids_long) + # torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) + # torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) - @unittest.skip(reason="PLM uses MLA on all models so the KV cache is a non standard format") - def test_past_key_values_format(self): - pass + # @unittest.skip(reason="PLM uses MLA on all models so the KV cache is a non standard format") + # def test_past_key_values_format(self): + # pass - @require_torch_sdpa - @slow - def test_eager_matches_sdpa_generate(self): - """ - Overwritting the common test as the test is flaky on tiny models - """ - max_new_tokens = 30 + # @require_torch_sdpa + # @slow + # def test_eager_matches_sdpa_generate(self): + # """ + # Overwritting the common test as the test is flaky on tiny models + # """ + # max_new_tokens = 30 - tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base") + # tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base") - model_sdpa = PLMForCausalLM.from_pretrained( - "PLM-Team/PLM-1.8B-Base", - torch_dtype=torch.float16, - low_cpu_mem_usage=True, - trust_remote_code=True, - ).to(torch_device) + # model_sdpa = PLMForCausalLM.from_pretrained( + # "PLM-Team/PLM-1.8B-Base", + # torch_dtype=torch.float16, + # low_cpu_mem_usage=True, + # trust_remote_code=True, + # ).to(torch_device) - self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + # self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - model_eager = PLMForCausalLM.from_pretrained( - "PLM-Team/PLM-1.8B-Base", - torch_dtype=torch.float16, - low_cpu_mem_usage=True, - attn_implementation="eager", - trust_remote_code=True, - ).to(torch_device) - breakpoint() - self.assertTrue(model_eager.config._attn_implementation == "eager") + # model_eager = PLMForCausalLM.from_pretrained( + # "PLM-Team/PLM-1.8B-Base", + # torch_dtype=torch.float16, + # low_cpu_mem_usage=True, + # attn_implementation="eager", + # trust_remote_code=True, + # ).to(torch_device) + # breakpoint() + # self.assertTrue(model_eager.config._attn_implementation == "eager") - texts = [ - "hi here's a longer context, getting longer and", - "Hello this is a very long sentence my friend, very long for real", - "Today I am in Paris and", - ] + # texts = [ + # "hi here's a longer context, getting longer and", + # "Hello this is a very long sentence my friend, very long for real", + # "Today I am in Paris and", + # ] - for padding_side in ["left", "right"]: - tokenizer.padding_side = padding_side - tokenizer.pad_token = tokenizer.eos_token + # for padding_side in ["left", "right"]: + # tokenizer.padding_side = padding_side + # tokenizer.pad_token = tokenizer.eos_token - inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device) + # inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device) - res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) - res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) + # res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) + # res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) - with self.subTest(f"{padding_side}"): - torch.testing.assert_close( - res_eager, - res_sdpa, - msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}", - ) + # with self.subTest(f"{padding_side}"): + # torch.testing.assert_close( + # res_eager, + # res_sdpa, + # msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}", + # ) From 1e9e950e3591c54b622378b38802c1b78efb0856 Mon Sep 17 00:00:00 2001 From: JiwenJ <3522936020@qq.com> Date: Sun, 20 Apr 2025 06:12:25 +0000 Subject: [PATCH 04/17] test plm model --- tests/models/plm/test_modeling_plm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/plm/test_modeling_plm.py b/tests/models/plm/test_modeling_plm.py index 3e14e903dbe..d28d6c15b04 100644 --- a/tests/models/plm/test_modeling_plm.py +++ b/tests/models/plm/test_modeling_plm.py @@ -410,7 +410,7 @@ class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format") def test_greedy_generate_dict_outputs_use_cache(self): pass - # breakpoint() + def test_config(self): self.config_tester.run_common_tests() @@ -540,7 +540,7 @@ class PLMIntegrationTest(unittest.TestCase): # Depending on the hardware we get different logits / generations cuda_compute_capability_major_version = None - @classmeth#od + @classmethod def setUpClass(cls): if is_torch_available() and torch.cuda.is_available(): cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] From 25cd37ab4ea2032f3e5c22457b4d9a3c9a96469e Mon Sep 17 00:00:00 2001 From: JiwenJ <3522936020@qq.com> Date: Sun, 20 Apr 2025 06:25:58 +0000 Subject: [PATCH 05/17] test code format --- src/transformers/models/plm/configuration_plm.py | 4 ++-- src/transformers/models/plm/modular_plm.py | 12 +++++------- tests/models/plm/test_modeling_plm.py | 14 ++++---------- 3 files changed, 11 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/plm/configuration_plm.py b/src/transformers/models/plm/configuration_plm.py index 92956103b56..659f7448b0d 100644 --- a/src/transformers/models/plm/configuration_plm.py +++ b/src/transformers/models/plm/configuration_plm.py @@ -24,7 +24,7 @@ logger = logging.get_logger(__name__) class PLMConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`PLMModel`]. It is used to instantiate a - PLM model according to the specified arguments, defining the model architecture. + PLM model according to the specified arguments, defining the model architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the defaults will yield a similar configuration to that of the PLM model. @@ -147,4 +147,4 @@ class PLMConfig(PretrainedConfig): **kwargs, ) -__all__ = ["PLMConfig"] \ No newline at end of file +__all__ = ["PLMConfig"] diff --git a/src/transformers/models/plm/modular_plm.py b/src/transformers/models/plm/modular_plm.py index 4d5592afe9c..250c0b12c77 100644 --- a/src/transformers/models/plm/modular_plm.py +++ b/src/transformers/models/plm/modular_plm.py @@ -1,4 +1,3 @@ -import math from typing import Callable, Optional, Tuple import torch @@ -15,12 +14,12 @@ from ...utils import logging from ..llama.modeling_llama import ( LlamaDecoderLayer, LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaForTokenClassification, LlamaModel, LlamaPreTrainedModel, LlamaRMSNorm, - LlamaForSequenceClassification, LlamaRotaryEmbedding, - LlamaForTokenClassification, apply_rotary_pos_emb, eager_attention_forward, rotate_half, @@ -87,7 +86,7 @@ class PLMMLP(nn.Module): self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - + def forward(self, hidden_state): h = self.up_proj(hidden_state) h = self.act_fn(h) @@ -121,7 +120,7 @@ class PLMAttention(nn.Module): self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) else: self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) - + self.kv_a_proj_with_mqa = nn.Linear( config.hidden_size, @@ -215,7 +214,6 @@ class PLMAttention(nn.Module): attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights - class PLMDecoderLayer(LlamaDecoderLayer, nn.Module): @@ -259,4 +257,4 @@ __all__ = [ "PLMForCausalLM", "PLMForSequenceClassification", "PLMForTokenClassification" -] \ No newline at end of file +] diff --git a/tests/models/plm/test_modeling_plm.py b/tests/models/plm/test_modeling_plm.py index d28d6c15b04..f0626648617 100644 --- a/tests/models/plm/test_modeling_plm.py +++ b/tests/models/plm/test_modeling_plm.py @@ -19,12 +19,11 @@ import unittest from packaging import version from parameterized import parameterized -from transformers import AutoTokenizer, PLMConfig, is_torch_available, set_seed +from transformers import AutoTokenizer, PLMConfig, is_torch_available from transformers.testing_utils import ( require_read_token, require_torch, require_torch_accelerator, - require_torch_sdpa, slow, torch_device, ) @@ -42,9 +41,9 @@ if is_torch_available(): PLMForCausalLM, PLMModel, ) - from transformers.models.plm.modeling_plm import ( - PLMRotaryEmbedding, - ) + # from transformers.models.plm.modeling_plm import ( + # PLMRotaryEmbedding, + # ) @@ -291,10 +290,6 @@ class PLMModelTester: ) = config_and_inputs inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} return config, inputs_dict - - - - @require_torch @@ -498,7 +493,6 @@ class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, # ).to(torch_device) # self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - # model_eager = PLMForCausalLM.from_pretrained( # "PLM-Team/PLM-1.8B-Base", # torch_dtype=torch.float16, From 645fc36db381e9393a3720a169dbb414c5ae5fe2 Mon Sep 17 00:00:00 2001 From: JiwenJ <3522936020@qq.com> Date: Sun, 20 Apr 2025 06:38:34 +0000 Subject: [PATCH 06/17] reformat the code --- .../models/plm/configuration_plm.py | 11 +- src/transformers/models/plm/modular_plm.py | 90 ++++++++++---- tests/models/plm/test_modeling_plm.py | 111 +++++++++++++----- 3 files changed, 155 insertions(+), 57 deletions(-) diff --git a/src/transformers/models/plm/configuration_plm.py b/src/transformers/models/plm/configuration_plm.py index 659f7448b0d..756a9184ebc 100644 --- a/src/transformers/models/plm/configuration_plm.py +++ b/src/transformers/models/plm/configuration_plm.py @@ -94,11 +94,11 @@ class PLMConfig(PretrainedConfig): num_hidden_layers=32, num_attention_heads=16, num_key_value_heads=16, - kv_lora_rank = 512, - q_lora_rank = None, - qk_rope_head_dim = 64, - v_head_dim = 128, - qk_nope_head_dim = 128, + kv_lora_rank=512, + q_lora_rank=None, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, hidden_act="relu2", max_position_embeddings=4096, initializer_range=0.02, @@ -147,4 +147,5 @@ class PLMConfig(PretrainedConfig): **kwargs, ) + __all__ = ["PLMConfig"] diff --git a/src/transformers/models/plm/modular_plm.py b/src/transformers/models/plm/modular_plm.py index 250c0b12c77..bfe333ec922 100644 --- a/src/transformers/models/plm/modular_plm.py +++ b/src/transformers/models/plm/modular_plm.py @@ -29,6 +29,7 @@ from .configuration_plm import PLMConfig logger = logging.get_logger(__name__) + class PLMRMSNorm(LlamaRMSNorm): pass @@ -37,8 +38,6 @@ class PLMRotaryEmbedding(LlamaRotaryEmbedding): pass - - def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): r""" TODO let's just use the original freqcis computation to not have the view @@ -77,7 +76,6 @@ def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze return q_embed, k_embed - class PLMMLP(nn.Module): def __init__(self, config): super().__init__() @@ -94,7 +92,6 @@ class PLMMLP(nn.Module): return h - class PLMAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -102,7 +99,9 @@ class PLMAttention(nn.Module): super().__init__() self.config = config self.layer_idx = layer_idx - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) self.attention_dropout = config.attention_dropout self.num_heads = config.num_attention_heads self.rope_theta = config.rope_theta @@ -115,12 +114,17 @@ class PLMAttention(nn.Module): self.is_causal = True if config.q_lora_rank is not None: - self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_proj = nn.Linear( + config.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) self.q_a_layernorm = PLMRMSNorm(config.q_lora_rank) - self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False + ) else: - self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) - + self.q_proj = nn.Linear( + config.hidden_size, self.num_heads * self.qk_head_dim, bias=False + ) self.kv_a_proj_with_mqa = nn.Linear( config.hidden_size, @@ -142,7 +146,6 @@ class PLMAttention(nn.Module): self.scaling = self.qk_head_dim ** (-0.5) - def forward( self, hidden_states: torch.Tensor, @@ -154,23 +157,42 @@ class PLMAttention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: batch_size, seq_length = hidden_states.shape[:-1] query_shape = (batch_size, seq_length, -1, self.qk_head_dim) - key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) + key_shape = ( + batch_size, + seq_length, + -1, + self.qk_nope_head_dim + self.v_head_dim, + ) if self.q_lora_rank is not None: - q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2) + q_states = ( + self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + .view(query_shape) + .transpose(1, 2) + ) else: q_states = self.q_proj(hidden_states).view(query_shape).transpose(1, 2) - q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_pass, q_rot = torch.split( + q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + k_pass, k_rot = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) - k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) - k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pass = ( + self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) + ) + k_pass, value_states = torch.split( + k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) cos, sin = position_embeddings - if self.config.rope_interleave: # support using interleaved weights for efficiency + if ( + self.config.rope_interleave + ): # support using interleaved weights for efficiency q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) else: q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) @@ -182,20 +204,29 @@ class PLMAttention(nn.Module): if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) - if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: + if ( + self.config._attn_implementation == "flash_attention_2" + and self.qk_head_dim != self.v_head_dim + ): value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + if self.config._attn_implementation == "sdpa" and kwargs.get( + "output_attentions", False + ): logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] attn_output, attn_weights = attention_interface( self, @@ -208,7 +239,10 @@ class PLMAttention(nn.Module): **kwargs, ) - if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: + if ( + self.config._attn_implementation == "flash_attention_2" + and self.qk_head_dim != self.v_head_dim + ): attn_output = attn_output[:, :, :, : self.v_head_dim] attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() @@ -223,7 +257,9 @@ class PLMDecoderLayer(LlamaDecoderLayer, nn.Module): self.self_attn = PLMAttention(config, layer_idx) self.mlp = PLMMLP(config) self.input_layernorm = PLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = PLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = PLMRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) class PLMPreTrainedModel(LlamaPreTrainedModel): @@ -240,21 +276,27 @@ class PLMPreTrainedModel(LlamaPreTrainedModel): elif isinstance(module, nn.Parameter): module.weight.data.normal_(mean=0.0, std=std) + class PLMForTokenClassification(LlamaForTokenClassification): pass + class PLMForCausalLM(LlamaForCausalLM): pass + class PLMModel(LlamaModel): pass + + class PLMForSequenceClassification(LlamaForSequenceClassification): pass + __all__ = [ "PLMPreTrainedModel", "PLMModel", "PLMForCausalLM", "PLMForSequenceClassification", - "PLMForTokenClassification" + "PLMForTokenClassification", ] diff --git a/tests/models/plm/test_modeling_plm.py b/tests/models/plm/test_modeling_plm.py index f0626648617..1cda4f240a3 100644 --- a/tests/models/plm/test_modeling_plm.py +++ b/tests/models/plm/test_modeling_plm.py @@ -41,13 +41,12 @@ if is_torch_available(): PLMForCausalLM, PLMModel, ) + # from transformers.models.plm.modeling_plm import ( # PLMRotaryEmbedding, # ) - - class PLMModelTester: def __init__( self, @@ -122,19 +121,33 @@ class PLMModelTester: token_type_ids = None if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + token_type_ids = ids_tensor( + [self.batch_size, self.seq_length], self.type_vocab_size + ) sequence_labels = None token_labels = None choice_labels = None if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + sequence_labels = ids_tensor( + [self.batch_size], self.type_sequence_label_size + ) + token_labels = ids_tensor( + [self.batch_size, self.seq_length], self.num_labels + ) choice_labels = ids_tensor([self.batch_size], self.num_choices) config = self.get_config() - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + return ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) def get_config(self): return PLMConfig( @@ -158,14 +171,24 @@ class PLMModelTester: ) def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, ): model = PLMModel(config=config) model.to(torch_device) model.eval() result = model(input_ids, attention_mask=input_mask) result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.batch_size, self.seq_length, self.hidden_size), + ) def create_and_check_model_as_decoder( self, @@ -195,7 +218,10 @@ class PLMModelTester: encoder_hidden_states=encoder_hidden_states, ) result = model(input_ids, attention_mask=input_mask) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.batch_size, self.seq_length, self.hidden_size), + ) def create_and_check_for_causal_lm( self, @@ -213,7 +239,9 @@ class PLMModelTester: model.to(torch_device) model.eval() result = model(input_ids, attention_mask=input_mask, labels=token_labels) - self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + self.parent.assertEqual( + result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size) + ) def create_and_check_decoder_model_past_large_inputs( self, @@ -269,13 +297,17 @@ class PLMModelTester: # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() - output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_no_past_slice = output_from_no_past[ + :, -3:, random_slice_idx + ].detach() output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) # test that outputs are equal for slice - self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + self.parent.assertTrue( + torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + ) def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() @@ -293,7 +325,9 @@ class PLMModelTester: @require_torch -class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class PLMModelTest( + ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase +): # breakpoint() all_model_classes = ( ( @@ -390,19 +424,27 @@ class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, def test_sdpa_equivalence(self): pass - @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format") + @unittest.skip( + "PLM uses MLA so it is not compatible with the standard cache format" + ) def test_beam_search_generate_dict_outputs_use_cache(self): pass - @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format") + @unittest.skip( + "PLM uses MLA so it is not compatible with the standard cache format" + ) def test_generate_compilation_all_outputs(self): pass - @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format") + @unittest.skip( + "PLM uses MLA so it is not compatible with the standard cache format" + ) def test_generate_compile_model_forward(self): pass - @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format") + @unittest.skip( + "PLM uses MLA so it is not compatible with the standard cache format" + ) def test_greedy_generate_dict_outputs_use_cache(self): pass @@ -470,7 +512,6 @@ class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, # torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) # torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) - # @unittest.skip(reason="PLM uses MLA on all models so the KV cache is a non standard format") # def test_past_key_values_format(self): # pass @@ -526,8 +567,6 @@ class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, # ) - - @require_torch_accelerator class PLMIntegrationTest(unittest.TestCase): # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) @@ -537,7 +576,9 @@ class PLMIntegrationTest(unittest.TestCase): @classmethod def setUpClass(cls): if is_torch_available() and torch.cuda.is_available(): - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] + cls.cuda_compute_capability_major_version = ( + torch.cuda.get_device_capability()[0] + ) @slow @require_torch_accelerator @@ -563,29 +604,43 @@ class PLMIntegrationTest(unittest.TestCase): "Simply put, the theory of relativity states that ", "My favorite all time favorite condiment is ketchup.", ] - tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base", use_fast=False) + tokenizer = AutoTokenizer.from_pretrained( + "PLM-Team/PLM-1.8B-Base", use_fast=False + ) model = PLMForCausalLM.from_pretrained( "PLM-Team/PLM-1.8B-Base", device_map=torch_device, torch_dtype=torch.float16 ) inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) # Dynamic Cache - generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False + ) dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) # Static Cache generated_ids = model.generate( - **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + **inputs, + max_new_tokens=NUM_TOKENS_TO_GENERATE, + do_sample=False, + cache_implementation="static" ) static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) # Static Cache + compile model._cache = None # clear cache object, initialized when we pass `cache_implementation="static"` - model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) - generated_ids = model.generate( - **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + model.forward = torch.compile( + model.forward, mode="reduce-overhead", fullgraph=True + ) + generated_ids = model.generate( + **inputs, + max_new_tokens=NUM_TOKENS_TO_GENERATE, + do_sample=False, + cache_implementation="static" + ) + static_compiled_text = tokenizer.batch_decode( + generated_ids, skip_special_tokens=True ) - static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) From 88840ee45a7d241239e9765a6a2f92d87c15685e Mon Sep 17 00:00:00 2001 From: JiwenJ <3522936020@qq.com> Date: Sun, 20 Apr 2025 06:50:23 +0000 Subject: [PATCH 07/17] ruff format --- src/transformers/models/plm/__init__.py | 4 +- .../models/plm/configuration_plm.py | 1 + src/transformers/models/plm/modular_plm.py | 66 +++++-------------- 3 files changed, 20 insertions(+), 51 deletions(-) diff --git a/src/transformers/models/plm/__init__.py b/src/transformers/models/plm/__init__.py index 389e9e39abc..1d25968f5d2 100644 --- a/src/transformers/models/plm/__init__.py +++ b/src/transformers/models/plm/__init__.py @@ -24,4 +24,6 @@ else: import sys _file = globals()["__file__"] - sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) + sys.modules[__name__] = _LazyModule( + __name__, _file, define_import_structure(_file), module_spec=__spec__ + ) diff --git a/src/transformers/models/plm/configuration_plm.py b/src/transformers/models/plm/configuration_plm.py index 756a9184ebc..0f71a4b79f5 100644 --- a/src/transformers/models/plm/configuration_plm.py +++ b/src/transformers/models/plm/configuration_plm.py @@ -14,6 +14,7 @@ # limitations under the License. """PLM model configuration""" + from ...configuration_utils import PretrainedConfig from ...utils import logging diff --git a/src/transformers/models/plm/modular_plm.py b/src/transformers/models/plm/modular_plm.py index bfe333ec922..3027b8dbcc2 100644 --- a/src/transformers/models/plm/modular_plm.py +++ b/src/transformers/models/plm/modular_plm.py @@ -99,9 +99,7 @@ class PLMAttention(nn.Module): super().__init__() self.config = config self.layer_idx = layer_idx - self.num_key_value_groups = ( - config.num_attention_heads // config.num_key_value_heads - ) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.attention_dropout = config.attention_dropout self.num_heads = config.num_attention_heads self.rope_theta = config.rope_theta @@ -114,17 +112,11 @@ class PLMAttention(nn.Module): self.is_causal = True if config.q_lora_rank is not None: - self.q_a_proj = nn.Linear( - config.hidden_size, config.q_lora_rank, bias=config.attention_bias - ) + self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) self.q_a_layernorm = PLMRMSNorm(config.q_lora_rank) - self.q_b_proj = nn.Linear( - config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False - ) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) else: - self.q_proj = nn.Linear( - config.hidden_size, self.num_heads * self.qk_head_dim, bias=False - ) + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) self.kv_a_proj_with_mqa = nn.Linear( config.hidden_size, @@ -165,34 +157,22 @@ class PLMAttention(nn.Module): ) if self.q_lora_rank is not None: q_states = ( - self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - .view(query_shape) - .transpose(1, 2) + self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2) ) else: q_states = self.q_proj(hidden_states).view(query_shape).transpose(1, 2) - q_pass, q_rot = torch.split( - q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 - ) + q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - k_pass, k_rot = torch.split( - compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 - ) + k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_pass = ( - self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) - ) - k_pass, value_states = torch.split( - k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) + k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) cos, sin = position_embeddings - if ( - self.config.rope_interleave - ): # support using interleaved weights for efficiency + if self.config.rope_interleave: # support using interleaved weights for efficiency q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) else: q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) @@ -204,29 +184,20 @@ class PLMAttention(nn.Module): if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if ( - self.config._attn_implementation == "flash_attention_2" - and self.qk_head_dim != self.v_head_dim - ): + if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get( - "output_attentions", False - ): + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) else: - attention_interface = ALL_ATTENTION_FUNCTIONS[ - self.config._attn_implementation - ] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -239,10 +210,7 @@ class PLMAttention(nn.Module): **kwargs, ) - if ( - self.config._attn_implementation == "flash_attention_2" - and self.qk_head_dim != self.v_head_dim - ): + if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: attn_output = attn_output[:, :, :, : self.v_head_dim] attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() @@ -257,9 +225,7 @@ class PLMDecoderLayer(LlamaDecoderLayer, nn.Module): self.self_attn = PLMAttention(config, layer_idx) self.mlp = PLMMLP(config) self.input_layernorm = PLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = PLMRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) + self.post_attention_layernorm = PLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) class PLMPreTrainedModel(LlamaPreTrainedModel): From c808083e398f06290e0b8b1f78600615e6de3411 Mon Sep 17 00:00:00 2001 From: JiwenJ <3522936020@qq.com> Date: Sun, 20 Apr 2025 06:52:26 +0000 Subject: [PATCH 08/17] ruff format --- src/transformers/models/plm/__init__.py | 4 +- tests/models/plm/test_modeling_plm.py | 86 ++++++------------------- 2 files changed, 22 insertions(+), 68 deletions(-) diff --git a/src/transformers/models/plm/__init__.py b/src/transformers/models/plm/__init__.py index 1d25968f5d2..389e9e39abc 100644 --- a/src/transformers/models/plm/__init__.py +++ b/src/transformers/models/plm/__init__.py @@ -24,6 +24,4 @@ else: import sys _file = globals()["__file__"] - sys.modules[__name__] = _LazyModule( - __name__, _file, define_import_structure(_file), module_spec=__spec__ - ) + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/tests/models/plm/test_modeling_plm.py b/tests/models/plm/test_modeling_plm.py index 1cda4f240a3..81bc567c115 100644 --- a/tests/models/plm/test_modeling_plm.py +++ b/tests/models/plm/test_modeling_plm.py @@ -121,20 +121,14 @@ class PLMModelTester: token_type_ids = None if self.use_token_type_ids: - token_type_ids = ids_tensor( - [self.batch_size, self.seq_length], self.type_vocab_size - ) + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) sequence_labels = None token_labels = None choice_labels = None if self.use_labels: - sequence_labels = ids_tensor( - [self.batch_size], self.type_sequence_label_size - ) - token_labels = ids_tensor( - [self.batch_size, self.seq_length], self.num_labels - ) + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) choice_labels = ids_tensor([self.batch_size], self.num_choices) config = self.get_config() @@ -239,9 +233,7 @@ class PLMModelTester: model.to(torch_device) model.eval() result = model(input_ids, attention_mask=input_mask, labels=token_labels) - self.parent.assertEqual( - result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size) - ) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) def create_and_check_decoder_model_past_large_inputs( self, @@ -297,17 +289,13 @@ class PLMModelTester: # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() - output_from_no_past_slice = output_from_no_past[ - :, -3:, random_slice_idx - ].detach() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) # test that outputs are equal for slice - self.parent.assertTrue( - torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) - ) + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() @@ -325,9 +313,7 @@ class PLMModelTester: @require_torch -class PLMModelTest( - ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase -): +class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): # breakpoint() all_model_classes = ( ( @@ -402,21 +388,15 @@ class PLMModelTest( def test_contrastive_generate_low_memory(self): pass - @unittest.skip( - "PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support." - ) + @unittest.skip("PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") def test_generate_with_static_cache(self): pass - @unittest.skip( - "PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support." - ) + @unittest.skip("PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") def test_generate_from_inputs_embeds_with_static_cache(self): pass - @unittest.skip( - "PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support." - ) + @unittest.skip("PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.") def test_generate_continue_from_inputs_embeds(self): pass @@ -424,27 +404,19 @@ class PLMModelTest( def test_sdpa_equivalence(self): pass - @unittest.skip( - "PLM uses MLA so it is not compatible with the standard cache format" - ) + @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format") def test_beam_search_generate_dict_outputs_use_cache(self): pass - @unittest.skip( - "PLM uses MLA so it is not compatible with the standard cache format" - ) + @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format") def test_generate_compilation_all_outputs(self): pass - @unittest.skip( - "PLM uses MLA so it is not compatible with the standard cache format" - ) + @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format") def test_generate_compile_model_forward(self): pass - @unittest.skip( - "PLM uses MLA so it is not compatible with the standard cache format" - ) + @unittest.skip("PLM uses MLA so it is not compatible with the standard cache format") def test_greedy_generate_dict_outputs_use_cache(self): pass @@ -576,9 +548,7 @@ class PLMIntegrationTest(unittest.TestCase): @classmethod def setUpClass(cls): if is_torch_available() and torch.cuda.is_available(): - cls.cuda_compute_capability_major_version = ( - torch.cuda.get_device_capability()[0] - ) + cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] @slow @require_torch_accelerator @@ -604,43 +574,29 @@ class PLMIntegrationTest(unittest.TestCase): "Simply put, the theory of relativity states that ", "My favorite all time favorite condiment is ketchup.", ] - tokenizer = AutoTokenizer.from_pretrained( - "PLM-Team/PLM-1.8B-Base", use_fast=False - ) + tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base", use_fast=False) model = PLMForCausalLM.from_pretrained( "PLM-Team/PLM-1.8B-Base", device_map=torch_device, torch_dtype=torch.float16 ) inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) # Dynamic Cache - generated_ids = model.generate( - **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False - ) + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) # Static Cache generated_ids = model.generate( - **inputs, - max_new_tokens=NUM_TOKENS_TO_GENERATE, - do_sample=False, - cache_implementation="static" + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" ) static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) # Static Cache + compile model._cache = None # clear cache object, initialized when we pass `cache_implementation="static"` - model.forward = torch.compile( - model.forward, mode="reduce-overhead", fullgraph=True - ) + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) generated_ids = model.generate( - **inputs, - max_new_tokens=NUM_TOKENS_TO_GENERATE, - do_sample=False, - cache_implementation="static" - ) - static_compiled_text = tokenizer.batch_decode( - generated_ids, skip_special_tokens=True + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) From 269a49e5d638a6d615bb3f9e9d25faa8172e1023 Mon Sep 17 00:00:00 2001 From: JiwenJ <3522936020@qq.com> Date: Sun, 20 Apr 2025 07:13:19 +0000 Subject: [PATCH 09/17] docstring PLMConfig --- .../models/plm/configuration_plm.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/plm/configuration_plm.py b/src/transformers/models/plm/configuration_plm.py index 0f71a4b79f5..55a98160ead 100644 --- a/src/transformers/models/plm/configuration_plm.py +++ b/src/transformers/models/plm/configuration_plm.py @@ -33,7 +33,7 @@ class PLMConfig(PretrainedConfig): vocab_size (`int`, *optional*, defaults to 151936): Vocabulary size of the PLM model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`PLMModel`] - hidden_size (`int`, *optional*, defaults to 4096): + hidden_size (`int`, *optional*, defaults to 2048): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 8192): Dimension of the MLP representations. @@ -48,13 +48,13 @@ class PLMConfig(PretrainedConfig): converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + kv_lora_rank (`int`, *optional*, defaults to 512): + q_lora_rank (`int`, *optional*): + qk_rope_head_dim (`int`, *optional*, defaults to 64): + v_head_dim (`int`, *optional*, defaults to 128): + qk_nope_head_dim (`int`, *optional*, defaults to 128): + hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`): The non-linear activation function (function or string) in the decoder. - pretraining_tp (`int`, *optional*, defaults to 1): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this - issue](https://github.com/pytorch/pytorch/issues/76232). max_position_embeddings (`int`, *optional*, defaults to 4096): The maximum sequence length that this model might ever be used with. initializer_range (`float`, *optional*, defaults to 0.02): @@ -64,16 +64,22 @@ class PLMConfig(PretrainedConfig): use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). tie_word_embeddings (`bool`, *optional*, defaults to `True`): Whether the model's input and output word embeddings should be tied. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports normal rope. rope_theta (`float`, *optional*, defaults to 100000.0): The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports normal rope. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. + rope_interleave (`bool`, *optional*, defaults to `True`): ```python >>> from transformers import PLMModel, PLMConfig >>> # Initializing a PLM style configuration From 90ce1658c30a250b2b049816136e0d45a7a750ef Mon Sep 17 00:00:00 2001 From: JiwenJ <3522936020@qq.com> Date: Sun, 20 Apr 2025 07:19:34 +0000 Subject: [PATCH 10/17] regenerate modeling --- src/transformers/models/plm/modeling_plm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/plm/modeling_plm.py b/src/transformers/models/plm/modeling_plm.py index 0514633f2e3..7094a89a165 100644 --- a/src/transformers/models/plm/modeling_plm.py +++ b/src/transformers/models/plm/modeling_plm.py @@ -289,7 +289,12 @@ class PLMAttention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: batch_size, seq_length = hidden_states.shape[:-1] query_shape = (batch_size, seq_length, -1, self.qk_head_dim) - key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) + key_shape = ( + batch_size, + seq_length, + -1, + self.qk_nope_head_dim + self.v_head_dim, + ) if self.q_lora_rank is not None: q_states = ( self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2) From 5e1e8fafe79551e372c04d82d77c25f410847dc3 Mon Sep 17 00:00:00 2001 From: JiwenJ <3522936020@qq.com> Date: Sun, 20 Apr 2025 07:49:55 +0000 Subject: [PATCH 11/17] add model markdown doc --- docs/source/en/model_doc/plm.md | 54 +++++++++++++++++++++++++++ tests/models/plm/test_modeling_plm.py | 34 +++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 docs/source/en/model_doc/plm.md diff --git a/docs/source/en/model_doc/plm.md b/docs/source/en/model_doc/plm.md new file mode 100644 index 00000000000..1245a05bdd9 --- /dev/null +++ b/docs/source/en/model_doc/plm.md @@ -0,0 +1,54 @@ + + +# PLM + +## Overview + +To be released with the official model launch. + +### Model Details + +To be released with the official model launch. + + +## Usage tips + +To be released with the official model launch. + +## PLMConfig + +[[autodoc]] PLMConfig + +## PLMModel + +[[autodoc]] PLMModel + - forward + +## PLMForCausalLM + +[[autodoc]] PLMForCausalLM + - forward + +## PLMForSequenceClassification + +[[autodoc]] PLMForSequenceClassification + - forward + +## PLMForTokenClassification + +[[autodoc]] PLMForTokenClassification + - forward \ No newline at end of file diff --git a/tests/models/plm/test_modeling_plm.py b/tests/models/plm/test_modeling_plm.py index 81bc567c115..397c7e1c546 100644 --- a/tests/models/plm/test_modeling_plm.py +++ b/tests/models/plm/test_modeling_plm.py @@ -40,6 +40,8 @@ if is_torch_available(): from transformers import ( PLMForCausalLM, PLMModel, + PLMForSequenceClassification, + PLMForTokenClassification, ) # from transformers.models.plm.modeling_plm import ( @@ -319,6 +321,8 @@ class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ( PLMModel, PLMForCausalLM, + PLMForSequenceClassification, + PLMForTokenClassification, ) if is_torch_available() else () @@ -327,7 +331,10 @@ class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, pipeline_model_mapping = ( { "feature-extraction": PLMModel, + "text-classification": PLMForSequenceClassification, + "token-classification": PLMForTokenClassification, "text-generation": PLMForCausalLM, + "zero-shot": PLMForSequenceClassification, } if is_torch_available() else {} @@ -423,6 +430,33 @@ class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, def test_config(self): self.config_tester.run_common_tests() + def test_PLM_token_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = PLMForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + + def test_Qwen2_sequence_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = PLMForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + # def test_model(self): # config_and_inputs = self.model_tester.prepare_config_and_inputs() # self.model_tester.create_and_check_model(*config_and_inputs) From ebd137fe03589c313d305892ba103cd25ffecf9e Mon Sep 17 00:00:00 2001 From: JiwenJ <3522936020@qq.com> Date: Sun, 20 Apr 2025 07:54:23 +0000 Subject: [PATCH 12/17] reorganize import --- tests/models/plm/test_modeling_plm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/plm/test_modeling_plm.py b/tests/models/plm/test_modeling_plm.py index 397c7e1c546..7ae737dbe2f 100644 --- a/tests/models/plm/test_modeling_plm.py +++ b/tests/models/plm/test_modeling_plm.py @@ -39,9 +39,10 @@ if is_torch_available(): from transformers import ( PLMForCausalLM, - PLMModel, PLMForSequenceClassification, PLMForTokenClassification, + PLMModel, + ) # from transformers.models.plm.modeling_plm import ( From 062a9ce0264b9b62b09b50831ce2c55c7fd9a3e7 Mon Sep 17 00:00:00 2001 From: JiwenJ <3522936020@qq.com> Date: Sun, 20 Apr 2025 07:58:13 +0000 Subject: [PATCH 13/17] reformat test --- tests/models/plm/test_modeling_plm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/plm/test_modeling_plm.py b/tests/models/plm/test_modeling_plm.py index 7ae737dbe2f..8c5c83e2417 100644 --- a/tests/models/plm/test_modeling_plm.py +++ b/tests/models/plm/test_modeling_plm.py @@ -42,7 +42,6 @@ if is_torch_available(): PLMForSequenceClassification, PLMForTokenClassification, PLMModel, - ) # from transformers.models.plm.modeling_plm import ( From 5d946425500adfad80347af9c4997079ec4197dd Mon Sep 17 00:00:00 2001 From: JiwenJ <3522936020@qq.com> Date: Sun, 20 Apr 2025 08:10:43 +0000 Subject: [PATCH 14/17] add model link for PLMConfig --- src/transformers/models/plm/configuration_plm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/plm/configuration_plm.py b/src/transformers/models/plm/configuration_plm.py index 55a98160ead..4fbe47459d8 100644 --- a/src/transformers/models/plm/configuration_plm.py +++ b/src/transformers/models/plm/configuration_plm.py @@ -28,7 +28,8 @@ class PLMConfig(PretrainedConfig): PLM model according to the specified arguments, defining the model architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the - defaults will yield a similar configuration to that of the PLM model. + defaults will yield a similar configuration to that of + PLM-1.8B-Base [PLM-Team/PLM-1.8B-Base](https://huggingface.co/PLM-Team/PLM-1.8B-Base). Args: vocab_size (`int`, *optional*, defaults to 151936): Vocabulary size of the PLM model. Defines the number of different tokens that can be represented by the From d923bb44054bee2780801f083ef5315b69cc0fa0 Mon Sep 17 00:00:00 2001 From: JiwenJ <3522936020@qq.com> Date: Mon, 21 Apr 2025 03:55:57 +0000 Subject: [PATCH 15/17] change tie_word_embeddings to default false --- src/transformers/models/plm/configuration_plm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/plm/configuration_plm.py b/src/transformers/models/plm/configuration_plm.py index 4fbe47459d8..ffdb1e9beff 100644 --- a/src/transformers/models/plm/configuration_plm.py +++ b/src/transformers/models/plm/configuration_plm.py @@ -113,7 +113,7 @@ class PLMConfig(PretrainedConfig): rms_norm_eps=1e-6, use_cache=True, pretraining_tp=1, - tie_word_embeddings=True, + tie_word_embeddings=False, rope_theta=100000.0, rope_scaling=None, attention_bias=False, From da9c6b4c573381459e0cb60e473b70ce5c424ef0 Mon Sep 17 00:00:00 2001 From: JiwenJ Date: Sun, 4 May 2025 03:37:49 +0000 Subject: [PATCH 16/17] modify 2025 & plm doc --- docs/source/en/model_doc/plm.md | 39 +++++- src/transformers/models/plm/__init__.py | 3 +- .../models/plm/configuration_plm.py | 2 +- src/transformers/models/plm/modular_plm.py | 14 ++ tests/models/plm/test_modeling_plm.py | 122 +----------------- 5 files changed, 52 insertions(+), 128 deletions(-) diff --git a/docs/source/en/model_doc/plm.md b/docs/source/en/model_doc/plm.md index 1245a05bdd9..cf8860461af 100644 --- a/docs/source/en/model_doc/plm.md +++ b/docs/source/en/model_doc/plm.md @@ -1,4 +1,4 @@ - # PLM +
+PyTorch +FlashAttention +SDPA +
## Overview -To be released with the official model launch. +The PLM model was proposed in [PLM: Efficient Peripheral Language Models Hardware-Co-Designed for Ubiquitous Computing](https://arxiv.org/abs/2503.12167) by PLM-Team. -### Model Details +### Summary -To be released with the official model launch. +The PLM (Peripheral Language Model) series introduces a novel model architecture to peripheral computing by delivering powerful language capabilities within the constraints of resource-limited devices. Through modeling and system co-design strategy, PLM optimizes model performance and fits edge system requirements, PLM employs Multi-head Latent Attention and squared ReLU activation to achieve sparsity, significantly reducing memory footprint and computational demands. Coupled with a meticulously crafted training regimen using curated datasets and a Warmup-Stable-Decay-Constant learning rate scheduler, PLM demonstrates superior performance compared to existing small language models, all while maintaining the lowest activated parameters, making it ideally suited for deployment on diverse peripheral platforms like mobile phones and Raspberry Pis. ## Usage tips -To be released with the official model launch. +Ensure your Transformers library version is up-to-date. PLM requires Transformers>=4.51.3 for full support. + + +`PLM-1.8B-Instruct` can be found on the [Huggingface Hub](https://huggingface.co/PLM-Team/PLM-1.8B-Instruct) + + +In the following, we demonstrate how to use it for inference + +```python +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM + +# Load model and tokenizer +tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Instruct") +model = AutoModelForCausalLM.from_pretrained("PLM-Team/PLM-1.8B-Instruct", torch_dtype=torch.bfloat16) + +# Input text +input_text = "Tell me something about reinforcement learning." +inputs = tokenizer(input_text, return_tensors="pt") + +# Completion +output = model.generate(inputs["input_ids"], max_new_tokens=100) +print(tokenizer.decode(output[0], skip_special_tokens=True)) +``` + ## PLMConfig diff --git a/src/transformers/models/plm/__init__.py b/src/transformers/models/plm/__init__.py index 389e9e39abc..eae0d284a7c 100644 --- a/src/transformers/models/plm/__init__.py +++ b/src/transformers/models/plm/__init__.py @@ -1,4 +1,5 @@ -# Copyright 2024 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# coding=utf-8 +# Copyright 2025 The PLM team and the HuggingFace Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/transformers/models/plm/configuration_plm.py b/src/transformers/models/plm/configuration_plm.py index ffdb1e9beff..b22ebd1bcd6 100644 --- a/src/transformers/models/plm/configuration_plm.py +++ b/src/transformers/models/plm/configuration_plm.py @@ -70,7 +70,7 @@ class PLMConfig(PretrainedConfig): document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is necessary to ensure exact reproducibility of the pretraining results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232). - tie_word_embeddings (`bool`, *optional*, defaults to `True`): + tie_word_embeddings (`bool`, *optional*, defaults to `False`): Whether the model's input and output word embeddings should be tied. rope_theta (`float`, *optional*, defaults to 100000.0): The base period of the RoPE embeddings. diff --git a/src/transformers/models/plm/modular_plm.py b/src/transformers/models/plm/modular_plm.py index 3027b8dbcc2..c0b41a8d39c 100644 --- a/src/transformers/models/plm/modular_plm.py +++ b/src/transformers/models/plm/modular_plm.py @@ -1,3 +1,17 @@ +# coding=utf-8 +# Copyright 2025 The PLM team and the HuggingFace Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Callable, Optional, Tuple import torch diff --git a/tests/models/plm/test_modeling_plm.py b/tests/models/plm/test_modeling_plm.py index 8c5c83e2417..f77f033ecc9 100644 --- a/tests/models/plm/test_modeling_plm.py +++ b/tests/models/plm/test_modeling_plm.py @@ -44,11 +44,6 @@ if is_torch_available(): PLMModel, ) - # from transformers.models.plm.modeling_plm import ( - # PLMRotaryEmbedding, - # ) - - class PLMModelTester: def __init__( self, @@ -112,7 +107,6 @@ class PLMModelTester: self.num_choices = num_choices self.pad_token_id = pad_token_id self.scope = scope - # breakpoint() def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -445,7 +439,7 @@ class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), ) - def test_Qwen2_sequence_classification_model(self): + def test_PLM_sequence_classification_model(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() config.num_labels = 3 input_ids = input_dict["input_ids"] @@ -457,120 +451,6 @@ class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - # def test_model(self): - # config_and_inputs = self.model_tester.prepare_config_and_inputs() - # self.model_tester.create_and_check_model(*config_and_inputs) - - # def test_model_various_embeddings(self): - # config_and_inputs = self.model_tester.prepare_config_and_inputs() - # for type in ["absolute", "relative_key", "relative_key_query"]: - # config_and_inputs[0].position_embedding_type = type - # self.model_tester.create_and_check_model(*config_and_inputs) - - # @parameterized.expand([("yarn",)]) - # def test_model_rope_scaling_from_config(self, scaling_type): - # config, _ = self.model_tester.prepare_config_and_inputs_for_common() - # short_input = ids_tensor([1, 10], config.vocab_size) - # long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) - - # set_seed(42) # Fixed seed at init time so the two models get the same random weights - # original_model = PLMModel(config) - # original_model.to(torch_device) - # original_model.eval() - # original_short_output = original_model(short_input).last_hidden_state - # original_long_output = original_model(long_input).last_hidden_state - - # set_seed(42) # Fixed seed at init time so the two models get the same random weights - # config.rope_scaling = {"type": scaling_type, "factor": 10.0} - # scaled_model = PLMModel(config) - # scaled_model.to(torch_device) - # scaled_model.eval() - # scaled_short_output = scaled_model(short_input).last_hidden_state - # scaled_long_output = scaled_model(long_input).last_hidden_state - - # # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original - # # maximum sequence length, so the outputs for the short input should match. - # if scaling_type == "dynamic": - # torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) - # else: - # self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) - - # # The output should be different for long inputs - # self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - - # def test_model_rope_scaling(self): - # config, _ = self.model_tester.prepare_config_and_inputs_for_common() - # scaling_factor = 10 - # short_input_length = 10 - # long_input_length = int(config.max_position_embeddings * 1.5) - - # # Inputs - # x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device - # position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) - # position_ids_short = position_ids_short.unsqueeze(0) - # position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) - # position_ids_long = position_ids_long.unsqueeze(0) - - # # Sanity check original RoPE - # original_rope = PLMRotaryEmbedding(config=config).to(torch_device) - # original_cos_short, original_sin_short = original_rope(x, position_ids_short) - # original_cos_long, original_sin_long = original_rope(x, position_ids_long) - # torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) - # torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) - - # @unittest.skip(reason="PLM uses MLA on all models so the KV cache is a non standard format") - # def test_past_key_values_format(self): - # pass - - # @require_torch_sdpa - # @slow - # def test_eager_matches_sdpa_generate(self): - # """ - # Overwritting the common test as the test is flaky on tiny models - # """ - # max_new_tokens = 30 - - # tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base") - - # model_sdpa = PLMForCausalLM.from_pretrained( - # "PLM-Team/PLM-1.8B-Base", - # torch_dtype=torch.float16, - # low_cpu_mem_usage=True, - # trust_remote_code=True, - # ).to(torch_device) - - # self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - # model_eager = PLMForCausalLM.from_pretrained( - # "PLM-Team/PLM-1.8B-Base", - # torch_dtype=torch.float16, - # low_cpu_mem_usage=True, - # attn_implementation="eager", - # trust_remote_code=True, - # ).to(torch_device) - # breakpoint() - # self.assertTrue(model_eager.config._attn_implementation == "eager") - - # texts = [ - # "hi here's a longer context, getting longer and", - # "Hello this is a very long sentence my friend, very long for real", - # "Today I am in Paris and", - # ] - - # for padding_side in ["left", "right"]: - # tokenizer.padding_side = padding_side - # tokenizer.pad_token = tokenizer.eos_token - - # inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device) - - # res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) - # res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) - - # with self.subTest(f"{padding_side}"): - # torch.testing.assert_close( - # res_eager, - # res_sdpa, - # msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}", - # ) @require_torch_accelerator From ef3d1e548ab71bca13cf90c9299db5efd22ccb84 Mon Sep 17 00:00:00 2001 From: JiwenJ Date: Sun, 4 May 2025 08:07:57 +0000 Subject: [PATCH 17/17] modify mlp & decoder --- src/transformers/models/plm/modular_plm.py | 28 ++++++---------------- 1 file changed, 7 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/plm/modular_plm.py b/src/transformers/models/plm/modular_plm.py index c0b41a8d39c..d8176a7e4e7 100644 --- a/src/transformers/models/plm/modular_plm.py +++ b/src/transformers/models/plm/modular_plm.py @@ -19,8 +19,9 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from ...activations import ACT2FN + from ...cache_utils import Cache +from ..clip.modeling_clip import CLIPMLP from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack @@ -90,20 +91,8 @@ def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze return q_embed, k_embed -class PLMMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_state): - h = self.up_proj(hidden_state) - h = self.act_fn(h) - h = self.down_proj(h) - return h +class PLMMLP(CLIPMLP): + pass class PLMAttention(nn.Module): @@ -232,14 +221,11 @@ class PLMAttention(nn.Module): return attn_output, attn_weights -class PLMDecoderLayer(LlamaDecoderLayer, nn.Module): +class PLMDecoderLayer(LlamaDecoderLayer): def __init__(self, config: PLMConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = PLMAttention(config, layer_idx) + super().__init__(config, layer_idx) + self.self_attn = PLMAttention(config=config, layer_idx=layer_idx) self.mlp = PLMMLP(config) - self.input_layernorm = PLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = PLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) class PLMPreTrainedModel(LlamaPreTrainedModel):