This commit is contained in:
JiwenJ 2025-04-20 05:28:48 +00:00
parent 6daa3eeba5
commit ef97fe7e0e
9 changed files with 2150 additions and 0 deletions

View File

@@ -229,6 +229,7 @@ if TYPE_CHECKING:
from .pix2struct import *
from .pixtral import *
from .plbart import *
from .plm import *
from .poolformer import *
from .pop2piano import *
from .prophetnet import *

View File

@@ -250,6 +250,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("pix2struct", "Pix2StructConfig"),
("pixtral", "PixtralVisionConfig"),
("plbart", "PLBartConfig"),
("plm", "PLMConfig"),
("poolformer", "PoolFormerConfig"),
("pop2piano", "Pop2PianoConfig"),
("prompt_depth_anything", "PromptDepthAnythingConfig"),
@@ -621,6 +622,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("pix2struct", "Pix2Struct"),
("pixtral", "Pixtral"),
("plbart", "PLBart"),
("plm", "PLM"),
("poolformer", "PoolFormer"),
("pop2piano", "Pop2Piano"),
("prompt_depth_anything", "PromptDepthAnything"),

View File

@@ -228,6 +228,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("phimoe", "PhimoeModel"),
("pixtral", "PixtralVisionModel"),
("plbart", "PLBartModel"),
("plm", "PLMModel"),
("poolformer", "PoolFormerModel"),
("prophetnet", "ProphetNetModel"),
("pvt", "PvtModel"),
@@ -587,6 +588,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("phi4_multimodal", "Phi4MultimodalForCausalLM"),
("phimoe", "PhimoeForCausalLM"),
("plbart", "PLBartForCausalLM"),
("plm", "PLMForCausalLM"),
("prophetnet", "ProphetNetForCausalLM"),
("qdqbert", "QDQBertLMHeadModel"),
("qwen2", "Qwen2ForCausalLM"),
@@ -1095,6 +1097,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("phi3", "Phi3ForSequenceClassification"),
("phimoe", "PhimoeForSequenceClassification"),
("plbart", "PLBartForSequenceClassification"),
("plm", "PLMForSequenceClassification"),
("qdqbert", "QDQBertForSequenceClassification"),
("qwen2", "Qwen2ForSequenceClassification"),
("qwen2_moe", "Qwen2MoeForSequenceClassification"),
@@ -1285,6 +1288,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("persimmon", "PersimmonForTokenClassification"),
("phi", "PhiForTokenClassification"),
("phi3", "Phi3ForTokenClassification"),
("plm", "PLMForTokenClassification"),
("qdqbert", "QDQBertForTokenClassification"),
("qwen2", "Qwen2ForTokenClassification"),
("qwen2_moe", "Qwen2MoeForTokenClassification"),

View File

@@ -0,0 +1,27 @@
# Copyright 2025 The PLM Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_plm import *
from .modeling_plm import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@@ -0,0 +1,150 @@
# coding=utf-8
# Copyright 2025 The PLM team and The HuggingFace Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PLM model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class PLMConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PLMModel`]. It is used to instantiate a
PLM model according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the
defaults will yield a configuration similar to that of [PLM-Team/PLM-1.8B-Base](https://huggingface.co/PLM-Team/PLM-1.8B-Base).
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the PLM model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`PLMModel`]
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 8192):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`.
kv_lora_rank (`int`, *optional*, defaults to 512):
Rank of the low-rank compression applied to the key/value projections in the multi-head latent attention.
q_lora_rank (`int`, *optional*):
Rank of the low-rank compression applied to the query projection. If `None`, a plain dense query projection is used.
qk_rope_head_dim (`int`, *optional*, defaults to 64):
Dimension of the per-head query/key channels that carry rotary position embeddings.
v_head_dim (`int`, *optional*, defaults to 128):
Dimension of the value heads.
qk_nope_head_dim (`int`, *optional*, defaults to 128):
Dimension of the per-head query/key channels that do not use rotary position embeddings.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
issue](https://github.com/pytorch/pytorch/issues/76232).
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether the model's input and output word embeddings should be tied.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports normal rope.
rope_theta (`float`, *optional*, defaults to 100000.0):
The base period of the RoPE embeddings.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
rope_interleave (`bool`, *optional*, defaults to `True`):
Whether to apply the rotary position embedding to interleaved (even/odd) channel pairs instead of to the first and second halves of each head.
```python
>>> from transformers import PLMModel, PLMConfig
>>> # Initializing a PLM style configuration
>>> configuration = PLMConfig()
>>> # Initializing a model from the PLM style configuration
>>> model = PLMModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "plm"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=151936,
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=32,
num_attention_heads=16,
num_key_value_heads=16,
kv_lora_rank=512,
q_lora_rank=None,
qk_rope_head_dim=64,
v_head_dim=128,
qk_nope_head_dim=128,
hidden_act="relu2",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pretraining_tp=1,
tie_word_embeddings=True,
rope_theta=100000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
rope_interleave=True,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.rope_interleave = rope_interleave
self.head_dim = qk_rope_head_dim
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
__all__ = ["PLMConfig"]

File diff suppressed because it is too large

View File

@@ -0,0 +1,262 @@
import math
from typing import Callable, Optional, Tuple
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import logging
from ..llama.modeling_llama import (
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
LlamaPreTrainedModel,
LlamaRMSNorm,
LlamaForSequenceClassification,
LlamaRotaryEmbedding,
LlamaForTokenClassification,
apply_rotary_pos_emb,
eager_attention_forward,
rotate_half,
)
from .configuration_plm import PLMConfig
logger = logging.get_logger(__name__)
class PLMRMSNorm(LlamaRMSNorm):
pass
class PLMRotaryEmbedding(LlamaRotaryEmbedding):
pass
def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
r"""
TODO let's just use the original freqs_cis computation to not have the view
transpose + reshape! This is not optimized!
Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offset position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
b, h, s, d = q.shape
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
b, h, s, d = k.shape
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class PLMMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_state):
h = self.up_proj(hidden_state)
h = self.act_fn(h)
h = self.down_proj(h)
return h
class PLMAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: PLMConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.attention_dropout = config.attention_dropout
self.num_heads = config.num_attention_heads
self.rope_theta = config.rope_theta
self.q_lora_rank = config.q_lora_rank
self.qk_rope_head_dim = config.qk_rope_head_dim
self.kv_lora_rank = config.kv_lora_rank
self.v_head_dim = config.v_head_dim
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_head_dim = config.qk_head_dim
self.is_causal = True
if config.q_lora_rank is not None:
self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
self.q_a_layernorm = PLMRMSNorm(config.q_lora_rank)
self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
else:
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
self.kv_a_proj_with_mqa = nn.Linear(
config.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=config.attention_bias,
)
self.kv_a_layernorm = PLMRMSNorm(self.kv_lora_rank)
self.kv_b_proj = nn.Linear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
)
self.o_proj = nn.Linear(
self.num_heads * self.v_head_dim,
config.hidden_size,
bias=config.attention_bias,
)
self.scaling = self.qk_head_dim ** (-0.5)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
batch_size, seq_length = hidden_states.shape[:-1]
query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)
if self.q_lora_rank is not None:
q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(query_shape).transpose(1, 2)
else:
q_states = self.q_proj(hidden_states).view(query_shape).transpose(1, 2)
q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)
cos, sin = position_embeddings
if self.config.rope_interleave: # support using interleaved weights for efficiency
q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
else:
q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
k_rot = k_rot.expand(*k_pass.shape[:-1], -1)
query_states = torch.cat((q_pass, q_rot), dim=-1)
key_states = torch.cat((k_pass, k_rot), dim=-1)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
attn_output = attn_output[:, :, :, : self.v_head_dim]
attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class PLMDecoderLayer(LlamaDecoderLayer, nn.Module):
def __init__(self, config: PLMConfig, layer_idx: int):
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
self.self_attn = PLMAttention(config, layer_idx)
self.mlp = PLMMLP(config)
self.input_layernorm = PLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = PLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
class PLMPreTrainedModel(LlamaPreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.Parameter):
module.data.normal_(mean=0.0, std=std)
class PLMForTokenClassification(LlamaForTokenClassification):
pass
class PLMForCausalLM(LlamaForCausalLM):
pass
class PLMModel(LlamaModel):
pass
class PLMForSequenceClassification(LlamaForSequenceClassification):
pass
__all__ = [
"PLMPreTrainedModel",
"PLMModel",
"PLMForCausalLM",
"PLMForSequenceClassification",
"PLMForTokenClassification"
]
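`apply_rotary_pos_emb_interleave` above differs from the standard `apply_rotary_pos_emb` only in how it regroups channels before the rotation. The short sketch below is not part of the commit; it shows the regrouping on a toy tensor.

```python
# Sketch (not part of the commit): the channel regrouping performed by
# apply_rotary_pos_emb_interleave before the usual half-split rotation.
# A head dim of 4 keeps the tensor readable.
import torch

d = 4
q = torch.arange(d, dtype=torch.float32).view(1, 1, 1, d)  # channels [0, 1, 2, 3]

# Reshape to (..., d//2, 2), swap the last two axes and flatten: channels are
# regrouped as [even channels, odd channels] = [0, 2, 1, 3].
b, h, s, _ = q.shape
q_regrouped = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
print(q_regrouped)  # tensor([[[[0., 2., 1., 3.]]]])

# After this regrouping, cos/sin and rotate_half are applied exactly as in the
# non-interleaved apply_rotary_pos_emb, so each rotation pairs an even channel
# with the odd channel that originally followed it.
```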

View File

View File

@@ -0,0 +1,597 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch PLM model."""
import unittest
from packaging import version
from parameterized import parameterized
from transformers import AutoTokenizer, PLMConfig, is_torch_available, set_seed
from transformers.testing_utils import (
require_read_token,
require_torch,
require_torch_accelerator,
require_torch_sdpa,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
PLMForCausalLM,
PLMModel,
)
from transformers.models.plm.modeling_plm import (
PLMRotaryEmbedding,
)
class PLMModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
hidden_size=32,
intermediate_size=37,
num_hidden_layers=5,
num_attention_heads=4,
num_key_value_heads=4,
kv_lora_rank=16,
q_lora_rank=32,
qk_rope_head_dim=16,
v_head_dim=32,
qk_nope_head_dim=32,
n_group=2,
first_k_dense_replace=2,
norm_topk_prob=True,
hidden_act="relu2",
max_position_embeddings=512,
initializer_range=0.02,
attention_probs_dropout_prob=0.1,
type_vocab_size=16,
type_sequence_label_size=2,
num_labels=3,
num_choices=4,
pad_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.hidden_act = hidden_act
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.num_labels = num_labels
self.num_choices = num_choices
self.pad_token_id = pad_token_id
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config()
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def get_config(self):
return PLMConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
kv_lora_rank=self.kv_lora_rank,
q_lora_rank=self.q_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
qk_nope_head_dim=self.qk_nope_head_dim,
hidden_act=self.hidden_act,
max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range,
use_cache=True,
pad_token_id=self.pad_token_id,
attention_dropout=self.attention_probs_dropout_prob,
)
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = PLMModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = PLMModel(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
)
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_causal_lm(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
model = PLMForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.is_decoder = True
config.add_cross_attention = True
model = PLMForCausalLM(config=config)
model.to(torch_device)
model.eval()
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical multiple next tokens and extend to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append to next input_ids and attention mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
)["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)["hidden_states"][0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
(
PLMModel,
PLMForCausalLM,
)
if is_torch_available()
else ()
)
all_generative_model_classes = (PLMForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": PLMModel,
"text-generation": PLMForCausalLM,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
fx_compatible = False
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]
# used in `test_torch_compile_for_training`
_torch_compile_train_cls = PLMForCausalLM if is_torch_available() else None
def setUp(self):
self.model_tester = PLMModelTester(self)
self.config_tester = ConfigTester(self, config_class=PLMConfig, hidden_size=37)
@unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs):
pass
@parameterized.expand([("random",), ("same",)])
@unittest.skip("PLM has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("PLM has HybridCache which is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("PLM has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("PLM has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
@unittest.skip("PLM has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("PLM has HybridCache and doesn't support low_memory generation")
def test_beam_search_low_memory(self):
pass
@unittest.skip("PLM has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass
@unittest.skip("PLM has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("PLM has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip(
"PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
)
def test_generate_with_static_cache(self):
pass
@unittest.skip(
"PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
)
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip(
"PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
)
def test_generate_continue_from_inputs_embeds(self):
pass
@unittest.skip("PLM's eager attn/sdpa attn outputs are expected to be different")
def test_sdpa_equivalence(self):
pass
@unittest.skip("PLM uses MLA so it is not compatible with the standard cache format")
def test_beam_search_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("PLM uses MLA so it is not compatible with the standard cache format")
def test_generate_compilation_all_outputs(self):
pass
@unittest.skip("PLM uses MLA so it is not compatible with the standard cache format")
def test_generate_compile_model_forward(self):
pass
@unittest.skip("PLM uses MLA so it is not compatible with the standard cache format")
def test_greedy_generate_dict_outputs_use_cache(self):
pass
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
# def test_model_various_embeddings(self):
# config_and_inputs = self.model_tester.prepare_config_and_inputs()
# for type in ["absolute", "relative_key", "relative_key_query"]:
# config_and_inputs[0].position_embedding_type = type
# self.model_tester.create_and_check_model(*config_and_inputs)
# @parameterized.expand([("yarn",)])
# def test_model_rope_scaling_from_config(self, scaling_type):
# config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# short_input = ids_tensor([1, 10], config.vocab_size)
# long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
# set_seed(42) # Fixed seed at init time so the two models get the same random weights
# original_model = PLMModel(config)
# original_model.to(torch_device)
# original_model.eval()
# original_short_output = original_model(short_input).last_hidden_state
# original_long_output = original_model(long_input).last_hidden_state
# set_seed(42) # Fixed seed at init time so the two models get the same random weights
# config.rope_scaling = {"type": scaling_type, "factor": 10.0}
# scaled_model = PLMModel(config)
# scaled_model.to(torch_device)
# scaled_model.eval()
# scaled_short_output = scaled_model(short_input).last_hidden_state
# scaled_long_output = scaled_model(long_input).last_hidden_state
# # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
# # maximum sequence length, so the outputs for the short input should match.
# if scaling_type == "dynamic":
# torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
# else:
# self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
# # The output should be different for long inputs
# self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
# def test_model_rope_scaling(self):
# config, _ = self.model_tester.prepare_config_and_inputs_for_common()
# scaling_factor = 10
# short_input_length = 10
# long_input_length = int(config.max_position_embeddings * 1.5)
# # Inputs
# x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exclusively to get the dtype and the device
# position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
# position_ids_short = position_ids_short.unsqueeze(0)
# position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
# position_ids_long = position_ids_long.unsqueeze(0)
# # Sanity check original RoPE
# original_rope = PLMRotaryEmbedding(config=config).to(torch_device)
# original_cos_short, original_sin_short = original_rope(x, position_ids_short)
# original_cos_long, original_sin_long = original_rope(x, position_ids_long)
# torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
# torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
# @unittest.skip(reason="PLM uses MLA on all models so the KV cache is a non standard format")
# def test_past_key_values_format(self):
# pass
# @require_torch_sdpa
# @slow
# def test_eager_matches_sdpa_generate(self):
# """
# Overwriting the common test as the test is flaky on tiny models
# """
# max_new_tokens = 30
# tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base")
# model_sdpa = PLMForCausalLM.from_pretrained(
# "PLM-Team/PLM-1.8B-Base",
# torch_dtype=torch.float16,
# low_cpu_mem_usage=True,
# trust_remote_code=True,
# ).to(torch_device)
# self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
# model_eager = PLMForCausalLM.from_pretrained(
# "PLM-Team/PLM-1.8B-Base",
# torch_dtype=torch.float16,
# low_cpu_mem_usage=True,
# attn_implementation="eager",
# trust_remote_code=True,
# ).to(torch_device)
# self.assertTrue(model_eager.config._attn_implementation == "eager")
# texts = [
# "hi here's a longer context, getting longer and",
# "Hello this is a very long sentence my friend, very long for real",
# "Today I am in Paris and",
# ]
# for padding_side in ["left", "right"]:
# tokenizer.padding_side = padding_side
# tokenizer.pad_token = tokenizer.eos_token
# inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
# res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
# res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
# with self.subTest(f"{padding_side}"):
# torch.testing.assert_close(
# res_eager,
# res_sdpa,
# msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
# )
@require_torch_accelerator
class PLMIntegrationTest(unittest.TestCase):
# This variable is used to determine which CUDA device we are using for our runners (A10 or T4).
# Depending on the hardware we get different logits / generations.
cuda_compute_capability_major_version = None
@classmethod
def setUpClass(cls):
if is_torch_available() and torch.cuda.is_available():
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
@slow
@require_torch_accelerator
@require_read_token
def test_compile_static_cache(self):
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
if version.parse(torch.__version__) < version.parse("2.3.0"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
NUM_TOKENS_TO_GENERATE = 40
# Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
# was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
EXPECTED_TEXT_COMPLETION = [
"Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
"reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
"theory of relativ",
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
"my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
]
prompts = [
"Simply put, the theory of relativity states that ",
"My favorite all time favorite condiment is ketchup.",
]
tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base", use_fast=False)
model = PLMForCausalLM.from_pretrained(
"PLM-Team/PLM-1.8B-Base", device_map=torch_device, torch_dtype=torch.float16
)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
# Dynamic Cache
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)
# Static Cache
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)
# Static Cache + compile
model._cache = None # clear cache object, initialized when we pass `cache_implementation="static"`
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)
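As a companion to the tester above, a minimal forward-pass smoke check along the following lines can exercise a tiny randomly initialized PLM end to end. This sketch is not part of the commit; the dimensions are borrowed from `PLMModelTester` and it assumes the classes behave as defined in this commit.

```python
# Sketch (not part of the commit): tiny end-to-end forward pass through PLMForCausalLM,
# with dimensions borrowed from PLMModelTester above.
import torch
from transformers import PLMConfig, PLMForCausalLM

config = PLMConfig(
    vocab_size=99,
    hidden_size=32,
    intermediate_size=37,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    kv_lora_rank=16,
    q_lora_rank=32,
    qk_rope_head_dim=16,
    v_head_dim=32,
    qk_nope_head_dim=32,
    max_position_embeddings=512,
)
model = PLMForCausalLM(config).eval()
input_ids = torch.randint(0, config.vocab_size, (2, 7))
with torch.no_grad():
    logits = model(input_ids).logits
print(logits.shape)  # torch.Size([2, 7, 99])
```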