From 57f551c78d681a5441d0f1f04e9a69d4984e8990 Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Mon, 24 Mar 2025 12:36:08 +0100
Subject: [PATCH] [chameleon] fix num image token check (#36918)

* [chameleon] fix num image token check

* embed after merging image token

* skip this also

* mistral require_read_token
---
 .../models/chameleon/modeling_chameleon.py |   8 +-
 tests/generation/test_utils.py             |   1 +
 .../chameleon/test_modeling_chameleon.py   | 161 +++++++++++++++++-
 .../mistral3/test_processor_mistral3.py    |   3 +-
 4 files changed, 160 insertions(+), 13 deletions(-)

diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index 30a8ab60f20..33193365359 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -1289,13 +1289,10 @@ class ChameleonModel(ChameleonPreTrainedModel):
                 "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
             )
 
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
         if pixel_values is not None:
             image_tokens = self.get_image_tokens(pixel_values)
             special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
-            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_tokens.numel():
+            if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel():
                 n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
                 n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
                 raise ValueError(
@@ -1304,6 +1301,9 @@ class ChameleonModel(ChameleonPreTrainedModel):
             image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
             input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
 
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
         # torch.jit.trace() doesn't support cache objects in the output
         if use_cache and past_key_values is None and not torch.jit.is_tracing():
             past_key_values = DynamicCache()
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index c166bbeec12..f096d667d6d 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -126,6 +126,7 @@ VLM_CLASS_NAMES = [
     "ayavision",
     "gemma3",
     "mistral3",
+    "chameleon",
 ]
diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py
index 09eec986857..ae06de34ab2 100644
--- a/tests/models/chameleon/test_modeling_chameleon.py
+++ b/tests/models/chameleon/test_modeling_chameleon.py
@@ -14,6 +14,7 @@
 # limitations under the License.
"""Testing suite for the PyTorch chameleon model.""" +import copy import unittest import requests @@ -30,7 +31,7 @@ from transformers.testing_utils import ( from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_pipeline_mixin import PipelineTesterMixin @@ -52,12 +53,12 @@ class ChameleonModelTester: self, parent, batch_size=13, - seq_length=7, + seq_length=35, is_training=False, use_input_mask=True, use_labels=True, vocab_size=99, - image_token_id=98, + image_token_id=4, hidden_size=32, num_hidden_layers=2, num_attention_heads=2, @@ -73,9 +74,9 @@ class ChameleonModelTester: num_labels=3, num_choices=4, pad_token_id=0, - vq_num_embeds=12, - vq_embed_dim=12, - vq_channel_multiplier=[1, 2], + vq_num_embeds=5, + vq_embed_dim=5, + vq_channel_multiplier=[1, 4], vq_img_token_start_id=10, # has to be less than vocab size when added with vq_num_embeds scope=None, ): @@ -138,7 +139,9 @@ class ChameleonModelTester: start = self.vq_img_token_start_id end = self.vq_img_token_start_id + self.vq_num_embeds for i in range(start, end): - vocab_map[i] = f"IMGIMGBS{i}" # dummy str for each token, anything starting with IMGIMG + image_token_infix = "".join(chr(ord("A") + int(c)) for c in str(i)) + # dummy str for each image token, anything starting with IMGIMG + vocab_map[i] = f"IMGIMG{image_token_infix}Z" return ChameleonConfig( vocab_size=self.vocab_size, @@ -275,7 +278,6 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester { "feature-extraction": ChameleonModel, "text-generation": ChameleonForConditionalGeneration, - "image-text-to-text": ChameleonForConditionalGeneration, } if is_torch_available() else {} @@ -330,6 +332,149 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester def test_batching_equivalence(self): pass + @unittest.skip("Chameleon VQ model cannot be squishes more due to hardcoded layer params in model code") + def test_model_is_small(self): + pass + + +class ChameleonVision2SeqModelTester(ChameleonModelTester): + def __init__(self, parent, image_size=10, **kwargs): + super().__init__(parent, **kwargs) + self.image_size = image_size + self.image_seq_length = 25 + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + input_ids[input_ids == self.image_token_id] = self.pad_token_id + input_ids[:, : self.image_seq_length] = self.image_token_id + attention_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) + pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size]) + + config = self.get_config() + + return config, input_ids, attention_mask, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, attention_mask, pixel_values = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class ChameleonVision2SeqModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (ChameleonModel, ChameleonForConditionalGeneration) if is_torch_available() else () + pipeline_model_mapping = ( + { + "image-text-to-text": ChameleonForConditionalGeneration, + } + if is_torch_available() + else {} + ) + 
+    test_headmasking = False
+    test_pruning = False
+    fx_compatible = False
+
+    def setUp(self):
+        self.model_tester = ChameleonVision2SeqModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=ChameleonConfig, hidden_size=37)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    @unittest.skip("Chameleon forces some token ids to be -inf!")
+    def test_batching_equivalence(self):
+        pass
+
+    @unittest.skip("Chameleon cannot do offload because it uses `self.linear.weight` in forward")
+    def test_cpu_offload(self):
+        pass
+
+    @unittest.skip("Chameleon cannot do offload because it uses `self.linear.weight` in forward")
+    def test_disk_offload_bin(self):
+        pass
+
+    @unittest.skip("Chameleon cannot do offload because it uses `self.linear.weight` in forward")
+    def test_disk_offload_safetensors(self):
+        pass
+
+    @unittest.skip("Chameleon VQ model cannot be squished further due to hardcoded layer params in model code")
+    def test_model_is_small(self):
+        pass
+
+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs raise an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        Also tests multi-image cases where one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            curr_input_dict = copy.deepcopy(input_dict)  # the below tests modify dict in-place
+            _ = model(**curr_input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in text
+            curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**curr_input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = curr_input_dict["input_ids"][:1]
+            pixel_values = curr_input_dict["pixel_values"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+    # overwrite inputs_embeds tests because we need to delete "pixel_values" for LVLMs
+    def test_inputs_embeds(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = self._prepare_for_class(inputs_dict, model_class)
+
+            input_ids = inputs["input_ids"]
+            del inputs["input_ids"]
+            del inputs["pixel_values"]
+
+            wte = model.get_input_embeddings()
+            inputs["inputs_embeds"] = wte(input_ids)
+
+            with torch.no_grad():
+                model(**inputs)
+
+    # overwrite inputs_embeds tests because we need to delete "pixel_values" for LVLMs
+    # while some other models require pixel_values to be present
+    def test_inputs_embeds_matches_input_ids(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = self._prepare_for_class(inputs_dict, model_class)
+            input_ids = inputs["input_ids"]
+            del inputs["input_ids"]
+            del inputs["pixel_values"]
inputs["pixel_values"] + + inputs_embeds = model.get_input_embeddings()(input_ids) + + with torch.no_grad(): + out_ids = model(input_ids=input_ids, **inputs)[0] + out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] + torch.testing.assert_close(out_embeds, out_ids) + @require_torch class ChameleonIntegrationTest(unittest.TestCase): diff --git a/tests/models/mistral3/test_processor_mistral3.py b/tests/models/mistral3/test_processor_mistral3.py index 0a5eeaa99c7..4bebc8a3ad5 100644 --- a/tests/models/mistral3/test_processor_mistral3.py +++ b/tests/models/mistral3/test_processor_mistral3.py @@ -20,7 +20,7 @@ import unittest import requests from transformers import PixtralProcessor -from transformers.testing_utils import require_vision +from transformers.testing_utils import require_read_token, require_vision from transformers.utils import is_torch_available, is_vision_available from ...test_processing_common import ProcessorTesterMixin @@ -35,6 +35,7 @@ if is_vision_available(): @require_vision +@require_read_token class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase): """This tests Pixtral processor with the new `spatial_merge_size` argument in Mistral3."""