# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Mixtral model. """


import tempfile
import unittest

import pytest

from transformers import MixtralConfig, is_torch_available
from transformers.testing_utils import (
    require_flash_attn,
    require_torch,
    require_torch_gpu,
    slow,
    torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        MixtralForCausalLM,
        MixtralForSequenceClassification,
        MixtralModel,
    )


class MixtralModelTester:
    """
    Builds a tiny `MixtralConfig` plus random inputs and runs shape/consistency checks on the Mixtral model classes.

    NOTE(review): methods carrying a `# Copied from ...` marker are intentionally left token-identical to the
    Mistral/Llama testers they name (presumably enforced by the repository's copy-consistency check - confirm
    before editing them here rather than at their source).
    """

    # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.__init__
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        pad_token_id=0,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.pad_token_id = pad_token_id
        self.scope = scope

    # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        """Return a small `MixtralConfig` built from this tester's attributes, with 2 local experts and top-2 routing."""
        return MixtralConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
            pad_token_id=self.pad_token_id,
            num_experts_per_tok=2,
            num_local_experts=2,
        )

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mixtral
    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = MixtralModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Mixtral
    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = MixtralModel(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Mixtral
    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        model = MixtralForCausalLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->Mixtral
    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.is_decoder = True
        config.add_cross_attention = True
        model = MixtralForCausalLM(config=config)
        model.to(torch_device)
        model.eval()

        # first forward pass
        outputs = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # create hypothetical multiple next token and extent to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append to next input_ids and
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_hidden_states=True,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
        )["hidden_states"][0]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->Mixtral
    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral
class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (
        (MixtralModel, MixtralForCausalLM, MixtralForSequenceClassification) if is_torch_available() else ()
    )
    all_generative_model_classes = (MixtralForCausalLM,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": MixtralModel,
            "text-classification": MixtralForSequenceClassification,
            "text-generation": MixtralForCausalLM,
            "zero-shot": MixtralForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False

    # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
    def is_pipeline_test_to_skip(
        self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
    ):
        return True

    def setUp(self):
        self.model_tester = MixtralModelTester(self)
        self.config_tester = ConfigTester(self, config_class=MixtralConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    def test_Mixtral_sequence_classification_model(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        print(config)
        config.num_labels = 3
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = MixtralForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    def test_Mixtral_sequence_classification_model_for_single_label(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        config.problem_type = "single_label_classification"
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
        model = MixtralForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    def test_Mixtral_sequence_classification_model_for_multi_label(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        config.problem_type = "multi_label_classification"
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        sequence_labels = ids_tensor(
            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
        ).to(torch.float)
        model = MixtralForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    @unittest.skip("Mixtral buffers include complex numbers, which breaks this test")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip("Mixtral uses GQA on all models so the KV cache is a non standard format")
    def test_past_key_values_format(self):
        pass

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_generate_padding_right(self):
        import torch

        for model_class in self.all_generative_model_classes:
            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
                    torch_device
                )

                dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
                dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)

                model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)

                model = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    low_cpu_mem_usage=True,
                ).to(torch_device)

                with self.assertRaises(ValueError):
                    _ = model.generate(
                        dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
                    )

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_generate_use_cache(self):
        import torch

        max_new_tokens = 30

        for model_class in self.all_generative_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

            dummy_input = inputs_dict[model_class.main_input_name]
            if dummy_input.dtype in [torch.float32, torch.bfloat16]:
                dummy_input = dummy_input.to(torch.float16)

            # make sure that all models have enough positions for generation
            if hasattr(config, "max_position_embeddings"):
                config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
                # NOTE: Mixtral apparently does not support right padding + use_cache with FA2.
                dummy_attention_mask[:, -1] = 1

                model = model_class.from_pretrained(
                    tmpdirname,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    low_cpu_mem_usage=True,
                ).to(torch_device)

                # Just test that a large cache works as expected
                _ = model.generate(
                    dummy_input,
                    attention_mask=dummy_attention_mask,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    use_cache=True,
                )

    @require_flash_attn
    @require_torch_gpu
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_inference_padding_right(self):
        self.skipTest("Mixtral flash attention does not support right padding")

    # Ignore copy
    def test_load_balancing_loss(self):
        r"""
        Check that a forward pass with `output_router_logits=True` returns per-layer router logits of the expected
        shape together with the auxiliary load-balancing loss. Only the forward pass is exercised here; no backward
        pass is run.
        """
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.num_labels = 3
        config.num_local_experts = 8
        config.output_router_logits = True
        input_ids = input_dict["input_ids"]
        attention_mask = input_ids.ne(1).to(torch_device)
        model = MixtralForCausalLM(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=attention_mask)
        # One row of router logits per input token. Derived from the actual input instead of the former literal 91
        # (= default batch_size 13 * seq_length 7) so the check survives a change of the tester defaults.
        num_tokens = input_ids.numel()
        self.assertEqual(result.router_logits[0].shape, (num_tokens, config.num_local_experts))
        torch.testing.assert_close(result.aux_loss.cpu(), torch.tensor(8, dtype=torch.float32))


@require_torch
class MixtralIntegrationTest(unittest.TestCase):
    """Slow regression tests comparing logits of the `hf-internal-testing/Mixtral-tiny` checkpoint to stored values."""

    @slow
    @require_torch_gpu
    def test_small_model_logits(self):
        """Both (identical) rows of an unpadded batch must reproduce the reference logits."""
        model_id = "hf-internal-testing/Mixtral-tiny"
        dummy_input = torch.LongTensor([[0, 1, 0], [0, 1, 0]]).to(torch_device)

        model = MixtralForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(
            torch_device
        )
        # TODO: might need to tweak it in case the logits do not match on our daily runners
        # these logits have been obtained with the original megablocks implementation.
        EXPECTED_LOGITS = torch.Tensor(
            [[0.1670, 0.1620, 0.6094], [-0.8906, -0.1588, -0.6060], [0.1572, 0.1290, 0.7246]]
        ).to(torch_device)

        with torch.no_grad():
            # `assert_close` checks dtype by default, so the bfloat16 logits are upcast to the float32 dtype of the
            # reference tensor; otherwise the comparison fails on the dtype before any value is looked at.
            logits = model(dummy_input).logits.float()

        torch.testing.assert_close(logits[0, :3, :3], EXPECTED_LOGITS, atol=1e-3, rtol=1e-3)
        torch.testing.assert_close(logits[1, :3, :3], EXPECTED_LOGITS, atol=1e-3, rtol=1e-3)

    @slow
    @require_torch_gpu
    def test_small_model_logits_batched(self):
        """A batch mixing a left-padded row (token 0 masked out) and a full row must reproduce the reference logits."""
        model_id = "hf-internal-testing/Mixtral-tiny"
        dummy_input = torch.LongTensor([[0, 0, 0, 0, 0, 0, 1, 2, 3], [1, 1, 2, 3, 4, 5, 6, 7, 8]]).to(torch_device)
        attention_mask = dummy_input.ne(0).to(torch.long)

        model = MixtralForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(
            torch_device
        )

        # TODO: might need to tweak it in case the logits do not match on our daily runners
        # The reference tensors live on `torch_device`: `assert_close` also checks the device by default.
        EXPECTED_LOGITS_LEFT = torch.Tensor(
            [[0.1750, 0.0537, 0.7007], [0.1750, 0.0537, 0.7007], [0.1750, 0.0537, 0.7007]],
        ).to(torch_device)

        # logits[0, -3:, -3:].half()
        EXPECTED_LOGITS_LEFT_UNPADDED = torch.Tensor(
            [[0.2212, 0.5200, -0.3816], [0.8213, -0.2313, 0.6069], [0.2664, -0.7090, 0.2468]],
        ).to(torch_device)

        # logits[1, -3:, -3:].half()
        EXPECTED_LOGITS_RIGHT_UNPADDED = torch.Tensor(
            [[0.2205, 0.1232, -0.1611], [-0.3484, 0.3030, -1.0312], [0.0742, 0.7930, 0.7969]]
        ).to(torch_device)

        with torch.no_grad():
            # Upcast to float32 so the dtype matches the reference tensors (see `test_small_model_logits`).
            logits = model(dummy_input, attention_mask=attention_mask).logits.float()

        torch.testing.assert_close(logits[0, :3, :3], EXPECTED_LOGITS_LEFT, atol=1e-3, rtol=1e-3)
        torch.testing.assert_close(logits[0, -3:, -3:], EXPECTED_LOGITS_LEFT_UNPADDED, atol=1e-3, rtol=1e-3)
        torch.testing.assert_close(logits[1, -3:, -3:], EXPECTED_LOGITS_RIGHT_UNPADDED, atol=1e-3, rtol=1e-3)