transformers/tests/models/cohere/test_modeling_cohere.py
Cyril Vallez 896833c183
Fix some tests (especially compile with fullgraph=True on Python<3.11) (#38319)
* fix tests

* better fix for python<3.11

* fixes

* style
2025-05-23 17:11:40 +02:00

255 lines
9.6 KiB
Python

# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Cohere model."""
import unittest
from transformers import CohereConfig, is_torch_available
from transformers.testing_utils import (
require_bitsandbytes,
require_torch,
require_torch_multi_gpu,
require_torch_sdpa,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import AutoTokenizer, CohereForCausalLM, CohereModel
# Copied from transformers.tests.models.llama.LlamaModelTester with Llama->Cohere
class CohereModelTester:
config_class = CohereConfig
if is_torch_available():
model_class = CohereModel
for_causal_lm_class = CohereForCausalLM
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=4,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
pad_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.pad_token_id = pad_token_id
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config()
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
# Ignore copy
def get_config(self):
return self.config_class(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
pad_token_id=self.pad_token_id,
)
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = self.model_class(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (CohereModel, CohereForCausalLM) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": CohereModel,
"text-generation": CohereForCausalLM,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
fx_compatible = False
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]
def setUp(self):
self.model_tester = CohereModelTester(self)
self.config_tester = ConfigTester(self, config_class=CohereConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
def test_torch_fx_output_loss(self):
super().test_torch_fx_output_loss()
@require_torch
@slow
class CohereIntegrationTest(unittest.TestCase):
@require_torch_multi_gpu
@require_bitsandbytes
def test_batched_4bit(self):
model_id = "CohereForAI/c4ai-command-r-v01-4bit"
EXPECTED_TEXT = [
'Hello today I am going to show you how to make a simple and easy card using the new stamp set called "Hello" from the Occasions catalog. This set is so versatile and can be used for many occasions. I used the new In',
"Hi there, here we are again with another great collection of free fonts for your next project. This time we have gathered 10 free fonts that you can download and use in your designs. These fonts are perfect for any kind",
]
model = CohereForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
text = ["Hello today I am going to show you how to", "Hi there, here we are"]
inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=40, do_sample=False)
self.assertEqual(tokenizer.batch_decode(output, skip_special_tokens=True), EXPECTED_TEXT)
@require_torch_sdpa
def test_batched_small_model_logits(self):
# Since the model is very large, we created a random cohere model so that we can do a simple
# logits check on it.
model_id = "hf-internal-testing/cohere-random"
EXPECTED_LOGITS = torch.Tensor(
[
[[0.0000, 0.0285, 0.0322], [0.0000, 0.0011, 0.1105], [0.0000, -0.0018, -0.1019]],
[[0.0000, 0.1080, 0.0454], [0.0000, -0.1808, -0.1553], [0.0000, 0.0452, 0.0369]],
]
).to(device=torch_device, dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = CohereForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(
torch_device
)
tokenizer.pad_token = tokenizer.eos_token
text = ["Hello today I am going to show you how to", "Hi there, here we are"]
inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch_device)
with torch.no_grad():
output = model(**inputs)
logits = output.logits
torch.testing.assert_close(EXPECTED_LOGITS, logits[:, -3:, :3], rtol=1e-3, atol=1e-3)