mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-23 14:29:01 +06:00
🚨 🚨 Inherited CausalLM Tests (#37590)
* stash commit
* Experiment 1: Try just Gemma
* Experiment 1: Just try Gemma
* make fixup
* Trigger tests
* stash commit
* Try adding Gemma3 as well
* make fixup
* Correct attrib names
* Correct pipeline model mapping
* Add in all_model_classes for Gemma1 again
* Move the pipeline model mapping around again
* make fixup
* Revert Gemma3 changes since it's a VLM
* Let's try Falcon
* Correct attributes
* Correct attributes
* Let's try just overriding get_config() for now
* Do Nemotron too
* And Llama!
* Do llama/persimmon
* Correctly skip tests
* Fix Persimmon
* Include Phimoe
* Fix Gemma2
* Set model_tester_class correctly
* Add GLM
* More models!
* models models models
* make fixup
* Add Qwen3 + Qwen3MoE
* Correct import
* make fixup
* Add the QuestionAnswering classes
* Add the QuestionAnswering classes
* Move pipeline mapping to the right place
* Jetmoe too
* Stop RoPE testing models with no RoPE
* Fix up JetMOE a bit
* Fix up JetMOE a bit
* Can we just force pad_token_id all the time?
* make fixup
* fix starcoder2
* Move pipeline mapping
* Fix RoPE skipping
* Fix RecurrentGemma tests
* Fix Falcon tests
* Add MoE attributes
* Fix values for RoPE testing
* Make sure we set bos_token_id and eos_token_id in an appropriate range
* make fixup
* Fix GLM4
* Add mamba attributes
* Revert bits of JetMOE
* Re-add the JetMOE skips
* Update tests/causal_lm_tester.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add licence

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
d5f992f5e6
commit
53fb245eb6
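For orientation, this is the pattern the commit introduces: a per-model test module now only subclasses the shared tester and test case and fills in a few class attributes, inheriting the common config, modeling, classification and (optionally) RoPE tests. A minimal usage sketch follows; the MyModel* names are hypothetical placeholders, not classes added by this commit (the real conversions are visible in the Falcon, Gemma and Gemma2 diffs below).

# Hypothetical usage sketch: MyModelConfig / MyModel / MyModelForCausalLM are placeholders.
import unittest

from transformers import MyModelConfig, MyModel, MyModelForCausalLM  # hypothetical model classes
from transformers.testing_utils import require_torch

from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester


class MyModelModelTester(CausalLMModelTester):
    # The three required attributes; optional ones (sequence/token classification,
    # question answering) can be added when the model provides those heads.
    config_class = MyModelConfig
    base_model_class = MyModel
    causal_lm_class = MyModelForCausalLM


@require_torch
class MyModelModelTest(CausalLMModelTest, unittest.TestCase):
    model_tester_class = MyModelModelTester
    all_model_classes = (MyModel, MyModelForCausalLM)
    pipeline_model_mapping = {
        "feature-extraction": MyModel,
        "text-generation": MyModelForCausalLM,
    }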
tests/causal_lm_tester.py (new file, 479 lines)
@@ -0,0 +1,479 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
from inspect import signature

import pytest
from parameterized import parameterized

from transformers import set_seed
from transformers.testing_utils import (
    is_flaky,
    require_flash_attn,
    require_torch_accelerator,
    require_torch_gpu,
    require_torch_sdpa,
    slow,
)

from .test_configuration_common import ConfigTester
from .test_modeling_common import (
    GenerationTesterMixin,
    ModelTesterMixin,
    ids_tensor,
    is_torch_available,
    require_torch,
    torch_device,
)
from .test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch


class CausalLMModelTester:
    _required_attributes = ("base_model_class", "config_class", "causal_lm_class")
    forced_config_args = [
        "pad_token_id"
    ]  # Arguments that should be passed to the config class even if not in its signature
    config_class = None
    base_model_class = None
    causal_lm_class = None
    sequence_classification_class = None
    token_classification_class = None
    question_answering_class = None

    def _verify_model_attributes(self):
        for required_attribute in self._required_attributes:
            if getattr(self, required_attribute) is None:
                raise ValueError(
                    f"You have inherited from CausalLMModelTester but did not set the {required_attribute} attribute."
                )

    @property
    def all_model_classes(self):
        return [
            model_class
            for model_class in (
                self.base_model_class,
                self.causal_lm_class,
                self.sequence_classification_class,
                self.token_classification_class,
            )
            if model_class is not None
        ]

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        num_key_value_heads=2,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        is_decoder=False,
        scope=None,
        expert_interval=1,
        moe_intermediate_size=12,
        shared_expert_intermediate_size=36,
        shared_expert_gate=True,
        num_experts_per_tok=2,
        num_experts=8,
        mamba_n_groups=1,
        mamba_n_heads=16,
        mamba_d_state=16,
        mamba_d_conv=4,
        mamba_expand=2,
        mamba_chunk_size=16,
    ):
        self._verify_model_attributes()
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.scope = scope
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.is_decoder = is_decoder
        self.expert_interval = expert_interval
        self.moe_intermediate_size = moe_intermediate_size
        self.shared_expert_intermediate_size = shared_expert_intermediate_size
        self.shared_expert_gate = shared_expert_gate
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.mamba_n_groups = mamba_n_groups
        self.mamba_n_heads = mamba_n_heads
        self.mamba_d_state = mamba_d_state
        self.mamba_d_conv = mamba_d_conv
        self.mamba_expand = mamba_expand
        self.mamba_chunk_size = mamba_chunk_size

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        kwarg_names = list(signature(self.config_class.__init__).parameters.keys())
        kwargs = {
            k: getattr(self, k) for k in kwarg_names + self.forced_config_args if hasattr(self, k) and k != "self"
        }
        return self.config_class(**kwargs)

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = self.base_model_class(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class CausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin):
    test_headmasking = False
    test_pruning = False
    model_tester_class = None
    all_model_classes = None
    rotary_embedding_layer = None  # Enables RoPE tests if set
    pipeline_model_mapping = None

    def setUp(self):
        if self.model_tester_class is None:
            raise ValueError(
                "You have inherited from CausalLMModelTest but did not set the model_tester_class attribute."
            )
        self.model_tester = self.model_tester_class(self)
        self.config_tester = ConfigTester(self, config_class=self.model_tester.config_class)
        if self.all_model_classes is None:
            self.all_model_classes = self.model_tester.all_model_classes
        if self.pipeline_model_mapping is None:
            raise ValueError(
                "You have inherited from CausalLMModelTest but did not set the pipeline_model_mapping attribute."
            )

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_sequence_classification_model(self):
        if self.model_tester.sequence_classification_class is None:
            self.skipTest("Model does not support sequence classification")
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = self.model_tester.sequence_classification_class(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    def test_sequence_classification_model_for_single_label(self):
        if self.model_tester.sequence_classification_class is None:
            self.skipTest("Model does not support sequence classification")
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        config.problem_type = "single_label_classification"
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = self.model_tester.sequence_classification_class(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    def test_sequence_classification_model_for_multi_label(self):
        if self.model_tester.sequence_classification_class is None:
            self.skipTest("Model does not support sequence classification")
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        config.problem_type = "multi_label_classification"
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor(
            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
        ).to(torch.float)
        model = self.model_tester.sequence_classification_class(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    def test_token_classification_model(self):
        if self.model_tester.token_classification_class is None:
            self.skipTest("Model does not support token classification")
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
        model = self.model_tester.token_classification_class(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
        self.assertEqual(
            result.logits.shape,
            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
        )

    @parameterized.expand([("linear",), ("dynamic",), ("yarn",)])
    def test_model_rope_scaling_from_config(self, scaling_type):
        if self.rotary_embedding_layer is None:
            self.skipTest("Rotary embedding layer not set")
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        short_input = ids_tensor([1, 10], config.vocab_size)
        long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)

        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
        original_model = self.model_tester_class.base_model_class(config)
        original_model.to(torch_device)
        original_model.eval()
        original_short_output = original_model(short_input).last_hidden_state
        original_long_output = original_model(long_input).last_hidden_state

        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
        config.rope_scaling = {"type": scaling_type, "factor": 10.0}
        scaled_model = self.model_tester_class.base_model_class(config)
        scaled_model.to(torch_device)
        scaled_model.eval()
        scaled_short_output = scaled_model(short_input).last_hidden_state
        scaled_long_output = scaled_model(long_input).last_hidden_state

        # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
        # maximum sequence length, so the outputs for the short input should match.
        if scaling_type == "dynamic":
            torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
        else:
            self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))

        # The output should be different for long inputs
        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

    def test_model_rope_scaling(self):
        if self.rotary_embedding_layer is None:
            self.skipTest("Rotary embedding layer not set")
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        scaling_factor = 10
        short_input_length = 10
        long_input_length = int(config.max_position_embeddings * 1.5)

        # Inputs
        x = torch.randn(
            1, dtype=torch.float32, device=torch_device
        )  # used exclusively to get the dtype and the device
        position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
        position_ids_short = position_ids_short.unsqueeze(0)
        position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
        position_ids_long = position_ids_long.unsqueeze(0)

        # Sanity check original RoPE
        original_rope = self.rotary_embedding_layer(config=config).to(torch_device)
        original_cos_short, original_sin_short = original_rope(x, position_ids_short)
        original_cos_long, original_sin_long = original_rope(x, position_ids_long)
        torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
        torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])

        # Sanity check linear RoPE scaling
        # New position "x" should match original position with index "x/scaling_factor"
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
        linear_scaling_rope = self.rotary_embedding_layer(config=config).to(torch_device)
        linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
        linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
        torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
        torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
        for new_position in range(0, long_input_length, scaling_factor):
            original_position = int(new_position // scaling_factor)
            torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
            torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])

        # Sanity check Dynamic NTK RoPE scaling
        # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
        # with scaling_factor (or that `inv_freq` decreases)
        config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
        ntk_scaling_rope = self.rotary_embedding_layer(config=config).to(torch_device)
        ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
        ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
        torch.testing.assert_close(ntk_cos_short, original_cos_short)
        torch.testing.assert_close(ntk_sin_short, original_sin_short)
        with self.assertRaises(AssertionError):
            torch.testing.assert_close(ntk_cos_long, original_cos_long)
        with self.assertRaises(AssertionError):
            torch.testing.assert_close(ntk_sin_long, original_sin_long)
        self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())

        # Sanity check Yarn RoPE scaling
        # Scaling should be over the entire input
        config.rope_scaling = {"type": "yarn", "factor": scaling_factor}
        yarn_scaling_rope = self.rotary_embedding_layer(config=config).to(torch_device)
        yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
        yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
        torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
        torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
        with self.assertRaises(AssertionError):
            torch.testing.assert_close(yarn_cos_short, original_cos_short)
        with self.assertRaises(AssertionError):
            torch.testing.assert_close(yarn_sin_short, original_sin_short)
        with self.assertRaises(AssertionError):
            torch.testing.assert_close(yarn_cos_long, original_cos_long)
        with self.assertRaises(AssertionError):
            torch.testing.assert_close(yarn_sin_long, original_sin_long)

    @require_torch_sdpa
    @require_torch_accelerator
    @slow
    def test_sdpa_equivalence(self):
        for model_class in self.all_model_classes:
            if not model_class._supports_sdpa:
                self.skipTest(reason="Model does not support SDPA")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_sdpa = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
                )
                model_sdpa.to(torch_device)

                model = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager"
                )
                model.to(torch_device)

                dummy_input = inputs_dict[model_class.main_input_name]
                dummy_input = dummy_input.to(torch_device)
                outputs = model(dummy_input, output_hidden_states=True)
                outputs_sdpa = model_sdpa(dummy_input, output_hidden_states=True)

                logits = outputs.hidden_states[-1]
                logits_sdpa = outputs_sdpa.hidden_states[-1]

                assert torch.allclose(logits_sdpa, logits, atol=2e-3)

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @is_flaky()
    @slow
    def test_flash_attn_2_equivalence(self):
        for model_class in self.all_model_classes:
            if not model_class._supports_flash_attn_2:
                self.skipTest(reason="Model does not support Flash Attention 2")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_fa = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
                )
                model_fa.to(torch_device)

                model = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager"
                )
                model.to(torch_device)

                dummy_input = inputs_dict[model_class.main_input_name]
                dummy_input = dummy_input.to(torch_device)
                outputs = model(dummy_input, output_hidden_states=True)
                outputs_fa = model_fa(dummy_input, output_hidden_states=True)

                logits = outputs.hidden_states[-1]
                logits_fa = outputs_fa.hidden_states[-1]

                assert torch.allclose(logits_fa, logits, atol=2e-3)
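A note on get_config() in the tester above: it builds the config by reflecting over the config class constructor, so every tester attribute whose name matches a constructor argument (plus the names in forced_config_args) is forwarded automatically, and everything else is ignored. A self-contained sketch of that mechanism, using a hypothetical DemoConfig rather than a real transformers config:

from inspect import signature


class DemoConfig:  # hypothetical stand-in for a transformers config class
    def __init__(self, vocab_size=99, hidden_size=32, pad_token_id=None):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.pad_token_id = pad_token_id


class DemoTester:
    config_class = DemoConfig
    forced_config_args = ["pad_token_id"]

    def __init__(self):
        self.vocab_size = 99
        self.hidden_size = 32
        self.pad_token_id = 0
        self.unrelated_attr = "ignored"  # not a DemoConfig argument, so it is skipped

    def get_config(self):
        # Collect every constructor argument name, keep only those the tester actually defines.
        kwarg_names = list(signature(self.config_class.__init__).parameters.keys())
        kwargs = {
            k: getattr(self, k)
            for k in kwarg_names + self.forced_config_args
            if hasattr(self, k) and k != "self"
        }
        return self.config_class(**kwargs)


config = DemoTester().get_config()
assert config.vocab_size == 99 and config.pad_token_id == 0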
@@ -179,7 +179,6 @@ class DbrxModelTester:
         )
         return config
 
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Dbrx
     def create_and_check_model(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
@@ -190,7 +189,6 @@ class DbrxModelTester:
         result = model(input_ids)
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
 
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->Dbrx
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -15,14 +15,11 @@
 
 import unittest
 
-from parameterized import parameterized
-
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     FalconConfig,
     is_torch_available,
-    set_seed,
 )
 from transformers.testing_utils import (
     require_bitsandbytes,
@@ -32,10 +29,7 @@ from transformers.testing_utils import (
     torch_device,
 )
 
-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
-from ...test_pipeline_mixin import PipelineTesterMixin
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
 
 
 if is_torch_available():
@@ -48,126 +42,24 @@ if is_torch_available():
         FalconForTokenClassification,
         FalconModel,
     )
-    from transformers.models.falcon.modeling_falcon import (
-        FalconRotaryEmbedding,
-    )
 
 
-class FalconModelTester:
-    def __init__(
-        self,
-        parent,
-        batch_size=3,
-        seq_length=7,
-        is_training=True,
-        use_input_mask=True,
-        use_token_type_ids=False,
-        use_labels=True,
-        vocab_size=99,
-        hidden_size=32,
-        num_hidden_layers=2,
-        num_attention_heads=4,
-        intermediate_size=37,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_position_embeddings=512,
-        type_vocab_size=16,
-        type_sequence_label_size=2,
-        initializer_range=0.02,
-        num_labels=3,
-        num_choices=4,
-        scope=None,
-    ):
-        self.parent = parent
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.is_training = is_training
-        self.use_input_mask = use_input_mask
-        self.use_token_type_ids = use_token_type_ids
-        self.use_labels = use_labels
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.intermediate_size = intermediate_size
-        self.hidden_act = hidden_act
-        self.hidden_dropout_prob = hidden_dropout_prob
-        self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.max_position_embeddings = max_position_embeddings
-        self.type_vocab_size = type_vocab_size
-        self.type_sequence_label_size = type_sequence_label_size
-        self.initializer_range = initializer_range
-        self.num_labels = num_labels
-        self.num_choices = num_choices
-        self.scope = scope
-
-    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
-        input_mask = None
-        if self.use_input_mask:
-            input_mask = random_attention_mask([self.batch_size, self.seq_length])
-
-        token_type_ids = None
-
-        sequence_labels = None
-        token_labels = None
-        choice_labels = None
-        if self.use_labels:
-            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-            choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
-        config = self.get_config()
-
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-    def get_config(self):
-        return FalconConfig(
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,
-            num_hidden_layers=self.num_hidden_layers,
-            num_attention_heads=self.num_attention_heads,
-            intermediate_size=self.intermediate_size,
-            hidden_act=self.hidden_act,
-            hidden_dropout_prob=self.hidden_dropout_prob,
-            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-            max_position_embeddings=self.max_position_embeddings,
-            type_vocab_size=self.type_vocab_size,
-            is_decoder=False,
-            initializer_range=self.initializer_range,
-            pad_token_id=1,
-            new_decoder_architecture=True,
-        )
-
-    def create_and_check_model(
-        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-    ):
-        model = FalconModel(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=input_mask)
-        result = model(input_ids)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
-    def prepare_config_and_inputs_for_common(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-        (
-            config,
-            input_ids,
-            token_type_ids,
-            input_mask,
-            sequence_labels,
-            token_labels,
-            choice_labels,
-        ) = config_and_inputs
-        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
-        return config, inputs_dict
+class FalconModelTester(CausalLMModelTester):
+    if is_torch_available():
+        config_class = FalconConfig
+        base_model_class = FalconModel
+        causal_lm_class = FalconForCausalLM
+        sequence_class = FalconForSequenceClassification
+        token_class = FalconForTokenClassification
+
+    def __init__(self, parent, new_decoder_architecture=True):
+        super().__init__(parent)
+        self.new_decoder_architecture = new_decoder_architecture
 
 
 @require_torch
-class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class FalconModelTest(CausalLMModelTest, unittest.TestCase):
+    model_tester_class = FalconModelTester
     all_model_classes = (
         (
             FalconModel,
@@ -182,10 +74,9 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
     pipeline_model_mapping = (
         {
             "feature-extraction": FalconModel,
-            "question-answering": FalconForQuestionAnswering,
             "text-classification": FalconForSequenceClassification,
-            "text-generation": FalconForCausalLM,
             "token-classification": FalconForTokenClassification,
+            "text-generation": FalconForCausalLM,
            "zero-shot": FalconForSequenceClassification,
         }
         if is_torch_available()
@@ -207,146 +98,6 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
     ):
         return True
 
-    def setUp(self):
-        self.model_tester = FalconModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=FalconConfig, hidden_size=37)
-
-    def test_config(self):
-        self.config_tester.run_common_tests()
-
-    def test_model(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_model(*config_and_inputs)
-
-    def test_position_embedding_types(self):
-        config, *inputs = self.model_tester.prepare_config_and_inputs()
-        for alibi in [True, False]:
-            config.alibi = alibi
-            self.model_tester.create_and_check_model(config, *inputs)
-
-    def test_falcon_sequence_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = FalconForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_falcon_sequence_classification_model_for_single_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "single_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = FalconForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_falcon_sequence_classification_model_for_multi_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "multi_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor(
-            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
-        ).to(torch.float)
-        model = FalconForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    @parameterized.expand([("linear",), ("dynamic",)])
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Falcon
-    def test_model_rope_scaling_from_config(self, scaling_type):
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        short_input = ids_tensor([1, 10], config.vocab_size)
-        long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
-
-        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        original_model = FalconModel(config)
-        original_model.to(torch_device)
-        original_model.eval()
-        original_short_output = original_model(short_input).last_hidden_state
-        original_long_output = original_model(long_input).last_hidden_state
-
-        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        config.rope_scaling = {"type": scaling_type, "factor": 10.0}
-        scaled_model = FalconModel(config)
-        scaled_model.to(torch_device)
-        scaled_model.eval()
-        scaled_short_output = scaled_model(short_input).last_hidden_state
-        scaled_long_output = scaled_model(long_input).last_hidden_state
-
-        # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
-        # maximum sequence length, so the outputs for the short input should match.
-        if scaling_type == "dynamic":
-            torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
-        else:
-            self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
-
-        # The output should be different for long inputs
-        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
-
-    # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Falcon
-    def test_model_rope_scaling(self):
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        scaling_factor = 10
-        short_input_length = 10
-        long_input_length = int(config.max_position_embeddings * 1.5)
-
-        # Inputs
-        x = torch.randn(
-            1, dtype=torch.float32, device=torch_device
-        )  # used exclusively to get the dtype and the device
-        position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
-        position_ids_short = position_ids_short.unsqueeze(0)
-        position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
-        position_ids_long = position_ids_long.unsqueeze(0)
-
-        # Sanity check original RoPE
-        original_rope = FalconRotaryEmbedding(config).to(torch_device)
-        original_cos_short, original_sin_short = original_rope(x, position_ids_short)
-        original_cos_long, original_sin_long = original_rope(x, position_ids_long)
-        torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
-        torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
-
-        # Sanity check linear RoPE scaling
-        # New position "x" should match original position with index "x/scaling_factor"
-        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
-        linear_scaling_rope = FalconRotaryEmbedding(config).to(torch_device)
-        linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
-        linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
-        torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
-        torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
-        for new_position in range(0, long_input_length, scaling_factor):
-            original_position = int(new_position // scaling_factor)
-            torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
-            torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
-
-        # Sanity check Dynamic NTK RoPE scaling
-        # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
-        # with scaling_factor (or that `inv_freq` decreases)
-        config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
-        ntk_scaling_rope = FalconRotaryEmbedding(config).to(torch_device)
-        ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
-        ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
-        torch.testing.assert_close(ntk_cos_short, original_cos_short)
-        torch.testing.assert_close(ntk_sin_short, original_sin_short)
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_close(ntk_cos_long, original_cos_long)
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_close(ntk_sin_long, original_sin_long)
-        self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
-
 
 @require_torch
 class FalconLanguageGenerationTest(unittest.TestCase):
@ -33,10 +33,7 @@ from transformers.testing_utils import (
|
|||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
|
||||||
from ...test_configuration_common import ConfigTester
|
|
||||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
|
||||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@ -51,138 +48,17 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class GemmaModelTester:
|
class GemmaModelTester(CausalLMModelTester):
|
||||||
config_class = GemmaConfig
|
config_class = GemmaConfig
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
model_class = GemmaModel
|
base_model_class = GemmaModel
|
||||||
for_causal_lm_class = GemmaForCausalLM
|
causal_lm_class = GemmaForCausalLM
|
||||||
for_sequence_class = GemmaForSequenceClassification
|
sequence_classification_class = GemmaForSequenceClassification
|
||||||
for_token_class = GemmaForTokenClassification
|
token_classification_class = GemmaForTokenClassification
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
parent,
|
|
||||||
batch_size=13,
|
|
||||||
seq_length=7,
|
|
||||||
is_training=True,
|
|
||||||
use_input_mask=True,
|
|
||||||
use_token_type_ids=False,
|
|
||||||
use_labels=True,
|
|
||||||
vocab_size=99,
|
|
||||||
hidden_size=32,
|
|
||||||
num_hidden_layers=2,
|
|
||||||
num_attention_heads=4,
|
|
||||||
num_key_value_heads=2,
|
|
||||||
intermediate_size=37,
|
|
||||||
hidden_act="gelu",
|
|
||||||
hidden_dropout_prob=0.1,
|
|
||||||
attention_probs_dropout_prob=0.1,
|
|
||||||
max_position_embeddings=512,
|
|
||||||
type_vocab_size=16,
|
|
||||||
type_sequence_label_size=2,
|
|
||||||
initializer_range=0.02,
|
|
||||||
num_labels=3,
|
|
||||||
num_choices=4,
|
|
||||||
pad_token_id=0,
|
|
||||||
scope=None,
|
|
||||||
):
|
|
||||||
self.parent = parent
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.seq_length = seq_length
|
|
||||||
self.is_training = is_training
|
|
||||||
self.use_input_mask = use_input_mask
|
|
||||||
self.use_token_type_ids = use_token_type_ids
|
|
||||||
self.use_labels = use_labels
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.num_hidden_layers = num_hidden_layers
|
|
||||||
self.num_attention_heads = num_attention_heads
|
|
||||||
self.num_key_value_heads = num_key_value_heads
|
|
||||||
self.intermediate_size = intermediate_size
|
|
||||||
self.hidden_act = hidden_act
|
|
||||||
self.hidden_dropout_prob = hidden_dropout_prob
|
|
||||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
|
||||||
self.max_position_embeddings = max_position_embeddings
|
|
||||||
self.type_vocab_size = type_vocab_size
|
|
||||||
self.type_sequence_label_size = type_sequence_label_size
|
|
||||||
self.initializer_range = initializer_range
|
|
||||||
self.num_labels = num_labels
|
|
||||||
self.num_choices = num_choices
|
|
||||||
self.pad_token_id = pad_token_id
|
|
||||||
self.scope = scope
|
|
||||||
self.head_dim = self.hidden_size // self.num_attention_heads
|
|
||||||
|
|
||||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
|
|
||||||
def prepare_config_and_inputs(self):
|
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
|
||||||
|
|
||||||
input_mask = None
|
|
||||||
if self.use_input_mask:
|
|
||||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
|
||||||
|
|
||||||
token_type_ids = None
|
|
||||||
if self.use_token_type_ids:
|
|
||||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
|
||||||
|
|
||||||
sequence_labels = None
|
|
||||||
token_labels = None
|
|
||||||
choice_labels = None
|
|
||||||
if self.use_labels:
|
|
||||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
|
||||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
|
||||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
|
||||||
|
|
||||||
config = self.get_config()
|
|
||||||
|
|
||||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
|
||||||
|
|
||||||
def get_config(self):
|
|
||||||
return self.config_class(
|
|
||||||
vocab_size=self.vocab_size,
|
|
||||||
hidden_size=self.hidden_size,
|
|
||||||
num_hidden_layers=self.num_hidden_layers,
|
|
||||||
num_attention_heads=self.num_attention_heads,
|
|
||||||
num_key_value_heads=self.num_key_value_heads,
|
|
||||||
intermediate_size=self.intermediate_size,
|
|
||||||
hidden_act=self.hidden_act,
|
|
||||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
|
||||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
|
||||||
max_position_embeddings=self.max_position_embeddings,
|
|
||||||
type_vocab_size=self.type_vocab_size,
|
|
||||||
is_decoder=False,
|
|
||||||
initializer_range=self.initializer_range,
|
|
||||||
pad_token_id=self.pad_token_id,
|
|
||||||
head_dim=self.head_dim,
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_and_check_model(
|
|
||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
|
||||||
):
|
|
||||||
model = self.model_class(config=config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
result = model(input_ids, attention_mask=input_mask)
|
|
||||||
result = model(input_ids)
|
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
|
||||||
|
|
||||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->Gemma
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
|
||||||
(
|
|
||||||
config,
|
|
||||||
input_ids,
|
|
||||||
token_type_ids,
|
|
||||||
input_mask,
|
|
||||||
sequence_labels,
|
|
||||||
token_labels,
|
|
||||||
choice_labels,
|
|
||||||
) = config_and_inputs
|
|
||||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
|
||||||
return config, inputs_dict
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class GemmaModelTest(CausalLMModelTest, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification)
|
(GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification)
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
@ -199,12 +75,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
test_headmasking = False
|
model_tester_class = GemmaModelTester
|
||||||
test_pruning = False
|
|
||||||
|
|
||||||
# Need to remove 0.9 in `test_cpu_offload`
|
|
||||||
# This is because we are hitting edge cases with the causal_mask buffer
|
|
||||||
model_split_percents = [0.5, 0.6]
|
|
||||||
|
|
||||||
# used in `test_torch_compile_for_training`
|
# used in `test_torch_compile_for_training`
|
||||||
_torch_compile_train_cls = GemmaForCausalLM if is_torch_available() else None
|
_torch_compile_train_cls = GemmaForCausalLM if is_torch_available() else None
|
||||||
@ -222,78 +93,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.model_tester = GemmaModelTester(self)
|
|
||||||
self.config_tester = ConfigTester(self, config_class=GemmaConfig, hidden_size=37)
|
|
||||||
|
|
||||||
def test_config(self):
|
|
||||||
self.config_tester.run_common_tests()
|
|
||||||
|
|
||||||
def test_model(self):
|
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
|
||||||
|
|
||||||
def test_model_various_embeddings(self):
|
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
||||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
|
||||||
config_and_inputs[0].position_embedding_type = type
|
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
|
||||||
|
|
||||||
def test_Gemma_sequence_classification_model(self):
|
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.num_labels = 3
|
|
||||||
input_ids = input_dict["input_ids"]
|
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
|
||||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
|
||||||
model = self.model_tester.for_sequence_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
|
||||||
|
|
||||||
def test_Gemma_sequence_classification_model_for_single_label(self):
|
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.num_labels = 3
|
|
||||||
config.problem_type = "single_label_classification"
|
|
||||||
input_ids = input_dict["input_ids"]
|
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
|
||||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
|
||||||
model = self.model_tester.for_sequence_class(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_Gemma_sequence_classification_model_for_multi_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "multi_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor(
-            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
-        ).to(torch.float)
-        model = self.model_tester.for_sequence_class(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_Gemma_token_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
-        model = self.model_tester.for_token_class(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
-        self.assertEqual(
-            result.logits.shape,
-            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
-        )

     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test
@@ -301,46 +100,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest(reason="Gemma flash attention does not support right padding")

-    @require_torch_sdpa
-    @require_torch_accelerator
-    def test_sdpa_equivalence(self):
-        for model_class in self.all_model_classes:
-            if not model_class._supports_sdpa:
-                self.skipTest(reason="Model does not support SDPA")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config).to(torch_device)
-            dummy_input = inputs_dict[model_class.main_input_name].to(torch_device)
-
-            model.config._attn_implementation = "sdpa"
-            states_sdpa = model(dummy_input, output_hidden_states=True).hidden_states[-1]
-
-            model.config._attn_implementation = "eager"
-            states_eager = model(dummy_input, output_hidden_states=True).hidden_states[-1]
-
-            torch.testing.assert_close(states_sdpa, states_eager, atol=1e-5, rtol=1e-5)
-
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    def test_flash_attn_2_equivalence(self):
-        for model_class in self.all_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(reason="Model does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config).to(device=torch_device, dtype=torch.float16)
-            dummy_input = inputs_dict[model_class.main_input_name].to(torch_device)
-
-            model.config._attn_implementation = "flash_attention_2"
-            states_sdpa = model(dummy_input, output_hidden_states=True).hidden_states[1]
-
-            model.config._attn_implementation = "eager"
-            states_eager = model(dummy_input, output_hidden_states=True).hidden_states[1]
-
-            # Here we use higher tolerance and the output of the 2nd layer because otherwise small diffs add-up
-            torch.testing.assert_close(states_sdpa, states_eager, atol=1e-3, rtol=1e-3)
-

 @slow
 @require_torch_accelerator
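The two equivalence tests removed above run the same model once with the SDPA kernel and once with the hand-written "eager" attention path and compare hidden states; they differ only in dtype and tolerance. The snippet below is a self-contained illustration of the property being checked, using plain PyTorch tensors instead of a Transformers model; it is a sketch of the idea only, not the shared test that replaces these methods.

import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)

# Eager attention: explicit softmax(QK^T / sqrt(d)) V
scores = q @ k.transpose(-1, -2) / math.sqrt(q.size(-1))
eager_out = scores.softmax(dim=-1) @ v

# PyTorch SDPA kernel (no mask, no dropout)
sdpa_out = F.scaled_dot_product_attention(q, k, v)

torch.testing.assert_close(sdpa_out, eager_out, atol=1e-5, rtol=1e-5)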
@@ -33,7 +33,7 @@ from transformers.testing_utils import (
     torch_device,
 )

-from ...models.gemma.test_modeling_gemma import GemmaModelTest, GemmaModelTester
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
 from ...test_configuration_common import ConfigTester


@@ -48,17 +48,28 @@ if is_torch_available():
     )


-class Gemma2ModelTester(GemmaModelTester):
+class Gemma2ModelTester(CausalLMModelTester):
     if is_torch_available():
         config_class = Gemma2Config
-        model_class = Gemma2Model
-        for_causal_lm_class = Gemma2ForCausalLM
-        for_sequence_class = Gemma2ForSequenceClassification
-        for_token_class = Gemma2ForTokenClassification
+        base_model_class = Gemma2Model
+        causal_lm_class = Gemma2ForCausalLM
+        sequence_class = Gemma2ForSequenceClassification
+        token_class = Gemma2ForTokenClassification
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": Gemma2Model,
+            "text-classification": Gemma2ForSequenceClassification,
+            "token-classification": Gemma2ForTokenClassification,
+            "text-generation": Gemma2ForCausalLM,
+            "zero-shot": Gemma2ForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
+    )


 @require_torch
-class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
+class Gemma2ModelTest(CausalLMModelTest, unittest.TestCase):
     all_model_classes = (
         (Gemma2Model, Gemma2ForCausalLM, Gemma2ForSequenceClassification, Gemma2ForTokenClassification)
         if is_torch_available()
@@ -75,10 +86,12 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
         if is_torch_available()
         else {}
     )

     test_headmasking = False
     test_pruning = False
     _is_stateful = True
     model_split_percents = [0.5, 0.6]
+    model_tester_class = Gemma2ModelTester

     def setUp(self):
         self.model_tester = Gemma2ModelTester(self)
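The Gemma2 hunks above show the whole migration pattern: the per-model tester shrinks to a handful of class attributes and the test class points at it through `model_tester_class`. For a new model the boilerplate would look roughly like the skeleton below; every `MyModel*` name is a placeholder rather than a real class, so this is a template, not runnable code from the PR.

import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch

from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester


class MyModelTester(CausalLMModelTester):
    if is_torch_available():
        config_class = MyModelConfig              # placeholder names from here on
        base_model_class = MyModel
        causal_lm_class = MyModelForCausalLM
        sequence_class = MyModelForSequenceClassification
        token_class = MyModelForTokenClassification


@require_torch
class MyModelTest(CausalLMModelTest, unittest.TestCase):
    all_model_classes = (
        (MyModel, MyModelForCausalLM, MyModelForSequenceClassification, MyModelForTokenClassification)
        if is_torch_available()
        else ()
    )
    model_tester_class = MyModelTester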
@@ -19,7 +19,6 @@ import pytest

 from transformers import AutoModelForCausalLM, AutoTokenizer, GlmConfig, is_torch_available
 from transformers.testing_utils import (
-    is_flaky,
     require_flash_attn,
     require_torch,
     require_torch_large_accelerator,
@@ -28,10 +27,7 @@ from transformers.testing_utils import (
     torch_device,
 )

-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor
-from ...test_pipeline_mixin import PipelineTesterMixin
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester


 if is_torch_available():
@@ -46,133 +42,17 @@ if is_torch_available():


 @require_torch
-class GlmModelTester:
+class GlmModelTester(CausalLMModelTester):
     config_class = GlmConfig
     if is_torch_available():
-        model_class = GlmModel
-        for_causal_lm_class = GlmForCausalLM
-        for_sequence_class = GlmForSequenceClassification
-        for_token_class = GlmForTokenClassification
-
-    def __init__(
-        self,
-        parent,
-        batch_size=13,
-        seq_length=7,
-        is_training=True,
-        use_input_mask=True,
-        use_token_type_ids=False,
-        use_labels=True,
-        vocab_size=99,
-        hidden_size=32,
-        num_hidden_layers=2,
-        num_attention_heads=4,
-        num_key_value_heads=2,
-        intermediate_size=37,
-        hidden_act="silu",
-        attention_dropout=0.1,
-        max_position_embeddings=512,
-        type_vocab_size=16,
-        type_sequence_label_size=2,
-        initializer_range=0.02,
-        num_labels=3,
-        num_choices=4,
-        pad_token_id=0,
-        scope=None,
-    ):
-        self.parent = parent
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.is_training = is_training
-        self.use_input_mask = use_input_mask
-        self.use_token_type_ids = use_token_type_ids
-        self.use_labels = use_labels
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.num_key_value_heads = num_key_value_heads
-        self.intermediate_size = intermediate_size
-        self.hidden_act = hidden_act
-        self.attention_dropout = attention_dropout
-        self.max_position_embeddings = max_position_embeddings
-        self.type_vocab_size = type_vocab_size
-        self.type_sequence_label_size = type_sequence_label_size
-        self.initializer_range = initializer_range
-        self.num_labels = num_labels
-        self.num_choices = num_choices
-        self.pad_token_id = pad_token_id
-        self.scope = scope
-        self.head_dim = self.hidden_size // self.num_attention_heads
-
-    # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
-    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
-        input_mask = None
-        if self.use_input_mask:
-            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
-
-        token_type_ids = None
-        if self.use_token_type_ids:
-            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-        sequence_labels = None
-        token_labels = None
-        choice_labels = None
-        if self.use_labels:
-            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-            choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
-        config = self.get_config()
-
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-    def get_config(self):
-        return self.config_class(
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,
-            num_hidden_layers=self.num_hidden_layers,
-            num_attention_heads=self.num_attention_heads,
-            num_key_value_heads=self.num_key_value_heads,
-            intermediate_size=self.intermediate_size,
-            hidden_act=self.hidden_act,
-            attention_dropout=self.attention_dropout,
-            max_position_embeddings=self.max_position_embeddings,
-            initializer_range=self.initializer_range,
-            pad_token_id=self.pad_token_id,
-            head_dim=self.head_dim,
-        )
-
-    def create_and_check_model(
-        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-    ):
-        model = self.model_class(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=input_mask)
-        result = model(input_ids)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->Glm
-    def prepare_config_and_inputs_for_common(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-        (
-            config,
-            input_ids,
-            token_type_ids,
-            input_mask,
-            sequence_labels,
-            token_labels,
-            choice_labels,
-        ) = config_and_inputs
-        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
-        return config, inputs_dict
+        base_model_class = GlmModel
+        causal_lm_class = GlmForCausalLM
+        sequence_class = GlmForSequenceClassification
+        token_class = GlmForTokenClassification


 @require_torch
-class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class GlmModelTest(CausalLMModelTest, unittest.TestCase):
     all_model_classes = (
         (GlmModel, GlmForCausalLM, GlmForSequenceClassification, GlmForTokenClassification)
         if is_torch_available()
@@ -188,120 +68,10 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
         if is_torch_available()
         else {}
     )

     test_headmasking = False
     test_pruning = False
+    model_tester_class = GlmModelTester
-    def setUp(self):
-        self.model_tester = GlmModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=GlmConfig, hidden_size=37)
-
-    def test_config(self):
-        self.config_tester.run_common_tests()
-
-    def test_model(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_model(*config_and_inputs)
-
-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
-    def test_Glm_sequence_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        print(config)
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = self.model_tester.for_sequence_class(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_Glm_sequence_classification_model_for_single_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "single_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = self.model_tester.for_sequence_class(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_Glm_sequence_classification_model_for_multi_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "multi_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor(
-            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
-        ).to(torch.float)
-        model = self.model_tester.for_sequence_class(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_Glm_token_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
-        model = self.model_tester.for_token_class(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
-        self.assertEqual(
-            result.logits.shape,
-            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
-        )
-
-    @is_flaky()
-    def test_custom_4d_attention_mask(self):
-        """Overwrite the common test to use atol=1e-3 instead of 1e-4. Can still rarely fail, thus flaky."""
-        for model_class in self.all_generative_model_classes:
-            if not model_class._supports_static_cache:
-                self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
-            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-            if getattr(config, "sliding_window", 0) is not None and getattr(config, "sliding_window", 0) > 0:
-                self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")
-            model = model_class(config).to(device=torch_device, dtype=torch.float32)
-
-            (
-                input_ids,
-                position_ids,
-                input_ids_shared_prefix,
-                mask_shared_prefix,
-                position_ids_shared_prefix,
-            ) = self._get_custom_4d_mask_test_data()
-
-            logits = model.forward(input_ids, position_ids=position_ids).logits
-            # logits.shape == torch.Size([3, 4, ...])
-
-            logits_shared_prefix = model(
-                input_ids_shared_prefix,
-                attention_mask=mask_shared_prefix,
-                position_ids=position_ids_shared_prefix,
-            )[0]
-            # logits_shared_prefix.shape == torch.Size([1, 6, ...])
-
-            out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
-            out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens
-
-            # comparing softmax-normalized logits:
-            normalized_0 = torch.nn.functional.softmax(out_last_tokens)
-            normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens)
-            print(torch.abs(normalized_0 - normalized_1).max())
-
-            torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-3)


 @slow
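What disappears from the GLM file is roughly 120 lines of hand-rolled config and input plumbing. The calls that survive elsewhere in this commit (`prepare_config_and_inputs_for_common()`, the `batch_size`/`seq_length` attributes) suggest the inherited tester keeps the same entry points; under that assumption, a test needing a config and dummy inputs would use the subclass as sketched below. The forward-pass lines are an illustrative usage, not code from the PR.

# Inside a unittest.TestCase method; `torch_device` comes from transformers.testing_utils.
tester = GlmModelTester(self)  # assumed: the base tester still takes the test case as `parent`
config, inputs_dict = tester.prepare_config_and_inputs_for_common()

model = GlmModelTester.causal_lm_class(config).to(torch_device).eval()
logits = model(**inputs_dict).logits
assert logits.shape == (tester.batch_size, tester.seq_length, config.vocab_size)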
@@ -28,8 +28,7 @@ from transformers.testing_utils import (
     torch_device,
 )

-from ...models.gemma.test_modeling_gemma import GemmaModelTest, GemmaModelTester
-from ...test_configuration_common import ConfigTester
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester


 if is_torch_available():
@@ -43,17 +42,18 @@ if is_torch_available():
     )


-class Glm4ModelTester(GemmaModelTester):
+class Glm4ModelTester(CausalLMModelTester):
     if is_torch_available():
         config_class = Glm4Config
-        model_class = Glm4Model
-        for_causal_lm_class = Glm4ForCausalLM
-        for_sequence_class = Glm4ForSequenceClassification
-        for_token_class = Glm4ForTokenClassification
+        base_model_class = Glm4Model
+        causal_lm_class = Glm4ForCausalLM
+        sequence_classification_class = Glm4ForSequenceClassification
+        token_classification_class = Glm4ForTokenClassification


 @require_torch
-class Glm4ModelTest(GemmaModelTest, unittest.TestCase):
+class Glm4ModelTest(CausalLMModelTest, unittest.TestCase):
+    model_tester_class = Glm4ModelTester
     all_model_classes = (
         (Glm4Model, Glm4ForCausalLM, Glm4ForSequenceClassification, Glm4ForTokenClassification)
         if is_torch_available()
@@ -75,10 +75,6 @@ class Glm4ModelTest(GemmaModelTest, unittest.TestCase):
     _is_stateful = True
     model_split_percents = [0.5, 0.6]

-    def setUp(self):
-        self.model_tester = Glm4ModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=Glm4Config, hidden_size=37)
-

 @slow
 @require_torch_large_gpu
@@ -341,7 +341,6 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         pass

     @parameterized.expand([("linear",), ("dynamic",)])
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->GPTNeoX
     def test_model_rope_scaling_from_config(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         short_input = ids_tensor([1, 10], config.vocab_size)
@@ -28,10 +28,7 @@ from transformers.testing_utils import (
     torch_device,
 )

-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor
-from ...test_pipeline_mixin import PipelineTesterMixin
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester


 if is_torch_available():
@@ -44,7 +41,14 @@ if is_torch_available():
     )


-class JetMoeModelTester:
+class JetMoeModelTester(CausalLMModelTester):
+    config_class = JetMoeConfig
+    forced_config_args = ["pad_token_id"]
+    if is_torch_available():
+        base_model_class = JetMoeModel
+        causal_lm_class = JetMoeForCausalLM
+        sequence_class = JetMoeForSequenceClassification
+
     def __init__(
         self,
         parent,
@@ -72,6 +76,7 @@ class JetMoeModelTester:
         pad_token_id=0,
         scope=None,
     ):
+        super().__init__(parent)
         self.parent = parent
         self.batch_size = batch_size
         self.seq_length = seq_length
@@ -98,159 +103,29 @@ class JetMoeModelTester:
         self.pad_token_id = pad_token_id
         self.scope = scope

-    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
-        input_mask = None
-        if self.use_input_mask:
-            input_mask = torch.ones(self.batch_size, self.seq_length).to(torch_device)
-
-        token_type_ids = None
-        if self.use_token_type_ids:
-            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-        sequence_labels = None
-        token_labels = None
-        choice_labels = None
-        if self.use_labels:
-            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-            choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
-        config = self.get_config()
-
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-    def get_config(self):
-        return JetMoeConfig(
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,
-            num_hidden_layers=self.num_hidden_layers,
-            num_key_value_heads=self.num_key_value_heads,
-            kv_channels=self.kv_channels,
-            intermediate_size=self.intermediate_size,
-            activation_function=self.hidden_act,
-            num_local_experts=self.num_local_experts,
-            num_experts_per_tok=self.num_experts_per_tok,
-            max_position_embeddings=self.max_position_embeddings,
-            type_vocab_size=self.type_vocab_size,
-            is_decoder=False,
-            initializer_range=self.initializer_range,
-            pad_token_id=self.pad_token_id,
-        )
-
-    def create_and_check_model(
-        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-    ):
-        model = JetMoeModel(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=input_mask)
-        result = model(input_ids)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
-    def prepare_config_and_inputs_for_common(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-        (
-            config,
-            input_ids,
-            token_type_ids,
-            input_mask,
-            sequence_labels,
-            token_labels,
-            choice_labels,
-        ) = config_and_inputs
-        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
-        return config, inputs_dict
-

 @require_torch
-class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class JetMoeModelTest(CausalLMModelTest, unittest.TestCase):
     all_model_classes = (
         (JetMoeModel, JetMoeForCausalLM, JetMoeForSequenceClassification) if is_torch_available() else ()
     )
-    pipeline_model_mapping = (
-        {
-            "feature-extraction": JetMoeModel,
-            "text-classification": JetMoeForSequenceClassification,
-            "text-generation": JetMoeForCausalLM,
-            "zero-shot": JetMoeForSequenceClassification,
-        }
-        if is_torch_available()
-        else {}
-    )
     test_headmasking = False
     test_pruning = False
     test_mismatched_shapes = False
     test_cpu_offload = False
     test_disk_offload_bin = False
     test_disk_offload_safetensors = False
+    model_tester_class = JetMoeModelTester
-    def setUp(self):
-        self.model_tester = JetMoeModelTester(self)
-        self.config_tester = ConfigTester(
-            self, config_class=JetMoeConfig, common_properties=["hidden_size", "num_hidden_layers"]
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": JetMoeModel,
+            "text-classification": JetMoeForSequenceClassification,
+            "text-generation": JetMoeForCausalLM,
+        }
+        if is_torch_available()
+        else {}
     )

-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config
-    def test_config(self):
-        self.config_tester.run_common_tests()
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model
-    def test_model(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_model(*config_and_inputs)
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_various_embeddings
-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with llama->jetmoe, Llama->JetMoe
-    def test_jetmoe_sequence_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = JetMoeForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with llama->jetmoe, Llama->JetMoe
-    def test_jetmoe_sequence_classification_model_for_single_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "single_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = JetMoeForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with llama->jetmoe, Llama->JetMoe
-    def test_jetmoe_sequence_classification_model_for_multi_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "multi_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor(
-            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
-        ).to(torch.float)
-        model = JetMoeForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test
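JetMoe keeps its own `__init__` because of the extra mixture-of-experts knobs and simply chains to the shared tester; `forced_config_args` is not defined in this hunk, but from the commit message it presumably names config arguments (here `pad_token_id`) that must always be forwarded to the config. Under those assumptions, the shape of such a tester is roughly the sketch below; the default values and the trimmed argument list are illustrative, not the PR's exact code.

class JetMoeLikeTesterSketch(CausalLMModelTester):
    config_class = JetMoeConfig
    forced_config_args = ["pad_token_id"]  # assumed semantics: always passed through to the config

    def __init__(self, parent, num_local_experts=8, num_experts_per_tok=2):
        super().__init__(parent)                        # common sizes/defaults live in the base class
        self.num_local_experts = num_local_experts      # MoE-specific knobs stay on the subclass
        self.num_experts_per_tok = num_experts_per_tok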
|
@ -16,9 +16,8 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from parameterized import parameterized
|
|
||||||
|
|
||||||
from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed
|
from transformers import AutoTokenizer, StaticCache, is_torch_available
|
||||||
from transformers.generation.configuration_utils import GenerationConfig
|
from transformers.generation.configuration_utils import GenerationConfig
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
Expectations,
|
Expectations,
|
||||||
@ -30,16 +29,14 @@ from transformers.testing_utils import (
|
|||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
|
||||||
from ...test_configuration_common import ConfigTester
|
|
||||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
|
||||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
LlamaConfig,
|
||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
LlamaForQuestionAnswering,
|
LlamaForQuestionAnswering,
|
||||||
LlamaForSequenceClassification,
|
LlamaForSequenceClassification,
|
||||||
@ -50,124 +47,17 @@ if is_torch_available():
|
|||||||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
|
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
|
||||||
|
|
||||||
|
|
||||||
class LlamaModelTester:
|
class LlamaModelTester(CausalLMModelTester):
|
||||||
def __init__(
|
if is_torch_available():
|
||||||
self,
|
config_class = LlamaConfig
|
||||||
parent,
|
base_model_class = LlamaModel
|
||||||
batch_size=13,
|
causal_lm_class = LlamaForCausalLM
|
||||||
seq_length=7,
|
sequence_class = LlamaForSequenceClassification
|
||||||
is_training=True,
|
token_class = LlamaForTokenClassification
|
||||||
use_input_mask=True,
|
|
||||||
use_token_type_ids=False,
|
|
||||||
use_labels=True,
|
|
||||||
vocab_size=99,
|
|
||||||
hidden_size=32,
|
|
||||||
num_hidden_layers=2,
|
|
||||||
num_attention_heads=4,
|
|
||||||
intermediate_size=37,
|
|
||||||
hidden_act="gelu",
|
|
||||||
hidden_dropout_prob=0.1,
|
|
||||||
attention_probs_dropout_prob=0.1,
|
|
||||||
max_position_embeddings=512,
|
|
||||||
type_vocab_size=16,
|
|
||||||
type_sequence_label_size=2,
|
|
||||||
initializer_range=0.02,
|
|
||||||
num_labels=3,
|
|
||||||
num_choices=4,
|
|
||||||
pad_token_id=0,
|
|
||||||
scope=None,
|
|
||||||
):
|
|
||||||
self.parent = parent
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.seq_length = seq_length
|
|
||||||
self.is_training = is_training
|
|
||||||
self.use_input_mask = use_input_mask
|
|
||||||
self.use_token_type_ids = use_token_type_ids
|
|
||||||
self.use_labels = use_labels
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.num_hidden_layers = num_hidden_layers
|
|
||||||
self.num_attention_heads = num_attention_heads
|
|
||||||
self.intermediate_size = intermediate_size
|
|
||||||
self.hidden_act = hidden_act
|
|
||||||
self.hidden_dropout_prob = hidden_dropout_prob
|
|
||||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
|
||||||
self.max_position_embeddings = max_position_embeddings
|
|
||||||
self.type_vocab_size = type_vocab_size
|
|
||||||
self.type_sequence_label_size = type_sequence_label_size
|
|
||||||
self.initializer_range = initializer_range
|
|
||||||
self.num_labels = num_labels
|
|
||||||
self.num_choices = num_choices
|
|
||||||
self.pad_token_id = pad_token_id
|
|
||||||
self.scope = scope
|
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
|
||||||
|
|
||||||
input_mask = None
|
|
||||||
if self.use_input_mask:
|
|
||||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
|
||||||
|
|
||||||
token_type_ids = None
|
|
||||||
if self.use_token_type_ids:
|
|
||||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
|
||||||
|
|
||||||
sequence_labels = None
|
|
||||||
token_labels = None
|
|
||||||
choice_labels = None
|
|
||||||
if self.use_labels:
|
|
||||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
|
||||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
|
||||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
|
||||||
|
|
||||||
config = self.get_config()
|
|
||||||
|
|
||||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
|
||||||
|
|
||||||
def get_config(self):
|
|
||||||
return LlamaConfig(
|
|
||||||
vocab_size=self.vocab_size,
|
|
||||||
hidden_size=self.hidden_size,
|
|
||||||
num_hidden_layers=self.num_hidden_layers,
|
|
||||||
num_attention_heads=self.num_attention_heads,
|
|
||||||
intermediate_size=self.intermediate_size,
|
|
||||||
hidden_act=self.hidden_act,
|
|
||||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
|
||||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
|
||||||
max_position_embeddings=self.max_position_embeddings,
|
|
||||||
type_vocab_size=self.type_vocab_size,
|
|
||||||
is_decoder=False,
|
|
||||||
initializer_range=self.initializer_range,
|
|
||||||
pad_token_id=self.pad_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_and_check_model(
|
|
||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
|
||||||
):
|
|
||||||
model = LlamaModel(config=config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
result = model(input_ids, attention_mask=input_mask)
|
|
||||||
result = model(input_ids)
|
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
|
||||||
(
|
|
||||||
config,
|
|
||||||
input_ids,
|
|
||||||
token_type_ids,
|
|
||||||
input_mask,
|
|
||||||
sequence_labels,
|
|
||||||
token_labels,
|
|
||||||
choice_labels,
|
|
||||||
) = config_and_inputs
|
|
||||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
|
||||||
return config, inputs_dict
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class LlamaModelTest(CausalLMModelTest, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
LlamaModel,
|
LlamaModel,
|
||||||
@ -194,6 +84,8 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
|
fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
|
||||||
|
model_tester_class = LlamaModelTester
|
||||||
|
rotary_embedding_layer = LlamaRotaryEmbedding # Enables RoPE tests if set
|
||||||
|
|
||||||
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
||||||
# This is because we are hitting edge cases with the causal_mask buffer
|
# This is because we are hitting edge cases with the causal_mask buffer
|
||||||
@ -202,230 +94,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
# used in `test_torch_compile_for_training`
|
# used in `test_torch_compile_for_training`
|
||||||
_torch_compile_train_cls = LlamaForCausalLM if is_torch_available() else None
|
_torch_compile_train_cls = LlamaForCausalLM if is_torch_available() else None
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.model_tester = LlamaModelTester(self)
|
|
||||||
self.config_tester = ConfigTester(self, config_class=LlamaConfig, hidden_size=37)
|
|
||||||
|
|
||||||
def test_config(self):
|
|
||||||
self.config_tester.run_common_tests()
|
|
||||||
|
|
||||||
def test_model(self):
|
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
|
||||||
|
|
||||||
def test_model_various_embeddings(self):
|
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
||||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
|
||||||
config_and_inputs[0].position_embedding_type = type
|
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
|
||||||
|
|
||||||
def test_llama_sequence_classification_model(self):
|
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.num_labels = 3
|
|
||||||
input_ids = input_dict["input_ids"]
|
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
|
||||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
|
||||||
model = LlamaForSequenceClassification(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
|
||||||
|
|
||||||
def test_llama_sequence_classification_model_for_single_label(self):
|
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.num_labels = 3
|
|
||||||
config.problem_type = "single_label_classification"
|
|
||||||
input_ids = input_dict["input_ids"]
|
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
|
||||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
|
||||||
model = LlamaForSequenceClassification(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
|
||||||
|
|
||||||
def test_llama_sequence_classification_model_for_multi_label(self):
|
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.num_labels = 3
|
|
||||||
config.problem_type = "multi_label_classification"
|
|
||||||
input_ids = input_dict["input_ids"]
|
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
|
||||||
sequence_labels = ids_tensor(
|
|
||||||
[self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
|
|
||||||
).to(torch.float)
|
|
||||||
model = LlamaForSequenceClassification(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
|
||||||
|
|
||||||
def test_llama_token_classification_model(self):
|
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.num_labels = 3
|
|
||||||
input_ids = input_dict["input_ids"]
|
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
|
||||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
|
||||||
model = LlamaForTokenClassification(config=config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
|
||||||
self.assertEqual(
|
|
||||||
result.logits.shape,
|
|
||||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
|
||||||
)
|
|
||||||
|
|
||||||
@parameterized.expand([("linear",), ("dynamic",), ("yarn",)])
|
|
||||||
def test_model_rope_scaling_from_config(self, scaling_type):
|
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
short_input = ids_tensor([1, 10], config.vocab_size)
|
|
||||||
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
|
|
||||||
|
|
||||||
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
|
||||||
original_model = LlamaModel(config)
|
|
||||||
original_model.to(torch_device)
|
|
||||||
original_model.eval()
|
|
||||||
original_short_output = original_model(short_input).last_hidden_state
|
|
||||||
original_long_output = original_model(long_input).last_hidden_state
|
|
||||||
|
|
||||||
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
|
||||||
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
|
|
||||||
scaled_model = LlamaModel(config)
|
|
||||||
scaled_model.to(torch_device)
|
|
||||||
scaled_model.eval()
|
|
||||||
scaled_short_output = scaled_model(short_input).last_hidden_state
|
|
||||||
scaled_long_output = scaled_model(long_input).last_hidden_state
|
|
||||||
|
|
||||||
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
|
|
||||||
# maximum sequence length, so the outputs for the short input should match.
|
|
||||||
if scaling_type == "dynamic":
|
|
||||||
torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
|
|
||||||
else:
|
|
||||||
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
|
||||||
|
|
||||||
# The output should be different for long inputs
|
|
||||||
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
|
||||||
|
|
||||||
def test_model_rope_scaling(self):
|
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
scaling_factor = 10
|
|
||||||
short_input_length = 10
|
|
||||||
long_input_length = int(config.max_position_embeddings * 1.5)
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
x = torch.randn(
|
|
||||||
1, dtype=torch.float32, device=torch_device
|
|
||||||
) # used exclusively to get the dtype and the device
|
|
||||||
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
|
|
||||||
position_ids_short = position_ids_short.unsqueeze(0)
|
|
||||||
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
|
|
||||||
position_ids_long = position_ids_long.unsqueeze(0)
|
|
||||||
|
|
||||||
# Sanity check original RoPE
|
|
||||||
original_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
|
|
||||||
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
|
|
||||||
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
|
|
||||||
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
|
|
||||||
torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
|
|
||||||
|
|
||||||
# Sanity check linear RoPE scaling
|
|
||||||
# New position "x" should match original position with index "x/scaling_factor"
|
|
||||||
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
|
|
||||||
linear_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
|
|
||||||
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
|
|
||||||
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
|
|
||||||
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
|
|
||||||
torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
|
|
||||||
for new_position in range(0, long_input_length, scaling_factor):
|
|
||||||
original_position = int(new_position // scaling_factor)
|
|
||||||
torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
|
|
||||||
torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
|
|
||||||
|
|
||||||
# Sanity check Dynamic NTK RoPE scaling
|
|
||||||
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
|
|
||||||
# with scaling_factor (or that `inv_freq` decreases)
|
|
||||||
config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
|
|
||||||
ntk_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
|
|
||||||
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
|
|
||||||
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
|
|
||||||
torch.testing.assert_close(ntk_cos_short, original_cos_short)
|
|
||||||
torch.testing.assert_close(ntk_sin_short, original_sin_short)
|
|
||||||
with self.assertRaises(AssertionError):
|
|
||||||
torch.testing.assert_close(ntk_cos_long, original_cos_long)
|
|
||||||
with self.assertRaises(AssertionError):
|
|
||||||
torch.testing.assert_close(ntk_sin_long, original_sin_long)
|
|
||||||
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
|
|
||||||
|
|
||||||
# Sanity check Yarn RoPE scaling
|
|
||||||
# Scaling should be over the entire input
|
|
||||||
config.rope_scaling = {"type": "yarn", "factor": scaling_factor}
|
|
||||||
yarn_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device)
|
|
||||||
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
|
|
||||||
yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
|
|
||||||
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
|
|
||||||
torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
|
|
||||||
with self.assertRaises(AssertionError):
|
|
||||||
torch.testing.assert_close(yarn_cos_short, original_cos_short)
|
|
||||||
with self.assertRaises(AssertionError):
|
|
||||||
torch.testing.assert_close(yarn_sin_short, original_sin_short)
|
|
||||||
with self.assertRaises(AssertionError):
|
|
||||||
torch.testing.assert_close(yarn_cos_long, original_cos_long)
|
|
||||||
with self.assertRaises(AssertionError):
|
|
||||||
torch.testing.assert_close(yarn_sin_long, original_sin_long)
|
|
||||||
|
|
||||||
def test_model_loading_old_rope_configs(self):
|
|
||||||
def _reinitialize_config(base_config, new_kwargs):
|
|
||||||
# Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation
|
|
||||||
# steps.
|
|
||||||
base_config_dict = base_config.to_dict()
|
|
||||||
new_config = LlamaConfig.from_dict(config_dict={**base_config_dict, **new_kwargs})
|
|
||||||
return new_config
|
|
||||||
|
|
||||||
# from untouched config -> ✅
|
|
||||||
base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
original_model = LlamaForCausalLM(base_config).to(torch_device)
|
|
||||||
original_model(**model_inputs)
|
|
||||||
|
|
||||||
# from a config with the expected rope configuration -> ✅
|
|
||||||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}})
|
|
||||||
original_model = LlamaForCausalLM(config).to(torch_device)
|
|
||||||
original_model(**model_inputs)
|
|
||||||
|
|
||||||
# from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC
|
|
||||||
config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}})
|
|
||||||
original_model = LlamaForCausalLM(config).to(torch_device)
|
|
||||||
original_model(**model_inputs)
|
|
||||||
|
|
||||||
# from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config)
|
|
||||||
config = _reinitialize_config(
|
|
||||||
base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}}
|
|
||||||
)
|
|
||||||
self.assertTrue(config.rope_scaling["type"] == "linear")
|
|
||||||
self.assertTrue(config.rope_scaling["rope_type"] == "linear")
|
|
||||||
original_model = LlamaForCausalLM(config).to(torch_device)
|
|
||||||
original_model(**model_inputs)
|
|
||||||
|
|
||||||
# from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning
|
|
||||||
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
|
|
||||||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}})
|
|
||||||
original_model = LlamaForCausalLM(config).to(torch_device)
|
|
||||||
original_model(**model_inputs)
|
|
||||||
self.assertEqual(len(logs.output), 1)
|
|
||||||
self.assertIn("factor field", logs.output[0])
|
|
||||||
|
|
||||||
# from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning
|
|
||||||
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
|
|
||||||
config = _reinitialize_config(
|
|
||||||
base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}}
|
|
||||||
)
|
|
||||||
original_model = LlamaForCausalLM(config).to(torch_device)
|
|
||||||
original_model(**model_inputs)
|
|
||||||
self.assertEqual(len(logs.output), 1)
|
|
||||||
self.assertIn("Unrecognized keys", logs.output[0])
|
|
||||||
|
|
||||||
# from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception
|
|
||||||
with self.assertRaises(KeyError):
|
|
||||||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor"
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
class LlamaIntegrationTest(unittest.TestCase):
|
class LlamaIntegrationTest(unittest.TestCase):
|
||||||
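The removed `test_model_rope_scaling` encodes the key property of linear RoPE scaling: scaled position p must reproduce the original embedding at position p / factor. The standalone snippet below checks that identity with the usual inverse-frequency construction (base 10000); it only illustrates the property the deleted test asserted and does not touch the Transformers RoPE classes.

import torch

head_dim, factor = 64, 10
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))

def rope_cos_sin(positions):
    angles = torch.outer(positions, inv_freq)
    return angles.cos(), angles.sin()

orig_cos, orig_sin = rope_cos_sin(torch.arange(0, 30).float())                 # original RoPE
scaled_cos, scaled_sin = rope_cos_sin(torch.arange(0, 300).float() / factor)   # linear scaling divides positions

# scaled position p lines up with original position p / factor (exact for multiples of the factor)
torch.testing.assert_close(scaled_cos[::factor], orig_cos)
torch.testing.assert_close(scaled_sin[::factor], orig_sin)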
|
@ -34,11 +34,6 @@ from transformers.testing_utils import (
|
|||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
|
||||||
from ...test_configuration_common import ConfigTester
|
|
||||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
|
||||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
@ -51,131 +46,21 @@ if is_torch_available():
|
|||||||
MistralModel,
|
MistralModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
|
||||||
|
|
||||||
class MistralModelTester:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
parent,
|
|
||||||
batch_size=13,
|
|
||||||
seq_length=7,
|
|
||||||
is_training=True,
|
|
||||||
use_input_mask=True,
|
|
||||||
use_token_type_ids=False,
|
|
||||||
use_labels=True,
|
|
||||||
vocab_size=99,
|
|
||||||
hidden_size=32,
|
|
||||||
num_hidden_layers=2,
|
|
||||||
num_attention_heads=4,
|
|
||||||
num_key_value_heads=2,
|
|
||||||
intermediate_size=37,
|
|
||||||
hidden_act="gelu",
|
|
||||||
hidden_dropout_prob=0.1,
|
|
||||||
attention_probs_dropout_prob=0.1,
|
|
||||||
max_position_embeddings=512,
|
|
||||||
type_vocab_size=16,
|
|
||||||
type_sequence_label_size=2,
|
|
||||||
initializer_range=0.02,
|
|
||||||
num_labels=3,
|
|
||||||
num_choices=4,
|
|
||||||
pad_token_id=0,
|
|
||||||
scope=None,
|
|
||||||
):
|
|
||||||
self.parent = parent
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.seq_length = seq_length
|
|
||||||
self.is_training = is_training
|
|
||||||
self.use_input_mask = use_input_mask
|
|
||||||
self.use_token_type_ids = use_token_type_ids
|
|
||||||
self.use_labels = use_labels
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.num_hidden_layers = num_hidden_layers
|
|
||||||
self.num_attention_heads = num_attention_heads
|
|
||||||
self.num_key_value_heads = num_key_value_heads
|
|
||||||
self.intermediate_size = intermediate_size
|
|
||||||
self.hidden_act = hidden_act
|
|
||||||
self.hidden_dropout_prob = hidden_dropout_prob
|
|
||||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
|
||||||
self.max_position_embeddings = max_position_embeddings
|
|
||||||
self.type_vocab_size = type_vocab_size
|
|
||||||
self.type_sequence_label_size = type_sequence_label_size
|
|
||||||
self.initializer_range = initializer_range
|
|
||||||
self.num_labels = num_labels
|
|
||||||
self.num_choices = num_choices
|
|
||||||
self.pad_token_id = pad_token_id
|
|
||||||
self.scope = scope
|
|
||||||
|
|
||||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs
|
class MistralModelTester(CausalLMModelTester):
|
||||||
def prepare_config_and_inputs(self):
|
config_class = MistralConfig
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
if is_torch_available():
|
||||||
|
base_model_class = MistralModel
|
||||||
input_mask = None
|
causal_lm_class = MistralForCausalLM
|
||||||
if self.use_input_mask:
|
sequence_class = MistralForSequenceClassification
|
||||||
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
token_class = MistralForTokenClassification
|
||||||
|
question_answering_class = MistralForQuestionAnswering
|
||||||
token_type_ids = None
|
|
||||||
if self.use_token_type_ids:
|
|
||||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
|
||||||
|
|
||||||
sequence_labels = None
|
|
||||||
token_labels = None
|
|
||||||
choice_labels = None
|
|
||||||
if self.use_labels:
|
|
||||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
|
||||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
|
||||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
|
||||||
|
|
||||||
config = self.get_config()
|
|
||||||
|
|
||||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
|
||||||
|
|
||||||
def get_config(self):
|
|
||||||
return MistralConfig(
|
|
||||||
vocab_size=self.vocab_size,
|
|
||||||
hidden_size=self.hidden_size,
|
|
||||||
num_hidden_layers=self.num_hidden_layers,
|
|
||||||
num_attention_heads=self.num_attention_heads,
|
|
||||||
num_key_value_heads=self.num_key_value_heads,
|
|
||||||
intermediate_size=self.intermediate_size,
|
|
||||||
hidden_act=self.hidden_act,
|
|
||||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
|
||||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
|
||||||
max_position_embeddings=self.max_position_embeddings,
|
|
||||||
type_vocab_size=self.type_vocab_size,
|
|
||||||
is_decoder=False,
|
|
||||||
initializer_range=self.initializer_range,
|
|
||||||
pad_token_id=self.pad_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mistral
|
|
||||||
def create_and_check_model(
|
|
||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
|
||||||
):
|
|
||||||
model = MistralModel(config=config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
result = model(input_ids, attention_mask=input_mask)
|
|
||||||
result = model(input_ids)
|
|
||||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
|
||||||
|
|
||||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
|
||||||
(
|
|
||||||
config,
|
|
||||||
input_ids,
|
|
||||||
token_type_ids,
|
|
||||||
input_mask,
|
|
||||||
sequence_labels,
|
|
||||||
token_labels,
|
|
||||||
choice_labels,
|
|
||||||
) = config_and_inputs
|
|
||||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
|
||||||
return config, inputs_dict
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class MistralModelTest(CausalLMModelTest, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
MistralModel,
|
MistralModel,
|
||||||
@ -193,7 +78,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
"text-classification": MistralForSequenceClassification,
|
"text-classification": MistralForSequenceClassification,
|
||||||
"token-classification": MistralForTokenClassification,
|
"token-classification": MistralForTokenClassification,
|
||||||
"text-generation": MistralForCausalLM,
|
"text-generation": MistralForCausalLM,
|
||||||
"zero-shot": MistralForSequenceClassification,
|
|
||||||
"question-answering": MistralForQuestionAnswering,
|
"question-answering": MistralForQuestionAnswering,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
@ -201,7 +85,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
)
|
)
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
|
model_tester_class = MistralModelTester
|
||||||
|
|
||||||
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
||||||
def is_pipeline_test_to_skip(
|
def is_pipeline_test_to_skip(
|
||||||
@ -216,82 +100,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.model_tester = MistralModelTester(self)
|
|
||||||
self.config_tester = ConfigTester(self, config_class=MistralConfig, hidden_size=37)
|
|
||||||
|
|
||||||
def test_config(self):
|
|
||||||
self.config_tester.run_common_tests()
|
|
||||||
|
|
||||||
def test_model(self):
|
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
|
||||||
|
|
||||||
def test_model_various_embeddings(self):
|
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
||||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
|
||||||
config_and_inputs[0].position_embedding_type = type
|
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
|
||||||
|
|
||||||
def test_torch_fx_output_loss(self):
|
|
||||||
super().test_torch_fx_output_loss()
|
|
||||||
|
|
||||||
def test_Mistral_sequence_classification_model(self):
|
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.num_labels = 3
|
|
||||||
input_ids = input_dict["input_ids"]
|
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
|
||||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
|
||||||
model = MistralForSequenceClassification(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
|
||||||
|
|
||||||
def test_Mistral_sequence_classification_model_for_single_label(self):
|
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.num_labels = 3
|
|
||||||
config.problem_type = "single_label_classification"
|
|
||||||
input_ids = input_dict["input_ids"]
|
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
|
||||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
|
||||||
model = MistralForSequenceClassification(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
|
||||||
|
|
||||||
def test_Mistral_sequence_classification_model_for_multi_label(self):
|
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.num_labels = 3
|
|
||||||
config.problem_type = "multi_label_classification"
|
|
||||||
input_ids = input_dict["input_ids"]
|
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
|
||||||
sequence_labels = ids_tensor(
|
|
||||||
[self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
|
|
||||||
).to(torch.float)
|
|
||||||
model = MistralForSequenceClassification(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
|
||||||
|
|
||||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mistral,llama->Mistral
|
|
||||||
def test_Mistral_token_classification_model(self):
|
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.num_labels = 3
|
|
||||||
input_ids = input_dict["input_ids"]
|
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
|
||||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
|
||||||
model = MistralForTokenClassification(config=config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
|
||||||
self.assertEqual(
|
|
||||||
result.logits.shape,
|
|
||||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
|
||||||
)
|
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
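The Mistral hunks above show the full shape of the migration: the hand-rolled tester and the per-model classification tests are deleted, and what remains is a declarative subclass of the shared tester plus a `model_tester_class` attribute on the test class. A minimal sketch of that pattern for a hypothetical `Foo` model family (the `Foo*` names are placeholders, not part of this commit; the attribute names mirror the added lines above):

    from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester

    class FooModelTester(CausalLMModelTester):
        config_class = FooConfig
        if is_torch_available():
            base_model_class = FooModel
            causal_lm_class = FooForCausalLM
            sequence_class = FooForSequenceClassification
            token_class = FooForTokenClassification

    @require_torch
    class FooModelTest(CausalLMModelTest, unittest.TestCase):
        all_model_classes = (...)        # per-model tuple, unchanged by the migration
        pipeline_model_mapping = {...}   # per-model mapping, unchanged by the migration
        model_tester_class = FooModelTester

The setUp, config and classification-head tests that each per-model file used to copy from Llama are presumably supplied by the shared CausalLMModelTester/CausalLMModelTest base classes in tests/causal_lm_tester.py; the same rewrite repeats below for Mixtral, Nemotron, Persimmon and Phi.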
@@ -27,11 +27,6 @@ from transformers.testing_utils import (
     torch_device,
 )
 
-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor
-from ...test_pipeline_mixin import PipelineTesterMixin
-
 
 if is_torch_available():
     import torch
@@ -44,137 +39,21 @@ if is_torch_available():
         MixtralModel,
     )
 
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
 
-class MixtralModelTester:
-    def __init__(
-        self,
-        parent,
-        batch_size=13,
-        seq_length=7,
-        is_training=True,
-        use_input_mask=True,
-        use_token_type_ids=False,
-        use_labels=True,
-        vocab_size=99,
-        hidden_size=32,
-        num_hidden_layers=2,
-        num_attention_heads=4,
-        num_key_value_heads=2,
-        intermediate_size=37,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_position_embeddings=512,
-        type_vocab_size=16,
-        type_sequence_label_size=2,
-        initializer_range=0.02,
-        num_labels=3,
-        num_choices=4,
-        pad_token_id=0,
-        scope=None,
-        router_jitter_noise=0.1,
-    ):
-        self.parent = parent
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.is_training = is_training
-        self.use_input_mask = use_input_mask
-        self.use_token_type_ids = use_token_type_ids
-        self.use_labels = use_labels
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.num_key_value_heads = num_key_value_heads
-        self.intermediate_size = intermediate_size
-        self.hidden_act = hidden_act
-        self.hidden_dropout_prob = hidden_dropout_prob
-        self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.max_position_embeddings = max_position_embeddings
-        self.type_vocab_size = type_vocab_size
-        self.type_sequence_label_size = type_sequence_label_size
-        self.initializer_range = initializer_range
-        self.num_labels = num_labels
-        self.num_choices = num_choices
-        self.pad_token_id = pad_token_id
-        self.scope = scope
-        self.router_jitter_noise = router_jitter_noise
-
-    # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
-    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
-        input_mask = None
-        if self.use_input_mask:
-            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
-
-        token_type_ids = None
-        if self.use_token_type_ids:
-            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-        sequence_labels = None
-        token_labels = None
-        choice_labels = None
-        if self.use_labels:
-            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-            choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
-        config = self.get_config()
-
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-    def get_config(self):
-        return MixtralConfig(
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,
-            num_hidden_layers=self.num_hidden_layers,
-            num_attention_heads=self.num_attention_heads,
-            num_key_value_heads=self.num_key_value_heads,
-            intermediate_size=self.intermediate_size,
-            hidden_act=self.hidden_act,
-            hidden_dropout_prob=self.hidden_dropout_prob,
-            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-            max_position_embeddings=self.max_position_embeddings,
-            type_vocab_size=self.type_vocab_size,
-            is_decoder=False,
-            initializer_range=self.initializer_range,
-            pad_token_id=self.pad_token_id,
-            num_experts_per_tok=2,
-            num_local_experts=2,
-            router_jitter_noise=self.router_jitter_noise,
-        )
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mixtral
-    def create_and_check_model(
-        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-    ):
-        model = MixtralModel(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=input_mask)
-        result = model(input_ids)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->Mixtral
-    def prepare_config_and_inputs_for_common(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-        (
-            config,
-            input_ids,
-            token_type_ids,
-            input_mask,
-            sequence_labels,
-            token_labels,
-            choice_labels,
-        ) = config_and_inputs
-        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
-        return config, inputs_dict
+class MixtralModelTester(CausalLMModelTester):
+    config_class = MixtralConfig
+    if is_torch_available():
+        base_model_class = MixtralModel
+        causal_lm_class = MixtralForCausalLM
+        sequence_class = MixtralForSequenceClassification
+        token_class = MixtralForTokenClassification
+        question_answering_class = MixtralForQuestionAnswering
 
 
 @require_torch
-# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral
-class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class MistralModelTest(CausalLMModelTest, unittest.TestCase):
     all_model_classes = (
         (
             MixtralModel,
@@ -192,15 +71,15 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
             "text-classification": MixtralForSequenceClassification,
             "token-classification": MixtralForTokenClassification,
             "text-generation": MixtralForCausalLM,
-            "zero-shot": MixtralForSequenceClassification,
             "question-answering": MixtralForQuestionAnswering,
         }
         if is_torch_available()
         else {}
     )
 
     test_headmasking = False
     test_pruning = False
-    fx_compatible = False  # Broken by attention refactor cc @Cyrilvallez
+    model_tester_class = MixtralModelTester
 
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
@@ -215,88 +94,12 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
     ):
         return True
 
-    def setUp(self):
-        self.model_tester = MixtralModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=MixtralConfig, hidden_size=37)
-
-    def test_config(self):
-        self.config_tester.run_common_tests()
-
-    def test_model(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_model(*config_and_inputs)
-
-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
-    def test_torch_fx_output_loss(self):
-        super().test_torch_fx_output_loss()
-
-    def test_Mixtral_sequence_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = MixtralForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_Mixtral_sequence_classification_model_for_single_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "single_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = MixtralForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_Mixtral_sequence_classification_model_for_multi_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "multi_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor(
-            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
-        ).to(torch.float)
-        model = MixtralForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mixtral,llama->Mixtral
-    def test_Mixtral_token_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
-        model = MixtralForTokenClassification(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
-        self.assertEqual(
-            result.logits.shape,
-            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
-        )
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test
     @slow
     def test_flash_attn_2_inference_equivalence_right_padding(self):
-        self.skipTest(reason="Mixtral flash attention does not support right padding")
+        self.skipTest(reason="Mistral flash attention does not support right padding")
 
     # Ignore copy
     def test_load_balancing_loss(self):
@@ -14,25 +14,19 @@
 # limitations under the License.
 """Testing suite for the PyTorch Nemotron model."""
 
-import tempfile
 import unittest
 
-import pytest
-
 from transformers import NemotronConfig, is_torch_available
 from transformers.testing_utils import (
     Expectations,
-    is_flaky,
-    require_flash_attn,
     require_read_token,
     require_torch,
     require_torch_accelerator,
-    require_torch_gpu,
     slow,
     torch_device,
 )
 
-from ...models.gemma.test_modeling_gemma import GemmaModelTest, GemmaModelTester
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
 from ...test_configuration_common import ConfigTester
 
 
@@ -49,17 +43,18 @@ if is_torch_available():
     )
 
 
-class NemotronModelTester(GemmaModelTester):
+class NemotronModelTester(CausalLMModelTester):
     if is_torch_available():
         config_class = NemotronConfig
-        model_class = NemotronModel
-        for_causal_lm_class = NemotronForCausalLM
-        for_sequence_class = NemotronForSequenceClassification
-        for_token_class = NemotronForTokenClassification
+        base_model_class = NemotronModel
+        causal_lm_class = NemotronForCausalLM
+        sequence_class = NemotronForSequenceClassification
+        token_class = NemotronForTokenClassification
 
 
 @require_torch
-class NemotronModelTest(GemmaModelTest):
+class NemotronModelTest(CausalLMModelTest, unittest.TestCase):
+    model_tester_class = NemotronModelTester
     # Need to use `0.8` instead of `0.9` for `test_cpu_offload`
     # This is because we are hitting edge cases with the causal_mask buffer
     model_split_percents = [0.5, 0.7, 0.8]
@@ -101,40 +96,6 @@ class NemotronModelTest(GemmaModelTest):
     def test_model_outputs_equivalence(self, **kwargs):
         pass
 
-    @require_flash_attn
-    @require_torch_gpu
-    @pytest.mark.flash_attn_test
-    @is_flaky()
-    @slow
-    def test_flash_attn_2_equivalence(self):
-        for model_class in self.all_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(reason="Model does not support Flash Attention 2")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model_fa = model_class.from_pretrained(
-                    tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
-                )
-                model_fa.to(torch_device)
-
-                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
-                model.to(torch_device)
-
-                dummy_input = inputs_dict[model_class.main_input_name]
-                dummy_input = dummy_input.to(torch_device)
-                outputs = model(dummy_input, output_hidden_states=True)
-                outputs_fa = model_fa(dummy_input, output_hidden_states=True)
-
-                logits = outputs.hidden_states[-1]
-                logits_fa = outputs_fa.hidden_states[-1]
-
-                # nemotron flash attention 2 needs a high tolerance
-                assert torch.allclose(logits_fa, logits, atol=1e-2)
-
 
 @require_torch_accelerator
 class NemotronIntegrationTest(unittest.TestCase):
@@ -16,9 +16,7 @@
 import gc
 import unittest
 
-from parameterized import parameterized
-
-from transformers import PersimmonConfig, is_torch_available, set_seed
+from transformers import PersimmonConfig, is_torch_available
 from transformers.testing_utils import (
     backend_empty_cache,
     require_bitsandbytes,
@@ -29,11 +27,6 @@ from transformers.testing_utils import (
     torch_device,
 )
 
-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor
-from ...test_pipeline_mixin import PipelineTesterMixin
-
 
 if is_torch_available():
     import torch
@@ -45,128 +38,22 @@ if is_torch_available():
         PersimmonForTokenClassification,
         PersimmonModel,
     )
-    from transformers.models.persimmon.modeling_persimmon import PersimmonRotaryEmbedding
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
 
 
-# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Persimmon
-class PersimmonModelTester:
-    def __init__(
-        self,
-        parent,
-        batch_size=13,
-        seq_length=7,
-        is_training=True,
-        use_input_mask=True,
-        use_token_type_ids=False,
-        use_labels=True,
-        vocab_size=99,
-        hidden_size=32,
-        num_hidden_layers=2,
-        num_attention_heads=4,
-        intermediate_size=37,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_position_embeddings=512,
-        type_vocab_size=16,
-        type_sequence_label_size=2,
-        initializer_range=0.02,
-        num_labels=3,
-        num_choices=4,
-        pad_token_id=0,
-        scope=None,
-    ):
-        self.parent = parent
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.is_training = is_training
-        self.use_input_mask = use_input_mask
-        self.use_token_type_ids = use_token_type_ids
-        self.use_labels = use_labels
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.intermediate_size = intermediate_size
-        self.hidden_act = hidden_act
-        self.hidden_dropout_prob = hidden_dropout_prob
-        self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.max_position_embeddings = max_position_embeddings
-        self.type_vocab_size = type_vocab_size
-        self.type_sequence_label_size = type_sequence_label_size
-        self.initializer_range = initializer_range
-        self.num_labels = num_labels
-        self.num_choices = num_choices
-        self.pad_token_id = pad_token_id
-        self.scope = scope
-
-    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
-        input_mask = None
-        if self.use_input_mask:
-            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
-
-        token_type_ids = None
-        if self.use_token_type_ids:
-            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-        sequence_labels = None
-        token_labels = None
-        choice_labels = None
-        if self.use_labels:
-            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-            choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
-        config = self.get_config()
-
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-    def get_config(self):
-        return PersimmonConfig(
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,
-            num_hidden_layers=self.num_hidden_layers,
-            num_attention_heads=self.num_attention_heads,
-            intermediate_size=self.intermediate_size,
-            hidden_act=self.hidden_act,
-            hidden_dropout_prob=self.hidden_dropout_prob,
-            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-            max_position_embeddings=self.max_position_embeddings,
-            type_vocab_size=self.type_vocab_size,
-            is_decoder=False,
-            initializer_range=self.initializer_range,
-            pad_token_id=self.pad_token_id,
-        )
-
-    def create_and_check_model(
-        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-    ):
-        model = PersimmonModel(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=input_mask)
-        result = model(input_ids)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
-    def prepare_config_and_inputs_for_common(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-        (
-            config,
-            input_ids,
-            token_type_ids,
-            input_mask,
-            sequence_labels,
-            token_labels,
-            choice_labels,
-        ) = config_and_inputs
-        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
-        return config, inputs_dict
+class PersimmonModelTester(CausalLMModelTester):
+    if is_torch_available():
+        config_class = PersimmonConfig
+        base_model_class = PersimmonModel
+        causal_lm_class = PersimmonForCausalLM
+        sequence_class = PersimmonForSequenceClassification
+        token_class = PersimmonForTokenClassification
 
 
 @require_torch
-class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class PersimmonModelTest(CausalLMModelTest, unittest.TestCase):
+    model_tester_class = PersimmonModelTester
     all_model_classes = (
         (PersimmonModel, PersimmonForCausalLM, PersimmonForSequenceClassification, PersimmonForTokenClassification)
         if is_torch_available()
@@ -184,173 +71,11 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
         if is_torch_available()
         else {}
     )
+    model_tester_class = PersimmonModelTester
 
     test_headmasking = False
     test_pruning = False
 
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.setUp with Llama->Persimmon
-    def setUp(self):
-        self.model_tester = PersimmonModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=PersimmonConfig, hidden_size=37)
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config
-    def test_config(self):
-        self.config_tester.run_common_tests()
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model
-    def test_model(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_model(*config_and_inputs)
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_various_embeddings
-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with Llama->Persimmon,llama->persimmon
-    def test_persimmon_sequence_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = PersimmonForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with Llama->Persimmon,llama->persimmon
-    def test_persimmon_sequence_classification_model_for_single_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "single_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = PersimmonForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with Llama->Persimmon,llama->persimmon
-    def test_persimmon_sequence_classification_model_for_multi_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "multi_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor(
-            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
-        ).to(torch.float)
-        model = PersimmonForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Persimmon,llama->persimmon
-    def test_persimmon_token_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
-        model = PersimmonForTokenClassification(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
-        self.assertEqual(
-            result.logits.shape,
-            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
-        )
-
-    @parameterized.expand([("linear",), ("dynamic",)])
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Persimmon
-    def test_model_rope_scaling_from_config(self, scaling_type):
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        short_input = ids_tensor([1, 10], config.vocab_size)
-        long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
-
-        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        original_model = PersimmonModel(config)
-        original_model.to(torch_device)
-        original_model.eval()
-        original_short_output = original_model(short_input).last_hidden_state
-        original_long_output = original_model(long_input).last_hidden_state
-
-        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        config.rope_scaling = {"type": scaling_type, "factor": 10.0}
-        scaled_model = PersimmonModel(config)
-        scaled_model.to(torch_device)
-        scaled_model.eval()
-        scaled_short_output = scaled_model(short_input).last_hidden_state
-        scaled_long_output = scaled_model(long_input).last_hidden_state
-
-        # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
-        # maximum sequence length, so the outputs for the short input should match.
-        if scaling_type == "dynamic":
-            torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
-        else:
-            self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
-
-        # The output should be different for long inputs
-        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
-
-    # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Persimmon
-    def test_model_rope_scaling(self):
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        scaling_factor = 10
-        short_input_length = 10
-        long_input_length = int(config.max_position_embeddings * 1.5)
-
-        # Inputs
-        x = torch.randn(
-            1, dtype=torch.float32, device=torch_device
-        )  # used exclusively to get the dtype and the device
-        position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
-        position_ids_short = position_ids_short.unsqueeze(0)
-        position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
-        position_ids_long = position_ids_long.unsqueeze(0)
-
-        # Sanity check original RoPE
-        original_rope = PersimmonRotaryEmbedding(config).to(torch_device)
-        original_cos_short, original_sin_short = original_rope(x, position_ids_short)
-        original_cos_long, original_sin_long = original_rope(x, position_ids_long)
-        torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
-        torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
-
-        # Sanity check linear RoPE scaling
-        # New position "x" should match original position with index "x/scaling_factor"
-        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
-        linear_scaling_rope = PersimmonRotaryEmbedding(config).to(torch_device)
-        linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
-        linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
-        torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
-        torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
-        for new_position in range(0, long_input_length, scaling_factor):
-            original_position = int(new_position // scaling_factor)
-            torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
-            torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
-
-        # Sanity check Dynamic NTK RoPE scaling
-        # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
-        # with scaling_factor (or that `inv_freq` decreases)
-        config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
-        ntk_scaling_rope = PersimmonRotaryEmbedding(config).to(torch_device)
-        ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
-        ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
-        torch.testing.assert_close(ntk_cos_short, original_cos_short)
-        torch.testing.assert_close(ntk_sin_short, original_sin_short)
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_close(ntk_cos_long, original_cos_long)
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_close(ntk_sin_long, original_sin_long)
-        self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
-
 
 @require_torch
 class PersimmonIntegrationTest(unittest.TestCase):
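Persimmon's local RoPE-scaling tests (`test_model_rope_scaling_from_config` and `test_model_rope_scaling`) are deleted above without a per-model replacement, while the Phi diff below keeps rotary coverage by pointing the shared suite at the model's rotary class. A sketch of how that opt-in looks on a migrated test class, assuming the shared tester simply skips its RoPE checks when the attribute is left unset (that skip behaviour is an assumption here; the actual logic lives in tests/causal_lm_tester.py):

    from transformers.models.phi.modeling_phi import PhiRotaryEmbedding

    @require_torch
    class PhiModelTest(CausalLMModelTest, unittest.TestCase):
        model_tester_class = PhiModelTester
        # Assumed contract: setting this points the shared RoPE-scaling tests at the model's
        # rotary embedding class; models without RoPE leave it unset so those tests are skipped.
        rotary_embedding_layer = PhiRotaryEmbedding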
@@ -16,19 +16,14 @@
 
 import unittest
 
-from parameterized import parameterized
-
-from transformers import PhiConfig, is_torch_available, set_seed
+from transformers import PhiConfig, is_torch_available
 from transformers.testing_utils import (
     require_torch,
     slow,
     torch_device,
 )
 
-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
-from ...test_pipeline_mixin import PipelineTesterMixin
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
 
 
 if is_torch_available():
@@ -44,124 +39,17 @@ if is_torch_available():
     from transformers.models.phi.modeling_phi import PhiRotaryEmbedding
 
 
-class PhiModelTester:
-    def __init__(
-        self,
-        parent,
-        batch_size=13,
-        seq_length=7,
-        is_training=True,
-        use_input_mask=True,
-        use_token_type_ids=False,
-        use_labels=True,
-        vocab_size=99,
-        hidden_size=32,
-        num_hidden_layers=2,
-        num_attention_heads=4,
-        intermediate_size=37,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_position_embeddings=512,
-        type_vocab_size=16,
-        type_sequence_label_size=2,
-        initializer_range=0.02,
-        num_labels=3,
-        num_choices=4,
-        pad_token_id=0,
-        scope=None,
-    ):
-        self.parent = parent
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.is_training = is_training
-        self.use_input_mask = use_input_mask
-        self.use_token_type_ids = use_token_type_ids
-        self.use_labels = use_labels
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.intermediate_size = intermediate_size
-        self.hidden_act = hidden_act
-        self.hidden_dropout_prob = hidden_dropout_prob
-        self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.max_position_embeddings = max_position_embeddings
-        self.type_vocab_size = type_vocab_size
-        self.type_sequence_label_size = type_sequence_label_size
-        self.initializer_range = initializer_range
-        self.num_labels = num_labels
-        self.num_choices = num_choices
-        self.pad_token_id = pad_token_id
-        self.scope = scope
-
-    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
-        input_mask = None
-        if self.use_input_mask:
-            input_mask = random_attention_mask([self.batch_size, self.seq_length])
-
-        token_type_ids = None
-        if self.use_token_type_ids:
-            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-        sequence_labels = None
-        token_labels = None
-        choice_labels = None
-        if self.use_labels:
-            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-            choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
-        config = self.get_config()
-
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-    def get_config(self):
-        return PhiConfig(
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,
-            num_hidden_layers=self.num_hidden_layers,
-            num_attention_heads=self.num_attention_heads,
-            intermediate_size=self.intermediate_size,
-            hidden_act=self.hidden_act,
-            hidden_dropout_prob=self.hidden_dropout_prob,
-            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-            max_position_embeddings=self.max_position_embeddings,
-            type_vocab_size=self.type_vocab_size,
-            is_decoder=False,
-            initializer_range=self.initializer_range,
-            pad_token_id=self.pad_token_id,
-        )
-
-    def create_and_check_model(
-        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-    ):
-        model = PhiModel(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=input_mask)
-        result = model(input_ids)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
-    def prepare_config_and_inputs_for_common(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-        (
-            config,
-            input_ids,
-            token_type_ids,
-            input_mask,
-            sequence_labels,
-            token_labels,
-            choice_labels,
-        ) = config_and_inputs
-        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
-        return config, inputs_dict
+class PhiModelTester(CausalLMModelTester):
+    config_class = PhiConfig
+    if is_torch_available():
+        base_model_class = PhiModel
+        causal_lm_class = PhiForCausalLM
+        sequence_class = PhiForSequenceClassification
+        token_class = PhiForTokenClassification
 
 
 @require_torch
-class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class PhiModelTest(CausalLMModelTest, unittest.TestCase):
     all_model_classes = (
         (PhiModel, PhiForCausalLM, PhiForSequenceClassification, PhiForTokenClassification)
         if is_torch_available()
@@ -171,9 +59,8 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
         {
             "feature-extraction": PhiModel,
             "text-classification": PhiForSequenceClassification,
-            "text-generation": PhiForCausalLM,
             "token-classification": PhiForTokenClassification,
-            "zero-shot": PhiForSequenceClassification,
+            "text-generation": PhiForCausalLM,
         }
         if is_torch_available()
         else {}
@@ -181,6 +68,8 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
 
     test_headmasking = False
     test_pruning = False
+    model_tester_class = PhiModelTester
+    rotary_embedding_layer = PhiRotaryEmbedding
 
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79292/workflows/fa2ba644-8953-44a6-8f67-ccd69ca6a476/jobs/1012905
     def is_pipeline_test_to_skip(
@@ -195,146 +84,6 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
     ):
         return True
 
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.setUp with Llama->Phi
-    def setUp(self):
-        self.model_tester = PhiModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=PhiConfig, hidden_size=37)
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config
-    def test_config(self):
-        self.config_tester.run_common_tests()
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model
-    def test_model(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_model(*config_and_inputs)
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with Llama->Phi,llama->phi
-    def test_phi_sequence_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = PhiForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with Llama->Phi,llama->phi
-    def test_phi_sequence_classification_model_for_single_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "single_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = PhiForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with Llama->Phi,llama->phi
-    def test_phi_sequence_classification_model_for_multi_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "multi_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor(
-            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
-        ).to(torch.float)
-        model = PhiForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    @parameterized.expand([("linear",), ("dynamic",)])
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Phi
-    def test_model_rope_scaling_from_config(self, scaling_type):
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        short_input = ids_tensor([1, 10], config.vocab_size)
-        long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
-
-        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        original_model = PhiModel(config)
-        original_model.to(torch_device)
-        original_model.eval()
-        original_short_output = original_model(short_input).last_hidden_state
-        original_long_output = original_model(long_input).last_hidden_state
-
-        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        config.rope_scaling = {"type": scaling_type, "factor": 10.0}
-        scaled_model = PhiModel(config)
-        scaled_model.to(torch_device)
-        scaled_model.eval()
-        scaled_short_output = scaled_model(short_input).last_hidden_state
-        scaled_long_output = scaled_model(long_input).last_hidden_state
-
-        # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
-        # maximum sequence length, so the outputs for the short input should match.
-        if scaling_type == "dynamic":
-            torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
-        else:
-            self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
-
-        # The output should be different for long inputs
-        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
-
-    # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Phi
-    def test_model_rope_scaling(self):
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        scaling_factor = 10
-        short_input_length = 10
-        long_input_length = int(config.max_position_embeddings * 1.5)
-
-        # Inputs
-        x = torch.randn(
-            1, dtype=torch.float32, device=torch_device
-        )  # used exclusively to get the dtype and the device
-        position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
-        position_ids_short = position_ids_short.unsqueeze(0)
-        position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
-        position_ids_long = position_ids_long.unsqueeze(0)
-
-        # Sanity check original RoPE
-        original_rope = PhiRotaryEmbedding(config).to(torch_device)
-        original_cos_short, original_sin_short = original_rope(x, position_ids_short)
-        original_cos_long, original_sin_long = original_rope(x, position_ids_long)
-        torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
-        torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
-
-        # Sanity check linear RoPE scaling
|
|
||||||
# New position "x" should match original position with index "x/scaling_factor"
|
|
||||||
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
|
|
||||||
linear_scaling_rope = PhiRotaryEmbedding(config).to(torch_device)
|
|
||||||
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
|
|
||||||
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
|
|
||||||
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
|
|
||||||
torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
|
|
||||||
for new_position in range(0, long_input_length, scaling_factor):
|
|
||||||
original_position = int(new_position // scaling_factor)
|
|
||||||
torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
|
|
||||||
torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
|
|
||||||
|
|
||||||
# Sanity check Dynamic NTK RoPE scaling
|
|
||||||
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
|
|
||||||
# with scaling_factor (or that `inv_freq` decreases)
|
|
||||||
config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
|
|
||||||
ntk_scaling_rope = PhiRotaryEmbedding(config).to(torch_device)
|
|
||||||
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
|
|
||||||
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
|
|
||||||
torch.testing.assert_close(ntk_cos_short, original_cos_short)
|
|
||||||
torch.testing.assert_close(ntk_sin_short, original_sin_short)
|
|
||||||
with self.assertRaises(AssertionError):
|
|
||||||
torch.testing.assert_close(ntk_cos_long, original_cos_long)
|
|
||||||
with self.assertRaises(AssertionError):
|
|
||||||
torch.testing.assert_close(ntk_sin_long, original_sin_long)
|
|
||||||
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
|
|
||||||
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
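The deleted Phi RoPE tests above encode two standard facts about rotary position scaling, and every assertion follows from them. Below is a small illustrative sketch (not part of the diff; it assumes the usual inv_freq = 1 / base**(2i/dim) parameterization and the common dynamic-NTK base adjustment) of why linear scaling perturbs even short inputs while dynamic NTK scaling leaves short inputs untouched and only ever lowers inv_freq.

# Illustrative sketch only, not code from this diff.
import torch

base, dim, max_pos, factor = 10000.0, 64, 512, 10.0

# Standard RoPE inverse frequencies: inv_freq_i = base ** (-2i / dim)
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

# Linear scaling divides the position ids by `factor`, so even a short prompt sees
# different rotation angles -> outputs differ from the unscaled model at any length.
positions = torch.arange(10).float()
assert not torch.allclose(positions[:, None] * inv_freq, (positions / factor)[:, None] * inv_freq)

# Dynamic NTK only rewrites the base once the input outgrows the original maximum,
# and the rewritten base is larger, so every inv_freq entry can only shrink.
seq_len = int(max_pos * 1.5)
ntk_base = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))
ntk_inv_freq = 1.0 / (ntk_base ** (torch.arange(0, dim, 2).float() / dim))
assert (ntk_inv_freq <= inv_freq).all()

That is exactly the shape of the assertFalse(torch.allclose(...)) checks and the final inv_freq <= original_rope.inv_freq assertion in the removed test.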
@@ -16,9 +16,7 @@
 
 import unittest
 
-from parameterized import parameterized
-
-from transformers import Phi3Config, StaticCache, is_torch_available, set_seed
+from transformers import Phi3Config, StaticCache, is_torch_available
 from transformers.models.auto.configuration_auto import AutoConfig
 from transformers.testing_utils import (
     require_torch,
@@ -26,10 +24,7 @@ from transformers.testing_utils import (
     torch_device,
 )
 
-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor
-from ...test_pipeline_mixin import PipelineTesterMixin
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
 
 
 if is_torch_available():
@@ -42,6 +37,7 @@ if is_torch_available():
         Phi3ForTokenClassification,
         Phi3Model,
     )
+    from transformers.models.phi3.modeling_phi3 import Phi3RotaryEmbedding
 
     end_of_text_token = 32000
 
@@ -93,127 +89,17 @@ if is_torch_available():
         return response_tokens
 
 
-class Phi3ModelTester:
-    def __init__(
-        self,
-        parent,
-        batch_size=13,
-        seq_length=7,
-        is_training=True,
-        use_input_mask=True,
-        use_token_type_ids=False,
-        use_labels=True,
-        vocab_size=99,
-        hidden_size=32,
-        num_hidden_layers=2,
-        num_attention_heads=4,
-        intermediate_size=37,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_position_embeddings=512,
-        type_vocab_size=16,
-        type_sequence_label_size=2,
-        initializer_range=0.02,
-        num_labels=3,
-        num_choices=4,
-        pad_token_id=0,
-        scope=None,
-    ):
-        self.parent = parent
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.is_training = is_training
-        self.use_input_mask = use_input_mask
-        self.use_token_type_ids = use_token_type_ids
-        self.use_labels = use_labels
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.intermediate_size = intermediate_size
-        self.hidden_act = hidden_act
-        self.hidden_dropout_prob = hidden_dropout_prob
-        self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.max_position_embeddings = max_position_embeddings
-        self.type_vocab_size = type_vocab_size
-        self.type_sequence_label_size = type_sequence_label_size
-        self.initializer_range = initializer_range
-        self.num_labels = num_labels
-        self.num_choices = num_choices
-        self.pad_token_id = pad_token_id
-        self.scope = scope
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs
-    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
-        input_mask = None
-        if self.use_input_mask:
-            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
-
-        token_type_ids = None
-        if self.use_token_type_ids:
-            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-        sequence_labels = None
-        token_labels = None
-        choice_labels = None
-        if self.use_labels:
-            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-            choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
-        config = self.get_config()
-
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-    def get_config(self):
-        return Phi3Config(
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,
-            num_hidden_layers=self.num_hidden_layers,
-            num_attention_heads=self.num_attention_heads,
-            intermediate_size=self.intermediate_size,
-            hidden_act=self.hidden_act,
-            hidden_dropout_prob=self.hidden_dropout_prob,
-            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-            max_position_embeddings=self.max_position_embeddings,
-            type_vocab_size=self.type_vocab_size,
-            is_decoder=False,
-            initializer_range=self.initializer_range,
-            pad_token_id=self.pad_token_id,
-        )
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Phi3
-    def create_and_check_model(
-        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-    ):
-        model = Phi3Model(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=input_mask)
-        result = model(input_ids)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common
-    def prepare_config_and_inputs_for_common(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-        (
-            config,
-            input_ids,
-            token_type_ids,
-            input_mask,
-            sequence_labels,
-            token_labels,
-            choice_labels,
-        ) = config_and_inputs
-        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
-        return config, inputs_dict
+class Phi3ModelTester(CausalLMModelTester):
+    config_class = Phi3Config
+    if is_torch_available():
+        base_model_class = Phi3Model
+        causal_lm_class = Phi3ForCausalLM
+        sequence_class = Phi3ForSequenceClassification
+        token_class = Phi3ForTokenClassification
 
 
 @require_torch
-class Phi3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class Phi3ModelTest(CausalLMModelTest, unittest.TestCase):
     all_model_classes = (
         (Phi3Model, Phi3ForCausalLM, Phi3ForSequenceClassification, Phi3ForTokenClassification)
         if is_torch_available()
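The hunk above is the pattern this commit repeats for every model it touches: the hand-written tester collapses into a few class attributes on a CausalLMModelTester subclass, and the test class inherits everything else from CausalLMModelTest. A minimal sketch of that shape (illustrative only; the NewModel* names are hypothetical stand-ins for a real architecture, so this is not runnable as-is):

# Illustrative sketch only; `NewModelConfig`, `NewModelModel`, `NewModelForCausalLM`
# are hypothetical placeholders for a real model family.
import unittest

from transformers import is_torch_available

from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester


class NewModelModelTester(CausalLMModelTester):
    config_class = NewModelConfig
    if is_torch_available():
        base_model_class = NewModelModel
        causal_lm_class = NewModelForCausalLM


class NewModelModelTest(CausalLMModelTest, unittest.TestCase):
    all_model_classes = (NewModelModel, NewModelForCausalLM) if is_torch_available() else ()
    model_tester_class = NewModelModelTester

Optional heads (sequence_class, token_class, question_answering_class) and extras such as rotary_embedding_layer are declared the same way when the architecture has them, as the Phi3, Phimoe and Qwen2 hunks in this commit show.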
@@ -223,9 +109,8 @@ class Phi3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         {
             "feature-extraction": Phi3Model,
             "text-classification": Phi3ForSequenceClassification,
-            "text-generation": Phi3ForCausalLM,
             "token-classification": Phi3ForTokenClassification,
-            "zero-shot": Phi3ForSequenceClassification,
+            "text-generation": Phi3ForCausalLM,
         }
         if is_torch_available()
         else {}
@@ -233,150 +118,8 @@ class Phi3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
 
     test_headmasking = False
     test_pruning = False
+    model_tester_class = Phi3ModelTester
+    rotary_embedding_layer = Phi3RotaryEmbedding
-    # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79292/workflows/fa2ba644-8953-44a6-8f67-ccd69ca6a476/jobs/1012905
-    def is_pipeline_test_to_skip(
-        self,
-        pipeline_test_case_name,
-        config_class,
-        model_architecture,
-        tokenizer_name,
-        image_processor_name,
-        feature_extractor_name,
-        processor_name,
-    ):
-        return True
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.setUp with Llama->Phi3
-    def setUp(self):
-        self.model_tester = Phi3ModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=Phi3Config, hidden_size=37)
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config
-    def test_config(self):
-        self.config_tester.run_common_tests()
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model
-    def test_model(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_model(*config_and_inputs)
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with Llama->Phi3,llama->phi3
-    def test_phi3_sequence_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = Phi3ForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with Llama->Phi3,llama->phi3
-    def test_phi3_sequence_classification_model_for_single_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "single_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = Phi3ForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with Llama->Phi3,llama->phi3
-    def test_phi3_sequence_classification_model_for_multi_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "multi_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor(
-            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
-        ).to(torch.float)
-        model = Phi3ForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    @parameterized.expand([("longrope",)])
-    def test_model_rope_scaling_from_config(self, scaling_type):
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        short_input = ids_tensor([1, 10], config.vocab_size)
-        long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
-
-        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        original_model = Phi3Model(config)
-        original_model.to(torch_device)
-        original_model.eval()
-        original_short_output = original_model(short_input).last_hidden_state
-        original_long_output = original_model(long_input).last_hidden_state
-
-        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        n_factors = config.hidden_size // config.num_attention_heads // 2
-        config.rope_scaling = {
-            "type": scaling_type,
-            "short_factor": [5.0 for _ in range(n_factors)],
-            "long_factor": [5.0 for _ in range(n_factors)],
-        }
-        scaled_model = Phi3Model(config)
-        scaled_model.to(torch_device)
-        scaled_model.eval()
-        scaled_short_output = scaled_model(short_input).last_hidden_state
-        scaled_long_output = scaled_model(long_input).last_hidden_state
-
-        # Scaling changes the RoPE embeddings, both for the short and long outputs
-        self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
-        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
-
-    @parameterized.expand([("longrope",)])
-    def test_model_rope_scaling_short_long_factor(self, scaling_type):
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        n_factors = config.hidden_size // config.num_key_value_heads // 2
-        config.rope_scaling = {
-            "type": scaling_type,
-            "short_factor": [3.0 for _ in range(n_factors)],
-            "long_factor": [5.0 for _ in range(n_factors)],
-        }
-        input_tensor = ids_tensor([1, 4090], config.vocab_size)
-        # Make sure we don't have padding tokens. If this is the case, then the actual number of "true" tokens may be shorter
-        # than `config.original_max_position_embeddings + 5`, invalidating this test
-        input_tensor[input_tensor == config.pad_token_id] += 1
-        model = Phi3ForCausalLM(config)
-        model.to(torch_device)
-        model.eval()
-        generation_args_short = {
-            "max_length": config.original_max_position_embeddings,
-            "temperature": 0.0,
-            "use_cache": True,
-            "do_sample": False,
-            "return_dict_in_generate": True,
-        }
-        output_with_short_factor = model.generate(input_tensor, **generation_args_short)
-        keys_with_short_factor = output_with_short_factor.past_key_values[0][0]
-        generation_args_long = {
-            "max_length": config.original_max_position_embeddings + 5,
-            "temperature": 0.0,
-            "use_cache": True,
-            "do_sample": False,
-            "return_dict_in_generate": True,
-            "output_logits": True,
-        }
-        output_with_long_factor = model.generate(input_tensor, **generation_args_long)
-        keys_with_long_factor = output_with_long_factor.past_key_values[0][0]
-        last_token_logits = output_with_long_factor.logits[-1][-1]
-        regenerated_last_token_logits = model(output_with_long_factor.sequences[:, :-1]).logits[0][-1]
-        keys_with_long_factor = keys_with_long_factor[:, :, : config.original_max_position_embeddings - 1, :]
-
-        # KV cache is re-computed after reaching the (`config.original_max_position_embeddings`+1)th token position
-        self.assertFalse(torch.allclose(keys_with_short_factor, keys_with_long_factor, atol=1e-2, rtol=1e-2))
-        # Last token generated using long factor
-        torch.testing.assert_close(last_token_logits, regenerated_last_token_logits, rtol=1e-2, atol=1e-2)
-
 
 
 @slow
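One note on the longrope test removed just above: test_model_rope_scaling_short_long_factor relies on the rotary embedding switching from short_factor to long_factor once the sequence grows past config.original_max_position_embeddings, which is why the cached keys from the two generate() calls are expected to diverge while the final-token logits are still reproducible from one full forward pass. A tiny sketch of that selection rule (an assumption about the usual longrope behaviour, not code from this diff):

# Illustrative sketch only: the factor-selection rule the removed test exercises.
def pick_rescale_factors(seq_len, original_max_position_embeddings, short_factor, long_factor):
    """Choose the per-dimension rescale factors a longrope-style rotary embedding applies."""
    if seq_len > original_max_position_embeddings:
        return long_factor
    return short_factor


# Generating up to the original maximum stays on the short factors; generating 5 tokens
# past it re-computes the rotary angles with the long factors, so the two KV caches in
# the test are not expected to match.
assert pick_rescale_factors(4096, 4096, "short", "long") == "short"
assert pick_rescale_factors(4101, 4096, "short", "long") == "long"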
@@ -16,20 +16,14 @@
 
 import unittest
 
-from parameterized import parameterized
-
-from transformers import PhimoeConfig, StaticCache, is_torch_available, set_seed
+from transformers import PhimoeConfig, StaticCache, is_torch_available
 from transformers.testing_utils import (
-    is_flaky,
     require_torch,
     slow,
     torch_device,
 )
 
-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor
-from ...test_pipeline_mixin import PipelineTesterMixin
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
 
 
 if is_torch_available():
@@ -92,138 +86,23 @@ if is_torch_available():
         return response_tokens
 
 
-class PhimoeModelTester:
-    def __init__(
-        self,
-        parent,
-        batch_size=13,
-        seq_length=7,
-        is_training=True,
-        use_input_mask=True,
-        use_token_type_ids=False,
-        use_labels=True,
-        vocab_size=99,
-        hidden_size=32,
-        num_hidden_layers=2,
-        num_attention_heads=4,
-        num_key_value_heads=4,
-        intermediate_size=37,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_position_embeddings=131072,
-        type_vocab_size=16,
-        type_sequence_label_size=2,
-        initializer_range=0.02,
-        num_labels=3,
-        num_choices=4,
-        pad_token_id=0,
-        scope=None,
-        original_max_position_embeddings=4096,
-    ):
-        self.parent = parent
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.is_training = is_training
-        self.use_input_mask = use_input_mask
-        self.use_token_type_ids = use_token_type_ids
-        self.use_labels = use_labels
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.num_key_value_heads = num_key_value_heads
-        self.intermediate_size = intermediate_size
-        self.hidden_act = hidden_act
-        self.hidden_dropout_prob = hidden_dropout_prob
-        self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.max_position_embeddings = max_position_embeddings
-        self.type_vocab_size = type_vocab_size
-        self.type_sequence_label_size = type_sequence_label_size
-        self.initializer_range = initializer_range
-        self.num_labels = num_labels
-        self.num_choices = num_choices
-        self.pad_token_id = pad_token_id
-        self.scope = scope
-        self.original_max_position_embeddings = original_max_position_embeddings
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs
-    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
-        input_mask = None
-        if self.use_input_mask:
-            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
-
-        token_type_ids = None
-        if self.use_token_type_ids:
-            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-        sequence_labels = None
-        token_labels = None
-        choice_labels = None
-        if self.use_labels:
-            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-            choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
-        config = self.get_config()
-
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-    def get_config(self):
-        return PhimoeConfig(
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,
-            num_hidden_layers=self.num_hidden_layers,
-            num_attention_heads=self.num_attention_heads,
-            num_key_value_heads=self.num_key_value_heads,
-            intermediate_size=self.intermediate_size,
-            hidden_act=self.hidden_act,
-            hidden_dropout_prob=self.hidden_dropout_prob,
-            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-            max_position_embeddings=self.max_position_embeddings,
-            type_vocab_size=self.type_vocab_size,
-            is_decoder=False,
-            initializer_range=self.initializer_range,
-            pad_token_id=self.pad_token_id,
-            num_experts_per_tok=2,
-            num_local_experts=2,
-            original_max_position_embeddings=self.original_max_position_embeddings,
-        )
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Phimoe
-    def create_and_check_model(
-        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-    ):
-        model = PhimoeModel(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=input_mask)
-        result = model(input_ids)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common
-    def prepare_config_and_inputs_for_common(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-        (
-            config,
-            input_ids,
-            token_type_ids,
-            input_mask,
-            sequence_labels,
-            token_labels,
-            choice_labels,
-        ) = config_and_inputs
-        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
-        return config, inputs_dict
+class PhimoeModelTester(CausalLMModelTester):
+    if is_torch_available():
+        config_class = PhimoeConfig
+        base_model_class = PhimoeModel
+        causal_lm_class = PhimoeForCausalLM
+        sequence_class = PhimoeForSequenceClassification
 
 
 @require_torch
-class PhimoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class PhimoeModelTest(CausalLMModelTest, unittest.TestCase):
     all_model_classes = (
         (PhimoeModel, PhimoeForCausalLM, PhimoeForSequenceClassification) if is_torch_available() else ()
     )
 
+    test_headmasking = False
+    test_pruning = False
+    model_tester_class = PhimoeModelTester
     pipeline_model_mapping = (
         {
             "feature-extraction": PhimoeModel,
@@ -235,150 +114,12 @@ class PhimoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
         else {}
     )
 
-    test_headmasking = False
-    test_pruning = False
-
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79292/workflows/fa2ba644-8953-44a6-8f67-ccd69ca6a476/jobs/1012905
     def is_pipeline_test_to_skip(
         self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
     ):
         return True
 
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.setUp with Llama->Phimoe
-    def setUp(self):
-        self.model_tester = PhimoeModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=PhimoeConfig, hidden_size=37)
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config
-    def test_config(self):
-        self.config_tester.run_common_tests()
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model
-    def test_model(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_model(*config_and_inputs)
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with Llama->Phimoe,llama->phimoe
-    def test_phimoe_sequence_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = PhimoeForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with Llama->Phimoe,llama->phimoe
-    def test_phimoe_sequence_classification_model_for_single_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "single_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = PhimoeForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with Llama->Phimoe,llama->phimoe
-    def test_phimoe_sequence_classification_model_for_multi_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "multi_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor(
-            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
-        ).to(torch.float)
-        model = PhimoeForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    @parameterized.expand([("longrope",)])
-    def test_model_rope_scaling_from_config(self, scaling_type):
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        short_input = ids_tensor([1, 10], config.vocab_size)
-        long_input = ids_tensor([1, int(config.original_max_position_embeddings * 1.5)], config.vocab_size)
-
-        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        original_model = PhimoeModel(config)
-        original_model.to(torch_device)
-        original_model.eval()
-        original_short_output = original_model(short_input).last_hidden_state
-        original_long_output = original_model(long_input).last_hidden_state
-
-        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        n_factors = config.hidden_size // config.num_attention_heads // 2
-        config.rope_scaling = {
-            "type": scaling_type,
-            "short_factor": [3.0 for _ in range(n_factors)],
-            "long_factor": [5.0 for _ in range(n_factors)],
-            "short_mscale": 1.243163121016122,
-            "long_mscale": 1.243163121016122,
-            "original_max_position_embeddings": 4096,
-        }
-        scaled_model = PhimoeModel(config)
-        scaled_model.to(torch_device)
-        scaled_model.eval()
-        scaled_short_output = scaled_model(short_input).last_hidden_state
-        scaled_long_output = scaled_model(long_input).last_hidden_state
-
-        # Scaling changes the RoPE embeddings, both for the short and long outputs
-        self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
-        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
-
-    @parameterized.expand([("longrope",)])
-    @is_flaky()  # TODO (joao): unify rope tests in the mixin
-    def test_model_rope_scaling_short_long_factor(self, scaling_type):
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        n_factors = config.hidden_size // config.num_key_value_heads // 2
-        config.rope_scaling = {
-            "type": scaling_type,
-            "short_factor": [3.0 for _ in range(n_factors)],
-            "long_factor": [5.0 for _ in range(n_factors)],
-            "short_mscale": 1.243163121016122,
-            "long_mscale": 1.243163121016122,
-            "original_max_position_embeddings": 4096,
-        }
-        input_tensor = ids_tensor([1, 4090], config.vocab_size)
-        model = PhimoeForCausalLM(config)
-        model.to(torch_device)
-        model.eval()
-        generation_args_short = {
-            "max_length": config.original_max_position_embeddings,
-            "temperature": 0.0,
-            "use_cache": True,
-            "do_sample": False,
-            "return_dict_in_generate": True,
-        }
-        output_with_short_factor = model.generate(input_tensor, **generation_args_short)
-        keys_with_short_factor = output_with_short_factor.past_key_values[0][0]
-        generation_args_long = {
-            "max_length": config.original_max_position_embeddings + 5,
-            "temperature": 0.0,
-            "use_cache": True,
-            "do_sample": False,
-            "return_dict_in_generate": True,
-            "output_logits": True,
-        }
-        output_with_long_factor = model.generate(input_tensor, **generation_args_long)
-        keys_with_long_factor = output_with_long_factor.past_key_values[0][0]
-        last_token_logits = output_with_long_factor.logits[-1][-1]
-        regenerated_last_token_logits = model(output_with_long_factor.sequences[:, :-1]).logits[0][-1]
-        keys_with_long_factor = keys_with_long_factor[:, :, : config.original_max_position_embeddings - 1, :]
-
-        # KV cache is re-computed after reaching the (`config.original_max_position_embeddings`+1)th token position
-        self.assertFalse(torch.allclose(keys_with_short_factor, keys_with_long_factor, atol=1e-3, rtol=1e-3))
-        # Last token generated using long factor
-        torch.testing.assert_close(last_token_logits, regenerated_last_token_logits, rtol=1e-2, atol=1e-2)
-
 
 @slow
 @require_torch
@@ -33,11 +33,6 @@ from transformers.testing_utils import (
 )
 from transformers.utils.import_utils import is_torch_greater_or_equal
 
-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor
-from ...test_pipeline_mixin import PipelineTesterMixin
-
 
 if is_torch_available():
     import torch
@@ -51,143 +46,21 @@ if is_torch_available():
     )
 
 
-class Qwen2ModelTester:
-    def __init__(
-        self,
-        parent,
-        batch_size=13,
-        seq_length=7,
-        is_training=True,
-        use_input_mask=True,
-        use_token_type_ids=True,
-        use_labels=True,
-        vocab_size=99,
-        hidden_size=32,
-        num_hidden_layers=5,
-        max_window_layers=3,
-        use_sliding_window=True,
-        sliding_window=50,
-        num_attention_heads=4,
-        num_key_value_heads=2,
-        intermediate_size=37,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_position_embeddings=512,
-        type_vocab_size=16,
-        type_sequence_label_size=2,
-        initializer_range=0.02,
-        num_labels=3,
-        num_choices=4,
-        pad_token_id=0,
-        bos_token_id=1,
-        scope=None,
-    ):
-        self.parent = parent
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.is_training = is_training
-        self.use_input_mask = use_input_mask
-        self.use_token_type_ids = use_token_type_ids
-        self.use_labels = use_labels
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.max_window_layers = max_window_layers
-        self.use_sliding_window = use_sliding_window
-        self.sliding_window = sliding_window
-        self.num_attention_heads = num_attention_heads
-        self.num_key_value_heads = num_key_value_heads
-        self.intermediate_size = intermediate_size
-        self.hidden_act = hidden_act
-        self.hidden_dropout_prob = hidden_dropout_prob
-        self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.max_position_embeddings = max_position_embeddings
-        self.type_vocab_size = type_vocab_size
-        self.type_sequence_label_size = type_sequence_label_size
-        self.initializer_range = initializer_range
-        self.num_labels = num_labels
-        self.num_choices = num_choices
-        self.pad_token_id = pad_token_id
-        self.bos_token_id = bos_token_id
-        self.scope = scope
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs
-    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
-        input_mask = None
-        if self.use_input_mask:
-            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
-
-        token_type_ids = None
-        if self.use_token_type_ids:
-            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-        sequence_labels = None
-        token_labels = None
-        choice_labels = None
-        if self.use_labels:
-            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-            choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
-        config = self.get_config()
-
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-    def get_config(self):
-        return Qwen2Config(
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,
-            num_hidden_layers=self.num_hidden_layers,
-            max_window_layers=self.max_window_layers,
-            use_sliding_window=self.use_sliding_window,
-            sliding_window=self.sliding_window,
-            num_attention_heads=self.num_attention_heads,
-            num_key_value_heads=self.num_key_value_heads,
-            intermediate_size=self.intermediate_size,
-            hidden_act=self.hidden_act,
-            hidden_dropout_prob=self.hidden_dropout_prob,
-            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-            max_position_embeddings=self.max_position_embeddings,
-            type_vocab_size=self.type_vocab_size,
-            is_decoder=False,
-            initializer_range=self.initializer_range,
-            pad_token_id=self.pad_token_id,
-            bos_token_id=self.bos_token_id,
-        )
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Qwen2
-    def create_and_check_model(
-        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-    ):
-        model = Qwen2Model(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=input_mask)
-        result = model(input_ids)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common
-    def prepare_config_and_inputs_for_common(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-        (
-            config,
-            input_ids,
-            token_type_ids,
-            input_mask,
-            sequence_labels,
-            token_labels,
-            choice_labels,
-        ) = config_and_inputs
-        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
-        return config, inputs_dict
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
+
+
+class Qwen2ModelTester(CausalLMModelTester):
+    config_class = Qwen2Config
+    if is_torch_available():
+        base_model_class = Qwen2Model
+        causal_lm_class = Qwen2ForCausalLM
+        sequence_class = Qwen2ForSequenceClassification
+        token_class = Qwen2ForTokenClassification
+        question_answering_class = Qwen2ForQuestionAnswering
 
 
 @require_torch
-# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2
-class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class Qwen2ModelTest(CausalLMModelTest, unittest.TestCase):
     all_model_classes = (
         (
             Qwen2Model,
@@ -199,21 +72,20 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         if is_torch_available()
         else ()
     )
+    test_headmasking = False
+    test_pruning = False
+    model_tester_class = Qwen2ModelTester
     pipeline_model_mapping = (
         {
             "feature-extraction": Qwen2Model,
             "text-classification": Qwen2ForSequenceClassification,
             "token-classification": Qwen2ForTokenClassification,
             "text-generation": Qwen2ForCausalLM,
-            "zero-shot": Qwen2ForSequenceClassification,
             "question-answering": Qwen2ForQuestionAnswering,
         }
         if is_torch_available()
         else {}
     )
-    test_headmasking = False
-    test_pruning = False
-    fx_compatible = False  # Broken by attention refactor cc @Cyrilvallez
 
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
@@ -228,82 +100,6 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     ):
         return True
 
-    def setUp(self):
-        self.model_tester = Qwen2ModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=Qwen2Config, hidden_size=37)
-
-    def test_config(self):
-        self.config_tester.run_common_tests()
-
-    def test_model(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_model(*config_and_inputs)
-
-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
-    def test_torch_fx_output_loss(self):
-        super().test_torch_fx_output_loss()
-
-    def test_Qwen2_sequence_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = Qwen2ForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_Qwen2_sequence_classification_model_for_single_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "single_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = Qwen2ForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_Qwen2_sequence_classification_model_for_multi_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "multi_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor(
-            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
-        ).to(torch.float)
-        model = Qwen2ForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2,llama->Qwen2
-    def test_Qwen2_token_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
-        model = Qwen2ForTokenClassification(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
-        self.assertEqual(
-            result.logits.shape,
-            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
-        )
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test
@@ -30,11 +30,6 @@ from transformers.testing_utils import (
     torch_device,
 )
 
-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor
-from ...test_pipeline_mixin import PipelineTesterMixin
-
 
 if is_torch_available():
     import torch
@@ -48,173 +43,21 @@ if is_torch_available():
     )
 
 
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
+
+
+class Qwen2MoeModelTester(CausalLMModelTester):
+    config_class = Qwen2MoeConfig
+    if is_torch_available():
+        base_model_class = Qwen2MoeModel
+        causal_lm_class = Qwen2MoeForCausalLM
+        sequence_class = Qwen2MoeForSequenceClassification
+        token_class = Qwen2MoeForTokenClassification
+        question_answering_class = Qwen2MoeForQuestionAnswering
-class Qwen2MoeModelTester:
-    def __init__(
-        self,
-        parent,
-        batch_size=13,
-        seq_length=7,
-        is_training=True,
-        use_input_mask=True,
-        use_token_type_ids=True,
-        use_labels=True,
-        vocab_size=99,
-        hidden_size=32,
-        num_hidden_layers=5,
-        max_window_layers=3,
-        use_sliding_window=True,
-        sliding_window=50,
-        num_attention_heads=4,
-        num_key_value_heads=2,
-        intermediate_size=37,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_position_embeddings=512,
-        expert_interval=1,
-        moe_intermediate_size=12,
-        shared_expert_intermediate_size=36,
-        shared_expert_gate=True,
-        num_experts_per_tok=2,
-        num_experts=8,
-        norm_topk_prob=False,
-        output_router_logits=False,
-        router_aux_loss_coef=0.001,
-        type_vocab_size=16,
-        type_sequence_label_size=2,
-        initializer_range=0.02,
-        num_labels=3,
-        num_choices=4,
-        pad_token_id=0,
-        bos_token_id=1,
-        scope=None,
-        qkv_bias=False,
-    ):
-        self.parent = parent
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.is_training = is_training
-        self.use_input_mask = use_input_mask
-        self.use_token_type_ids = use_token_type_ids
-        self.use_labels = use_labels
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.max_window_layers = max_window_layers
-        self.use_sliding_window = use_sliding_window
-        self.sliding_window = sliding_window
-        self.num_attention_heads = num_attention_heads
-        self.num_key_value_heads = num_key_value_heads
-        self.intermediate_size = intermediate_size
-        self.hidden_act = hidden_act
-        self.hidden_dropout_prob = hidden_dropout_prob
-        self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.max_position_embeddings = max_position_embeddings
-        self.type_vocab_size = type_vocab_size
-        self.type_sequence_label_size = type_sequence_label_size
-        self.initializer_range = initializer_range
-        self.num_labels = num_labels
-        self.num_choices = num_choices
-        self.pad_token_id = pad_token_id
-        self.bos_token_id = bos_token_id
-        self.scope = scope
-        self.expert_interval = expert_interval
-        self.moe_intermediate_size = moe_intermediate_size
-        self.shared_expert_intermediate_size = shared_expert_intermediate_size
-        self.shared_expert_gate = shared_expert_gate
-        self.num_experts_per_tok = num_experts_per_tok
-        self.num_experts = num_experts
-        self.norm_topk_prob = norm_topk_prob
-        self.output_router_logits = output_router_logits
-        self.router_aux_loss_coef = router_aux_loss_coef
-        self.qkv_bias = qkv_bias
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs
-    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
-        input_mask = None
-        if self.use_input_mask:
-            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
-
-        token_type_ids = None
-        if self.use_token_type_ids:
-            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-        sequence_labels = None
-        token_labels = None
-        choice_labels = None
-        if self.use_labels:
-            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-            choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
-        config = self.get_config()
-
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-    def get_config(self):
-        return Qwen2MoeConfig(
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,
-            num_hidden_layers=self.num_hidden_layers,
-            max_window_layers=self.max_window_layers,
-            use_sliding_window=self.use_sliding_window,
-            sliding_window=self.sliding_window,
-            num_attention_heads=self.num_attention_heads,
-            num_key_value_heads=self.num_key_value_heads,
-            intermediate_size=self.intermediate_size,
-            hidden_act=self.hidden_act,
-            hidden_dropout_prob=self.hidden_dropout_prob,
-            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-            max_position_embeddings=self.max_position_embeddings,
-            expert_interval=self.expert_interval,
-            moe_intermediate_size=self.moe_intermediate_size,
-            shared_expert_intermediate_size=self.shared_expert_intermediate_size,
-            shared_expert_gate=self.shared_expert_gate,
-            num_experts_per_tok=self.num_experts_per_tok,
-            num_experts=self.num_experts,
-            norm_topk_prob=self.norm_topk_prob,
-            output_router_logits=self.output_router_logits,
-            router_aux_loss_coef=self.router_aux_loss_coef,
-            type_vocab_size=self.type_vocab_size,
-            is_decoder=False,
-            initializer_range=self.initializer_range,
-            pad_token_id=self.pad_token_id,
-            bos_token_id=self.bos_token_id,
-            qkv_bias=self.qkv_bias,
-        )
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Qwen2Moe
-    def create_and_check_model(
-        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-    ):
-        model = Qwen2MoeModel(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=input_mask)
-        result = model(input_ids)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common
-    def prepare_config_and_inputs_for_common(self):
|
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
|
||||||
(
|
|
||||||
config,
|
|
||||||
input_ids,
|
|
||||||
token_type_ids,
|
|
||||||
input_mask,
|
|
||||||
sequence_labels,
|
|
||||||
token_labels,
|
|
||||||
choice_labels,
|
|
||||||
) = config_and_inputs
|
|
||||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
|
||||||
return config, inputs_dict
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2Moe
|
class Qwen2MoeModelTest(CausalLMModelTest, unittest.TestCase):
|
||||||
class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(
|
(
|
||||||
Qwen2MoeModel,
|
Qwen2MoeModel,
|
||||||
@ -232,15 +75,15 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
"text-classification": Qwen2MoeForSequenceClassification,
|
"text-classification": Qwen2MoeForSequenceClassification,
|
||||||
"token-classification": Qwen2MoeForTokenClassification,
|
"token-classification": Qwen2MoeForTokenClassification,
|
||||||
"text-generation": Qwen2MoeForCausalLM,
|
"text-generation": Qwen2MoeForCausalLM,
|
||||||
"zero-shot": Qwen2MoeForSequenceClassification,
|
|
||||||
"question-answering": Qwen2MoeForQuestionAnswering,
|
"question-answering": Qwen2MoeForQuestionAnswering,
|
||||||
}
|
}
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else {}
|
else {}
|
||||||
)
|
)
|
||||||
|
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
fx_compatible = False # Broken by attention refactor cc @Cyrilvallez
|
model_tester_class = Qwen2MoeModelTester
|
||||||
|
|
||||||
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
# TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
|
||||||
def is_pipeline_test_to_skip(
|
def is_pipeline_test_to_skip(
|
||||||
@ -255,82 +98,6 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
):
|
):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.model_tester = Qwen2MoeModelTester(self)
|
|
||||||
self.config_tester = ConfigTester(self, config_class=Qwen2MoeConfig, hidden_size=37)
|
|
||||||
|
|
||||||
def test_config(self):
|
|
||||||
self.config_tester.run_common_tests()
|
|
||||||
|
|
||||||
def test_model(self):
|
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
|
||||||
|
|
||||||
def test_model_various_embeddings(self):
|
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
|
||||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
|
||||||
config_and_inputs[0].position_embedding_type = type
|
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
|
||||||
|
|
||||||
def test_torch_fx_output_loss(self):
|
|
||||||
super().test_torch_fx_output_loss()
|
|
||||||
|
|
||||||
def test_Qwen2Moe_sequence_classification_model(self):
|
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.num_labels = 3
|
|
||||||
input_ids = input_dict["input_ids"]
|
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
|
||||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
|
||||||
model = Qwen2MoeForSequenceClassification(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
|
||||||
|
|
||||||
def test_Qwen2Moe_sequence_classification_model_for_single_label(self):
|
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.num_labels = 3
|
|
||||||
config.problem_type = "single_label_classification"
|
|
||||||
input_ids = input_dict["input_ids"]
|
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
|
||||||
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
|
||||||
model = Qwen2MoeForSequenceClassification(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
|
||||||
|
|
||||||
def test_Qwen2Moe_sequence_classification_model_for_multi_label(self):
|
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.num_labels = 3
|
|
||||||
config.problem_type = "multi_label_classification"
|
|
||||||
input_ids = input_dict["input_ids"]
|
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
|
||||||
sequence_labels = ids_tensor(
|
|
||||||
[self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
|
|
||||||
).to(torch.float)
|
|
||||||
model = Qwen2MoeForSequenceClassification(config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
|
||||||
|
|
||||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2Moe,llama->Qwen2Moe
|
|
||||||
def test_Qwen2Moe_token_classification_model(self):
|
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
config.num_labels = 3
|
|
||||||
input_ids = input_dict["input_ids"]
|
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
|
||||||
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
|
||||||
model = Qwen2MoeForTokenClassification(config=config)
|
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
|
||||||
self.assertEqual(
|
|
||||||
result.logits.shape,
|
|
||||||
(self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
|
||||||
)
|
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
||||||
|
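The Qwen2Moe hunks above show the whole pattern this commit repeats for each decoder-only model: the hand-written tester and the copy-pasted test methods disappear, and the file only declares which classes the shared tester should exercise. Below is a minimal sketch of a new test file written against that pattern; the MyModel* names are placeholders for illustration (not a real transformers model), and the sketch assumes the file lives under tests/models/<name>/ inside the repository so the relative import of the shared tester resolves.

import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch

from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester


if is_torch_available():
    # Placeholder imports: substitute the real config/model classes of the model under test.
    from transformers import MyModelConfig, MyModelForCausalLM, MyModelModel


class MyModelModelTester(CausalLMModelTester):
    # Only the model-specific classes are declared here; the shared tester is expected to
    # supply the tiny-config defaults and the input-preparation machinery that each file
    # used to copy-paste (see the removed Qwen2MoeModelTester above).
    if is_torch_available():
        config_class = MyModelConfig
        base_model_class = MyModelModel
        causal_lm_class = MyModelForCausalLM


@require_torch
class MyModelModelTest(CausalLMModelTest, unittest.TestCase):
    all_model_classes = (MyModelModel, MyModelForCausalLM) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"feature-extraction": MyModelModel, "text-generation": MyModelForCausalLM}
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False
    # The base class is expected to instantiate this tester itself: note that the
    # per-model setUp methods are removed throughout this diff.
    model_tester_class = MyModelModelTester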
tests/models/qwen3/test_modeling_qwen3.py

@@ -33,11 +33,6 @@ from transformers.testing_utils import (
 )
 from transformers.utils.import_utils import is_torch_greater_or_equal
 
-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor
-from ...test_pipeline_mixin import PipelineTesterMixin
-
 
 if is_torch_available():
     import torch
@@ -50,147 +45,21 @@ if is_torch_available():
     Qwen3Model,
 )
 
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
 
-class Qwen3ModelTester:
-    def __init__(
-        self, parent, batch_size=13, seq_length=7, is_training=True, use_input_mask=True, use_token_type_ids=True,
-        use_labels=True, vocab_size=99, hidden_size=64, num_hidden_layers=5, max_window_layers=3,
-        use_sliding_window=True, sliding_window=50, num_attention_heads=4, num_key_value_heads=2, head_dim=16,
-        intermediate_size=37, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
-        max_position_embeddings=512, type_vocab_size=16, type_sequence_label_size=2, initializer_range=0.02,
-        num_labels=3, num_choices=4, pad_token_id=0, bos_token_id=1, scope=None,
-    ):
-        [the rest of the removed tester mirrors the removed Qwen2MoeModelTester above: the arguments are stored
-         on self, and prepare_config_and_inputs / get_config / create_and_check_model /
-         prepare_config_and_inputs_for_common (copied from LlamaModelTester) build a small Qwen3Config,
-         run Qwen3Model and check the output shape]
+class Qwen3ModelTester(CausalLMModelTester):
+    config_class = Qwen3Config
+    if is_torch_available():
+        base_model_class = Qwen3Model
+        causal_lm_class = Qwen3ForCausalLM
+        sequence_class = Qwen3ForSequenceClassification
+        token_class = Qwen3ForTokenClassification
+        question_answering_class = Qwen3ForQuestionAnswering
 
 
 @require_torch
-# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen3
-class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class Qwen3ModelTest(CausalLMModelTest, unittest.TestCase):
     all_model_classes = (
         (
             Qwen3Model,
@@ -202,21 +71,20 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         if is_torch_available()
         else ()
     )
+    test_headmasking = False
+    test_pruning = False
+    model_tester_class = Qwen3ModelTester
     pipeline_model_mapping = (
         {
             "feature-extraction": Qwen3Model,
             "text-classification": Qwen3ForSequenceClassification,
             "token-classification": Qwen3ForTokenClassification,
             "text-generation": Qwen3ForCausalLM,
-            "zero-shot": Qwen3ForSequenceClassification,
             "question-answering": Qwen3ForQuestionAnswering,
         }
         if is_torch_available()
         else {}
     )
-    test_headmasking = False
-    test_pruning = False
-    fx_compatible = False  # Broken by attention refactor cc @Cyrilvallez
 
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
@@ -231,82 +99,6 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     ):
         return True
 
-    [the removed setUp, test_config, test_model, test_model_various_embeddings, test_torch_fx_output_loss,
-     test_Qwen3_sequence_classification_model (plus its single-label and multi-label variants) and
-     test_Qwen3_token_classification_model are identical to the Qwen2Moe methods above with Qwen2Moe -> Qwen3]
 
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test
tests/models/qwen3_moe/test_modeling_qwen3_moe.py

@@ -30,185 +30,33 @@ from transformers.testing_utils import (
     torch_device,
 )
 
-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor
-from ...test_pipeline_mixin import PipelineTesterMixin
-
 
 if is_torch_available():
     import torch
 
     from transformers import (
+        Qwen3ForQuestionAnswering,
         Qwen3MoeForCausalLM,
         Qwen3MoeForQuestionAnswering,
         Qwen3MoeForSequenceClassification,
         Qwen3MoeForTokenClassification,
         Qwen3MoeModel,
     )
 
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
 
-class Qwen3MoeModelTester:
-    def __init__(
-        self, parent, batch_size=13, seq_length=7, is_training=True, use_input_mask=True, use_token_type_ids=True,
-        use_labels=True, vocab_size=99, hidden_size=64, num_hidden_layers=5, max_window_layers=3,
-        use_sliding_window=True, sliding_window=50, num_attention_heads=4, num_key_value_heads=2, head_dim=16,
-        intermediate_size=37, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
-        max_position_embeddings=512, expert_interval=1, moe_intermediate_size=12, num_experts_per_tok=2,
-        num_experts=8, norm_topk_prob=False, output_router_logits=False, router_aux_loss_coef=0.001,
-        type_vocab_size=16, type_sequence_label_size=2, initializer_range=0.02, num_labels=3, num_choices=4,
-        pad_token_id=0, bos_token_id=1, scope=None,
-    ):
-        [the rest of the removed tester mirrors the removed Qwen2MoeModelTester above: the arguments are stored
-         on self, and prepare_config_and_inputs / get_config / create_and_check_model /
-         prepare_config_and_inputs_for_common (copied from LlamaModelTester) build a small Qwen3MoeConfig
-         including the MoE arguments, run Qwen3MoeModel and check the output shape]
+class Qwen3MoeModelTester(CausalLMModelTester):
+    config_class = Qwen3MoeConfig
+    if is_torch_available():
+        base_model_class = Qwen3MoeModel
+        causal_lm_class = Qwen3MoeForCausalLM
+        sequence_class = Qwen3MoeForSequenceClassification
+        token_class = Qwen3MoeForTokenClassification
+        question_answering_class = Qwen3MoeForQuestionAnswering
 
 
 @require_torch
-# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen3Moe
-class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class Qwen3MoeModelTest(CausalLMModelTest, unittest.TestCase):
    all_model_classes = (
         (
             Qwen3MoeModel,
@@ -226,15 +74,15 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
             "text-classification": Qwen3MoeForSequenceClassification,
             "token-classification": Qwen3MoeForTokenClassification,
             "text-generation": Qwen3MoeForCausalLM,
-            "zero-shot": Qwen3MoeForSequenceClassification,
-            "question-answering": Qwen3MoeForQuestionAnswering,
+            "question-answering": Qwen3ForQuestionAnswering,
         }
         if is_torch_available()
         else {}
     )
 
     test_headmasking = False
     test_pruning = False
-    fx_compatible = False  # Broken by attention refactor cc @Cyrilvallez
+    model_tester_class = Qwen3MoeModelTester
 
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
@@ -249,82 +97,6 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
     ):
         return True
 
-    [the removed setUp, test_config, test_model, test_model_various_embeddings, test_torch_fx_output_loss,
-     test_Qwen3Moe_sequence_classification_model (plus its single-label and multi-label variants) and
-     test_Qwen3Moe_token_classification_model are identical to the Qwen2Moe methods above with
-     Qwen2Moe -> Qwen3Moe]
 
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test
tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py

@@ -16,6 +16,7 @@
 import unittest
 
 import pytest
+from parameterized import parameterized
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
@@ -27,151 +28,26 @@ from transformers.testing_utils import (
     torch_device,
 )
 
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor
-from ...test_pipeline_mixin import PipelineTesterMixin
-
 
 if is_torch_available():
     import torch
 
-    from transformers import RecurrentGemmaForCausalLM, RecurrentGemmaModel
+    from transformers import RecurrentGemmaConfig, RecurrentGemmaForCausalLM, RecurrentGemmaModel
 
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
 
-class RecurrentGemmaModelTester:
-    def __init__(
-        self, parent, batch_size=13, seq_length=12, is_training=True, use_input_mask=True,
-        use_token_type_ids=False, use_labels=True, num_hidden_layers=3, vocab_size=99, hidden_size=32,
-        intermediate_size=3 * 32, num_attention_heads=2, lru_width=2 * 32, embeddings_scale_by_sqrt_dim=True,
-        attention_window_size=16, conv1d_width=4, logits_soft_cap=30.0, rms_norm_eps=1e-6, use_cache=True,
-        rope_theta=10000.0, type_vocab_size=16, type_sequence_label_size=2, num_labels=3, num_choices=4,
-        pad_token_id=0, scope=None,
-    ):
-        [the removed tester stores the arguments on self (lru_width falls back to hidden_size when None) and
-         defines prepare_config_and_inputs (copied from MistralModelTester), get_config (a small
-         RecurrentGemmaConfig with output_attentions=False), create_and_check_model and
-         prepare_config_and_inputs_for_common (both copied from LlamaModelTester with Llama->RecurrentGemma)]
+class RecurrentGemmaModelTester(CausalLMModelTester):
+    config_class = RecurrentGemmaConfig
+    if is_torch_available():
+        base_model_class = RecurrentGemmaModel
+        causal_lm_class = RecurrentGemmaForCausalLM
 
 
 @require_torch
-class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
-    all_model_classes = (RecurrentGemmaForCausalLM,) if is_torch_available() else ()
-    # Doesn't run generation tests. TODO @gante not fully supported
-    all_generative_model_classes = ()
+class RecurrentGemmaModelTest(CausalLMModelTest, unittest.TestCase):
+    all_model_classes = (RecurrentGemmaModel, RecurrentGemmaForCausalLM) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
             "feature-extraction": RecurrentGemmaModel,
@@ -180,48 +56,10 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
         if is_torch_available()
         else {}
     )
-    fx_compatible = False  # FIXME let's try to support this @ArthurZucker
-    test_torchscript = False  # FIXME let's try to support this @ArthurZucker
-    test_missing_keys = False
-    test_model_parallel = False
+    test_headmasking = False
     test_pruning = False
-    test_head_masking = False  # RecurrentGemma does not have attention heads
-
-    # Need to remove 0.9 in `test_cpu_offload`
-    # This is because we are hitting edge cases with the causal_mask buffer
-    model_split_percents = [0.5, 0.6]
-
-    # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
-    def is_pipeline_test_to_skip(
-        self,
-        pipeline_test_case_name,
-        config_class,
-        model_architecture,
-        tokenizer_name,
-        image_processor_name,
-        feature_extractor_name,
-        processor_name,
-    ):
-        return True
-
-    def setUp(self):
-        # We don't output attentions
-        self.has_attentions = False
-        self.model_tester = RecurrentGemmaModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=RecurrentGemmaConfig, hidden_size=37)
-
-    def test_config(self):
-        self.config_tester.run_common_tests()
-
-    def test_model(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_model(*config_and_inputs)
-
-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
+    has_attentions = False
+    model_tester_class = RecurrentGemmaModelTester
 
     @unittest.skip(reason="RecurrentGemma only supports sdpa")
     def test_eager_matches_sdpa_generate(self):
@@ -255,6 +93,7 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
     def test_model_parallel_beam_search(self):
         pass
 
+    @parameterized.expand([("random",), ("same",)])
     @pytest.mark.generate
     @unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported")
     def test_assisted_decoding_matches_greedy_search(self):
@@ -273,6 +112,65 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
     def test_initialization(self):
         pass
 
+    @unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
+    @pytest.mark.generate
+    def test_beam_sample_generate_dict_output(self):
+        pass
+
+    [eleven more overrides are added with the same
+     @unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests") decorator:
+     test_beam_search_generate_dict_output, test_beam_search_generate_dict_outputs_use_cache,
+     test_constrained_beam_search_generate, test_constrained_beam_search_generate_dict_output,
+     test_dola_decoding_sample, test_generate_without_input_ids, test_group_beam_search_generate,
+     test_group_beam_search_generate_dict_output, test_greedy_generate_dict_outputs,
+     test_greedy_generate_dict_outputs_use_cache and test_model_outputs_equivalence]
 
     @require_torch_accelerator
     @slow
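RecurrentGemma shows how a model that cannot run the full inherited suite still fits the pattern: it keeps the tiny tester subclass and opts out of individual inherited tests with explicit skips instead of redefining them. A short sketch of that escape hatch follows, again with placeholder MyModel* names and building on the sketch after the Qwen2Moe section; it assumes the same tests/models/<name>/ placement for the relative import.

import unittest

import pytest

from ...causal_lm_tester import CausalLMModelTest  # shared base class, as in the earlier sketch


class MyModelModelTest(CausalLMModelTest, unittest.TestCase):
    model_tester_class = MyModelModelTester  # defined as in the earlier sketch
    has_attentions = False  # architectures without attention outputs turn off the attention checks
    test_pruning = False

    # An inherited generation test that does not apply is overridden with an explicit skip,
    # so the reason shows up in the test report instead of a hard failure.
    @unittest.skip(reason="This architecture does not support beam search")
    @pytest.mark.generate
    def test_beam_search_generate_dict_output(self):
        pass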
@ -16,9 +16,8 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from parameterized import parameterized
|
|
||||||
|
|
||||||
from transformers import StableLmConfig, is_torch_available, set_seed
|
from transformers import StableLmConfig, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
@ -27,11 +26,6 @@ from transformers.testing_utils import (
|
|||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
|
||||||
from ...test_configuration_common import ConfigTester
|
|
||||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
|
||||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
@@ -45,133 +39,27 @@ if is_torch_available():
     )
     from transformers.models.stablelm.modeling_stablelm import StableLmRotaryEmbedding


-# Copied from transformers.tests.models.persimmon.test_modeling_persimmon.PersimmonModelTester with Persimmon -> StableLm
-class StableLmModelTester:
-    # Ignore copy
-    def __init__(
-        self,
-        parent,
-        batch_size=13,
-        seq_length=7,
-        is_training=True,
-        use_input_mask=True,
-        use_token_type_ids=False,
-        use_labels=True,
-        vocab_size=99,
-        hidden_size=64,
-        num_hidden_layers=2,
-        num_attention_heads=4,
-        num_key_value_heads=4,
-        intermediate_size=37,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_position_embeddings=512,
-        type_vocab_size=16,
-        type_sequence_label_size=2,
-        initializer_range=0.02,
-        num_labels=3,
-        num_choices=4,
-        pad_token_id=0,
-        scope=None,
-    ):
-        self.parent = parent
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.is_training = is_training
-        self.use_input_mask = use_input_mask
-        self.use_token_type_ids = use_token_type_ids
-        self.use_labels = use_labels
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.num_key_value_heads = num_key_value_heads
-        self.intermediate_size = intermediate_size
-        self.hidden_act = hidden_act
-        self.hidden_dropout_prob = hidden_dropout_prob
-        self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.max_position_embeddings = max_position_embeddings
-        self.type_vocab_size = type_vocab_size
-        self.type_sequence_label_size = type_sequence_label_size
-        self.initializer_range = initializer_range
-        self.num_labels = num_labels
-        self.num_choices = num_choices
-        self.pad_token_id = pad_token_id
-        self.scope = scope
-
-    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
-        input_mask = None
-        if self.use_input_mask:
-            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
-
-        token_type_ids = None
-        if self.use_token_type_ids:
-            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-        sequence_labels = None
-        token_labels = None
-        choice_labels = None
-        if self.use_labels:
-            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-            choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
-        config = self.get_config()
-
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-    def get_config(self):
-        return StableLmConfig(
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,
-            num_hidden_layers=self.num_hidden_layers,
-            num_attention_heads=self.num_attention_heads,
-            num_key_value_heads=self.num_key_value_heads,
-            intermediate_size=self.intermediate_size,
-            hidden_act=self.hidden_act,
-            hidden_dropout_prob=self.hidden_dropout_prob,
-            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-            max_position_embeddings=self.max_position_embeddings,
-            type_vocab_size=self.type_vocab_size,
-            is_decoder=False,
-            initializer_range=self.initializer_range,
-            pad_token_id=self.pad_token_id,
-        )
-
-    def create_and_check_model(
-        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-    ):
-        model = StableLmModel(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=input_mask)
-        result = model(input_ids)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
-    def prepare_config_and_inputs_for_common(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-        (
-            config,
-            input_ids,
-            token_type_ids,
-            input_mask,
-            sequence_labels,
-            token_labels,
-            choice_labels,
-        ) = config_and_inputs
-        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
-        return config, inputs_dict
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
+
+
+class StableLmModelTester(CausalLMModelTester):
+    if is_torch_available():
+        config_class = StableLmConfig
+        base_model_class = StableLmModel
+        causal_lm_class = StableLmForCausalLM
+        sequence_class = StableLmForSequenceClassification
+        token_class = StableLmForTokenClassification


 @require_torch
-# Copied from transformers.tests.persimmon.test_modeling_persimmon.PersimmonModelTest with Persimmon -> StableLm
-class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class StableLmModelTest(CausalLMModelTest, unittest.TestCase):
     all_model_classes = (
-        (StableLmModel, StableLmForCausalLM, StableLmForSequenceClassification, StableLmForTokenClassification)
+        (
+            StableLmModel,
+            StableLmForCausalLM,
+            StableLmForSequenceClassification,
+            StableLmForTokenClassification,
+        )
         if is_torch_available()
        else ()
     )
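Under the pattern shown above, a model's tester no longer re-implements config and input plumbing; it only declares which classes to exercise (config_class, base_model_class, causal_lm_class, sequence_class, token_class), and the shared CausalLMModelTester / CausalLMModelTest supply the behaviour. A minimal, hypothetical sketch of that attribute-driven wiring follows; it is not the real implementation in causal_lm_tester.py, and the class name DeclarativeTesterBase and its methods are invented for illustration only.

# Illustrative only: a tiny attribute-driven tester base in the spirit of CausalLMModelTester.
class DeclarativeTesterBase:
    # Subclasses are expected to declare the model/config classes they want exercised.
    _required_attributes = ("config_class", "base_model_class", "causal_lm_class")

    def __init__(self, parent):
        self.parent = parent
        missing = [name for name in self._required_attributes if getattr(self, name, None) is None]
        if missing:
            raise ValueError(f"Tester subclass must set: {missing}")

    def get_config(self, **overrides):
        # Build a small config from the declared config class; subclasses supply the values.
        return self.config_class(**overrides)

With a base like this, the StableLm subclass above only has to bind its five class attributes, exactly as the added lines do.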
@@ -179,167 +67,18 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
         {
             "feature-extraction": StableLmModel,
             "text-classification": StableLmForSequenceClassification,
+            "text-generation": StableLmForCausalLM,
+            "zero-shot": StableLmForSequenceClassification,
             "token-classification": StableLmForTokenClassification,
-            # TODO (ydshieh): check why these two fail. Fix them or skip them in a better way.
-            # "text-generation": StableLmForCausalLM,
-            # "zero-shot": StableLmForSequenceClassification,
         }
         if is_torch_available()
         else {}
     )

     test_headmasking = False
     test_pruning = False
+    fx_compatible = False  # Broken by attention refactor cc @Cyrilvallez
+    model_tester_class = StableLmModelTester
+    rotary_embedding_layer = StableLmRotaryEmbedding  # Enables RoPE tests if set

-    def setUp(self):
-        self.model_tester = StableLmModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=StableLmConfig, hidden_size=37)
-
-    def test_config(self):
-        self.config_tester.run_common_tests()
-
-    def test_model(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_model(*config_and_inputs)
-
-    def test_stablelm_sequence_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = StableLmForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_stablelm_sequence_classification_model_for_single_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "single_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = StableLmForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_stablelm_sequence_classification_model_for_multi_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "multi_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor(
-            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
-        ).to(torch.float)
-        model = StableLmForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->StableLm,llama->stablelm
-    def test_stablelm_token_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
-        model = StableLmForTokenClassification(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
-        self.assertEqual(
-            result.logits.shape,
-            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
-        )
-
-    @parameterized.expand([("linear",), ("dynamic",)])
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->StableLm
-    def test_model_rope_scaling_from_config(self, scaling_type):
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        short_input = ids_tensor([1, 10], config.vocab_size)
-        long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
-
-        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        original_model = StableLmModel(config)
-        original_model.to(torch_device)
-        original_model.eval()
-        original_short_output = original_model(short_input).last_hidden_state
-        original_long_output = original_model(long_input).last_hidden_state
-
-        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        config.rope_scaling = {"type": scaling_type, "factor": 10.0}
-        scaled_model = StableLmModel(config)
-        scaled_model.to(torch_device)
-        scaled_model.eval()
-        scaled_short_output = scaled_model(short_input).last_hidden_state
-        scaled_long_output = scaled_model(long_input).last_hidden_state
-
-        # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
-        # maximum sequence length, so the outputs for the short input should match.
-        if scaling_type == "dynamic":
-            torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
-        else:
-            self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
-
-        # The output should be different for long inputs
-        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
-
-    # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->StableLm
-    def test_model_rope_scaling(self):
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        scaling_factor = 10
-        short_input_length = 10
-        long_input_length = int(config.max_position_embeddings * 1.5)
-
-        # Inputs
-        x = torch.randn(
-            1, dtype=torch.float32, device=torch_device
-        )  # used exclusively to get the dtype and the device
-        position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
-        position_ids_short = position_ids_short.unsqueeze(0)
-        position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
-        position_ids_long = position_ids_long.unsqueeze(0)
-
-        # Sanity check original RoPE
-        original_rope = StableLmRotaryEmbedding(config).to(torch_device)
-        original_cos_short, original_sin_short = original_rope(x, position_ids_short)
-        original_cos_long, original_sin_long = original_rope(x, position_ids_long)
-        torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
-        torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
-
-        # Sanity check linear RoPE scaling
-        # New position "x" should match original position with index "x/scaling_factor"
-        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
-        linear_scaling_rope = StableLmRotaryEmbedding(config).to(torch_device)
-        linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
-        linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
-        torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
-        torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
-        for new_position in range(0, long_input_length, scaling_factor):
-            original_position = int(new_position // scaling_factor)
-            torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
-            torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
-
-        # Sanity check Dynamic NTK RoPE scaling
-        # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
-        # with scaling_factor (or that `inv_freq` decreases)
-        config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
-        ntk_scaling_rope = StableLmRotaryEmbedding(config).to(torch_device)
-        ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
-        ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
-        torch.testing.assert_close(ntk_cos_short, original_cos_short)
-        torch.testing.assert_close(ntk_sin_short, original_sin_short)
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_close(ntk_cos_long, original_cos_long)
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_close(ntk_sin_long, original_sin_long)
-        self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())


 @require_torch
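The per-model RoPE tests removed above encode an invariant worth keeping in mind when reading the new rotary_embedding_layer hook: with linear scaling, position p of the scaled embedding must match position p // factor of the unscaled one, while dynamic NTK scaling only kicks in past the original maximum length. Below is a condensed sketch of the linear-scaling check, mirroring the calls used in the removed test; the config values are illustrative, and the shared suite presumably covers this centrally via the rotary_embedding_layer attribute.

import torch
from transformers import StableLmConfig
from transformers.models.stablelm.modeling_stablelm import StableLmRotaryEmbedding

# Condensed from the removed test above: with linear scaling, cos/sin at scaled
# position p should equal the original embedding at position p // factor.
config = StableLmConfig(max_position_embeddings=512)  # small, illustrative config
factor = 10
positions = torch.arange(int(config.max_position_embeddings * 1.5)).unsqueeze(0)
x = torch.randn(1)  # used only to carry dtype/device, as in the removed test

original = StableLmRotaryEmbedding(config)
cos_orig, sin_orig = original(x, positions)

config.rope_scaling = {"type": "linear", "factor": factor}
scaled = StableLmRotaryEmbedding(config)
cos_lin, sin_lin = scaled(x, positions)

for p in range(0, positions.shape[1], factor):
    torch.testing.assert_close(cos_lin[:, p, :], cos_orig[:, p // factor, :])
    torch.testing.assert_close(sin_lin[:, p, :], sin_orig[:, p // factor, :])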
tests/models/starcoder2/test_modeling_starcoder2.py
@@ -28,11 +28,6 @@ from transformers.testing_utils import (
     torch_device,
 )

-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor
-from ...test_pipeline_mixin import PipelineTesterMixin
-

 if is_torch_available():
     import torch
@@ -45,241 +40,38 @@ if is_torch_available():
         Starcoder2Model,
     )


-# Copied from transformers.tests.models.mistral.test_modeling_mistral.Starcoder2ModelTester with Mistral->Starcoder2
-class Starcoder2ModelTester:
-    def __init__(
-        self,
-        parent,
-        batch_size=13,
-        seq_length=7,
-        is_training=True,
-        use_input_mask=True,
-        use_token_type_ids=False,
-        use_labels=True,
-        vocab_size=99,
-        hidden_size=32,
-        num_hidden_layers=2,
-        num_attention_heads=4,
-        num_key_value_heads=2,
-        intermediate_size=37,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_position_embeddings=512,
-        type_vocab_size=16,
-        type_sequence_label_size=2,
-        initializer_range=0.02,
-        num_labels=3,
-        num_choices=4,
-        pad_token_id=0,
-        scope=None,
-    ):
-        self.parent = parent
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.is_training = is_training
-        self.use_input_mask = use_input_mask
-        self.use_token_type_ids = use_token_type_ids
-        self.use_labels = use_labels
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.num_key_value_heads = num_key_value_heads
-        self.intermediate_size = intermediate_size
-        self.hidden_act = hidden_act
-        self.hidden_dropout_prob = hidden_dropout_prob
-        self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.max_position_embeddings = max_position_embeddings
-        self.type_vocab_size = type_vocab_size
-        self.type_sequence_label_size = type_sequence_label_size
-        self.initializer_range = initializer_range
-        self.num_labels = num_labels
-        self.num_choices = num_choices
-        self.pad_token_id = pad_token_id
-        self.scope = scope
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs
-    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
-        input_mask = None
-        if self.use_input_mask:
-            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
-
-        token_type_ids = None
-        if self.use_token_type_ids:
-            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-        sequence_labels = None
-        token_labels = None
-        choice_labels = None
-        if self.use_labels:
-            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-            choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
-        config = self.get_config()
-
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-    # Ignore copy
-    def get_config(self):
-        return Starcoder2Config(
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,
-            num_hidden_layers=self.num_hidden_layers,
-            num_attention_heads=self.num_attention_heads,
-            num_key_value_heads=self.num_key_value_heads,
-            intermediate_size=self.intermediate_size,
-            hidden_act=self.hidden_act,
-            hidden_dropout_prob=self.hidden_dropout_prob,
-            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-            max_position_embeddings=self.max_position_embeddings,
-            type_vocab_size=self.type_vocab_size,
-            is_decoder=False,
-            initializer_range=self.initializer_range,
-            pad_token_id=self.pad_token_id,
-            eos_token_id=self.pad_token_id,
-            bos_token_id=self.pad_token_id,
-        )
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Starcoder2
-    def create_and_check_model(
-        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-    ):
-        model = Starcoder2Model(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=input_mask)
-        result = model(input_ids)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common
-    def prepare_config_and_inputs_for_common(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-        (
-            config,
-            input_ids,
-            token_type_ids,
-            input_mask,
-            sequence_labels,
-            token_labels,
-            choice_labels,
-        ) = config_and_inputs
-        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
-        return config, inputs_dict
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
+
+
+class Starcoder2ModelTester(CausalLMModelTester):
+    config_class = Starcoder2Config
+    if is_torch_available():
+        base_model_class = Starcoder2Model
+        causal_lm_class = Starcoder2ForCausalLM
+        sequence_class = Starcoder2ForSequenceClassification
+        token_class = Starcoder2ForTokenClassification


 @require_torch
-# Copied from transformers.tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Starcoder2
-class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class Starcoder2ModelTest(CausalLMModelTest, unittest.TestCase):
     all_model_classes = (
         (Starcoder2Model, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification, Starcoder2ForTokenClassification)
         if is_torch_available()
         else ()
     )
+    test_headmasking = False
+    test_pruning = False
+    model_tester_class = Starcoder2ModelTester
     pipeline_model_mapping = (
         {
             "feature-extraction": Starcoder2Model,
             "text-classification": Starcoder2ForSequenceClassification,
             "token-classification": Starcoder2ForTokenClassification,
             "text-generation": Starcoder2ForCausalLM,
-            "zero-shot": Starcoder2ForSequenceClassification,
         }
         if is_torch_available()
         else {}
     )
-    test_headmasking = False
-    test_pruning = False
-
-    # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
-    def is_pipeline_test_to_skip(
-        self,
-        pipeline_test_case_name,
-        config_class,
-        model_architecture,
-        tokenizer_name,
-        image_processor_name,
-        feature_extractor_name,
-        processor_name,
-    ):
-        return True
-
-    def setUp(self):
-        self.model_tester = Starcoder2ModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=Starcoder2Config, hidden_size=37)
-
-    def test_config(self):
-        self.config_tester.run_common_tests()
-
-    def test_model(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_model(*config_and_inputs)
-
-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
-    def test_Starcoder2_sequence_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        print(config)
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = Starcoder2ForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_Starcoder2_sequence_classification_model_for_single_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "single_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = Starcoder2ForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_Starcoder2_sequence_classification_model_for_multi_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "multi_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor(
-            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
-        ).to(torch.float)
-        model = Starcoder2ForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Starcoder2,llama->Starcoder2
-    def test_Starcoder2_token_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
-        model = Starcoder2ForTokenClassification(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
-        self.assertEqual(
-            result.logits.shape,
-            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
-        )

     @require_flash_attn
     @require_torch_gpu
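Both removed testers build their dummy batches the same way: random token ids over a small vocabulary plus a lower-triangular attention mask. A self-contained sketch of that pattern in plain torch follows; torch.randint stands in for the ids_tensor helper, and the sizes mirror the tester defaults shown above.

import torch

batch_size, seq_length, vocab_size = 13, 7, 99  # defaults used by the removed testers above
input_ids = torch.randint(0, vocab_size, (batch_size, seq_length))
attention_mask = torch.tril(torch.ones_like(input_ids))  # same masking trick as the removed code
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
assert inputs_dict["attention_mask"].shape == (batch_size, seq_length)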
tests/test_modeling_common.py
@@ -4426,7 +4426,7 @@ class ModelTesterMixin:
         # comparing softmax-normalized logits:
         normalized_0 = F.softmax(out_last_tokens, dim=-1)
         normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
-        torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
+        torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-3)

     @slow
     @require_torch_accelerator
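The tolerance bump above applies to a comparison of softmax-normalized logits. A self-contained sketch of the relaxed check, with illustrative tensor values:

import torch
import torch.nn.functional as F

# Two logit tensors that should agree after softmax normalization (values are illustrative).
out_last_tokens = torch.randn(4, 32)
out_shared_prefix_last_tokens = out_last_tokens + 1e-4 * torch.randn(4, 32)

normalized_0 = F.softmax(out_last_tokens, dim=-1)
normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
# atol relaxed from 1e-4 to 1e-3, as in the hunk above.
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-3)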