mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00

* Iterative generation using input embeds * Add Janus model * discard changes * Janus imports * Refactor config and processor * Added Vision tower of Janus * Import Janus Image processor * Vision tower fixes * Refactor code * Added VQ Model * Complete model integration * temp conversion script * processor refactor * Adding files to facilitate pulling * Fixes after debugging * Skip test for these models * Add Janus Model * discard changes * Janus imports * Refactor config and processor * Added Vision tower of Janus * Import Janus Image processor * Vision tower fixes * Refactor code * Added VQ Model * Complete model integration * temp conversion script * processor refactor * Adding files to facilitate pulling * Fixes after debugging * Refactor to Text config * ✨ Added generate function * Saving intermediate convert file. Still need to read configs from the hub and convert them to our format. * Adding version that reads from the JSON files. Still have to tweak some parameters manually. * relative imports * Initial tests * Refactor image processor * Seemingly working version of the conversion script, will need to test further. * Adding command message * Fixing conflicting JanusTextConfig class * Incorporating some of the discussed changes. * Small fix to create dir. * Removing system from JINJA template * Adding draft processor tests * style fixes * Minor fixes and enhancement * added generation config * Initial tests * Small modifications, tests are now passing. * Small changes I noticed while reading code. * more fixes * Added JanusModel class * Small merge adaptations * Small merge adaptations * Image processing tests passing * More tests and fixes * Convert script updated and refactored * Tests and cleanup * make style * Postprocessing for image generation * generate refactor * fixes * - Passing tests that write a part of the model to cpu (e.g. test_cpu_offload) - Passing tests of dispatching SDPA - Only gradient checkpointing tests are left. 
* Removing temporary code * Changes * Writing change to modular * Added JanusVisionModel. SDPA dispatch tests pass more robustly. Gradient checkpoint tests are next * Gradient checkpoint tests passing * Removing debug code * Major generate refactor 😮💨 * Temp changes for testing * Green quality CI * 2 out of 4 integration tests passing * breadcrumbs * Usage Examples * Regenerate modeling after merge * dirty code * JanusIntegrationTest are passing * breadcrumbs * happy CI * fixes * Changing template * nits * Text generation logits matching original codebase at 100% precision * Remove ./tmp from git tracking * Remove ./tmp from git tracking * Checkpointing changes after reviewing * Fixing code in docstrings * CHanging comments and small bug in convert file * Fixing bug in image_token_id for 7B version * Removing line that was added by both of us * Pushing changes after discussion. Only one left is to change the key mapping for convert file. * Updating module file * New convert file using dict. Tested that it is equivalent to the old one by: - comparing keys in a script - comparing checksums of the output files between version generated with the current convert script and those generated with the old script. This is a more reliable test. * revert changes * mistake * consistency change for CI * make style * doc fixes * more fixes * experimenting with masking out pad token * checkpoint * Batched generation with multi-images working for 1B models. Will test 7B next. * Device fix. * Writing changes to modular, previous ones were written to modeling just for quick testing. * Using passed processor attention mask (only in modeling for now) * Matching performance done in the non-standard way * Working version of batched generation. 
Will change how some args are passed to make it more similar to language case * More compliant version of the code * Removed duplicated `_prepare_4d_causal_attention_mask_with_cache_position` * Updating modular file, making masked filling with paddings more efficient * Slightly more efficient version * Modifying JanusVisionModel to be a wrapper * Fixing test to comply with new names * Modular overhaul * More refactoring * - Changing JanusVisionModel back - Changing forward pass - Adding boi token to the comparison * - Removing whole context model_ids - Using inherited implementation of prepare_inputs_for_generation * Moving the way boi token is passed to the model * Fixing sdpa test * Minor changes * testing changes * Minor fix * - Adding postprocessing test - checking values of generated image on integration test * changes * Removing pooled attention vision module, fixing convert script as a consequence * More changes * Fixes * Draft after merge * Bug fixes * More bug fix * Fixing docs * Nits * Refactor return dict * Moving image post processing test to main processor post process * Passing guidance_scale as kwarg * make style * 🔥 refactor * make style * Update and green CI * Nits and tests update * up * Added MID block * fix * Dead code * update testcase * update * model_id change * init_weight changes --------- Co-authored-by: hsilva664 <metallic-silver@hotmail.com>
456 lines
18 KiB
Python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the Janus processor."""

import tempfile
import unittest

import numpy as np

from transformers import AutoProcessor, AutoTokenizer, JanusProcessor
from transformers.models.janus.convert_janus_weights_to_hf import CHAT_TEMPLATE
from transformers.utils import is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
    pass


class JanusProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = JanusProcessor

    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()
        special_image_tokens = {
            "image_token": "<image_placeholder>",
            "boi_token": "<begin_of_image>",
            "eoi_token": "<end_of_image>",
        }

        processor = self.processor_class.from_pretrained(
            "deepseek-community/Janus-Pro-1B",
            extra_special_tokens=special_image_tokens,
        )
        processor.save_pretrained(self.tmpdirname)

    def get_tokenizer(self, **kwargs):
        return AutoTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_image_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

    def prepare_processor_dict(self):
        # similar to Emu3 and Qwen2VLProcessorTest, but keep the template in the convert script to avoid duplicated code
        return {
            "chat_template": CHAT_TEMPLATE,
        }

    def test_chat_template_single(self):
        """
        Tests that the chat template matches the original implementation when applied to a single message.
        """
        processor = self.get_processor()
        if processor.chat_template is None:
            self.skipTest("Processor has no chat template")

        # Single image message
        messages = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "What is shown in this image?"},
                        {"type": "image"},
                    ],
                },
            ]
        ]

        correct_prompt = ["<|User|>: What is shown in this image?\n<image_placeholder>\n\n<|Assistant|>:"]

        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        self.assertEqual(formatted_prompt, correct_prompt)

        # Single image message with capitalization
        messages = [
            [
                {
                    "role": "User",
                    "content": [
                        {"type": "text", "text": "What is shown in this image?"},
                        {"type": "image"},
                    ],
                },
            ]
        ]

        correct_prompt = ["<|User|>: What is shown in this image?\n<image_placeholder>\n\n<|Assistant|>:"]

        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        self.assertEqual(formatted_prompt, correct_prompt)

        # Single image message with uppercase
        messages = [
            [
                {
                    "role": "USER",
                    "content": [
                        {"type": "text", "text": "What is shown in this image?"},
                        {"type": "image"},
                    ],
                },
            ]
        ]

        correct_prompt = ["<|User|>: What is shown in this image?\n<image_placeholder>\n\n<|Assistant|>:"]

        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        self.assertEqual(formatted_prompt, correct_prompt)

        """
        Warning: other models usually have a test comparing chat templating and tokenization run as two separate
        steps against a single step (i.e. processor.apply_chat_template(..., tokenize=True)). That comparison does
        not hold here, because our processor does more than pass the formatted prompt to the tokenizer: it prepends
        the default system prompt and, following the implementation from the Janus codebase, expands the image token.
        """

        # Checking the output dict keys
        out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
        self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])

        # Now check that image inputs are also returned in the dict once the image has a URL
        messages[0][0]["content"][1].update(
            {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
        )
        out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
        self.assertTrue(self.images_input_name in out_dict)
        # should always have input_ids and attention_mask
        self.assertEqual(len(out_dict["input_ids"]), 1)
        self.assertEqual(len(out_dict["attention_mask"]), 1)
        self.assertEqual(len(out_dict[self.images_input_name]), 1)

        # Passing the generation prompt explicitly, as an empty assistant turn, should give the same prompt
        messages = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "What is shown in this image?"},
                        {"type": "image"},
                    ],
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": ""},
                    ],
                },
            ]
        ]

        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=False)
        self.assertEqual(formatted_prompt, correct_prompt)

        # Single prompt with multiple images
        messages = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Compare this image"},
                        {"type": "image"},
                        {"type": "text", "text": "with this image"},
                        {"type": "image"},
                    ],
                },
            ]
        ]

        correct_prompt = [
            "<|User|>: Compare this image\n<image_placeholder>\nwith this image\n<image_placeholder>\n\n<|Assistant|>:"
        ]
        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        self.assertEqual(formatted_prompt, correct_prompt)

        # Multiple turns and multiple images
        messages = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Compare this image"},
                        {"type": "image"},
                        {"type": "text", "text": "with this image"},
                        {"type": "image"},
                    ],
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": "The first image is an equation, the second is a pie chart."},
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {
                            "type": "text",
                            "text": "What about this third image? To which of the previous two is it more similar?",
                        },
                    ],
                },
            ]
        ]

        correct_prompt = [
            "<|User|>: Compare this image\n<image_placeholder>\nwith this image\n<image_placeholder>\n\n<|Assistant|>: The first image is an equation, the second is a pie chart.<|end▁of▁sentence|><|User|>: <image_placeholder>\nWhat about this third image? To which of the previous two is it more similar?\n\n<|Assistant|>:"
        ]
        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        self.assertEqual(formatted_prompt, correct_prompt)

    def test_chat_template_batched(self):
        """
        Tests that the chat template matches the original implementation when applied to a batch of messages.
        """
        processor = self.get_processor()
        if processor.chat_template is None:
            self.skipTest("Processor has no chat template")

        # Test 1: Simple single image per message batch
        batched_messages = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "What is shown in this image?"},
                        {"type": "image"},
                    ],
                },
            ],
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "What is shown in this image?"},
                        {"type": "image"},
                    ],
                },
            ],
        ]

        correct_prompts = [
            "<|User|>: What is shown in this image?\n<image_placeholder>\n\n<|Assistant|>:",
            "<|User|>: What is shown in this image?\n<image_placeholder>\n\n<|Assistant|>:",
        ]

        formatted_prompts = processor.apply_chat_template(batched_messages, add_generation_prompt=True)
        self.assertEqual(formatted_prompts, correct_prompts)

        # As in the single case, there is no test comparing chat templating and tokenization as two steps versus one

        # Checking the output dict keys
        out_dict = processor.apply_chat_template(
            batched_messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            padding=True,
        )
        self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])

        # Verify image inputs are included in the output dict
        batched_messages[0][0]["content"][1].update(
            {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
        )
        batched_messages[1][0]["content"][1].update(
            {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}
        )
        out_dict = processor.apply_chat_template(
            batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
        )
        self.assertTrue(self.images_input_name in out_dict)
        self.assertEqual(len(out_dict["input_ids"]), 2)  # Batch size for text
        self.assertEqual(len(out_dict["attention_mask"]), 2)  # Batch size for attention mask
        self.assertEqual(len(out_dict[self.images_input_name]), 2)  # Batch size for images

        # Test 2: Two images per message batch with different prompts
        batched_messages = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Compare this image"},
                        {"type": "image"},
                        {"type": "text", "text": "with this image"},
                        {"type": "image"},
                    ],
                },
            ],
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": "Describe how the previous image compares to the following"},
                        {"type": "image"},
                    ],
                },
            ],
        ]

        correct_prompts = [
            "<|User|>: Compare this image\n<image_placeholder>\nwith this image\n<image_placeholder>\n\n<|Assistant|>:",
            "<|User|>: <image_placeholder>\nDescribe how the previous image compares to the following\n<image_placeholder>\n\n<|Assistant|>:",
        ]
        formatted_prompts = processor.apply_chat_template(batched_messages, add_generation_prompt=True)
        self.assertEqual(formatted_prompts, correct_prompts)

        # Test 3: Multi-turn conversations with multiple images
        batched_messages = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Compare this image"},
                        {"type": "image"},
                        {"type": "text", "text": "with this image"},
                        {"type": "image"},
                    ],
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": "The first image is an equation, the second is a pie chart."},
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {
                            "type": "text",
                            "text": "What about this third image? To which of the previous two is it more similar?",
                        },
                    ],
                },
            ],
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": "Describe how the previous image compares to the following"},
                        {"type": "image"},
                    ],
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": "The first image is a formula, the second is a plot."},
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Which of them is closer to the following?"},
                        {"type": "image"},
                    ],
                },
            ],
        ]

        correct_prompts = [
            "<|User|>: Compare this image\n<image_placeholder>\nwith this image\n<image_placeholder>\n\n<|Assistant|>: The first image is an equation, the second is a pie chart.<|end▁of▁sentence|><|User|>: <image_placeholder>\nWhat about this third image? To which of the previous two is it more similar?\n\n<|Assistant|>:",
            "<|User|>: <image_placeholder>\nDescribe how the previous image compares to the following\n<image_placeholder>\n\n<|Assistant|>: The first image is a formula, the second is a plot.<|end▁of▁sentence|><|User|>: Which of them is closer to the following?\n<image_placeholder>\n\n<|Assistant|>:",
        ]
        formatted_prompts = processor.apply_chat_template(batched_messages, add_generation_prompt=True)
        self.assertEqual(formatted_prompts, correct_prompts)

    def test_chat_template_accepts_processing_kwargs(self):
        """Tests that the chat template correctly handles additional processing arguments."""
        # Get processor and skip if it doesn't have a chat template
        processor = self.get_processor()
        if processor.chat_template is None:
            self.skipTest("Processor has no chat template")

        # Create a simple text message for testing
        messages = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "What is shown in this image?"},
                    ],
                },
            ]
        ]

        # Test 1: Padding to max_length
        # Note: we override the parent test's max_length of 50 with 80, because the output is already 51 tokens long
        formatted_prompt_tokenized = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            padding="max_length",
            max_length=80,
        )
        self.assertEqual(len(formatted_prompt_tokenized[0]), 80)

        # Test 2: Truncation
        # Verify that the output is truncated to exactly 5 tokens
        formatted_prompt_tokenized = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            truncation=True,
            max_length=5,
        )
        self.assertEqual(len(formatted_prompt_tokenized[0]), 5)

        # Test 3: Image processing kwargs
        # Add an image and test image processing parameters
        messages[0][0]["content"].append(
            {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
        )
        # Rescale with a negative factor and verify that the resulting pixel values are non-positive
        out_dict = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            do_rescale=True,
            rescale_factor=-1,
            return_tensors="np",
        )
        self.assertLessEqual(out_dict[self.images_input_name][0][0].mean(), 0)

    def test_processor_postprocess(self):
        processor_components = self.prepare_components()
        processor = self.processor_class(**processor_components)

        input_str = "lower newer"
        orig_image_input = self.prepare_image_inputs()
        orig_image = np.array(orig_image_input).transpose(2, 0, 1)

        inputs = processor(text=input_str, images=orig_image, do_resize=False, return_tensors="np")
        normalized_image_input = inputs.pixel_values
        unnormalized_images = processor.postprocess(normalized_image_input, return_tensors="np")["pixel_values"]

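        # For an image whose pixels span 0 to 255, the round trip can be off by 1 because of floating-point error
        # when scaling and unscaling. Compare in a signed dtype so that uint8 subtraction cannot wrap around.
        max_diff = np.abs(orig_image.astype(np.int64) - unnormalized_images.astype(np.int64)).max()
        self.assertLessEqual(max_diff, 1)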
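The comment in test_processor_postprocess notes that a normalize/unnormalize round trip can differ from the original pixels by 1. The sketch below illustrates why, under loudly stated assumptions: a rescale factor of 1/255, a per-channel mean and std of 0.5, and a truncating float-to-uint8 cast. These values are chosen for illustration and are not read from the Janus image processor, whose postprocessing may quantize differently. The sketch also shows why such comparisons belong in a signed dtype: subtracting two uint8 arrays wraps around instead of going negative.

```python
import numpy as np

# Assumed preprocessing constants (illustrative only, not read from the Janus config).
RESCALE_FACTOR = 1 / 255
MEAN = np.full((3, 1, 1), 0.5)
STD = np.full((3, 1, 1), 0.5)


def normalize(image: np.ndarray) -> np.ndarray:
    """Rescale a (C, H, W) uint8 image to [0, 1], then standardize each channel."""
    scaled = image.astype(np.float32) * RESCALE_FACTOR
    return ((scaled - MEAN) / STD).astype(np.float32)


def unnormalize(pixel_values: np.ndarray) -> np.ndarray:
    """Undo normalize() and quantize back to uint8 with a truncating cast (an assumption)."""
    restored = (pixel_values * STD + MEAN) / RESCALE_FACTOR
    # Float error can leave a value just below an integer (e.g. 99.99999),
    # which the truncating cast turns into 99: an off-by-one, never more.
    return np.clip(restored, 0, 255).astype(np.uint8)


def max_abs_diff(a: np.ndarray, b: np.ndarray) -> int:
    """Largest per-pixel difference, computed in int64 so it cannot wrap around."""
    return int(np.abs(a.astype(np.int64) - b.astype(np.int64)).max())


rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(3, 30, 400)).astype(np.uint8)
round_tripped = unnormalize(normalize(image))
print(max_abs_diff(image, round_tripped))  # 0 or 1, never more

# Pitfall: uint8 arithmetic wraps, so 0 - 1 becomes 255 rather than -1.
print(np.array([0], dtype=np.uint8) - np.array([1], dtype=np.uint8))  # [255]
```

Truncation is what makes the error one-sided in this sketch: 99.99999 becomes 99, while 100.00001 stays 100. A rounding cast would make the round trip exact for floating-point errors this small.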