# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch chameleon model."""

import unittest

import pytest
import requests
from parameterized import parameterized

from transformers import ChameleonConfig, is_torch_available, is_vision_available, set_seed
from transformers.testing_utils import (
    require_bitsandbytes,
    require_flash_attn,
    require_read_token,
    require_torch,
    require_torch_gpu,
    slow,
    torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_vision_available():
    from PIL import Image

if is_torch_available():
    import torch

    from transformers import (
        ChameleonForConditionalGeneration,
        ChameleonModel,
        ChameleonProcessor,
    )


# Checkpoint and fixtures shared by the slow tests below.
CHAMELEON_7B_CHECKPOINT = "facebook/chameleon-7b"
BIG_DIPPER_IMAGE_URL = "https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg"
ORION_IMAGE_URL = "https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg"

# `requests` never times out unless told to; bound the wait so a dead host fails the test
# instead of hanging the run.
IMAGE_REQUEST_TIMEOUT_SECONDS = 60


def _load_image(url):
    """Download `url` and open it as a PIL image.

    Raises:
        requests.HTTPError: if the server answers with an error status, so a broken fixture URL is
            reported as such instead of as an image decoding failure.
    """
    response = requests.get(url, stream=True, timeout=IMAGE_REQUEST_TIMEOUT_SECONDS)
    response.raise_for_status()
    return Image.open(response.raw)


class ChameleonModelTester:
    """Builds a tiny `ChameleonConfig` and random inputs for the Chameleon model tests."""

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=False,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        image_token_id=98,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        num_key_value_heads=2,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        pad_token_id=0,
        vq_num_embeds=12,
        vq_embed_dim=12,
        vq_channel_multiplier=None,
        vq_img_token_start_id=10,  # has to be less than vocab size when added with vq_num_embeds
        scope=None,
    ):
        """Store the test case (`parent`) and the hyper-parameters of the tiny model.

        `vq_channel_multiplier` defaults to `[1, 2]`; `None` is used as the sentinel so that every
        tester instance gets its own list instead of sharing one mutable default.
        """
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.image_token_id = image_token_id
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.pad_token_id = pad_token_id
        self.scope = scope
        self.vq_num_embeds = vq_num_embeds
        self.vq_embed_dim = vq_embed_dim
        self.vq_channel_multiplier = [1, 2] if vq_channel_multiplier is None else vq_channel_multiplier
        self.vq_img_token_start_id = vq_img_token_start_id

    def prepare_config_and_inputs(self):
        """Return `(config, input_ids, input_mask, sequence_labels, token_labels, choice_labels)`.

        `input_mask` is `None` when `use_input_mask` is false, and the three label tensors are `None`
        when `use_labels` is false.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            # Lower-triangular matrix of ones with the same shape as `input_ids`.
            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        """Build the tiny `ChameleonConfig`, including a dummy token-name -> id vocabulary map."""
        # create dummy vocab map for image2bpe mapping if it needs remapping
        # we assume that vocab size is big enough to account for image tokens somewhere in the beginning
        # same way as in real ckpt, when img tokens are in first half of embeds
        # we will need "vq_num_embeds" amount of tokens
        vocab_map = {i: chr(i) for i in range(self.vocab_size)}

        # Register the image placeholder under its real name: the model looks the placeholder id up
        # as "<image>" in the vocabulary map, so any other key (e.g. an empty string) is never found.
        vocab_map[self.image_token_id] = "<image>"

        start = self.vq_img_token_start_id
        end = self.vq_img_token_start_id + self.vq_num_embeds
        for i in range(start, end):
            vocab_map[i] = f"IMGIMGBS{i}"  # dummy str for each token, anything starting with IMGIMG

        return ChameleonConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
            pad_token_id=self.pad_token_id,
            # `vocab_map` is id -> name; the config receives the inverse (name -> id).
            vocabulary_map={v: k for k, v in vocab_map.items()},
            vq_config=self.get_vq_config(),
        )

    def get_vq_config(self):
        """Return the keyword dictionary passed to `ChameleonConfig` as `vq_config`."""
        return {
            "embed_dim": self.vq_embed_dim,
            "num_embeddings": self.vq_num_embeds,
            "latent_channels": self.vq_embed_dim,
            "in_channels": 3,
            "base_channels": 32,  # we have a GroupNorm of 32 groups, so can't do less
            "channel_multiplier": self.vq_channel_multiplier,
        }

    def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
        """Run `ChameleonModel` with and without an attention mask and check the hidden-state shape."""
        model = ChameleonModel(config=config)
        model.to(torch_device)
        model.eval()
        # The masked call only has to run without error; the shape check uses the unmasked output.
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Run `ChameleonForConditionalGeneration` with labels and check the logits shape.

        `encoder_hidden_states` and `encoder_attention_mask` are accepted but not used here.
        """
        model = ChameleonForConditionalGeneration(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Check that decoding three extra tokens with a cache matches a full forward pass.

        Note that this mutates `config` by setting `is_decoder = True`.
        """
        config.is_decoder = True
        model = ChameleonForConditionalGeneration(config=config)
        model.to(torch_device)
        model.eval()

        # first forward pass
        outputs = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append the new tokens and mask to the original inputs
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_hidden_states=True,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
        )["hidden_states"][0]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertEqual(output_from_past_slice.shape[1], next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def prepare_config_and_inputs_for_common(self):
        """Return `(config, inputs_dict)` holding only `input_ids` and `attention_mask`."""
        config, input_ids, input_mask, *_ = self.prepare_config_and_inputs()
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common model, generation and pipeline tests run against the tiny Chameleon model."""

    all_model_classes = (ChameleonModel, ChameleonForConditionalGeneration) if is_torch_available() else ()
    all_generative_model_classes = (ChameleonForConditionalGeneration,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": ChameleonModel,
            "text-generation": ChameleonForConditionalGeneration,
        }
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False
    fx_compatible = False

    def setUp(self):
        self.model_tester = ChameleonModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ChameleonConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @parameterized.expand([("linear",), ("dynamic",)])
    def test_model_rope_scaling(self, scaling_type):
        """Compare outputs of identically seeded models with and without `rope_scaling`."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        short_input = ids_tensor([1, 10], config.vocab_size)
        long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)

        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
        original_model = ChameleonModel(config)
        original_model.to(torch_device)
        original_model.eval()
        original_short_output = original_model(short_input).last_hidden_state
        original_long_output = original_model(long_input).last_hidden_state

        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
        config.rope_scaling = {"type": scaling_type, "factor": 10.0}
        scaled_model = ChameleonModel(config)
        scaled_model.to(torch_device)
        scaled_model.eval()
        scaled_short_output = scaled_model(short_input).last_hidden_state
        scaled_long_output = scaled_model(long_input).last_hidden_state

        # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
        # maximum sequence length, so the outputs for the short input should match.
        if scaling_type == "dynamic":
            self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
        else:
            self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))

        # The output should be different for long inputs
        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

    @require_flash_attn
    @require_read_token
    @require_torch_gpu
    @require_bitsandbytes
    @pytest.mark.flash_attn_test
    @slow
    def test_flash_attn_2_generate_padding_right(self):
        """
        Overwriting the common test as the test is flaky on tiny models
        """
        model = ChameleonForConditionalGeneration.from_pretrained(
            CHAMELEON_7B_CHECKPOINT,
            load_in_4bit=True,
            device_map={"": 0},
        )

        processor = ChameleonProcessor.from_pretrained(CHAMELEON_7B_CHECKPOINT)
        texts = ["hi", "Hello this is a very long sentence"]

        processor.tokenizer.padding_side = "right"

        inputs = processor(text=texts, return_tensors="pt", padding=True).to(0)

        output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
        output_native = processor.tokenizer.batch_decode(output_native)

        # Reload the same checkpoint with flash attention and generate from the same inputs.
        model = ChameleonForConditionalGeneration.from_pretrained(
            CHAMELEON_7B_CHECKPOINT,
            load_in_4bit=True,
            attn_implementation="flash_attention_2",
        )

        output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
        output_fa_2 = processor.tokenizer.batch_decode(output_fa_2)

        self.assertListEqual(output_native, output_fa_2)

    @unittest.skip("Chameleon forces some token ids to be -inf!")
    def test_batching_equivalence(self):
        pass

    # TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow
    @unittest.skip("Chameleon is not compatible with end-to-end generation compilation")
    def test_generate_compile_fullgraph(self):
        pass


@require_torch
class ChameleonIntegrationTest(unittest.TestCase):
    """Slow tests comparing greedy generations of the 4-bit 7B checkpoint with recorded outputs."""

    def _load_model_and_processor(self):
        """Load the 7B checkpoint in 4-bit together with its processor."""
        model = ChameleonForConditionalGeneration.from_pretrained(
            CHAMELEON_7B_CHECKPOINT, load_in_4bit=True, device_map="auto"
        )
        processor = ChameleonProcessor.from_pretrained(CHAMELEON_7B_CHECKPOINT)
        return model, processor

    def _generate_text(self, model, processor, inputs):
        """Greedily generate 40 new tokens and decode them without special tokens."""
        generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False)
        return processor.batch_decode(generated_ids, skip_special_tokens=True)

    @slow
    @require_bitsandbytes
    @require_read_token
    def test_model_7b(self):
        model, processor = self._load_model_and_processor()

        image = _load_image(BIG_DIPPER_IMAGE_URL)
        prompt = "Describe what do you see here and tell me about the history behind it?"

        inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.float16)

        # greedy generation outputs
        EXPECTED_TEXT_COMPLETION = ['Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue line extending across the center of the image. The line is labeled "390 light years" and is accompanied by a small black and']  # fmt: skip
        text = self._generate_text(model, processor, inputs)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

    @slow
    @require_bitsandbytes
    @require_read_token
    def test_model_7b_batched(self):
        model, processor = self._load_model_and_processor()

        image = _load_image(BIG_DIPPER_IMAGE_URL)
        image_2 = _load_image(ORION_IMAGE_URL)
        prompts = [
            "Describe what do you see here and tell me about the history behind it?",
            "What constellation is this image showing?",
        ]

        inputs = processor(images=[image, image_2], text=prompts, padding=True, return_tensors="pt").to(
            model.device, torch.float16
        )

        # greedy generation outputs
        EXPECTED_TEXT_COMPLETION = [
            'Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot in the center representing the star Alpha Centauri. The star map is a representation of the night sky, showing the positions of stars in',
            'What constellation is this image showing?The image is showing the constellation of Orion.'
        ]  # fmt: skip
        text = self._generate_text(model, processor, inputs)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

    @slow
    @require_bitsandbytes
    @require_read_token
    def test_model_7b_multi_image(self):
        model, processor = self._load_model_and_processor()

        image = _load_image(BIG_DIPPER_IMAGE_URL)
        image_2 = _load_image(ORION_IMAGE_URL)
        prompt = "What do these two images have in common?"

        inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.float16)

        # greedy generation outputs
        EXPECTED_TEXT_COMPLETION = ['What do these two images have in common?The two images show a connection between two things that are not necessarily related. The first image shows a group of stars, while the second image shows a network of lines connecting two points. The connection between']  # fmt: skip
        text = self._generate_text(model, processor, inputs)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)