mirror of https://github.com/huggingface/transformers.git
[MllamaProcessor] Update errors and API with multiple images (#33715)

* update error
* update and add a test
* update
* update
parent 0a21381ba3
commit 46841d3eb2
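For context, a minimal sketch of the batched multi-image call whose validation this commit reworks. The checkpoint id is illustrative, `AutoProcessor` is assumed to resolve to `MllamaProcessor` for Mllama checkpoints, and the image file names are placeholders:

```python
from PIL import Image
from transformers import AutoProcessor

# Illustrative checkpoint id (assumption); resolves to MllamaProcessor.
processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision")

# One sample with one image, one sample with two. Nested lists pair each
# sample's images with the <|image|> tokens in its prompt.
text = [
    "<|image|>Describe this picture.",
    "Compare <|image|> and <|image|> for me.",
]
images = [
    [Image.open("cat.png")],
    [Image.open("cat.png"), Image.open("dog.png")],
]

# After this commit, mismatches fail fast: image tokens with images=None,
# or a mixed batch where some prompts have images and others have none.
inputs = processor(text=text, images=images, padding=True, return_tensors="pt")
```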
@@ -12,11 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
-Processor class for Mllama.
-"""
+"""Processor class for Mllama."""
 
-from statistics import mean
 from typing import List, Optional, Union
 
 import numpy as np

@@ -296,25 +294,27 @@ class MllamaProcessor(ProcessorMixin):
         encoding = self.tokenizer(text, **text_kwargs)
         data.update(encoding)
 
         n_images_in_images = [0]
         if images is not None:
             images = make_list_of_images(images)
             n_images_in_images = [len(sample) for sample in images]
 
         if text is not None:
-            if (
-                not all(batch_img_per_prompt == n_images_in_images for batch_img_per_prompt in n_images_in_text)
-                and len(text) > 1
-            ):
-                raise ValueError(
-                    f"The number of images in each batch {n_images_in_text} should be the same {n_images_in_images} should be the same. Yes, the model does not \
-                        support having a different number of images per batch."
-                )
-            if int(mean(n_images_in_text)) != int(mean(n_images_in_images)):
-                raise ValueError(
-                    f"The number of images in the text ({n_images_in_text}) should be the same as in the number of provided images ({n_images_in_images}) \
-                        should be the same."
-                )
+            if any(batch_img == 0 for batch_img in n_images_in_text) and not all(
+                batch_img == 0 for batch_img in n_images_in_text
+            ):
+                raise ValueError(
+                    "If a batch of text is provided, there should be either no images or at least one image per sample"
+                )
+            if sum(n_images_in_images) != sum(n_images_in_text):
+                if images is None:
+                    raise ValueError("No images were provided, but there are image tokens in the prompt")
+                else:
+                    raise ValueError(
+                        f"The number of image tokens ({sum(n_images_in_text)}) should be the same as the number of provided images ({sum(n_images_in_images)})"
+                    )
 
         if images is not None:
             image_features = self.image_processor(images, **images_kwargs)
             num_tiles = image_features.pop("num_tiles")
             data.update(image_features)

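Read on their own, the new checks reduce to two rules: a batch is either fully text-only or every sample carries at least one image, and the total image-token count must equal the total image count. A self-contained sketch of that logic (the helper name is invented for illustration, not part of the library):

```python
def check_image_counts(n_images_in_text, images):
    """n_images_in_text holds the <|image|> token count per prompt."""
    n_images_in_images = [len(sample) for sample in images] if images is not None else [0]

    # Rule 1: once any sample has an image token, every sample must have one.
    if any(n == 0 for n in n_images_in_text) and not all(n == 0 for n in n_images_in_text):
        raise ValueError("If a batch of text is provided, there should be either no images or at least one image per sample")

    # Rule 2: total image tokens must match the total number of images.
    if sum(n_images_in_images) != sum(n_images_in_text):
        if images is None:
            raise ValueError("No images were provided, but there are image tokens in the prompt")
        raise ValueError(
            f"The number of image tokens ({sum(n_images_in_text)}) should be the same "
            f"as the number of provided images ({sum(n_images_in_images)})"
        )

check_image_counts([1, 1], [["img1"], ["img2"]])    # passes
check_image_counts([1, 2], [["img1"], ["a", "b"]])  # passes: totals match
try:
    check_image_counts([1, 0], [["img1"], []])      # raises: mixed batch
except ValueError as e:
    print(e)
```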
@@ -15,6 +15,8 @@
 
 import unittest
 
+import numpy as np
+
 from transformers import MllamaProcessor
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_vision_available

@@ -177,3 +179,119 @@ class MllamaProcessorTest(unittest.TestCase):
         rendered_list = self.processor.apply_chat_template(messages_list, add_generation_prompt=True, tokenize=False)
         rendered_str = self.processor.apply_chat_template(messages_str, add_generation_prompt=True, tokenize=False)
         self.assertEqual(rendered_list, rendered_str)
+
+    def test_process_interleaved_images_prompts_image_splitting(self):
+        # Test that a single image is processed correctly
+        inputs = self.processor(images=self.image2, size={"width": 224, "height": 224})
+        self.assertEqual(inputs["pixel_values"].shape, (1, 1, 4, 3, 224, 224))
+
+        # Test that text is processed correctly
+        text = "<|begin_of_text|>This is a test sentence.<|end_of_text|>"
+        inputs = self.processor(text=text)
+        expected_ids = [128000, 2028, 374, 264, 1296, 11914, 13, 128001]
+        self.assertEqual(inputs["input_ids"][0], expected_ids)
+        self.assertEqual(inputs["attention_mask"][0], [1] * len(expected_ids))
+        self.assertEqual(inputs.get("cross_attention_mask"), None)
+
+        # Test a single sample with image and text
+        image_str = "<|image|>"
+        text_str = "This is a test sentence."
+        text = image_str + text_str
+        inputs = self.processor(
+            text=text,
+            images=self.image1,
+            size={"width": 128, "height": 128},
+        )
+        expected_ids = [self.image_token_id, self.bos_token_id] + [2028, 374, 264, 1296, 11914, 13]
+
+        self.assertEqual(inputs["pixel_values"].shape, (1, 1, 4, 3, 128, 128))
+        self.assertEqual(inputs["input_ids"][0], expected_ids)
+        self.assertEqual(inputs["attention_mask"][0], [1] * len(expected_ids))
+        cross_attention_mask = inputs["cross_attention_mask"]
+        self.assertEqual(cross_attention_mask.shape, (1, 8, 1, 4))
+        self.assertTrue(
+            np.all(cross_attention_mask == 1), f"Cross attention mask is not all ones: {cross_attention_mask}"
+        )
+
+        # Test batch
+        text = [
+            "<|image|>This is a test sentence.",
+            "This is a test sentence.<|image|><|image|>This is a test sentence.",
+        ]
+        # fmt: off
+        expected_ids = [
+            [self.image_token_id, self.bos_token_id, 2028, 374, 264, 1296, 11914, 13],
+            [self.bos_token_id, 2028, 374, 264, 1296, 11914, 13, self.image_token_id, self.image_token_id, 2028, 374, 264, 1296, 11914, 13],
+        ]
+        # fmt: on
+        images = [[self.image1], [self.image1, self.image2]]
+        inputs = self.processor(text=text, images=images, padding=True, size={"width": 256, "height": 256})
+
+        self.assertEqual(inputs["pixel_values"].shape, (2, 2, 4, 3, 256, 256))
+        for input_ids_i, attention_mask_i, expected_ids_i in zip(inputs["input_ids"], inputs["attention_mask"], expected_ids):
+            pad_ids = [id for id, m in zip(input_ids_i, attention_mask_i) if m == 0]
+            input_ids = [id for id, m in zip(input_ids_i, attention_mask_i) if m == 1]
+            self.assertEqual(input_ids, expected_ids_i)
+            self.assertEqual(pad_ids, [self.pad_token_id] * len(pad_ids))
+
+        cross_attention_mask = inputs["cross_attention_mask"]
+        self.assertEqual(cross_attention_mask.shape, (2, 15, 2, 4))
+
+        # Check that only the first tile of the first sample's image is attended to by all text tokens
+        first_sample_mask = cross_attention_mask[0].copy()
+        first_image_first_tile_attention = first_sample_mask[:, :1, :1]  # text tokens, images, tiles
+        self.assertTrue(np.all(first_image_first_tile_attention == 1), f"Cross attention mask is not all ones: {first_image_first_tile_attention}")
+
+        # Zero out the first tile of the first image; nothing else should have been attended
+        first_sample_mask[:, :1, :1] = 0
+        self.assertTrue(np.all(first_sample_mask == 0), f"Cross attention mask is not all zeros: {first_sample_mask}")
+
+        # Second sample
+        second_sample_mask = cross_attention_mask[1].copy()
+        first_image_first_tile_attention = second_sample_mask[7:, :1, :1]  # text tokens, images, tiles
+        self.assertTrue(np.all(first_image_first_tile_attention == 1), f"Cross attention mask is not all ones: {first_image_first_tile_attention}")
+
+        second_image_two_tiles_attention = second_sample_mask[8:, 1:2, :2]  # text tokens, images, tiles
+        self.assertTrue(np.all(second_image_two_tiles_attention == 1), f"Cross attention mask is not all ones: {second_image_two_tiles_attention}")
+
+        # Zero out both image masks; nothing else should have been attended
+        second_sample_mask[7:, :1, :1] = 0
+        second_sample_mask[8:, 1:2, :2] = 0
+        self.assertTrue(np.all(second_sample_mask == 0), f"Cross attention mask is not all zeros: {second_sample_mask}")
+
+    def test_process_interleaved_images_prompts_image_error(self):
+        text = [
+            "This is a test sentence.",
+            "In this other sentence we try some good things",
+        ]
+        inputs = self.processor(text=text, images=None, padding=True)
+        self.assertIsNotNone(inputs["input_ids"])
+
+        text = [
+            "This is a test sentence.<|image|>",
+            "In this other sentence we try some good things",
+        ]
+        with self.assertRaises(ValueError):
+            self.processor(text=text, images=None, padding=True)
+
+        images = [[self.image1], []]
+        with self.assertRaises(ValueError):
+            self.processor(text=text, images=images, padding=True)
+
+        text = [
+            "This is a test sentence.<|image|>",
+            "In this other sentence we try some good things<|image|>",
+        ]
+        with self.assertRaises(ValueError):
+            self.processor(text=text, images=None, padding=True)
+
+        text = [
+            "This is a test sentence.<|image|>",
+            "In this other sentence we try some good things<|image|>",
+        ]
+        images = [[self.image1], [self.image2]]
+        inputs = self.processor(text=text, images=images, padding=True)
+
+        images = [[self.image1, self.image2], []]
+        with self.assertRaises(ValueError):
+            self.processor(text=text, images=None, padding=True)
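The splitting test reads each sample's `cross_attention_mask` as `(text tokens, images, tiles)`. A small numpy sketch of the second sample's layout, mirroring the indices asserted in the test above:

```python
import numpy as np

# Batch of 2 samples, 15 padded text positions, up to 2 images, up to 4 tiles.
mask = np.zeros((2, 15, 2, 4), dtype=np.int64)

# Sample 0: all text positions attend to the single image's first tile.
mask[0, :, 0, 0] = 1

# Sample 1: tokens from position 7 on attend to image 0 (first tile only);
# tokens from position 8 on attend to image 1 (its two tiles).
mask[1, 7:, 0, 0] = 1
mask[1, 8:, 1, :2] = 1

second = mask[1].copy()
assert np.all(second[7:, :1, :1] == 1)  # same slices the test checks
assert np.all(second[8:, 1:2, :2] == 1)
second[7:, :1, :1] = 0
second[8:, 1:2, :2] = 0
assert np.all(second == 0)  # nothing else was attended
```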