Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

Commit 25cd37ab4e ("test code format"), parent 1e9e950e35:
@@ -24,7 +24,7 @@ logger = logging.get_logger(__name__)
 class PLMConfig(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`PLMModel`]. It is used to instantiate a
-    PLM model according to the specified arguments, defining the model architecture.
+    PLM model according to the specified arguments, defining the model architecture.
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the
     defaults will yield a similar configuration to that of the PLM model.
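
For orientation, the usage this docstring describes is the standard transformers config/model round trip. A minimal sketch (not part of the diff; it only uses PLMConfig and PLMModel, which the test hunks further down import from transformers):

    from transformers import PLMConfig, PLMModel

    # Build a configuration with the default PLM arguments
    configuration = PLMConfig()

    # Instantiate a model (randomly initialized) from that configuration
    model = PLMModel(configuration)

    # The model keeps a reference to its configuration
    configuration = model.config
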
@@ -147,4 +147,4 @@ class PLMConfig(PretrainedConfig):
             **kwargs,
         )

-__all__ = ["PLMConfig"]
+__all__ = ["PLMConfig"]
@@ -1,4 +1,3 @@
-import math
 from typing import Callable, Optional, Tuple

 import torch
@@ -15,12 +14,12 @@ from ...utils import logging
 from ..llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaForCausalLM,
+    LlamaForSequenceClassification,
+    LlamaForTokenClassification,
     LlamaModel,
     LlamaPreTrainedModel,
     LlamaRMSNorm,
-    LlamaForSequenceClassification,
     LlamaRotaryEmbedding,
-    LlamaForTokenClassification,
     apply_rotary_pos_emb,
     eager_attention_forward,
     rotate_half,
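
The rotary-embedding helpers imported here (apply_rotary_pos_emb, rotate_half) are reused from the Llama implementation. As a rough sketch of what they compute, paraphrased from modeling_llama (the exact signatures there are authoritative):

    import torch

    def rotate_half(x):
        # Split the head dimension in half and swap the halves with a sign flip.
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
        # Rotate query/key features by the position-dependent cos/sin tables.
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed
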
@@ -87,7 +86,7 @@ class PLMMLP(nn.Module):
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
-
+
     def forward(self, hidden_state):
         h = self.up_proj(hidden_state)
         h = self.act_fn(h)
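
Pieced together from the context lines above, PLMMLP appears to be a plain up-projection / activation / down-projection block. The hunk cuts off before the return, so this self-contained sketch fills in the assumed final step; the toy sizes and the SiLU activation (standing in for ACT2FN[config.hidden_act]) are assumptions:

    import torch
    from torch import nn

    class SketchMLP(nn.Module):
        # Illustrative stand-in for PLMMLP's shape flow, not the actual implementation.
        def __init__(self, hidden_size=64, intermediate_size=256):
            super().__init__()
            self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
            self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
            self.act_fn = nn.SiLU()

        def forward(self, hidden_state):
            h = self.up_proj(hidden_state)   # hidden_size -> intermediate_size
            h = self.act_fn(h)
            return self.down_proj(h)         # assumed final step: back to hidden_size
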
@@ -121,7 +120,7 @@ class PLMAttention(nn.Module):
             self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)
         else:
             self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)

-
+
         self.kv_a_proj_with_mqa = nn.Linear(
             config.hidden_size,
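
The branch above gives PLMAttention a DeepSeek-V2-style low-rank query path: when q_lora_rank is set, queries pass through a narrow bottleneck before q_b_proj expands them back to num_heads * qk_head_dim; otherwise a single full-rank q_proj is used. A runnable toy sketch of that shape flow (only q_b_proj and q_proj appear in this hunk; the q_a_proj bottleneck and the toy sizes are assumptions):

    import torch
    from torch import nn

    hidden_size, q_lora_rank, num_heads, qk_head_dim = 64, 16, 4, 16

    q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False)             # assumed bottleneck
    q_b_proj = nn.Linear(q_lora_rank, num_heads * qk_head_dim, bias=False)  # as in the hunk
    q_proj = nn.Linear(hidden_size, num_heads * qk_head_dim, bias=False)    # full-rank alternative

    hidden_states = torch.randn(2, 5, hidden_size)  # (batch, seq, hidden)
    q_low_rank = q_b_proj(q_a_proj(hidden_states))  # hidden -> q_lora_rank -> heads * head_dim
    q_full = q_proj(hidden_states)                  # hidden -> heads * head_dim directly
    assert q_low_rank.shape == q_full.shape == (2, 5, num_heads * qk_head_dim)
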
@@ -215,7 +214,6 @@ class PLMAttention(nn.Module):
         attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
         attn_output = self.o_proj(attn_output)
         return attn_output, attn_weights


-
 class PLMDecoderLayer(LlamaDecoderLayer, nn.Module):
@@ -259,4 +257,4 @@ __all__ = [
     "PLMForCausalLM",
     "PLMForSequenceClassification",
     "PLMForTokenClassification"
-]
+]
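
With PLMForCausalLM among the exported names, end-to-end usage follows the standard transformers generation API. A hedged sketch (the checkpoint id "PLM-Team/PLM-1.8B-Base" and the torch_dtype setting are taken from the commented-out integration test further down; the prompt is illustrative):

    import torch
    from transformers import AutoTokenizer, PLMForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base")
    model = PLMForCausalLM.from_pretrained("PLM-Team/PLM-1.8B-Base", torch_dtype=torch.float16)

    inputs = tokenizer("Hello, PLM!", return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
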
@@ -19,12 +19,11 @@ import unittest
 from packaging import version
 from parameterized import parameterized

-from transformers import AutoTokenizer, PLMConfig, is_torch_available, set_seed
+from transformers import AutoTokenizer, PLMConfig, is_torch_available
 from transformers.testing_utils import (
     require_read_token,
     require_torch,
     require_torch_accelerator,
     require_torch_sdpa,
     slow,
     torch_device,
 )
@@ -42,9 +41,9 @@ if is_torch_available():
         PLMForCausalLM,
         PLMModel,
     )
-    from transformers.models.plm.modeling_plm import (
-        PLMRotaryEmbedding,
-    )
+    # from transformers.models.plm.modeling_plm import (
+    #     PLMRotaryEmbedding,
+    # )


@@ -291,10 +290,6 @@ class PLMModelTester:
         ) = config_and_inputs
         inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
         return config, inputs_dict

-
-
-
-

 @require_torch
@@ -498,7 +493,6 @@ class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
         # ).to(torch_device)

         # self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")

         # model_eager = PLMForCausalLM.from_pretrained(
         #     "PLM-Team/PLM-1.8B-Base",
         #     torch_dtype=torch.float16,