test code format

JiwenJ committed 2025-04-20 06:25:58 +00:00
parent 1e9e950e35
commit 25cd37ab4e
3 changed files with 11 additions and 19 deletions

View File

@@ -24,7 +24,7 @@ logger = logging.get_logger(__name__)
class PLMConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`PLMModel`]. It is used to instantiate a
-PLM model according to the specified arguments, defining the model architecture.
+PLM model according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the
defaults will yield a similar configuration to that of the PLM model.
@@ -147,4 +147,4 @@ class PLMConfig(PretrainedConfig):
**kwargs,
)
-__all__ = ["PLMConfig"]
+__all__ = ["PLMConfig"]
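The docstring above describes the standard PretrainedConfig workflow; as a minimal usage sketch (not part of this diff, and only meaningful once the PLM classes ship in transformers), instantiating the config with defaults and building a model from it looks like the following. The hidden_size attribute is assumed to exist because the attention hunk further down reads config.hidden_size.

# Hypothetical usage sketch, not code from this commit.
from transformers import PLMConfig, PLMModel

config = PLMConfig()      # default arguments; per the docstring, similar to the PLM model's own configuration
model = PLMModel(config)  # randomly initialized weights; only the architecture comes from the config
print(config.hidden_size)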

View File

@@ -1,4 +1,3 @@
import math
from typing import Callable, Optional, Tuple
import torch
@@ -15,12 +14,12 @@ from ...utils import logging
from ..llama.modeling_llama import (
LlamaDecoderLayer,
LlamaForCausalLM,
+LlamaForSequenceClassification,
+LlamaForTokenClassification,
LlamaModel,
LlamaPreTrainedModel,
LlamaRMSNorm,
-LlamaForSequenceClassification,
LlamaRotaryEmbedding,
-LlamaForTokenClassification,
apply_rotary_pos_emb,
eager_attention_forward,
rotate_half,
@@ -87,7 +86,7 @@ class PLMMLP(nn.Module):
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_state):
h = self.up_proj(hidden_state)
h = self.act_fn(h)
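The PLMMLP hunk above ends inside forward. As a rough, self-contained sketch of the non-gated up_proj -> activation -> down_proj pattern it shows (the closing down_proj call, the ReLU placeholder, and the toy sizes are assumptions, not lines from this commit):

import torch
from torch import nn

class ToyMLP(nn.Module):
    """Illustrative stand-in for the up_proj -> activation -> down_proj pattern above."""
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.ReLU()  # PLM resolves this via ACT2FN[config.hidden_act]; ReLU is only a placeholder

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        h = self.up_proj(hidden_state)
        h = self.act_fn(h)
        return self.down_proj(h)  # assumed closing step, not shown in the hunk

x = torch.randn(1, 4, 32)
print(ToyMLP(32, 128)(x).shape)  # torch.Size([1, 4, 32])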
@@ -121,7 +120,7 @@ class PLMAttention(nn.Module):
self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
else:
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
self.kv_a_proj_with_mqa = nn.Linear(
config.hidden_size,
@@ -215,7 +214,6 @@ class PLMAttention(nn.Module):
attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class PLMDecoderLayer(LlamaDecoderLayer, nn.Module):
@@ -259,4 +257,4 @@ __all__ = [
"PLMForCausalLM",
"PLMForSequenceClassification",
"PLMForTokenClassification"
-]
+]
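For orientation, here is a standalone sketch of the optional low-rank query path that the PLMAttention hunk hints at: with q_lora_rank set, the query comes from a narrow down-projection followed by q_b_proj; otherwise a single full-rank q_proj is used. Only q_b_proj, q_proj, and kv_a_proj_with_mqa appear in the diff; the q_a_proj name, the toy dimensions, and the omission of any normalization between the two projections are assumptions.

from typing import Optional

import torch
from torch import nn

class ToyLowRankQuery(nn.Module):
    """Illustrates the two query-projection paths suggested by the PLMAttention hunk."""
    def __init__(self, hidden_size: int, num_heads: int, qk_head_dim: int, q_lora_rank: Optional[int] = None):
        super().__init__()
        self.q_lora_rank = q_lora_rank
        if q_lora_rank is not None:
            # Low-rank path: hidden_size -> q_lora_rank -> num_heads * qk_head_dim.
            # The down-projection name (q_a_proj) is an assumption; the diff only shows q_b_proj.
            self.q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False)
            self.q_b_proj = nn.Linear(q_lora_rank, num_heads * qk_head_dim, bias=False)
        else:
            # Full-rank path, matching the else branch in the diff.
            self.q_proj = nn.Linear(hidden_size, num_heads * qk_head_dim, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.q_lora_rank is not None:
            return self.q_b_proj(self.q_a_proj(hidden_states))
        return self.q_proj(hidden_states)

x = torch.randn(2, 8, 64)
print(ToyLowRankQuery(64, 4, 16, q_lora_rank=8)(x).shape)  # torch.Size([2, 8, 64])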

View File

@@ -19,12 +19,11 @@ import unittest
from packaging import version
from parameterized import parameterized
-from transformers import AutoTokenizer, PLMConfig, is_torch_available, set_seed
+from transformers import AutoTokenizer, PLMConfig, is_torch_available
from transformers.testing_utils import (
require_read_token,
require_torch,
require_torch_accelerator,
require_torch_sdpa,
slow,
torch_device,
)
@@ -42,9 +41,9 @@ if is_torch_available():
PLMForCausalLM,
PLMModel,
)
-from transformers.models.plm.modeling_plm import (
-PLMRotaryEmbedding,
-)
+# from transformers.models.plm.modeling_plm import (
+# PLMRotaryEmbedding,
+# )
@@ -291,10 +290,6 @@ class PLMModelTester:
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
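For context on how the tuple returned by prepare_config_and_inputs_for_common is normally consumed, a hypothetical smoke-test sketch follows. It is not part of this commit; it assumes the usual ModelTester convention of passing the test case as parent, that it would sit inside the PLMModelTest class below and reuse the module's existing imports, and that PLMModel returns an output exposing last_hidden_state.

# Hypothetical smoke test, not code from this diff.
def test_model_forward_smoke(self):
    config, inputs_dict = PLMModelTester(self).prepare_config_and_inputs_for_common()
    model = PLMModel(config).to(torch_device).eval()
    with torch.no_grad():
        outputs = model(**inputs_dict)
    expected_batch = inputs_dict["input_ids"].shape[0]
    self.assertEqual(outputs.last_hidden_state.shape[0], expected_batch)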
@@ -498,7 +493,6 @@ class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# ).to(torch_device)
# self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
# model_eager = PLMForCausalLM.from_pretrained(
# "PLM-Team/PLM-1.8B-Base",
# torch_dtype=torch.float16,