mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
[chameleon] fix num image token check (#36918)
* [chameleon] fix num image token check * embed after merging image token * skip this also * mistral require_read_token
This commit is contained in:
parent
a41e08aa19
commit
57f551c78d
@ -1289,13 +1289,10 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if pixel_values is not None:
|
||||
image_tokens = self.get_image_tokens(pixel_values)
|
||||
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
|
||||
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_tokens.numel():
|
||||
if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel():
|
||||
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
|
||||
n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
|
||||
raise ValueError(
|
||||
@ -1304,6 +1301,9 @@ class ChameleonModel(ChameleonPreTrainedModel):
|
||||
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
|
||||
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# torch.jit.trace() doesn't support cache objects in the output
|
||||
if use_cache and past_key_values is None and not torch.jit.is_tracing():
|
||||
past_key_values = DynamicCache()
|
||||
|
@ -126,6 +126,7 @@ VLM_CLASS_NAMES = [
|
||||
"ayavision",
|
||||
"gemma3",
|
||||
"mistral3",
|
||||
"chameleon",
|
||||
]
|
||||
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch chameleon model."""
|
||||
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
@ -30,7 +31,7 @@ from transformers.testing_utils import (
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
@ -52,12 +53,12 @@ class ChameleonModelTester:
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
seq_length=35,
|
||||
is_training=False,
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
image_token_id=98,
|
||||
image_token_id=4,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
@ -73,9 +74,9 @@ class ChameleonModelTester:
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
pad_token_id=0,
|
||||
vq_num_embeds=12,
|
||||
vq_embed_dim=12,
|
||||
vq_channel_multiplier=[1, 2],
|
||||
vq_num_embeds=5,
|
||||
vq_embed_dim=5,
|
||||
vq_channel_multiplier=[1, 4],
|
||||
vq_img_token_start_id=10, # has to be less than vocab size when added with vq_num_embeds
|
||||
scope=None,
|
||||
):
|
||||
@ -138,7 +139,9 @@ class ChameleonModelTester:
|
||||
start = self.vq_img_token_start_id
|
||||
end = self.vq_img_token_start_id + self.vq_num_embeds
|
||||
for i in range(start, end):
|
||||
vocab_map[i] = f"IMGIMGBS{i}" # dummy str for each token, anything starting with IMGIMG
|
||||
image_token_infix = "".join(chr(ord("A") + int(c)) for c in str(i))
|
||||
# dummy str for each image token, anything starting with IMGIMG
|
||||
vocab_map[i] = f"IMGIMG{image_token_infix}Z"
|
||||
|
||||
return ChameleonConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
@ -275,7 +278,6 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
{
|
||||
"feature-extraction": ChameleonModel,
|
||||
"text-generation": ChameleonForConditionalGeneration,
|
||||
"image-text-to-text": ChameleonForConditionalGeneration,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
@ -330,6 +332,149 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
def test_batching_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Chameleon VQ model cannot be squishes more due to hardcoded layer params in model code")
|
||||
def test_model_is_small(self):
|
||||
pass
|
||||
|
||||
|
||||
class ChameleonVision2SeqModelTester(ChameleonModelTester):
|
||||
def __init__(self, parent, image_size=10, **kwargs):
|
||||
super().__init__(parent, **kwargs)
|
||||
self.image_size = image_size
|
||||
self.image_seq_length = 25
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
input_ids[input_ids == self.image_token_id] = self.pad_token_id
|
||||
input_ids[:, : self.image_seq_length] = self.image_token_id
|
||||
attention_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
|
||||
pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size])
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, attention_mask, pixel_values
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, input_ids, attention_mask, pixel_values = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class ChameleonVision2SeqModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (ChameleonModel, ChameleonForConditionalGeneration) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"image-text-to-text": ChameleonForConditionalGeneration,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
fx_compatible = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = ChameleonVision2SeqModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=ChameleonConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@unittest.skip("Chameleon forces some token ids to be -inf!")
|
||||
def test_batching_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Chameleon cannot do offload because it uses `self.linear.weight` in forward")
|
||||
def test_cpu_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Chameleon cannot do offload because it uses `self.linear.weight` in forward")
|
||||
def test_disk_offload_bin(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Chameleon cannot do offload because it uses `self.linear.weight` in forward")
|
||||
def test_disk_offload_safetensors(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Chameleon VQ model cannot be squishes more due to hardcoded layer params in model code")
|
||||
def test_model_is_small(self):
|
||||
pass
|
||||
|
||||
def test_mismatching_num_image_tokens(self):
|
||||
"""
|
||||
Tests that VLMs through an error with explicit message saying what is wrong
|
||||
when number of images don't match number of image tokens in the text.
|
||||
Also we need to test multi-image cases when one prompr has multiple image tokens.
|
||||
"""
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
curr_input_dict = copy.deepcopy(input_dict) # the below tests modify dict in-place
|
||||
_ = model(**curr_input_dict) # successful forward with no modifications
|
||||
|
||||
# remove one image but leave the image token in text
|
||||
curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
|
||||
with self.assertRaises(ValueError):
|
||||
_ = model(**curr_input_dict)
|
||||
|
||||
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
|
||||
input_ids = curr_input_dict["input_ids"][:1]
|
||||
pixel_values = curr_input_dict["pixel_values"][:1]
|
||||
input_ids = torch.cat([input_ids, input_ids], dim=0)
|
||||
|
||||
# one image and two image tokens raise an error
|
||||
with self.assertRaises(ValueError):
|
||||
_ = model(input_ids=input_ids, pixel_values=pixel_values)
|
||||
|
||||
# two images and two image tokens don't raise an error
|
||||
pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
|
||||
_ = model(input_ids=input_ids, pixel_values=pixel_values)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
inputs["inputs_embeds"] = wte(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
|
||||
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
|
||||
# while some other models require pixel_values to be present
|
||||
def test_inputs_embeds_matches_input_ids(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
input_ids = inputs["input_ids"]
|
||||
del inputs["input_ids"]
|
||||
del inputs["pixel_values"]
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
out_ids = model(input_ids=input_ids, **inputs)[0]
|
||||
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
|
||||
torch.testing.assert_close(out_embeds, out_ids)
|
||||
|
||||
|
||||
@require_torch
|
||||
class ChameleonIntegrationTest(unittest.TestCase):
|
||||
|
@ -20,7 +20,7 @@ import unittest
|
||||
import requests
|
||||
|
||||
from transformers import PixtralProcessor
|
||||
from transformers.testing_utils import require_vision
|
||||
from transformers.testing_utils import require_read_token, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
@ -35,6 +35,7 @@ if is_vision_available():
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_read_token
|
||||
class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
"""This tests Pixtral processor with the new `spatial_merge_size` argument in Mistral3."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user