mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
255 lines
9.6 KiB
Python
255 lines
9.6 KiB
Python
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Testing suite for the PyTorch Cohere model."""
|
|
|
|
import unittest
|
|
|
|
from transformers import CohereConfig, is_torch_available
|
|
from transformers.testing_utils import (
|
|
require_bitsandbytes,
|
|
require_torch,
|
|
require_torch_multi_gpu,
|
|
require_torch_sdpa,
|
|
slow,
|
|
torch_device,
|
|
)
|
|
|
|
from ...generation.test_utils import GenerationTesterMixin
|
|
from ...test_configuration_common import ConfigTester
|
|
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
from transformers import AutoTokenizer, CohereForCausalLM, CohereModel
|
|
|
|
|
|
# Copied from transformers.tests.models.llama.LlamaModelTester with Llama->Cohere
|
|
class CohereModelTester:
    """Helper that builds a tiny Cohere config and random inputs for `CohereModelTest`."""

    config_class = CohereConfig
    if is_torch_available():
        model_class = CohereModel
        for_causal_lm_class = CohereForCausalLM

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=4,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        pad_token_id=0,
        scope=None,
    ):
        """Record the owning test case and every hyper-parameter as an attribute.

        Args:
            parent: The test case that owns this tester; its `assertEqual` is used by
                `create_and_check_model`.

        All remaining arguments are stored unchanged under an attribute of the same name.
        """
        # The test case that performs the actual assertions.
        self.parent = parent

        # Shapes of the generated inputs and toggles for the optional ones.
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels

        # Values forwarded to the config in `get_config`.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.pad_token_id = pad_token_id

        # Upper bounds for the randomly drawn labels, plus a value kept only as an attribute.
        self.type_sequence_label_size = type_sequence_label_size
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
        """Create a config together with random inputs and labels.

        Returns:
            The 7-tuple `(config, input_ids, token_type_ids, input_mask, sequence_labels,
            token_labels, choice_labels)`. Entries whose `use_*` flag is disabled are `None`.
        """
        # NOTE: the `ids_tensor` calls are kept in their original order so that the sequence
        # of random draws is unchanged.
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        # Lower-triangular matrix of ones with the same shape as `input_ids`.
        input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) if self.use_input_mask else None

        token_type_ids = (
            ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
            if self.use_token_type_ids
            else None
        )

        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)
        else:
            sequence_labels = token_labels = choice_labels = None

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    # Ignore copy
    def get_config(self):
        """Instantiate `config_class` from the hyper-parameters stored on this tester."""
        config_kwargs = {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "num_hidden_layers": self.num_hidden_layers,
            "num_attention_heads": self.num_attention_heads,
            "intermediate_size": self.intermediate_size,
            "hidden_act": self.hidden_act,
            "hidden_dropout_prob": self.hidden_dropout_prob,
            "attention_probs_dropout_prob": self.attention_probs_dropout_prob,
            "max_position_embeddings": self.max_position_embeddings,
            "type_vocab_size": self.type_vocab_size,
            "is_decoder": False,
            "initializer_range": self.initializer_range,
            "pad_token_id": self.pad_token_id,
        }
        return self.config_class(**config_kwargs)

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run the base model in eval mode and check the shape of its last hidden state.

        Only `config`, `input_ids` and `input_mask` are used; the other arguments are accepted so
        the tuple returned by `prepare_config_and_inputs` can be splatted into this method.
        """
        model = self.model_class(config=config)
        model.to(torch_device)
        model.eval()

        # The forward pass is run with and without the attention mask; only the output of the
        # second (unmasked) call is inspected.
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)

        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_shape)

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` where `inputs_dict` holds `input_ids` and `attention_mask`."""
        # Unpack all seven entries explicitly so a tuple of a different length still fails loudly.
        (
            config,
            input_ids,
            _token_type_ids,
            input_mask,
            _sequence_labels,
            _token_labels,
            _choice_labels,
        ) = self.prepare_config_and_inputs()
        return config, {"input_ids": input_ids, "attention_mask": input_mask}
|
|
|
|
|
|
@require_torch
class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Runs the shared model, generation and pipeline test mixins against the Cohere model classes."""

    all_model_classes = (CohereModel, CohereForCausalLM) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": CohereModel,
            "text-generation": CohereForCausalLM,
        }
        if is_torch_available()
        else {}
    )
    # NOTE(review): these flags are presumably read by the inherited mixins to skip the
    # corresponding shared tests - confirm against `ModelTesterMixin`.
    test_headmasking = False
    test_pruning = False
    fx_compatible = False

    # Need to use `0.8` instead of `0.9` for `test_cpu_offload`
    # This is because we are hitting edge cases with the causal_mask buffer
    model_split_percents = [0.5, 0.7, 0.8]

    def setUp(self):
        """Create the model tester and the config tester used by the tests below."""
        self.model_tester = CohereModelTester(self)
        self.config_tester = ConfigTester(self, config_class=CohereConfig, hidden_size=37)

    def test_config(self):
        """Run the common configuration checks for `CohereConfig`."""
        self.config_tester.run_common_tests()

    def test_model(self):
        """Check the base model's output shape on freshly generated inputs."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_various_embeddings(self):
        """Repeat the base model check with several `position_embedding_type` values on the config."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        # The loop variable is deliberately not called `type`, which would shadow the builtin.
        for embedding_type in ["absolute", "relative_key", "relative_key_query"]:
            # The config is the first entry of the tuple and is mutated in place between runs.
            config_and_inputs[0].position_embedding_type = embedding_type
            self.model_tester.create_and_check_model(*config_and_inputs)

    def test_torch_fx_output_loss(self):
        """Delegate to the inherited implementation unchanged."""
        super().test_torch_fx_output_loss()
|
|
|
|
|
|
@require_torch
@slow
class CohereIntegrationTest(unittest.TestCase):
    """Slow integration tests that load pretrained Cohere checkpoints by model id."""

    @require_torch_multi_gpu
    @require_bitsandbytes
    def test_batched_4bit(self):
        """Greedy batched generation with the 4-bit checkpoint must reproduce the reference text."""
        model_id = "CohereForAI/c4ai-command-r-v01-4bit"

        EXPECTED_TEXT = [
            'Hello today I am going to show you how to make a simple and easy card using the new stamp set called "Hello" from the Occasions catalog. This set is so versatile and can be used for many occasions. I used the new In',
            "Hi there, here we are again with another great collection of free fonts for your next project. This time we have gathered 10 free fonts that you can download and use in your designs. These fonts are perfect for any kind",
        ]

        model = CohereForCausalLM.from_pretrained(model_id, device_map="auto")
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # Padding is needed to batch prompts of different lengths; the EOS token is reused for it.
        tokenizer.pad_token = tokenizer.eos_token

        text = ["Hello today I am going to show you how to", "Hi there, here we are"]
        inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch_device)

        # `do_sample=False` keeps decoding deterministic so the output can be compared verbatim.
        output = model.generate(**inputs, max_new_tokens=40, do_sample=False)
        self.assertEqual(tokenizer.batch_decode(output, skip_special_tokens=True), EXPECTED_TEXT)

    @require_torch_sdpa
    def test_batched_small_model_logits(self):
        """Compare a small slice of the logits of a random Cohere checkpoint with reference values."""
        # Since the model is very large, we created a random cohere model so that we can do a simple
        # logits check on it.
        model_id = "hf-internal-testing/cohere-random"

        EXPECTED_LOGITS = torch.Tensor(
            [
                [[0.0000, 0.0285, 0.0322], [0.0000, 0.0011, 0.1105], [0.0000, -0.0018, -0.1019]],
                [[0.0000, 0.1080, 0.0454], [0.0000, -0.1808, -0.1553], [0.0000, 0.0452, 0.0369]],
            ]
        ).to(device=torch_device, dtype=torch.float16)

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = CohereForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(
            torch_device
        )

        # Padding is needed to batch prompts of different lengths; the EOS token is reused for it.
        tokenizer.pad_token = tokenizer.eos_token

        text = ["Hello today I am going to show you how to", "Hi there, here we are"]
        inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch_device)

        with torch.no_grad():
            output = model(**inputs)

        # Only the last three positions and the first three vocabulary entries are checked.
        actual_logits = output.logits[:, -3:, :3]

        # `assert_close` takes `(actual, expected)`: the relative tolerance is scaled by the
        # expected values, so the reference tensor must be the second argument.
        torch.testing.assert_close(actual_logits, EXPECTED_LOGITS, rtol=1e-3, atol=1e-3)
|