transformers/tests/models/plm/test_modeling_plm.py
2025-05-04 03:37:49 +00:00

517 lines
19 KiB
Python

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch PLM model."""
import unittest
from packaging import version
from parameterized import parameterized
from transformers import AutoTokenizer, PLMConfig, is_torch_available
from transformers.testing_utils import (
require_read_token,
require_torch,
require_torch_accelerator,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
PLMForCausalLM,
PLMForSequenceClassification,
PLMForTokenClassification,
PLMModel,
)
class PLMModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
hidden_size=32,
intermediate_size=37,
num_hidden_layers=5,
num_attention_heads=4,
num_key_value_heads=4,
kv_lora_rank=16,
q_lora_rank=32,
qk_rope_head_dim=16,
v_head_dim=32,
qk_nope_head_dim=32,
n_group=2,
first_k_dense_replace=2,
norm_topk_prob=True,
hidden_act="relu2",
max_position_embeddings=512,
initializer_range=0.02,
attention_probs_dropout_prob=0.1,
type_vocab_size=16,
type_sequence_label_size=2,
num_labels=3,
num_choices=4,
pad_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.hidden_act = hidden_act
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.num_labels = num_labels
self.num_choices = num_choices
self.pad_token_id = pad_token_id
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config()
return (
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
)
def get_config(self):
return PLMConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_key_value_heads,
kv_lora_rank=self.kv_lora_rank,
q_lora_rank=self.q_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
qk_nope_head_dim=self.qk_nope_head_dim,
hidden_act=self.hidden_act,
max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range,
use_cache=True,
pad_token_id=self.pad_token_id,
attention_dropout=self.attention_probs_dropout_prob,
)
def create_and_check_model(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
):
model = PLMModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.seq_length, self.hidden_size),
)
def create_and_check_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = PLMModel(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
)
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.seq_length, self.hidden_size),
)
def create_and_check_for_causal_lm(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
model = PLMForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.is_decoder = True
config.add_cross_attention = True
model = PLMForCausalLM(config=config)
model.to(torch_device)
model.eval()
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical multiple next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
)["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)["hidden_states"][0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class PLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
# breakpoint()
all_model_classes = (
(
PLMModel,
PLMForCausalLM,
PLMForSequenceClassification,
PLMForTokenClassification,
)
if is_torch_available()
else ()
)
all_generative_model_classes = (PLMForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": PLMModel,
"text-classification": PLMForSequenceClassification,
"token-classification": PLMForTokenClassification,
"text-generation": PLMForCausalLM,
"zero-shot": PLMForSequenceClassification,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
fx_compatible = False
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]
# used in `test_torch_compile_for_training`
_torch_compile_train_cls = PLMForCausalLM if is_torch_available() else None
def setUp(self):
self.model_tester = PLMModelTester(self)
self.config_tester = ConfigTester(self, config_class=PLMConfig, hidden_size=37)
@unittest.skip("Failing because of unique cache (HybridCache)")
def test_model_outputs_equivalence(self, **kwargs):
pass
@parameterized.expand([("random",), ("same",)])
@unittest.skip("PLM has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("PLM has HybridCache which is not compatible with assisted decoding")
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass
@unittest.skip("PLM has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("PLM has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
@unittest.skip("PLM has HybridCache and doesn't support continue from past kv")
def test_generate_continue_from_past_key_values(self):
pass
@unittest.skip("PLM has HybridCache and doesn't support low_memory generation")
def test_beam_search_low_memory(self):
pass
@unittest.skip("PLM has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate(self):
pass
@unittest.skip("PLM has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("PLM has HybridCache and doesn't support contrastive generation")
def test_contrastive_generate_low_memory(self):
pass
@unittest.skip("PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_with_static_cache(self):
pass
@unittest.skip("PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
@unittest.skip("PLM has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_continue_from_inputs_embeds(self):
pass
@unittest.skip("PLM's eager attn/sdpa attn outputs are expected to be different")
def test_sdpa_equivalence(self):
pass
@unittest.skip("PLM uses MLA so it is not compatible with the standard cache format")
def test_beam_search_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("PLM uses MLA so it is not compatible with the standard cache format")
def test_generate_compilation_all_outputs(self):
pass
@unittest.skip("PLM uses MLA so it is not compatible with the standard cache format")
def test_generate_compile_model_forward(self):
pass
@unittest.skip("PLM uses MLA so it is not compatible with the standard cache format")
def test_greedy_generate_dict_outputs_use_cache(self):
pass
def test_config(self):
self.config_tester.run_common_tests()
def test_PLM_token_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
model = PLMForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
self.assertEqual(
result.logits.shape,
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
)
def test_PLM_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
model = PLMForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@require_torch_accelerator
class PLMIntegrationTest(unittest.TestCase):
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
# Depending on the hardware we get different logits / generations
cuda_compute_capability_major_version = None
@classmethod
def setUpClass(cls):
if is_torch_available() and torch.cuda.is_available():
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
@slow
@require_torch_accelerator
@require_read_token
def test_compile_static_cache(self):
# `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
# work as intended. See https://github.com/pytorch/pytorch/issues/121943
if version.parse(torch.__version__) < version.parse("2.3.0"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
NUM_TOKENS_TO_GENERATE = 40
# Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
# was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
EXPECTED_TEXT_COMPLETION = [
"Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
"reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
"theory of relativ",
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
"my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
]
prompts = [
"Simply put, the theory of relativity states that ",
"My favorite all time favorite condiment is ketchup.",
]
tokenizer = AutoTokenizer.from_pretrained("PLM-Team/PLM-1.8B-Base", use_fast=False)
model = PLMForCausalLM.from_pretrained(
"PLM-Team/PLM-1.8B-Base", device_map=torch_device, torch_dtype=torch.float16
)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
# Dynamic Cache
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)
# Static Cache
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text)
# Static Cache + compile
model._cache = None # clear cache object, initialized when we pass `cache_implementation="static"`
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)