[chameleon] fix num image token check (#36918)

* [chameleon] fix num image token check

* embed after merging image token

* skip this also

* mistral require_read_token
Raushan Turganbay 2025-03-24 12:36:08 +01:00 committed by GitHub
parent a41e08aa19
commit 57f551c78d
4 changed files with 160 additions and 13 deletions

View File

@@ -1289,13 +1289,10 @@ class ChameleonModel(ChameleonPreTrainedModel):
                 "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
             )
 
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
         if pixel_values is not None:
             image_tokens = self.get_image_tokens(pixel_values)
             special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
-            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_tokens.numel():
+            if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel():
                 n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
                 n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
                 raise ValueError(
@@ -1304,6 +1301,9 @@ class ChameleonModel(ChameleonPreTrainedModel):
             image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
             input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
 
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
         # torch.jit.trace() doesn't support cache objects in the output
         if use_cache and past_key_values is None and not torch.jit.is_tracing():
             past_key_values = DynamicCache()
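For readers scanning the hunk above: the count check now looks at placeholder tokens in input_ids rather than elements of inputs_embeds (whose boolean-masked slice is scaled by hidden_size, so its numel() does not correspond to a token count), and the text embedding is now computed only after the VQ image token ids have been merged in. Below is a minimal, self-contained sketch of that ordering, using toy sizes and a hypothetical merge_and_embed helper; it is not the actual Chameleon forward.

import torch

IMAGE_TOKEN_ID = 4  # hypothetical placeholder id, not Chameleon's real one


def merge_and_embed(input_ids, image_tokens, embed_tokens):
    """Sketch of the fixed ordering: check counts, merge VQ token ids, then embed."""
    special_image_mask = input_ids == IMAGE_TOKEN_ID
    # Compare token counts: input_ids[mask] is 1-D, whereas inputs_embeds[mask]
    # would have shape (num_placeholders, hidden_size) and a different numel().
    if input_ids[special_image_mask].numel() != image_tokens.numel():
        raise ValueError(
            f"Image placeholders in text: {special_image_mask.sum().item()}, "
            f"image tokens from pixel_values: {image_tokens.numel()}"
        )
    # Replace the placeholder ids with the discrete VQ token ids
    # (masked_scatter reads the source tensor in flattened order) ...
    input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
    # ... and only then embed, so image positions pick up the VQ token embeddings.
    return embed_tokens(input_ids)


embed = torch.nn.Embedding(20, 8)
ids = torch.tensor([[1, 4, 4, 4, 2]])     # one "image" worth of placeholders (3 here)
vq_tokens = torch.tensor([[10, 11, 12]])  # 3 VQ token ids for that image
print(merge_and_embed(ids, vq_tokens, embed).shape)  # torch.Size([1, 5, 8])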

View File

@@ -126,6 +126,7 @@ VLM_CLASS_NAMES = [
     "ayavision",
     "gemma3",
     "mistral3",
+    "chameleon",
 ]

View File

@@ -14,6 +14,7 @@
 # limitations under the License.
 """Testing suite for the PyTorch chameleon model."""
 
+import copy
 import unittest
 
 import requests
@@ -30,7 +31,7 @@ from transformers.testing_utils import (
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
 from ...test_pipeline_mixin import PipelineTesterMixin
@@ -52,12 +53,12 @@ class ChameleonModelTester:
         self,
         parent,
         batch_size=13,
-        seq_length=7,
+        seq_length=35,
         is_training=False,
         use_input_mask=True,
         use_labels=True,
         vocab_size=99,
-        image_token_id=98,
+        image_token_id=4,
         hidden_size=32,
         num_hidden_layers=2,
         num_attention_heads=2,
@@ -73,9 +74,9 @@ class ChameleonModelTester:
         num_labels=3,
         num_choices=4,
         pad_token_id=0,
-        vq_num_embeds=12,
-        vq_embed_dim=12,
-        vq_channel_multiplier=[1, 2],
+        vq_num_embeds=5,
+        vq_embed_dim=5,
+        vq_channel_multiplier=[1, 4],
         vq_img_token_start_id=10,  # has to be less than vocab size when added with vq_num_embeds
         scope=None,
     ):
@@ -138,7 +139,9 @@ class ChameleonModelTester:
         start = self.vq_img_token_start_id
         end = self.vq_img_token_start_id + self.vq_num_embeds
         for i in range(start, end):
-            vocab_map[i] = f"IMGIMGBS{i}"  # dummy str for each token, anything starting with IMGIMG
+            image_token_infix = "".join(chr(ord("A") + int(c)) for c in str(i))
+            # dummy str for each image token, anything starting with IMGIMG
+            vocab_map[i] = f"IMGIMG{image_token_infix}Z"
 
         return ChameleonConfig(
             vocab_size=self.vocab_size,
@@ -275,7 +278,6 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
         {
             "feature-extraction": ChameleonModel,
             "text-generation": ChameleonForConditionalGeneration,
-            "image-text-to-text": ChameleonForConditionalGeneration,
         }
         if is_torch_available()
         else {}
@@ -330,6 +332,149 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
     def test_batching_equivalence(self):
         pass
 
+    @unittest.skip("Chameleon VQ model cannot be squished further due to hardcoded layer params in model code")
+    def test_model_is_small(self):
+        pass
+
+
+class ChameleonVision2SeqModelTester(ChameleonModelTester):
+    def __init__(self, parent, image_size=10, **kwargs):
+        super().__init__(parent, **kwargs)
+        self.image_size = image_size
+        self.image_seq_length = 25
+
+    def prepare_config_and_inputs(self):
+        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+        input_ids[input_ids == self.image_token_id] = self.pad_token_id
+        input_ids[:, : self.image_seq_length] = self.image_token_id
+        attention_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
+        pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size])
+        config = self.get_config()
+
+        return config, input_ids, attention_mask, pixel_values
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, input_ids, attention_mask, pixel_values = config_and_inputs
+        inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
+        return config, inputs_dict
+
+
+@require_torch
+class ChameleonVision2SeqModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+    all_model_classes = (ChameleonModel, ChameleonForConditionalGeneration) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "image-text-to-text": ChameleonForConditionalGeneration,
+        }
+        if is_torch_available()
+        else {}
+    )
+    test_headmasking = False
+    test_pruning = False
+    fx_compatible = False
+
+    def setUp(self):
+        self.model_tester = ChameleonVision2SeqModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=ChameleonConfig, hidden_size=37)
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    @unittest.skip("Chameleon forces some token ids to be -inf!")
+    def test_batching_equivalence(self):
+        pass
+
+    @unittest.skip("Chameleon cannot do offload because it uses `self.linear.weight` in forward")
+    def test_cpu_offload(self):
+        pass
+
+    @unittest.skip("Chameleon cannot do offload because it uses `self.linear.weight` in forward")
+    def test_disk_offload_bin(self):
+        pass
+
+    @unittest.skip("Chameleon cannot do offload because it uses `self.linear.weight` in forward")
+    def test_disk_offload_safetensors(self):
+        pass
+
+    @unittest.skip("Chameleon VQ model cannot be squished further due to hardcoded layer params in model code")
+    def test_model_is_small(self):
+        pass
+
+    def test_mismatching_num_image_tokens(self):
+        """
+        Tests that VLMs raise an error with an explicit message saying what is wrong
+        when the number of images doesn't match the number of image tokens in the text.
+        Also we need to test multi-image cases when one prompt has multiple image tokens.
+        """
+        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model = model_class(config).to(torch_device)
+            curr_input_dict = copy.deepcopy(input_dict)  # the below tests modify the dict in-place
+            _ = model(**curr_input_dict)  # successful forward with no modifications
+
+            # remove one image but leave the image token in text
+            curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
+            with self.assertRaises(ValueError):
+                _ = model(**curr_input_dict)
+
+            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
+            input_ids = curr_input_dict["input_ids"][:1]
+            pixel_values = curr_input_dict["pixel_values"][:1]
+            input_ids = torch.cat([input_ids, input_ids], dim=0)
+
+            # one image and two image tokens raise an error
+            with self.assertRaises(ValueError):
+                _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+            # two images and two image tokens don't raise an error
+            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+            _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+    # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
+    def test_inputs_embeds(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = self._prepare_for_class(inputs_dict, model_class)
+
+            input_ids = inputs["input_ids"]
+            del inputs["input_ids"]
+            del inputs["pixel_values"]
+
+            wte = model.get_input_embeddings()
+            inputs["inputs_embeds"] = wte(input_ids)
+
+            with torch.no_grad():
+                model(**inputs)
+
+    # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
+    # while some other models require pixel_values to be present
+    def test_inputs_embeds_matches_input_ids(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = self._prepare_for_class(inputs_dict, model_class)
+            input_ids = inputs["input_ids"]
+            del inputs["input_ids"]
+            del inputs["pixel_values"]
+
+            inputs_embeds = model.get_input_embeddings()(input_ids)
+
+            with torch.no_grad():
+                out_ids = model(input_ids=input_ids, **inputs)[0]
+                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
+
+            torch.testing.assert_close(out_embeds, out_ids)
+
+
 @require_torch
 class ChameleonIntegrationTest(unittest.TestCase):

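As a side note on the test pattern above: the new tester writes the image placeholder id into the first image_seq_length positions of every row, which is what lets test_mismatching_num_image_tokens trigger the new check simply by duplicating input_ids without duplicating pixel_values. A small sketch of that layout with toy numbers and a hypothetical build_inputs helper (not the tester's real values):

import torch

IMAGE_TOKEN_ID = 4
PAD_TOKEN_ID = 0
IMAGE_SEQ_LENGTH = 5  # toy value; the tester above uses 25


def build_inputs(batch_size=2, seq_length=12, vocab_size=30):
    """Toy version of the tester's layout: image placeholders first, text after."""
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_length))
    # Scrub accidental placeholder ids so the count below is exactly what we set.
    input_ids[input_ids == IMAGE_TOKEN_ID] = PAD_TOKEN_ID
    input_ids[:, :IMAGE_SEQ_LENGTH] = IMAGE_TOKEN_ID
    return input_ids


ids = build_inputs()
print((ids == IMAGE_TOKEN_ID).sum().item())  # 2 rows * 5 placeholders = 10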
View File

@@ -20,7 +20,7 @@ import unittest
 import requests
 
 from transformers import PixtralProcessor
-from transformers.testing_utils import require_vision
+from transformers.testing_utils import require_read_token, require_vision
 from transformers.utils import is_torch_available, is_vision_available
 
 from ...test_processing_common import ProcessorTesterMixin
@@ -35,6 +35,7 @@ if is_vision_available():
 
 
 @require_vision
+@require_read_token
 class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
     """This tests Pixtral processor with the new `spatial_merge_size` argument in Mistral3."""