Add Janus model (#36053)

* Iterative generation using input embeds

* Add Janus model

* discard changes

* Janus imports

* Refactor config and processor

* Added Vision tower of Janus

* Import Janus Image processor

* Vision tower fixes

* Refactor code

* Added VQ Model

* Complete model integration

* temp conversion script

* processor refactor

* Adding files to facilitate pulling

* Fixes after debugging

* Skip test for these models

* Add Janus Model

* discard changes

* Janus imports

* Refactor config and processor

* Added Vision tower of Janus

* Import Janus Image processor

* Vision tower fixes

* Refactor code

* Added VQ Model

* Complete model integration

* temp conversion script

* processor refactor

* Adding files to facilitate pulling

* Fixes after debugging

* Refactor to Text config

* Added generate function

* Saving intermediate convert file. Still need to read configs from the hub and convert them to our format.

* Adding version that reads from the JSON files. Still have to tweak some parameters manually.

* relative imports

* Initial tests

* Refactor image processor

* Seemingly working version of the conversion script, will need to test further.

* Adding command message

* Fixing conflicting JanusTextConfig class

* Incorporating some of the discussed changes.

* Small fix to create dir.

* Removing system from JINJA template

* Adding draft processor tests

* style fixes

* Minor fixes and enhancement

* added generation config

* Initial tests

* Small modifications, tests are now passing.

* Small changes I noticed while reading code.

* more fixes

* Added JanusModel class

* Small merge adaptations

* Small merge adaptations

* Image processing tests passing

* More tests and fixes

* Convert script updated and refactored

* Tests and cleanup

* make style

* Postprocessing for image generation

* generate refactor

* fixes

* - Passing tests that write a part of the model to cpu (e.g. test_cpu_offload)
- Passing tests of dispatching SDPA
- Only gradient checkpointing tests are left.

* Removing temporary code

* Changes

* Writing change to modular

* Added JanusVisionModel. SDPA dispatch tests pass more robustly. Gradient checkpoint tests are next

* Gradient checkpoint tests passing

* Removing debug code

* Major generate refactor 😮‍💨

* Temp changes for testing

* Green quality CI

* 2 out of 4 integration tests passing

* breadcrumbs

* Usage Examples

* Regenerate modeling after merge

* dirty code

* JanusIntegrationTest are passing

* breadcrumbs

* happy CI

* fixes

* Changing template

* nits

* Text generation logits matching original codebase at 100% precision

* Remove ./tmp from git tracking

* Remove ./tmp from git tracking

* Checkpointing changes after reviewing

* Fixing code in docstrings

* Changing comments and small bug in convert file

* Fixing bug in image_token_id for 7B version

* Removing line that was added by both of us

* Pushing changes after discussion. Only one left is to change the key mapping for convert file.

* Updating module file

* New convert file using dict. Tested that it is equivalent to the old one by:
- comparing keys in a script
- comparing checksums of the output files between version generated with the current convert script and those generated with the old script. This is a more reliable test.

* revert changes

* mistake

* consistency change for CI

* make style

* doc fixes

* more fixes

* experimenting with masking out pad token

* checkpoint

* Batched generation with multi-images working for 1B models. Will test 7B next.

* Device fix.

* Writing changes to modular, previous ones were written to modeling just for quick testing.

* Using passed processor attention mask (only in modeling for now)

* Matching performance done in the non-standard way

* Working version of batched generation. Will change how some args are passed to make it more similar to language case

* More compliant version of the code

* Removed duplicated `_prepare_4d_causal_attention_mask_with_cache_position`

* Updating modular file, making masked filling with paddings more efficient

* Slightly more efficient version

* Modifying JanusVisionModel to be a wrapper

* Fixing test to comply with new names

* Modular overhaul

* More refactoring

* - Changing JanusVisionModel back
- Changing forward pass
- Adding boi token to the comparison

* - Removing whole context model_ids
- Using inherited implementation of prepare_inputs_for_generation

* Moving the way boi token is passed to the model

* Fixing sdpa test

* Minor changes

* testing changes

* Minor fix

* - Adding postprocessing test
- checking values of generated image on integration test

* changes

* Removing pooled attention vision module, fixing convert script as a consequence

* More changes

* Fixes

* Draft after merge

* Bug fixes

* More bug fix

* Fixing docs

* Nits

* Refactor return dict

* Moving image post processing test to main processor post process

* Passing guidance_scale as kwarg

* make style

* 🔥 refactor

* make style

* Update and green CI

* Nits and tests update

* up

* Added MID block

* fix

* Dead code

* update testcase

* update

* model_id change

* init_weight changes

---------

Co-authored-by: hsilva664 <metallic-silver@hotmail.com>
Yaswanth Gali 2025-04-17 12:48:51 +05:30 committed by GitHub
parent 688f4707bf
commit a2ef3cf537
22 changed files with 6411 additions and 1 deletions

View File

@ -953,6 +953,8 @@
title: InstructBLIP
- local: model_doc/instructblipvideo
title: InstructBlipVideo
- local: model_doc/janus
title: Janus
- local: model_doc/kosmos-2
title: KOSMOS-2
- local: model_doc/layoutlm

View File

@ -0,0 +1,230 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Janus
## Overview
The Janus Model was originally proposed in [Janus: Decoupling Visual Encoding for Unified Multimodal Understanding and Generation](https://arxiv.org/abs/2410.13848) by the DeepSeek AI team and later refined in [Janus-Pro: Unified Multimodal Understanding and Generation with Data and Model Scaling](https://arxiv.org/abs/2501.17811). Janus is a vision-language model that can take both images and text as input and can generate either text or images as output.
> [!NOTE]
> The model doesn't generate both images and text in an interleaved format. The user has to pass a parameter indicating whether to generate text or image.
The abstract from the original paper is the following:
*In this paper, we introduce Janus, an autoregressive framework that unifies multimodal understanding and generation. Prior research often relies on a single visual encoder for both tasks, such as Chameleon. However, due to the differing levels of information granularity required by multimodal understanding and generation, this approach can lead to suboptimal performance, particularly in multimodal understanding. To address this issue, we decouple visual encoding into separate pathways, while still leveraging a single, unified transformer architecture for processing. The decoupling not only alleviates the conflict between the visual encoder's roles in understanding and generation, but also enhances the framework's flexibility. For instance, both the multimodal understanding and generation components can independently select their most suitable encoding methods. Experiments show that Janus surpasses previous unified model and matches or exceeds the performance of task-specific models. The simplicity, high flexibility, and effectiveness of Janus make it a strong candidate for next-generation unified multimodal models.*
The abstract from the follow-up `Janus-Pro` paper is the following:
*In this work, we introduce Janus-Pro, an advanced version of the previous work Janus. Specifically, Janus-Pro incorporates (1) an optimized training strategy, (2) expanded training data, and (3) scaling to larger model size. With these improvements, Janus-Pro achieves significant advancements in both multimodal understanding and text-to-image instruction-following capabilities, while also enhancing the stability of text-to-image generation. We hope this work will inspire further exploration in the field. Code and models are publicly available.*
This model was contributed by [Yaswanth Gali](https://huggingface.co/yaswanthgali) and [Hugo Silva](https://huggingface.co/hugosilva664).
The original code can be found [here](https://github.com/deepseek-ai/Janus).
## Usage Example
### Single image inference
Here is an example of visual understanding with a single image.
> [!NOTE]
> Note that the model has been trained with a specific prompt format for chatting. Use `processor.apply_chat_template(my_conversation_dict)` to correctly format your prompts.
```python
import torch
from PIL import Image
import requests
from transformers import JanusForConditionalGeneration, JanusProcessor
model_id = "deepseek-community/Janus-Pro-1B"
# Prepare Input for generation.
messages = [
{
"role": "user",
"content": [
            {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
            {"type": "text", "text": "What do you see in this image?"}
]
},
]
# Set generation mode to `text` to perform text generation.
processor = JanusProcessor.from_pretrained(model_id)
model = JanusForConditionalGeneration.from_pretrained(model_id,
torch_dtype=torch.bfloat16,
device_map="auto")
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
generation_mode="text",
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device, dtype=torch.bfloat16)
output = model.generate(**inputs, max_new_tokens=40, generation_mode="text", do_sample=True)
text = processor.decode(output[0], skip_special_tokens=True)
print(text)
```
### Multi image inference
Janus can also take multiple images as input, either within a single prompt or across different prompts in batched inference, where the model processes several conversations in parallel. Here is how you can do it:
```python
import torch
from PIL import Image
import requests
from transformers import JanusForConditionalGeneration, JanusProcessor
model_id = "deepseek-community/Janus-Pro-1B"
image_urls = [
"http://images.cocodataset.org/val2017/000000039769.jpg",
"https://www.ilankelman.org/stopsigns/australia.jpg",
"https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
]
messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "Whats the difference between"},
{"type": "image", "url": image_urls[0]},
{"type": "text", "text": " and "},
{"type": "image", "url": image_urls[1]}
]
}
],
[
{
"role": "user",
"content": [
{"type": "image", "url": image_urls[2]},
{"type": "text", "text": "What do you see in this image?"}
]
}
]
]
# Load model and processor
processor = JanusProcessor.from_pretrained(model_id)
model = JanusForConditionalGeneration.from_pretrained(
model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
generation_mode="text",
tokenize=True,
padding=True,
return_dict=True,
return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)
# Generate response
output = model.generate(**inputs, max_new_tokens=40, generation_mode='text', do_sample=False)
text = processor.batch_decode(output, skip_special_tokens=True)
print(text)
```
### Text to Image generation
Janus can also generate images given a prompt.
```python
import torch
from transformers import JanusForConditionalGeneration, JanusProcessor
# Set generation mode to `image` to prepare inputs for image generation.
model_id = "deepseek-community/Janus-Pro-1B"
processor = JanusProcessor.from_pretrained(model_id)
model = JanusForConditionalGeneration.from_pretrained(model_id,
torch_dtype=torch.bfloat16,
device_map="auto")
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "A dog running under the rain."},
],
}
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, generation_mode="image", return_tensors="pt").to(model.device, dtype=torch.bfloat16)
# Set the num_return_sequences parameter to generate multiple images per prompt.
model.generation_config.num_return_sequences = 2
outputs = model.generate(**inputs,
generation_mode="image",
do_sample=True,
use_cache=True,
)
# Perform post-processing on the generated token ids.
decoded_image = model.decode_image_tokens(outputs)
images = processor.postprocess(list(decoded_image.float()), return_tensors="PIL.Image.Image")
# Save the image
for i, image in enumerate(images['pixel_values']):
image.save(f"result{i}.png")
```
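The converted checkpoints set `guidance_scale=5` (classifier-free guidance) in their generation config, and this PR also allows passing it directly as a generation kwarg. A minimal sketch, reusing `model` and `inputs` from the example above:

```python
# guidance_scale falls back to the value stored in model.generation_config when omitted.
outputs = model.generate(**inputs, generation_mode="image", do_sample=True, guidance_scale=5)
```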
## JanusConfig
[[autodoc]] JanusConfig
## JanusVisionConfig
[[autodoc]] JanusVisionConfig
## JanusVQVAEConfig
[[autodoc]] JanusVQVAEConfig
## JanusProcessor
[[autodoc]] JanusProcessor
## JanusImageProcessor
[[autodoc]] JanusImageProcessor
## JanusVisionModel
[[autodoc]] JanusVisionModel
- forward
## JanusVQVAE
[[autodoc]] JanusVQVAE
- forward
## JanusModel
[[autodoc]] JanusModel
- forward
## JanusForConditionalGeneration
[[autodoc]] JanusForConditionalGeneration
- forward

View File

@ -144,6 +144,7 @@ if TYPE_CHECKING:
from .instructblip import *
from .instructblipvideo import *
from .jamba import *
from .janus import *
from .jetmoe import *
from .kosmos2 import *
from .layoutlm import *

View File

@ -163,6 +163,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("instructblip", "InstructBlipConfig"),
("instructblipvideo", "InstructBlipVideoConfig"),
("jamba", "JambaConfig"),
("janus", "JanusConfig"),
("jetmoe", "JetMoeConfig"),
("jukebox", "JukeboxConfig"),
("kosmos-2", "Kosmos2Config"),
@ -517,6 +518,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("instructblip", "InstructBLIP"),
("instructblipvideo", "InstructBlipVideo"),
("jamba", "Jamba"),
("janus", "Janus"),
("jetmoe", "JetMoe"),
("jukebox", "Jukebox"),
("kosmos-2", "KOSMOS-2"),

View File

@ -101,6 +101,7 @@ else:
("imagegpt", ("ImageGPTImageProcessor",)),
("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
("instructblipvideo", ("InstructBlipVideoImageProcessor",)),
("janus", ("JanusImageProcessor")),
("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")),
("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),

View File

@ -152,6 +152,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("imagegpt", "ImageGPTModel"),
("informer", "InformerModel"),
("jamba", "JambaModel"),
("janus", "JanusModel"),
("jetmoe", "JetMoeModel"),
("jukebox", "JukeboxModel"),
("kosmos-2", "Kosmos2Model"),
@ -359,6 +360,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("idefics", "IdeficsForVisionText2Text"),
("idefics2", "Idefics2ForConditionalGeneration"),
("idefics3", "Idefics3ForConditionalGeneration"),
("janus", "JanusForConditionalGeneration"),
("layoutlm", "LayoutLMForMaskedLM"),
("llava", "LlavaForConditionalGeneration"),
("llava_next", "LlavaNextForConditionalGeneration"),
@ -858,6 +860,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
("idefics2", "Idefics2ForConditionalGeneration"),
("idefics3", "Idefics3ForConditionalGeneration"),
("instructblip", "InstructBlipForConditionalGeneration"),
("janus", "JanusForConditionalGeneration"),
("kosmos-2", "Kosmos2ForConditionalGeneration"),
("llama4", "Llama4ForConditionalGeneration"),
("llava", "LlavaForConditionalGeneration"),

View File

@ -75,6 +75,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("idefics3", "Idefics3Processor"),
("instructblip", "InstructBlipProcessor"),
("instructblipvideo", "InstructBlipVideoProcessor"),
("janus", "JanusProcessor"),
("kosmos-2", "Kosmos2Processor"),
("layoutlmv2", "LayoutLMv2Processor"),
("layoutlmv3", "LayoutLMv3Processor"),

View File

@ -265,6 +265,7 @@ else:
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
("janus", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
(
"jetmoe",
(

View File

@ -755,7 +755,6 @@ class ChameleonVQVAEVectorQuantizer(nn.Module):
self.beta = getattr(config, "beta", 0.25)
self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
self.re_embed = self.num_embeddings
def forward(self, hidden_state: torch.Tensor):
hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()

View File

@ -0,0 +1,29 @@
# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_janus import *
from .image_processing_janus import *
from .modeling_janus import *
from .processing_janus import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@ -0,0 +1,314 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/janus/modular_janus.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_janus.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING, AutoConfig
logger = logging.get_logger(__name__)
class JanusVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`JanusVisionModel`]. It is used to instantiate a
`JanusVisionModel` according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
image_size (`int`, *optional*, defaults to 384):
The size (resolution) of each image.
attention_dropout (`float`, *optional*, defaults to 0.0):
Dropout probability for attention weights.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"`, and `"gelu_new"` are supported.
mlp_ratio (`float`, *optional*, defaults to 4.0):
Ratio of MLP hidden dimensionality to embedding dimensionality.
attention_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to the queries, keys, and values in the attention layers.
hidden_dropout_rate (`float`, *optional*, defaults to 0.0):
The dropout probability for fully connected layers in the encoder.
projection_dim (`int`, *optional*, defaults to 2048):
Dimensionality of the MLP projection head.
projection_dropout (`float`, *optional*, defaults to 0.0):
Dropout probability for the projection layer.
use_qk_norm (`bool`, *optional*, defaults to `False`):
Whether to normalize the query and key matrices.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated normal initializer for initializing all weight matrices.
depth (`int`, *optional*, defaults to 2):
Number of hidden layers in the aligner module.
num_image_tokens (`int`, *optional*, defaults to 576):
Number of image tokens.
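
    Example (a minimal usage sketch; the model is initialized with random weights):

    ```python
    >>> from transformers import JanusVisionConfig, JanusVisionModel

    >>> # Initializing a JanusVisionConfig with the default values
    >>> configuration = JanusVisionConfig()

    >>> # Initializing a JanusVisionModel (with random weights) from that configuration
    >>> model = JanusVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```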
"""
model_type = "janus_vision_model"
base_config_key = "vision_config"
def __init__(
self,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
num_channels=3,
patch_size=16,
image_size=384,
attention_dropout=0.0,
layer_norm_eps=1e-6,
hidden_act="gelu",
mlp_ratio=4.0,
attention_bias=True,
hidden_dropout_rate=0.0,
projection_dim=2048,
projection_dropout=0.0,
use_qk_norm=False,
initializer_range=0.02,
depth=2,
num_image_tokens=576,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.mlp_ratio = mlp_ratio
self.attention_bias = attention_bias
self.hidden_dropout_rate = hidden_dropout_rate
self.projection_dim = projection_dim
self.projection_dropout = projection_dropout
self.use_qk_norm = use_qk_norm
self.initializer_range = initializer_range
self.depth = depth
self.num_image_tokens = num_image_tokens
class JanusVQVAEConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`JanusVQVAEModel`]. It is used to instantiate a
`JanusVQVAEModel` according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information. Instantiating a
configuration with the defaults will yield a similar configuration to the VQModel of the
[deepseek-community/Janus-Pro-1B](https://huggingface.co/deepseek-community/Janus-Pro-1B).
Args:
embed_dim (`int`, *optional*, defaults to 8):
Dimensionality of each embedding vector.
num_embeddings (`int`, *optional*, defaults to 16384):
Number of codebook embeddings.
double_latent (`bool`, *optional*, defaults to `False`):
Whether to use double z channels.
latent_channels (`int`, *optional*, defaults to 256):
Number of channels for the latent space.
num_patches (`int`, *optional*, defaults to 32):
            Number of patches the input image can be divided into.
in_channels (`int`, *optional*, defaults to 3):
Number of input channels.
out_channels (`int`, *optional*, defaults to 3):
            Number of output channels.
base_channels (`int`, *optional*, defaults to 128):
Base channel count.
channel_multiplier (`List[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`):
Channel multipliers for each resolution.
num_res_blocks (`int`, *optional*, defaults to 2):
Number of residual blocks.
dropout (`float`, *optional*, defaults to 0.0):
Dropout rate.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
projection_dim (`int`, *optional*, defaults to 2048):
Dimensionality of the MLP projection head.
num_hidden_layers (`int`, *optional*, defaults to 2):
            Number of hidden layers in the VQVAE MLP connector module.
hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
image_token_embed_dim (`int`, *optional*, defaults to 2048):
            Dimension of the image embeddings. It should be the same as the dimensionality of the text embeddings.
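
    Example (a minimal sketch relying only on the default arguments):

    ```python
    >>> from transformers import JanusVQVAEConfig, JanusVQVAE

    >>> # Initializing a JanusVQVAEConfig with the default values
    >>> configuration = JanusVQVAEConfig()

    >>> # Initializing a JanusVQVAE model (with random weights) from that configuration
    >>> model = JanusVQVAE(configuration)
    ```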
"""
model_type = "janus_vqgan"
base_config_key = "vq_config"
def __init__(
self,
embed_dim: int = 8,
num_embeddings: int = 16384,
double_latent: bool = False,
latent_channels: int = 256,
num_patches: int = 32,
in_channels: int = 3,
out_channels: int = 3,
base_channels: int = 128,
channel_multiplier: List[int] = [1, 1, 2, 2, 4],
num_res_blocks: int = 2,
dropout: float = 0.0,
initializer_range=0.02,
projection_dim=2048,
num_hidden_layers=2,
hidden_act="gelu",
image_token_embed_dim=2048,
**kwargs,
):
super().__init__(**kwargs)
self.embed_dim = embed_dim
self.num_embeddings = num_embeddings
self.double_latent = double_latent
self.latent_channels = latent_channels
self.in_channels = in_channels
self.base_channels = base_channels
self.channel_multiplier = channel_multiplier
self.num_res_blocks = num_res_blocks
self.dropout = dropout
self.initializer_range = initializer_range
self.num_patches = num_patches
self.out_channels = out_channels
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.hidden_act = hidden_act
self.image_token_embed_dim = image_token_embed_dim
class JanusConfig(PretrainedConfig):
r"""
    This is the configuration class to store the configuration of a [`JanusModel`]. It is used to instantiate a
Janus model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Janus-1B or Janus-7B models.
e.g. [deepseek-community/Janus-Pro-1B](https://huggingface.co/deepseek-community/Janus-Pro-1B) or
[deepseek-community/Janus-Pro-7B](https://huggingface.co/deepseek-community/Janus-Pro-7B)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `JanusVisionConfig`):
The config object or dictionary of the vision backbone.
vq_config (`Union[AutoConfig, dict]`, *optional*, defaults to `JanusVQVAEConfig`):
The config object or dictionary of the VQVAE backbone.
Example:
```python
>>> from transformers import JanusForConditionalGeneration, JanusConfig, JanusVisionConfig, JanusVQVAEConfig, LlamaConfig
>>> # Initializing a Janus vision config
>>> vision_config = JanusVisionConfig()
>>> # Initializing a Llama config
>>> text_config = LlamaConfig()
>>> # Initializing a VQ config
>>> vq_config = JanusVQVAEConfig()
>>> # Initializing a Janus Pro 1B style configuration
>>> configuration = JanusConfig(vision_config=vision_config, text_config=text_config, vq_config=vq_config)
>>> # Initializing a model from the Janus Pro 1B style configuration
>>> model = JanusForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "janus"
sub_configs = {
"text_config": AutoConfig,
"vision_config": JanusVisionConfig,
"vq_config": JanusVQVAEConfig,
}
def __init__(self, text_config=None, vision_config=None, vq_config=None, **kwargs):
if isinstance(text_config, dict):
text_config["model_type"] = text_config.get("model_type", "llama")
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
logger.info("`text_config` is None. Initializing with default values")
self.text_config = CONFIG_MAPPING["llama"]()
elif isinstance(text_config, PretrainedConfig):
self.text_config = text_config
else:
raise ValueError(
f"Invalid type for `text_config`. Must be either `dict` or `LlamaConfig`."
f" Type found: {type(text_config)}"
)
if vision_config is None:
logger.info("`vision_config` is None. Initializing with default JanusVisionConfig values")
self.vision_config = JanusVisionConfig()
elif isinstance(vision_config, dict):
self.vision_config = JanusVisionConfig(**vision_config)
elif isinstance(vision_config, JanusVisionConfig):
self.vision_config = vision_config
else:
raise ValueError(
f"Invalid type for `vision_config`. Must be either `dict` or `JanusVisionConfig`."
f" Type found: {type(vision_config)}"
)
if vq_config is None:
logger.info("`vq_config` is None. Initializing with default JanusVQVAEConfig values")
self.vq_config = JanusVQVAEConfig()
elif isinstance(vq_config, dict):
self.vq_config = JanusVQVAEConfig(**vq_config)
elif isinstance(vq_config, JanusVQVAEConfig):
self.vq_config = vq_config
else:
raise ValueError(
f"Invalid type for `vq_config`. Must be either `dict` or `JanusVQVAEConfig`."
f" Type found: {type(vq_config)}"
)
# This dimension is required when decoding discrete image tokens to continuous input.
self.vq_config.num_patches = self.vision_config.image_size // self.vision_config.patch_size
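        # With the default vision config this is 384 // 16 = 24 patches per side.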
        # The default image token index applies only to the 1B model; the 7B model uses a different one.
self.image_token_index = kwargs.get("image_token_index", 100581)
super().__init__(**kwargs)
__all__ = ["JanusVQVAEConfig", "JanusVisionConfig", "JanusConfig"]

View File

@ -0,0 +1,501 @@
# coding=utf-8
# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example of a run command (run from the repo root):
python src/transformers/models/janus/convert_janus_weights_to_hf.py --repo_id deepseek-ai/Janus-Pro-1B --local_dir tmp/hub_code_in --output_dir tmp/hub_code_out --safe_serialization
If the local directory already exists, the script starts by printing: Using provided local directory: tmp/hub_code_in
"""
import argparse
import gc
import json
import os
import re
import torch
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from transformers import (
AutoTokenizer,
JanusConfig,
JanusForConditionalGeneration,
JanusVisionConfig,
JanusVQVAEConfig,
LlamaConfig,
)
from transformers.models.janus.image_processing_janus import JanusImageProcessor
from transformers.models.janus.processing_janus import JanusProcessor
# Mappings
MAPPINGS = {
# Vision model
r"(?<!gen_)vision_model\.vision_tower\.blocks\.(\d+)\.attn": r"model.vision_model.encoder.layers.\1.self_attn",
r"(?<!gen_)vision_model.vision_tower.blocks": "model.vision_model.encoder.layers",
r"(?<!gen_)vision_model.vision_tower.pos_embed": "model.vision_model.embeddings.position_embedding.weight",
r"(?<!gen_)vision_model.vision_tower.patch_embed.proj": "model.vision_model.embeddings.patch_embedding",
r"(?<!gen_)vision_model.vision_tower.norm": "model.vision_model.post_layernorm",
r"(?P<pre>\b(vision_model|model\.vision_model)\b.*\.)proj(?=\.|\s|$)": r"\g<pre>projection_layer",
r"(?P<pre>\b(vision_model|model\.vision_model)\b.*\.)norm(?=\.|\s|$)": r"\g<pre>layer_norm",
r"(?P<pre>\b(vision_model|model\.vision_model)\b.*\.)norm1(?=\.|\s|$)": r"\g<pre>layer_norm1",
r"(?P<pre>\b(vision_model|model\.vision_model)\b.*\.)norm2(?=\.|\s|$)": r"\g<pre>layer_norm2",
r"\bvision_model\.vision_tower\.attn_pool\.[^\s$]*": None,
# VQ Model
r"gen_vision_model": "model.vqmodel",
r"(?P<pre>\b(gen_vision_model|model\.vqmodel)\b.*\.)decoder\.conv_blocks(?=\.|\s|$)": r"\g<pre>decoder.up",
r"(?P<pre>\b(gen_vision_model|model\.vqmodel)\b.*\.)encoder\.conv_blocks(?=\.|\s|$)": r"\g<pre>encoder.down",
r"(?P<pre>\b(gen_vision_model|model\.vqmodel)\b.*\.)res(?=\.|\s|$)": r"\g<pre>block",
r"(?P<pre>\b(gen_vision_model|model\.vqmodel)\b.*\.)mid\.0(?=\.|\s|$)": r"\g<pre>mid.block_1",
r"(?P<pre>\b(gen_vision_model|model\.vqmodel)\b.*\.)mid\.1(?=\.|\s|$)": r"\g<pre>mid.attn_1",
r"(?P<pre>\b(gen_vision_model|model\.vqmodel)\b.*\.)mid\.2(?=\.|\s|$)": r"\g<pre>mid.block_2",
# Aligner Modules
r"(gen_aligner)\.layers\.0": r"model.generation_aligner.fc1",
r"(gen_aligner)\.layers\.2": r"model.generation_aligner.hidden_layers.0",
r"(?<!gen_)(aligner)\.layers\.0": r"model.aligner.fc1",
r"(?<!gen_)(aligner)\.layers\.2": r"model.aligner.hidden_layers.0",
"gen_head.output_mlp_projector": "model.generation_head.proj_out",
r"(\s|^)gen_embed": r"\1model.generation_embeddings",
r"(\s|^)gen_head": r"\1model.generation_head",
r"\b(gen_vision_model|model\.vqmodel)\.quantize\.codebook_used": None,
# Language model
r"(\s|^)language_model\.model": r"\1model.language_model",
r"\b(model\.language_model|(?<!model\.)language_model)\.lm_head\.weight": "lm_head.weight",
}
CHAT_TEMPLATE = (
"{%set seps=['\n\n','<\uff5cend\u2581of\u2581sentence\uff5c>']%}"
"{%set i=0%}"
"{%for message in messages%}"
"{%if message['role']|lower=='user'%}"
"<|User|>: "
"{%elif message['role']|lower=='assistant'%}"
"<|Assistant|>:{%if not (loop.last and not add_generation_prompt and message['content'][0]['type']=='text' and message['content'][0]['text']=='')%} {%endif%}"
"{%else%}"
"{{message['role'].capitalize()}}: "
"{%endif%}"
"{%for content in message['content']%}"
"{%if content['type']=='image'%}"
"{%if not loop.first%}{{'\n'}}{%endif%}"
"<image_placeholder>"
"{%if not loop.last%}{{'\n'}}{%endif%}"
"{%elif content['type']=='text'%}"
"{%set text=content['text']%}"
"{%if loop.first%}{%set text=text.lstrip()%}{%endif%}"
"{%if loop.last%}{%set text=text.rstrip()%}{%endif%}"
"{%if not loop.first and message['content'][loop.index0-1]['type']=='text'%}"
"{{' '+text}}"
"{%else%}"
"{{text}}"
"{%endif%}"
"{%endif%}"
"{%endfor%}"
"{%if not loop.last or add_generation_prompt%}"
"{%if message['role']|lower=='user'%}"
"{{seps[0]}}"
"{%else%}"
"{{seps[1]}}"
"{%endif%}"
"{%endif%}"
"{%endfor%}"
"{%if add_generation_prompt%}<|Assistant|>:{%endif%}"
)
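# Tracing the template above: a single user turn with one image followed by one text entry, rendered with
# add_generation_prompt=True, produces roughly "<|User|>: <image_placeholder>\n<text>\n\n<|Assistant|>:"
# (the template itself is the source of truth for exact spacing and separators).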
def convert_old_keys_to_new_keys(state_dict):
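    # The regex rules in MAPPINGS are applied to all keys at once: the keys are joined into one newline-separated
    # string, rewritten via re.sub, and then zipped back into an old-key -> new-key mapping.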
keys_as_text = "\n".join(state_dict.keys())
new_keys_as_text = keys_as_text
for old, repl in MAPPINGS.items():
if repl is None:
new_keys_as_text = re.sub(old, "", new_keys_as_text)
else:
new_keys_as_text = re.sub(old, repl, new_keys_as_text)
output_dict = dict(zip(keys_as_text.split("\n"), new_keys_as_text.split("\n")))
return output_dict
def split_tensor(tensor, key):
"""Splits a merged tensor (qkv or kv) into separate tensors and creates keys for each part."""
if "qkv" in key:
prefix_to_replace = "qkv"
num_splits = 3
new_keys = ["q_proj", "k_proj", "v_proj"]
elif "kv" in key:
prefix_to_replace = "kv"
num_splits = 2
new_keys = ["k_proj", "v_proj"]
else:
raise ValueError(f"Unrecognized tensor type in key: {key}")
split_size = tensor.shape[0] // num_splits
tensors = torch.split(tensor, split_size, dim=0)
return {key.replace(prefix_to_replace, new_keys[i]): tensors[i] for i in range(num_splits)}
def convert_state_dict_to_hf(state_dict):
"""Convert state dict keys to HF format."""
conversion_dict = convert_old_keys_to_new_keys(state_dict)
converted_state_dict = {}
for old_key, new_key in conversion_dict.items():
if new_key:
if "qkv" in new_key or "kv" in new_key: # Detect merged attention keys and split them.
qkv_split_dict = split_tensor(state_dict[old_key], new_key)
converted_state_dict.update(qkv_split_dict)
else:
converted_state_dict[new_key] = state_dict[old_key]
    # The original checkpoint stores the position embeddings with a leading batch dimension, which is dropped here.
pos_embed_key = "model.vision_model.embeddings.position_embedding.weight"
converted_state_dict[pos_embed_key] = converted_state_dict[pos_embed_key].squeeze(0)
return converted_state_dict
def ensure_model_downloaded(repo_id: str = None, revision: str = None, local_dir: str = None) -> str:
"""
Ensures model files are downloaded locally, downloads them if not.
Returns path to local files.
Args:
repo_id: The Hugging Face model repo ID (required if local_dir not provided)
revision: Optional git revision to use
local_dir: Optional local directory path where model files should be stored/found
"""
if local_dir is not None:
if os.path.exists(local_dir):
print(f"Using provided local directory: {local_dir}")
else:
# Create the local directory if it doesn't exist
os.makedirs(local_dir, exist_ok=True)
print(f"Created local directory: {local_dir}")
if repo_id is None:
raise ValueError("Either repo_id or local_dir must be provided")
print(f"Ensuring {repo_id} (revision: {revision or 'latest'}) is downloaded...")
try:
# First try to find files locally
download_dir = snapshot_download(repo_id, revision=revision, local_files_only=True, local_dir=local_dir)
print(f"Found model files locally at {download_dir}")
return download_dir
except Exception:
# If files not found locally, download them
print(f"Downloading model files for {repo_id}...")
download_dir = snapshot_download(repo_id, revision=revision, local_files_only=False, local_dir=local_dir)
print(f"Downloaded model files to {download_dir}")
return download_dir
def load_model_state_dict(input_path: str) -> dict:
"""
Load model state dict, handling both single and sharded files.
"""
index_path = os.path.join(input_path, "pytorch_model.bin.index.json")
single_file_path = os.path.join(input_path, "pytorch_model.bin")
# Check if we have a sharded model
if os.path.exists(index_path):
print("Loading sharded model...")
state_dict = {}
with open(index_path, "r") as f:
index = json.load(f)
# Get unique shard files and load each one only once
unique_shard_files = sorted(set(index["weight_map"].values()))
for shard_file in unique_shard_files:
print(f"Loading shard {shard_file}...")
shard_path = os.path.join(input_path, shard_file)
shard_dict = torch.load(shard_path, map_location="cpu")
state_dict.update(shard_dict)
return state_dict
# Single file model
elif os.path.exists(single_file_path):
print("Loading single file model...")
return torch.load(single_file_path, map_location="cpu")
else:
raise ValueError(f"No model files found in {input_path}")
def convert_model(
repo_id=None,
local_dir=None,
text_model_id=None,
output_dir=None,
output_hub_path=None,
safe_serialization=True,
revision=None,
):
"""Convert and save the model weights, processor, and configuration."""
if output_dir is None and output_hub_path is None:
raise ValueError("At least one of output_dir or output_hub_path must be specified")
if repo_id is None and local_dir is None:
raise ValueError("Either repo_id or local_dir must be specified")
# Create output directory if specified
if output_dir:
os.makedirs(output_dir, exist_ok=True)
print(f"Created/verified output directory: {output_dir}")
torch.set_default_dtype(torch.float16)
# Download or locate model files
input_path = ensure_model_downloaded(repo_id=repo_id, revision=revision, local_dir=local_dir)
# Load configuration files
required_files = ["config.json", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json"]
missing_files = [f for f in required_files if not os.path.exists(os.path.join(input_path, f))]
if missing_files:
raise ValueError(
f"The following required configuration files are missing from {input_path}: {', '.join(missing_files)}. "
"Please ensure you have downloaded all necessary model files."
)
with open(os.path.join(input_path, "config.json"), "r") as f:
config_data = json.load(f)
with open(os.path.join(input_path, "preprocessor_config.json"), "r") as f:
preprocessor_config = json.load(f)
with open(os.path.join(input_path, "special_tokens_map.json"), "r") as f:
special_tokens_map = json.load(f)
with open(os.path.join(input_path, "tokenizer_config.json"), "r") as f:
tokenizer_config = json.load(f)
# Create tokenizer directly from tokenizer.json if it exists
tokenizer_json_path = os.path.join(input_path, "tokenizer.json")
special_image_tokens = {
"image_token": "<image_placeholder>",
"boi_token": "<begin_of_image>",
"eoi_token": "<end_of_image>",
}
if os.path.exists(tokenizer_json_path) and not text_model_id:
tokenizer = AutoTokenizer.from_pretrained(
input_path, # This will load tokenizer.json directly
model_max_length=tokenizer_config["model_max_length"],
extra_special_tokens=special_image_tokens,
)
else:
# Fallback to creating from text_model_id with special tokens
tokenizer = AutoTokenizer.from_pretrained(
text_model_id,
bos_token=special_tokens_map["bos_token"],
eos_token=special_tokens_map["eos_token"],
pad_token=special_tokens_map["pad_token"],
additional_special_tokens=special_tokens_map["additional_special_tokens"],
model_max_length=tokenizer_config["model_max_length"],
extra_special_tokens=special_image_tokens,
)
# Create image processor from config
image_processor_kwargs = {}
for key in ["do_normalize", "image_mean", "image_std", "min_size", "rescale_factor"]:
if key in preprocessor_config:
image_processor_kwargs[key] = preprocessor_config[key]
if "image_size" in preprocessor_config:
image_processor_kwargs["size"] = {
"height": preprocessor_config["image_size"],
"width": preprocessor_config["image_size"],
}
image_processor = JanusImageProcessor(**image_processor_kwargs)
# Create processor with chat template
processor = JanusProcessor(
image_processor=image_processor,
tokenizer=tokenizer,
chat_template=CHAT_TEMPLATE,
use_default_system_prompt=True,
)
if output_dir:
print(f"Saving processor to {output_dir}...")
processor.save_pretrained(output_dir)
if output_hub_path:
print(f"Pushing processor to hub at {output_hub_path}...")
processor.push_to_hub(output_hub_path)
# Create model configurations
text_config_kwargs = {}
for key in [
"vocab_size",
"hidden_size",
"intermediate_size",
"num_hidden_layers",
"num_attention_heads",
"num_key_value_heads",
"hidden_act",
"max_position_embeddings",
"torch_dtype",
]:
if key in config_data["language_config"]:
text_config_kwargs[key] = config_data["language_config"][key]
# Add token IDs from tokenizer
text_config_kwargs.update(
{
"pad_token_id": tokenizer.pad_token_id,
"bos_token_id": tokenizer.bos_token_id,
"eos_token_id": tokenizer.eos_token_id,
}
)
text_config = LlamaConfig(**text_config_kwargs)
# Create vision config
vision_config_kwargs = {}
if "image_size" in config_data["vision_config"]["params"]:
vision_config_kwargs["image_size"] = config_data["vision_config"]["params"]["image_size"]
# Add aligner params if present
if "aligner_config" in config_data and "params" in config_data["aligner_config"]:
if "n_embed" in config_data["aligner_config"]["params"]:
vision_config_kwargs["projection_dim"] = config_data["aligner_config"]["params"]["n_embed"]
if "depth" in config_data["aligner_config"]["params"]:
vision_config_kwargs["depth"] = config_data["aligner_config"]["params"]["depth"]
vision_config = JanusVisionConfig(**vision_config_kwargs)
vq_config = JanusVQVAEConfig(
embed_dim=config_data["gen_vision_config"]["params"]["n_embed"],
num_embeddings=config_data["gen_vision_config"]["params"]["image_token_size"],
projection_dim=config_data["gen_aligner_config"]["params"]["n_embed"],
depth=config_data["gen_aligner_config"]["params"]["depth"],
image_token_embed_dim=config_data["gen_head_config"]["params"]["image_token_embed"],
)
# Create the main config
config = JanusConfig(
text_config=text_config,
vision_config=vision_config,
vq_config=vq_config,
image_token_index=tokenizer.vocab.get("<image_placeholder>"),
)
# Save the config
if output_dir:
config.save_pretrained(output_dir)
if output_hub_path:
config.push_to_hub(output_hub_path)
# Initialize model with empty weights
print("Creating empty model...")
with init_empty_weights():
model = JanusForConditionalGeneration(config)
model.generation_config.temperature = 1
model.generation_config.guidance_scale = 5
model.generation_config.pad_token_id = tokenizer.vocab.get("<\uff5c\u2581pad\u2581\uff5c>")
model.generation_config.generation_kwargs["boi_token_id"] = tokenizer.vocab.get("<begin_of_image>")
# Load and convert state dict
print("Loading state dict...")
state_dict = load_model_state_dict(input_path)
state_dict = convert_state_dict_to_hf(state_dict)
# Load converted state dict
print("Loading converted weights into model...")
model.load_state_dict(state_dict, strict=True, assign=True)
# Tie weights before any device mapping
print("Tying weights...")
model.tie_weights()
# Save the model
if output_dir:
print(f"Saving model to {output_dir}...")
model.save_pretrained(output_dir, safe_serialization=safe_serialization)
if output_hub_path:
print(f"Pushing model to hub at {output_hub_path}...")
model.push_to_hub(output_hub_path, safe_serialization=safe_serialization)
del state_dict, model
gc.collect()
# Validate the saved model if saved locally
if output_dir:
print("Reloading the local model to check if it's saved correctly...")
# TODO: warning about weights not being tied is raised here regardless of model.tie_weights() above
JanusForConditionalGeneration.from_pretrained(output_dir, device_map="auto")
print("Local model reloaded successfully.")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo_id",
help="HuggingFace Hub repo ID for the model",
default=None,
)
parser.add_argument(
"--local_dir",
help="Local directory containing the model files",
default=None,
)
parser.add_argument(
"--revision",
help="Specific revision to download from the Hub",
default=None,
)
parser.add_argument(
"--output_dir",
help="Location to write HF model locally",
default=None,
)
parser.add_argument(
"--output_hub_path",
help="Repository ID to push model to hub (e.g. 'username/model-name')",
default=None,
)
parser.add_argument(
"--text_model_id",
help="Hub ID of the text model to get tokenizer from. Optional if tokenizer.json exists in the model directory.",
required=False,
)
parser.add_argument(
"--safe_serialization",
action="store_true",
help="Whether to save using safetensors",
)
args = parser.parse_args()
if args.output_dir is None and args.output_hub_path is None:
raise ValueError("At least one of --output_dir or --output_hub_path must be specified")
if args.repo_id is None and args.local_dir is None:
raise ValueError("Either --repo_id or --local_dir must be specified")
convert_model(
repo_id=args.repo_id,
local_dir=args.local_dir,
text_model_id=args.text_model_id,
output_dir=args.output_dir,
output_hub_path=args.output_hub_path,
safe_serialization=args.safe_serialization,
revision=args.revision,
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,508 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/janus/modular_janus.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_janus.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format
from ...image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
make_flat_list_of_images,
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import (
TensorType,
filter_out_non_signature_kwargs,
is_vision_available,
logging,
)
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
class JanusImageProcessor(BaseImageProcessor):
r"""
    Constructs a Janus image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
min_size (`int`, *optional*, defaults to 14):
The minimum allowed size for the resized image. Ensures that neither the height nor width
falls below this value after resizing.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
overridden by the `resample` parameter in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
`do_rescale` parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
overridden by the `rescale_factor` parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
min_size: int = 14,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 384, "width": 384}
size = get_size_dict(size, default_to_square=True)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.do_convert_rgb = do_convert_rgb
self.min_size = min_size
if image_mean is None:
self.background_color = (127, 127, 127)
else:
self.background_color = tuple([int(x * 255) for x in image_mean])
def resize(
self,
image: np.ndarray,
size: Union[Dict[str, int], int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image to dynamically calculated size.
Args:
image (`np.ndarray`):
Image to resize.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `None`: will be inferred from input
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
Returns:
`np.ndarray`: The resized image.
"""
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
height, width = get_image_size(image, input_data_format)
max_size = max(height, width)
size = get_size_dict(size, default_to_square=True)
if size["height"] != size["width"]:
raise ValueError(
f"Output height and width must be the same. Got height={size['height']} and width={size['width']}"
)
size = size["height"]
delta = size / max_size
# Largest side becomes `size` and the other side is scaled according to the aspect ratio.
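        # For example, a 480x640 input with size=384 gives delta=0.6, so the image is first resized to 288x384
        # and then padded to a 384x384 square using the configured background color.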
output_size_nonpadded = [
max(int(height * delta), self.min_size),
max(int(width * delta), self.min_size),
]
image = resize(
image,
size=output_size_nonpadded,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
# Expand and pad the images to obtain a square image of dimensions `size x size`
image = self.pad_to_square(
image=image,
background_color=self.background_color,
input_data_format=input_data_format,
)
return image
@filter_out_non_signature_kwargs()
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
do_convert_rgb: Optional[bool] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Controls the size of the image after `resize`. Height and width must be equal; the longest edge of
                the image is resized to that value while preserving the aspect ratio, and the result is then padded
                to a `size["height"] x size["width"]` square.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image pixel values to the `[0, 1]` range.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to normalize the image by if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
images = make_flat_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
# PIL RGBA images are converted to RGB
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if do_rescale and is_scaled_image(images[0]):
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_resize:
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]
if do_rescale:
images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]
if do_normalize:
images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
return encoded_outputs
def pad_to_square(
self,
image: np.ndarray,
background_color: Union[int, Tuple[int, int, int]] = 0,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.array:
"""
Pads an image to a square based on the longest edge.
Args:
image (`np.ndarray`):
The image to pad.
background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
                The color to use for the padding. Can be an integer for single-channel images or a tuple of
                integers for multi-channel images. If an integer is passed for a multi-channel image, the
                remaining channels default to `0`.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The padded image.
"""
height, width = get_image_size(image, input_data_format)
num_channels = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1]
if height == width:
image = (
to_channel_dimension_format(image, data_format, input_data_format)
if data_format is not None
else image
)
return image
max_dim = max(height, width)
# Ensure background_color is the correct shape
if isinstance(background_color, int):
background_color = [background_color]
elif len(background_color) != num_channels:
raise ValueError(
f"background_color must have no more than {num_channels} elements to match the number of channels"
)
if input_data_format == ChannelDimension.FIRST:
result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype)
for i, color in enumerate(background_color):
result[i, :, :] = color
if width > height:
start = (max_dim - height) // 2
result[:, start : start + height, :] = image
else:
start = (max_dim - width) // 2
result[:, :, start : start + width] = image
else:
result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype)
for i, color in enumerate(background_color):
result[:, :, i] = color
if width > height:
start = (max_dim - height) // 2
result[start : start + height, :, :] = image
else:
start = (max_dim - width) // 2
result[:, start : start + width, :] = image
return result
def postprocess(
self,
images: ImageInput,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: List[float] = None,
image_std: List[float] = None,
input_data_format: str = None,
return_tensors: str = None,
):
"""Applies post-processing to the decoded image tokens by reversing transformations applied during preprocessing."""
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = 1.0 / self.rescale_factor if rescale_factor is None else rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
images = make_list_of_images(images) # Ensures input is a list
if isinstance(images[0], PIL.Image.Image):
return images if len(images) > 1 else images[0]
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images[0]) # Determine format dynamically
pixel_values = []
for image in images:
image = to_numpy_array(image) # Ensure NumPy format
if do_normalize:
image = self.unnormalize(
image=image, image_mean=image_mean, image_std=image_std, input_data_format=input_data_format
)
if do_rescale:
image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
image = image.clip(0, 255).astype(np.uint8)
if do_normalize and do_rescale and return_tensors == "PIL.Image.Image":
image = to_channel_dimension_format(image, ChannelDimension.LAST, input_channel_dim=input_data_format)
image = PIL.Image.fromarray(image)
pixel_values.append(image)
data = {"pixel_values": pixel_values}
return_tensors = return_tensors if return_tensors != "PIL.Image.Image" else None
return BatchFeature(data=data, tensor_type=return_tensors)
def unnormalize(
self,
image: np.array,
image_mean: Union[float, Iterable[float]],
image_std: Union[float, Iterable[float]],
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.array:
"""
Unnormalizes `image` using the mean and standard deviation specified by `mean` and `std`.
image = (image * image_std) + image_mean
Args:
image (`torch.Tensor` of shape `(batch_size, num_channels, image_size, image_size)` or `(num_channels, image_size, image_size)`):
Batch of pixel values to postprocess.
image_mean (`float` or `Iterable[float]`):
The mean to use for unnormalization.
image_std (`float` or `Iterable[float]`):
The standard deviation to use for unnormalization.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
num_channels = 3
if isinstance(image_mean, Iterable):
if len(image_mean) != num_channels:
raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(image_mean)}")
else:
image_mean = [image_mean] * num_channels
if isinstance(image_std, Iterable):
if len(image_std) != num_channels:
raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(image_std)}")
else:
image_std = [image_std] * num_channels
rev_image_mean = tuple(-mean / std for mean, std in zip(image_mean, image_std))
rev_image_std = tuple(1 / std for std in image_std)
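        # Reuses `normalize`, which computes (image - mean) / std: with mean' = -mean / std and
        # std' = 1 / std this becomes (image + mean / std) * std = image * std + mean, i.e. the
        # inverse of the normalization applied during preprocessing.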
image = self.normalize(
image=image, mean=rev_image_mean, std=rev_image_std, input_data_format=input_data_format
)
return image
__all__ = ["JanusImageProcessor"]

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,188 @@
# coding=utf-8
# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Janus.
"""
from typing import List, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
from ...tokenization_utils_base import (
PreTokenizedInput,
TextInput,
)
from ...utils import logging
logger = logging.get_logger(__name__)
DEFAULT_SYSTEM_PROMPT = (
"You are a helpful language and vision assistant. "
"You are able to understand the visual content that the user provides, "
"and assist the user with a variety of tasks using natural language.\n\n"
)
class JanusTextKwargs(TextKwargs, total=False):
generation_mode: str
class JanusProcessorKwargs(ProcessingKwargs, total=False):
text_kwargs: JanusTextKwargs
_defaults = {
"text_kwargs": {"padding": False, "generation_mode": "text"},
"common_kwargs": {"return_tensors": "pt"},
}
class JanusProcessor(ProcessorMixin):
r"""
Constructs a Janus processor which wraps a Janus Image Processor and a Llama tokenizer into a single processor.
[`JanusProcessor`] offers all the functionalities of [`JanusImageProcessor`] and [`LlamaTokenizerFast`]. See the
[`~JanusProcessor.__call__`] and [`~JanusProcessor.decode`] for more information.
Args:
image_processor ([`JanusImageProcessor`]):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`]):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
            Whether to use the default system prompt for text generation.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template", "use_default_system_prompt"]
image_processor_class = "JanusImageProcessor"
tokenizer_class = "LlamaTokenizerFast"
def __init__(self, image_processor, tokenizer, chat_template=None, use_default_system_prompt=False, **kwargs):
self.num_image_tokens = 576
self.image_token = tokenizer.image_token
self.image_start_token = tokenizer.boi_token
self.image_end_token = tokenizer.eoi_token
self.use_default_system_prompt = use_default_system_prompt
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
images: ImageInput = None,
videos=None,
audio=None,
**kwargs: Unpack[JanusProcessorKwargs],
) -> BatchFeature:
"""
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
        JanusImageProcessor's [`~JanusImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
        of the above two methods for more information.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
output_kwargs = self._merge_kwargs(
JanusProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs
)
if text is None and images is None:
raise ValueError("You must specify either text or images.")
if text is not None:
if isinstance(text, str):
text = [text]
elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
generation_mode = output_kwargs["text_kwargs"].pop("generation_mode")
# Replace the image token with expanded image tokens.
prompt_strings = []
one_img_tokens = self.image_start_token + (self.image_token * self.num_image_tokens) + self.image_end_token
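        # With the checkpoint's default special tokens, each placeholder expands to 1 + 576 + 1 = 578 tokens:
        # <begin_of_image>, 576 <image_placeholder> tokens and <end_of_image>.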
for prompt in text:
prompt = prompt.replace(self.image_token, one_img_tokens)
if self.use_default_system_prompt and generation_mode == "text":
prompt = DEFAULT_SYSTEM_PROMPT + prompt
if generation_mode == "image":
prompt += self.image_start_token
prompt_strings.append(prompt)
data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
        # Process the images if provided and we are not in image generation mode.
if images is not None and generation_mode != "image":
data["pixel_values"] = self.image_processor(images=images, **output_kwargs["images_kwargs"])[
"pixel_values"
]
return BatchFeature(data=data)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
def postprocess(self, images: ImageInput, **kwargs):
"""
Forwards all arguments to the image processor's `postprocess` method.
Refer to the original method's docstring for more details.
"""
return self.image_processor.postprocess(images, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
__all__ = ["JanusProcessor"]

View File

@@ -126,6 +126,7 @@ VLM_CLASS_NAMES = [
"qwen2vl",
"qwen2_5_vl",
"ayavision",
"janus",
"gemma3",
"mistral3",
"chameleon",

View File

@@ -0,0 +1,188 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
from transformers import JanusImageProcessor
class JanusImageProcessingTester:
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
image_size=384,
min_resolution=30,
max_resolution=200,
do_resize=True,
size=None,
do_normalize=True,
image_mean=[1.0, 1.0, 1.0],
image_std=[1.0, 1.0, 1.0],
do_convert_rgb=True,
):
size = size if size is not None else {"height": 384, "width": 384}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb
def prepare_image_processor_dict(self):
return {
"do_resize": self.do_resize,
"size": self.size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_convert_rgb": self.do_convert_rgb,
}
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)
@require_torch
@require_vision
class JanusImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = JanusImageProcessor if is_vision_available() else None
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->Janus
def setUp(self):
super().setUp()
self.image_processor_tester = JanusImageProcessingTester(self)
@property
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 384, "width": 384})
self.assertEqual(image_processor.image_mean, [1.0, 1.0, 1.0])
image_processor = self.image_processing_class.from_dict(
self.image_processor_dict, size=42, image_mean=[1.0, 2.0, 1.0]
)
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
self.assertEqual(image_processor.image_mean, [1.0, 2.0, 1.0])
def test_call_pil(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test Non batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 3, 384, 384)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 3, 384, 384)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_numpy(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 3, 384, 384)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 3, 384, 384)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_pytorch(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 3, 384, 384)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 3, 384, 384)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_nested_input(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
# Test batched as a list of images.
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 3, 384, 384)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched as a nested list of images, where each sublist is one batch.
image_inputs_nested = [image_inputs[:3], image_inputs[3:]]
encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 3, 384, 384)
self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)
# Image processor should return same pixel values, independently of input format.
self.assertTrue((encoded_images_nested == encoded_images).all())
@unittest.skip(reason="Not supported")
def test_call_numpy_4_channels(self):
pass

View File

@@ -0,0 +1,553 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Janus model."""
import re
import tempfile
import unittest
from functools import reduce
import numpy as np
import requests
from transformers import (
AutoProcessor,
JanusConfig,
JanusForConditionalGeneration,
JanusModel,
JanusVQVAE,
JanusVQVAEConfig,
is_torch_available,
is_vision_available,
)
from transformers.models.auto import get_values
from transformers.models.auto.modeling_auto import MODEL_FOR_BACKBONE_MAPPING_NAMES, MODEL_MAPPING_NAMES
from transformers.testing_utils import (
require_torch,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
class JanusVisionText2TextModelTester:
def __init__(
self,
parent,
image_token_index=0,
seq_length=25,
initializer_range=0.02,
text_config={
"model_type": "llama",
"seq_length": 7,
"is_training": True,
"use_input_mask": True,
"use_token_type_ids": False,
"use_labels": True,
"vocab_size": 99,
"hidden_size": 32,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"intermediate_size": 37,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"attention_probs_dropout_prob": 0.1,
"max_position_embeddings": 512,
"type_vocab_size": 16,
"type_sequence_label_size": 2,
"initializer_range": 0.02,
"num_labels": 3,
"num_choices": 4,
"pad_token_id": 1,
},
is_training=True,
vision_config={
"use_labels": True,
"image_size": 20,
"patch_size": 5,
"num_image_tokens": 4,
"num_channels": 3,
"is_training": True,
"hidden_size": 32,
"projection_dim": 32,
"num_key_value_heads": 1,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"mlp_ratio": 2,
"dropout": 0.1,
"attention_dropout": 0.1,
"initializer_range": 0.02,
"vision_feature_select_strategy": "default",
"vision_feature_layer": -1,
},
use_cache=False,
vq_num_embeds=12,
vq_embed_dim=12,
vq_channel_multiplier=[1, 1],
):
self.parent = parent
self.initializer_range = initializer_range
# `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify
self.image_token_index = image_token_index
self.text_config = text_config
self.vision_config = vision_config
self.seq_length = seq_length
self.pad_token_id = text_config["pad_token_id"]
self.num_hidden_layers = text_config["num_hidden_layers"]
self.vocab_size = text_config["vocab_size"]
self.hidden_size = text_config["hidden_size"]
self.num_attention_heads = text_config["num_attention_heads"]
self.is_training = is_training
self.batch_size = 3
self.num_channels = vision_config["num_channels"]
self.image_size = vision_config["image_size"]
self.num_image_tokens = vision_config["num_image_tokens"]
self.use_cache = use_cache
# vq model params
self.vq_num_embeds = vq_num_embeds
self.vq_embed_dim = vq_embed_dim
self.vq_channel_multiplier = vq_channel_multiplier
def get_vq_config(self):
return {
"embed_dim": self.vq_embed_dim,
"num_embeddings": self.vq_num_embeds,
"latent_channels": self.vq_embed_dim,
"in_channels": 3,
"base_channels": 32, # we have a GroupNorm of 32 groups, so can't do less
"channel_multiplier": self.vq_channel_multiplier,
"initializer_range": self.initializer_range,
"projection_dim": 10,
"image_token_embed_dim": 32, # Same as text model hidden size
}
def get_config(self):
return JanusConfig(
text_config=self.text_config,
vision_config=self.vision_config,
vq_config=self.get_vq_config(),
)
def prepare_config_and_inputs(self):
config = self.get_config()
pixel_values = floats_tensor(
[
self.batch_size,
3,
self.image_size,
self.image_size,
]
)
return config, pixel_values
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
        # set the first `num_image_tokens` tokens to be image tokens, and ensure that no other tokens are image tokens
# do not change this unless you modified image size or patch size
input_ids[input_ids == self.image_token_index] = self.pad_token_id
input_ids[:, : self.num_image_tokens] = self.image_token_index
inputs_dict = {
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": input_ids,
"generation_mode": "text", # Required to perform text generation instead of image generation.
}
return config, inputs_dict
@require_torch
class JanusVisionText2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (JanusModel, JanusForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (JanusForConditionalGeneration,) if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_head_masking = False
_is_composite = True
def setUp(self):
self.model_tester = JanusVisionText2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=JanusConfig, has_text_modality=False)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
del inputs["generation_mode"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# Overwrite inputs_embeds tests because we need to delete "pixel values" for VLMs.
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
del inputs["generation_mode"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
def test_sdpa_can_dispatch_composite_models(self):
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# Load the model with SDPA
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
# Load model with eager attention
model_eager = model_class.from_pretrained(
tmpdirname,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
                # SigLIP has one shared cls attr for all models, so we assign both submodels here
vision_attn = language_attn = "sdpa" if model._supports_sdpa else "eager"
if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "language_model"):
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model_sdpa.language_model.config._attn_implementation == language_attn)
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
self.assertTrue(model_eager.language_model.config._attn_implementation == "eager")
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if any(re.finditer(r"Attention(?!Pool)", class_name)):
self.assertTrue(submodule.config._attn_implementation == "eager")
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if any(re.finditer(r"Attention(?!Pool)", class_name)):
self.assertTrue(submodule.config._attn_implementation == "sdpa")
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
if not self.model_tester.is_training:
self.skipTest(reason="ModelTester is not configured to run training tests")
"""
We skip some parameters when checking for gradient checkpointing:
- VQ model, as its training is not supported.
- A few other modules used for image generation.
"""
skip_patterns = ["vqmodel", "generation_embeddings", "generation_aligner", "generation_head"]
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
if (
model_class.__name__
in [
*get_values(MODEL_MAPPING_NAMES),
*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
]
or not model_class.supports_gradient_checkpointing
):
# TODO (ydshieh): use `skipTest` once pytest-dev/pytest-subtests/pull/169 is merged
# self.skipTest(reason=f"`supports_gradient_checkpointing` is False for {model_class.__name__}.")
continue
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.use_cache = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
model.train()
# unfreeze additional layers
for p in model.parameters():
p.requires_grad_(True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
optimizer.step()
if self.test_all_params_have_gradient:
for k, v in model.named_parameters():
if v.requires_grad and not reduce(lambda t, s: t | (s in k), skip_patterns, False):
self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
else:
pass
class JanusVQModelTester:
def __init__(
self,
parent,
batch_size=5,
is_training=False,
initializer_range=0.02,
image_size=30,
num_embeds=12,
base_channels=32, # we have a GroupNorm of 32 groups, so can't do less
embed_dim=12,
channel_multiplier=[1, 2],
patch_size=2,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.is_training = is_training
self.initializer_range = initializer_range
self.image_size = image_size
self.base_channels = base_channels
self.num_embeds = num_embeds
self.embed_dim = embed_dim
self.channel_multiplier = channel_multiplier
self.num_patches = image_size // patch_size
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size])
config = self.get_config()
return config, pixel_values
def get_config(self):
return JanusVQVAEConfig(
embed_dim=self.embed_dim,
num_embeddings=self.num_embeds,
latent_channels=self.embed_dim,
in_channels=3,
base_channels=self.base_channels,
channel_multiplier=self.channel_multiplier,
initializer_range=self.initializer_range,
resolution=self.image_size,
num_patches=self.num_patches,
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class JanusVQModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (JanusVQVAE,) if is_torch_available() else ()
test_head_masking = False
test_pruning = False
fx_compatible = False
has_attentions = False
test_resize_embeddings = False
def setUp(self):
self.model_tester = JanusVQModelTester(self)
self.config_tester = ConfigTester(
self,
config_class=JanusVQVAEConfig,
has_text_modality=False,
common_properties=["embed_dim", "num_embeddings"],
)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip("Janus VQ module cannot offload due to using `self.weight` directly")
def test_cpu_offload(self):
pass
@unittest.skip("Janus VQ module cannot offload due to using `self.weight` directly")
def test_disk_offload_bin(self):
pass
@unittest.skip("Janus VQ module cannot offload due to using `self.weight` directly")
def test_disk_offload_safetensors(self):
pass
@unittest.skip("Janus VQ module has no hidden states")
def test_hidden_states_output(self):
pass
@unittest.skip("Janus VQ module has no hidden states")
def test_model_outputs_equivalence(self):
pass
@unittest.skip("Janus VQ module has no get/set embeddings method")
def test_model_get_set_embeddings(self):
pass
@unittest.skip("Janus VQ module has no hidden states")
def test_retain_grad_hidden_states_attentions(self):
pass
class JanusIntegrationTest(unittest.TestCase):
def setUp(self):
self.model_id = "deepseek-community/Janus-Pro-1B"
@slow
def test_model_text_generation(self):
model = JanusForConditionalGeneration.from_pretrained(self.model_id, device_map="auto")
model.eval()
processor = AutoProcessor.from_pretrained(self.model_id)
image = Image.open(
requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
)
prompt = "<image_placeholder>\nDescribe what do you see here and tell me about the history behind it?"
inputs = processor(images=image, text=prompt, generation_mode="text", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20, generation_mode="text", do_sample=False)
EXPECTED_DECODED_TEXT = 'You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n\n\nDescribe what do you see here and tell me about the history behind it?\n\nThe image depicts the constellation of Leo, which is often referred to as the "Lion"' # fmt: skip
text = processor.decode(output[0], skip_special_tokens=True)
self.assertEqual(
text,
EXPECTED_DECODED_TEXT,
)
@slow
def test_model_text_generation_batched(self):
model = JanusForConditionalGeneration.from_pretrained(self.model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(self.model_id)
image_1 = Image.open(
requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
)
image_2 = Image.open(
requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw
)
prompts = [
"<image_placeholder>\nDescribe what do you see here and tell me about the history behind it?",
"What constellation is this image showing?<image_placeholder>\n",
]
inputs = processor(
images=[image_1, image_2], text=prompts, generation_mode="text", padding=True, return_tensors="pt"
).to(model.device, torch.float16)
EXPECTED_TEXT_COMPLETION = [
'You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n\n\nDescribe what do you see here and tell me about the history behind it?\n\nThe image depicts the constellation of Leo, which is often referred to as the "Lion"',
"You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n\nWhat constellation is this image showing?\n\nThe image shows a constellation that is shaped like a stylized figure with a long tail. This",
]
generated_ids = model.generate(**inputs, max_new_tokens=20, generation_mode="text", do_sample=False)
text = processor.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
@slow
def test_model_text_generation_with_multi_image(self):
model = JanusForConditionalGeneration.from_pretrained(self.model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(self.model_id)
image_1 = Image.open(
requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
)
image_2 = Image.open(
requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw
)
prompt = "What do these two images <image_placeholder> and <image_placeholder> have in common?"
inputs = processor(images=[image_1, image_2], text=prompt, generation_mode="text", return_tensors="pt").to(
model.device, torch.float16
)
EXPECTED_TEXT_COMPLETION = ['You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n\nWhat do these two images and have in common?\n\nThe two images you provided are of the same constellation. The first image shows the constellation of Leo, and the second image shows the constellation of Ursa Major. Both constellations are part of'] # fmt: skip
generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False)
text = processor.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
@slow
def test_model_generate_images(self):
model = JanusForConditionalGeneration.from_pretrained(self.model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(self.model_id)
inputs = processor(
text=["A portrait of young girl. masterpiece, film grained, best quality."],
padding=True,
generation_mode="image",
return_tensors="pt",
).to(model.device)
self.assertTrue(inputs.input_ids.shape[1] == 17)
out = model.generate(
**inputs,
generation_mode="image",
do_sample=False,
)
        # Generation should run for `num_image_tokens` steps, in this case 576.
self.assertTrue(out.shape[1] == 576)
# fmt: off
expected_tokens = torch.tensor([4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971,
14985, 14834, 15438, 7548, 1820, 1465, 13529, 12761, 10503, 12761,
14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713,
14305, 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297,
1097, 12108, 15038, 311, 14998, 15165, 897, 4044, 1762, 4676,
]).to(model.device)
# fmt: on
# Compare the first 50 generated tokens.
self.assertTrue(torch.allclose(expected_tokens, out[0][:50]))
# Decode generated tokens to pixel values and postprocess them.
decoded_pixel_values = model.decode_image_tokens(out)
images = processor.postprocess(list(decoded_pixel_values.float()), return_tensors="np")
self.assertTrue(images["pixel_values"].shape == (1, 384, 384, 3))
self.assertTrue(isinstance(images["pixel_values"], np.ndarray))
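
For reference, the image-generation path exercised by the test above can be sketched end to end as follows (a minimal example reusing the same checkpoint and prompt; the output filename is arbitrary):

from transformers import AutoProcessor, JanusForConditionalGeneration

model_id = "deepseek-community/Janus-Pro-1B"
model = JanusForConditionalGeneration.from_pretrained(model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)

inputs = processor(
    text=["A portrait of young girl. masterpiece, film grained, best quality."],
    generation_mode="image",
    return_tensors="pt",
).to(model.device)

# Generate 576 discrete image tokens, decode them with the VQ model and convert the result to a PIL image.
out = model.generate(**inputs, generation_mode="image", do_sample=False)
decoded_pixel_values = model.decode_image_tokens(out)
images = processor.postprocess(list(decoded_pixel_values.float()), return_tensors="PIL.Image.Image")
images["pixel_values"][0].save("janus_generated.png")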

View File

@@ -0,0 +1,455 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Janus model."""
import tempfile
import unittest
import numpy as np
from transformers import AutoProcessor, AutoTokenizer, JanusProcessor
from transformers.models.janus.convert_janus_weights_to_hf import CHAT_TEMPLATE
from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
pass
class JanusProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = JanusProcessor
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
special_image_tokens = {
"image_token": "<image_placeholder>",
"boi_token": "<begin_of_image>",
"eoi_token": "<end_of_image>",
}
processor = self.processor_class.from_pretrained(
"deepseek-community/Janus-Pro-1B",
extra_special_tokens=special_image_tokens,
)
processor.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return AutoTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
def prepare_processor_dict(self):
# similar to Emu3 and Qwen2VLProcessorTest, but keep the template in the convert script to avoid duplicated code
return {
"chat_template": CHAT_TEMPLATE,
}
def test_chat_template_single(self):
"""
Tests that the chat template matches the original implementation when applied to a single message.
"""
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
# Single image message
messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
{"type": "image"},
],
},
]
]
correct_prompt = ["<|User|>: What is shown in this image?\n<image_placeholder>\n\n<|Assistant|>:"]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(formatted_prompt, correct_prompt)
# Single image message with capitalization
messages = [
[
{
"role": "User",
"content": [
{"type": "text", "text": "What is shown in this image?"},
{"type": "image"},
],
},
]
]
correct_prompt = ["<|User|>: What is shown in this image?\n<image_placeholder>\n\n<|Assistant|>:"]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(formatted_prompt, correct_prompt)
# Single image message with uppercase
messages = [
[
{
"role": "USER",
"content": [
{"type": "text", "text": "What is shown in this image?"},
{"type": "image"},
],
},
]
]
correct_prompt = ["<|User|>: What is shown in this image?\n<image_placeholder>\n\n<|Assistant|>:"]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(formatted_prompt, correct_prompt)
"""
Warning: normally, the other models have a test comparing chat template+tokenization as two separate steps
versus as a single step (i.e. processor.apply_chat_template(..., tokenize=True)). However, our processor has
some extra steps other than simply applying prompt to tokenizer. These include prepending the default system
prompts and, following the implementation from the Janus codebase, expanding the image token.
"""
# Checking the output dict keys
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
# Now test the ability to return dict
messages[0][0]["content"][1].update(
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertTrue(self.images_input_name in out_dict)
# should always have input_ids and attention_mask
self.assertEqual(len(out_dict["input_ids"]), 1)
self.assertEqual(len(out_dict["attention_mask"]), 1)
self.assertEqual(len(out_dict[self.images_input_name]), 1)
# Passing generation prompt explicitly
messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
{"type": "image"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": ""},
],
},
]
]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=False)
self.assertEqual(formatted_prompt, correct_prompt)
# Single prompt with multiple images
messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "Compare this image"},
{"type": "image"},
{"type": "text", "text": "with this image"},
{"type": "image"},
],
},
]
]
correct_prompt = [
"<|User|>: Compare this image\n<image_placeholder>\nwith this image\n<image_placeholder>\n\n<|Assistant|>:"
]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(formatted_prompt, correct_prompt)
# Multiple turns and multiple images
messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "Compare this image"},
{"type": "image"},
{"type": "text", "text": "with this image"},
{"type": "image"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "The first image is an equation, the second is a pie chart."},
],
},
{
"role": "user",
"content": [
{"type": "image"},
{
"type": "text",
"text": "What about this third image? To which of the previous to is it more similar?",
},
],
},
]
]
correct_prompt = [
"<|User|>: Compare this image\n<image_placeholder>\nwith this image\n<image_placeholder>\n\n<|Assistant|>: The first image is an equation, the second is a pie chart.<end▁of▁sentence><|User|>: <image_placeholder>\nWhat about this third image? To which of the previous to is it more similar?\n\n<|Assistant|>:"
]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(formatted_prompt, correct_prompt)
def test_chat_template_batched(self):
"""
Tests that the chat template matches the original implementation when applied to a batch of messages.
"""
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
# Test 1: Simple single image per message batch
batched_messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
{"type": "image"},
],
},
],
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
{"type": "image"},
],
},
],
]
correct_prompts = [
"<|User|>: What is shown in this image?\n<image_placeholder>\n\n<|Assistant|>:",
"<|User|>: What is shown in this image?\n<image_placeholder>\n\n<|Assistant|>:",
]
formatted_prompts = processor.apply_chat_template(batched_messages, add_generation_prompt=True)
self.assertEqual(formatted_prompts, correct_prompts)
# Similarly to the single case, no test for chat template+tokenization as two separate steps versus as a single step
# Checking the output dict keys
out_dict = processor.apply_chat_template(
batched_messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
padding=True,
)
self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
# Verify image inputs are included in the output dict
batched_messages[0][0]["content"][1].update(
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
)
batched_messages[1][0]["content"][1].update(
{"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}
)
out_dict = processor.apply_chat_template(
batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
)
self.assertTrue(self.images_input_name in out_dict)
self.assertEqual(len(out_dict["input_ids"]), 2) # Batch size for text
self.assertEqual(len(out_dict["attention_mask"]), 2) # Batch size for attention mask
self.assertEqual(len(out_dict[self.images_input_name]), 2) # Batch size for images
# Test 2: Two images per message batch with different prompts
batched_messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "Compare this image"},
{"type": "image"},
{"type": "text", "text": "with this image"},
{"type": "image"},
],
},
],
[
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "Describe how the previous image compares to the following"},
{"type": "image"},
],
},
],
]
correct_prompts = [
"<|User|>: Compare this image\n<image_placeholder>\nwith this image\n<image_placeholder>\n\n<|Assistant|>:",
"<|User|>: <image_placeholder>\nDescribe how the previous image compares to the following\n<image_placeholder>\n\n<|Assistant|>:",
]
formatted_prompts = processor.apply_chat_template(batched_messages, add_generation_prompt=True)
self.assertEqual(formatted_prompts, correct_prompts)
# Test 3: Multi-turn conversations with multiple images
batched_messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "Compare this image"},
{"type": "image"},
{"type": "text", "text": "with this image"},
{"type": "image"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "The first image is an equation, the second is a pie chart."},
],
},
{
"role": "user",
"content": [
{"type": "image"},
{
"type": "text",
"text": "What about this third image? To which of the previous to is it more similar?",
},
],
},
],
[
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "Describe how the previous image compares to the following"},
{"type": "image"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "The first image is a formula, the second is a plot."},
],
},
{
"role": "user",
"content": [
{"type": "text", "text": "Which of them is closer to the following?"},
{"type": "image"},
],
},
],
]
correct_prompts = [
"<|User|>: Compare this image\n<image_placeholder>\nwith this image\n<image_placeholder>\n\n<|Assistant|>: The first image is an equation, the second is a pie chart.<end▁of▁sentence><|User|>: <image_placeholder>\nWhat about this third image? To which of the previous to is it more similar?\n\n<|Assistant|>:",
"<|User|>: <image_placeholder>\nDescribe how the previous image compares to the following\n<image_placeholder>\n\n<|Assistant|>: The first image is a formula, the second is a plot.<end▁of▁sentence><|User|>: Which of them is closer to the following?\n<image_placeholder>\n\n<|Assistant|>:",
]
formatted_prompts = processor.apply_chat_template(batched_messages, add_generation_prompt=True)
self.assertEqual(formatted_prompts, correct_prompts)
def test_chat_template_accepts_processing_kwargs(self):
"""Tests that the chat template correctly handles additional processing arguments."""
# Get processor and skip if it doesn't have a chat template
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
# Create a simple text message for testing
messages = [
[
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
],
},
]
]
# Test 1: Padding to max_length
# PS: we have to override the parent max_length of 50 to 80 because the output is already 51 tokens
formatted_prompt_tokenized = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
padding="max_length",
max_length=80,
)
self.assertEqual(len(formatted_prompt_tokenized[0]), 80)
# Test 2: Truncation
# Verify that the output is truncated to exactly 5 tokens
formatted_prompt_tokenized = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
truncation=True,
max_length=5,
)
self.assertEqual(len(formatted_prompt_tokenized[0]), 5)
# Test 3: Image processing kwargs
# Add an image and test image processing parameters
messages[0][0]["content"].append(
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
)
# Process with image rescaling and verify the pixel values are negative
out_dict = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
do_rescale=True,
rescale_factor=-1,
return_tensors="np",
)
self.assertLessEqual(out_dict[self.images_input_name][0][0].mean(), 0)
def test_processor_postprocess(self):
processor_components = self.prepare_components()
processor = self.processor_class(**processor_components)
input_str = "lower newer"
orig_image_input = self.prepare_image_inputs()
orig_image = np.array(orig_image_input).transpose(2, 0, 1)
inputs = processor(text=input_str, images=orig_image, do_resize=False, return_tensors="np")
normalized_image_input = inputs.pixel_values
unnormalized_images = processor.postprocess(normalized_image_input, return_tensors="np")["pixel_values"]
# For an image where pixels go from 0 to 255 the diff can be 1 due to some numerical precision errors when scaling and unscaling
        self.assertTrue(np.abs(orig_image - unnormalized_images).max() <= 1)

View File

@@ -156,6 +156,7 @@ IGNORE_NON_TESTED = (
"Llama4VisionModel", # Building part of bigger (tested) model. # TODO: add tests
"Emu3VQVAE", # Building part of bigger (tested) model
"Emu3TextModel", # Building part of bigger (tested) model
"JanusVisionModel", # Building part of bigger (tested) model
"TimesFmModel", # Building part of bigger (tested) model
]
)
@@ -356,6 +357,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"MoshiForConditionalGeneration", # no auto class for speech-to-speech
"Emu3VQVAE", # no autoclass for VQ-VAE models
"Emu3TextModel", # Building part of bigger (tested) model
"JanusVQVAE", # no autoclass for VQ-VAE models
"JanusVisionModel", # Building part of bigger (tested) model
"Qwen2_5OmniTalkerForConditionalGeneration", # Building part of a bigger model
"Qwen2_5OmniTalkerModel", # Building part of a bigger model
"Qwen2_5OmniThinkerForConditionalGeneration", # Building part of a bigger model