diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 1827894f0dd..b3f068ab2ba 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -953,6 +953,8 @@ title: InstructBLIP - local: model_doc/instructblipvideo title: InstructBlipVideo + - local: model_doc/janus + title: Janus - local: model_doc/kosmos-2 title: KOSMOS-2 - local: model_doc/layoutlm diff --git a/docs/source/en/model_doc/janus.md b/docs/source/en/model_doc/janus.md new file mode 100644 index 00000000000..015f2910dfe --- /dev/null +++ b/docs/source/en/model_doc/janus.md @@ -0,0 +1,230 @@ + + +# Janus + +## Overview + +The Janus Model was originally proposed in [Janus: Decoupling Visual Encoding for Unified Multimodal Understanding and Generation](https://arxiv.org/abs/2410.13848) by DeepSeek AI team and later refined in [Janus-Pro: Unified Multimodal Understanding and Generation with Data and Model Scaling](https://arxiv.org/abs/2501.17811). Janus is a vision-language model that can generate both image and text output, it can also take both images and text as input. + +> [!NOTE] +> The model doesn't generate both images and text in an interleaved format. The user has to pass a parameter indicating whether to generate text or image. + +The abstract from the original paper is the following: + +*In this paper, we introduce Janus, an autoregressive framework that unifies multimodal understanding and generation. Prior research often relies on a single visual encoder for both tasks, such as Chameleon. However, due to the differing levels of information granularity required by multimodal understanding and generation, this approach can lead to suboptimal performance, particularly in multimodal understanding. To address this issue, we decouple visual encoding into separate pathways, while still leveraging a single, unified transformer architecture for processing. The decoupling not only alleviates the conflict between the visual encoder's roles in understanding and generation, but also enhances the framework's flexibility. For instance, both the multimodal understanding and generation components can independently select their most suitable encoding methods. Experiments show that Janus surpasses previous unified model and matches or exceeds the performance of task-specific models. The simplicity, high flexibility, and effectiveness of Janus make it a strong candidate for next-generation unified multimodal models.* + +The abstract from the aforementioned `Janus-Pro` paper, released afterwards, is the following: + +*In this work, we introduce Janus-Pro, an advanced version of the previous work Janus. Specifically, Janus-Pro incorporates (1) an optimized training strate (2) expanded training data, and (3) scaling to larger model size. With these improvements, Janus-Pro achieves significant advancements in both multimodal understanding and text-to-image instruction-following capabilities, while also enhancing the stability of text-to-image generation. We hope this work will inspire further exploration in the field. Code and models are publicly available.* + +This model was contributed by [Yaswanth Gali](https://huggingface.co/yaswanthgali) and [Hugo Silva](https://huggingface.co/hugosilva664). +The original code can be found [here](https://github.com/deepseek-ai/Janus). + +## Usage Example + +### Single image inference + +Here is the example of visual understanding with a single image. + +> [!NOTE] +> Note that the model has been trained with a specific prompt format for chatting. 
Use `processor.apply_chat_template(my_conversation_dict)` to correctly format your prompts. + +```python +import torch +from PIL import Image +import requests + +from transformers import JanusForConditionalGeneration, JanusProcessor + +model_id = "deepseek-community/Janus-Pro-1B" +# Prepare Input for generation. +messages = [ + { + "role": "user", + "content": [ + {'type':'image', 'url': 'http://images.cocodataset.org/val2017/000000039769.jpg'}, + {'type':"text", "text":"What do you see in this image?."} + ] + }, +] + +# Set generation mode to `text` to perform text generation. +processor = JanusProcessor.from_pretrained(model_id) +model = JanusForConditionalGeneration.from_pretrained(model_id, + torch_dtype=torch.bfloat16, + device_map="auto") + +inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + generation_mode="text", + tokenize=True, + return_dict=True, + return_tensors="pt", +).to(model.device, dtype=torch.bfloat16) + +output = model.generate(**inputs, max_new_tokens=40,generation_mode='text',do_sample=True) +text = processor.decode(output[0], skip_special_tokens=True) +print(text) +``` + +### Multi image inference + +Janus can perform inference with multiple images as input, where images can belong to the same prompt or different prompts in batched inference, where the model processes many conversations in parallel. Here is how you can do it: + +```python +import torch +from PIL import Image +import requests + +from transformers import JanusForConditionalGeneration, JanusProcessor + +model_id = "deepseek-community/Janus-Pro-1B" + +image_urls = [ + "http://images.cocodataset.org/val2017/000000039769.jpg", + "https://www.ilankelman.org/stopsigns/australia.jpg", + "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg" +] + +messages = [ + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What’s the difference between"}, + {"type": "image", "url": image_urls[0]}, + {"type": "text", "text": " and "}, + {"type": "image", "url": image_urls[1]} + ] + } + ], + [ + { + "role": "user", + "content": [ + {"type": "image", "url": image_urls[2]}, + {"type": "text", "text": "What do you see in this image?"} + ] + } + ] +] + +# Load model and processor +processor = JanusProcessor.from_pretrained(model_id) +model = JanusForConditionalGeneration.from_pretrained( + model_id, torch_dtype=torch.bfloat16, device_map="auto" +) + +inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + generation_mode="text", + tokenize=True, + padding=True, + return_dict=True, + return_tensors="pt" +).to(model.device, dtype=torch.bfloat16) + +# Generate response +output = model.generate(**inputs, max_new_tokens=40, generation_mode='text', do_sample=False) +text = processor.batch_decode(output, skip_special_tokens=True) +print(text) +``` + +## Text to Image generation + +Janus can also generate images given a prompt. + +```python +import torch +from transformers import JanusForConditionalGeneration, JanusProcessor + +# Set generation mode to `image` to prepare inputs for image generation.. 
+ +model_id = "deepseek-community/Janus-Pro-1B" +processor = JanusProcessor.from_pretrained(model_id) +model = JanusForConditionalGeneration.from_pretrained(model_id, + torch_dtype=torch.bfloat16, + device_map="auto") + +messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "A dog running under the rain."}, + ], + } +] + +prompt = processor.apply_chat_template(messages, add_generation_prompt=True) +inputs = processor(text=prompt,generation_mode="image",return_tensors="pt").to(model.device, dtype=torch.bfloat16) + +# Set num_return_sequence parameter to generate multiple images per prompt. +model.generation_config.num_return_sequences = 2 +outputs = model.generate(**inputs, + generation_mode="image", + do_sample=True, + use_cache=True, + ) +# Perform post-processing on the generated token ids. +decoded_image = model.decode_image_tokens(outputs) +images = processor.postprocess(list(decoded_image.float()),return_tensors="PIL.Image.Image") +# Save the image +for i, image in enumerate(images['pixel_values']): + image.save(f"result{i}.png") +``` + +## JanusConfig + +[[autodoc]] JanusConfig + +## JanusVisionConfig + +[[autodoc]] JanusVisionConfig + +## JanusVQVAEConfig + +[[autodoc]] JanusVQVAEConfig + +## JanusProcessor + +[[autodoc]] JanusProcessor + +## JanusImageProcessor + +[[autodoc]] JanusImageProcessor + +## JanusVisionModel + +[[autodoc]] JanusVisionModel + - forward + +## JanusVQVAE + +[[autodoc]] JanusVQVAE + - forward + +## JanusModel + +[[autodoc]] JanusModel + - forward + +## JanusForConditionalGeneration + +[[autodoc]] JanusForConditionalGeneration + - forward diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 94a68374cb7..e0478603b27 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -144,6 +144,7 @@ if TYPE_CHECKING: from .instructblip import * from .instructblipvideo import * from .jamba import * + from .janus import * from .jetmoe import * from .kosmos2 import * from .layoutlm import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index c28ec163d1c..73192931338 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -163,6 +163,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("instructblip", "InstructBlipConfig"), ("instructblipvideo", "InstructBlipVideoConfig"), ("jamba", "JambaConfig"), + ("janus", "JanusConfig"), ("jetmoe", "JetMoeConfig"), ("jukebox", "JukeboxConfig"), ("kosmos-2", "Kosmos2Config"), @@ -517,6 +518,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("instructblip", "InstructBLIP"), ("instructblipvideo", "InstructBlipVideo"), ("jamba", "Jamba"), + ("janus", "Janus"), ("jetmoe", "JetMoe"), ("jukebox", "Jukebox"), ("kosmos-2", "KOSMOS-2"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 296c3dad10d..10ee95475ed 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -101,6 +101,7 @@ else: ("imagegpt", ("ImageGPTImageProcessor",)), ("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")), ("instructblipvideo", ("InstructBlipVideoImageProcessor",)), + ("janus", ("JanusImageProcessor")), ("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")), ("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")), ("layoutlmv3", ("LayoutLMv3ImageProcessor", 
"LayoutLMv3ImageProcessorFast")), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index af832ee2393..1746d97fd0f 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -152,6 +152,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("imagegpt", "ImageGPTModel"), ("informer", "InformerModel"), ("jamba", "JambaModel"), + ("janus", "JanusModel"), ("jetmoe", "JetMoeModel"), ("jukebox", "JukeboxModel"), ("kosmos-2", "Kosmos2Model"), @@ -359,6 +360,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( ("idefics", "IdeficsForVisionText2Text"), ("idefics2", "Idefics2ForConditionalGeneration"), ("idefics3", "Idefics3ForConditionalGeneration"), + ("janus", "JanusForConditionalGeneration"), ("layoutlm", "LayoutLMForMaskedLM"), ("llava", "LlavaForConditionalGeneration"), ("llava_next", "LlavaNextForConditionalGeneration"), @@ -858,6 +860,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict( ("idefics2", "Idefics2ForConditionalGeneration"), ("idefics3", "Idefics3ForConditionalGeneration"), ("instructblip", "InstructBlipForConditionalGeneration"), + ("janus", "JanusForConditionalGeneration"), ("kosmos-2", "Kosmos2ForConditionalGeneration"), ("llama4", "Llama4ForConditionalGeneration"), ("llava", "LlavaForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 9213b23ced1..c55a4ab2129 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -75,6 +75,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict( ("idefics3", "Idefics3Processor"), ("instructblip", "InstructBlipProcessor"), ("instructblipvideo", "InstructBlipVideoProcessor"), + ("janus", "JanusProcessor"), ("kosmos-2", "Kosmos2Processor"), ("layoutlmv2", "LayoutLMv2Processor"), ("layoutlmv3", "LayoutLMv3Processor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 102a48f3ad4..8496588acb0 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -265,6 +265,7 @@ else: "LlamaTokenizerFast" if is_tokenizers_available() else None, ), ), + ("janus", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)), ( "jetmoe", ( diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 05e01c37bbd..1c83ddea5a7 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -755,7 +755,6 @@ class ChameleonVQVAEVectorQuantizer(nn.Module): self.beta = getattr(config, "beta", 0.25) self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) - self.re_embed = self.num_embeddings def forward(self, hidden_state: torch.Tensor): hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() diff --git a/src/transformers/models/janus/__init__.py b/src/transformers/models/janus/__init__.py new file mode 100644 index 00000000000..06bc90cd938 --- /dev/null +++ b/src/transformers/models/janus/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_janus import * + from .image_processing_janus import * + from .modeling_janus import * + from .processing_janus import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/janus/configuration_janus.py b/src/transformers/models/janus/configuration_janus.py new file mode 100644 index 00000000000..de727ab6d07 --- /dev/null +++ b/src/transformers/models/janus/configuration_janus.py @@ -0,0 +1,314 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/janus/modular_janus.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_janus.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING, AutoConfig + + +logger = logging.get_logger(__name__) + + +class JanusVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`JanusVisionModel`]. It is used to instantiate a + `JanusVisionModel` according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + image_size (`int`, *optional*, defaults to 384): + The size (resolution) of each image. 
+ attention_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for attention weights. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"`, and `"gelu_new"` are supported. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of MLP hidden dimensionality to embedding dimensionality. + attention_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys, and values in the attention layers. + hidden_dropout_rate (`float`, *optional*, defaults to 0.0): + The dropout probability for fully connected layers in the encoder. + projection_dim (`int`, *optional*, defaults to 2048): + Dimensionality of the MLP projection head. + projection_dropout (`float`, *optional*, defaults to 0.0): + Dropout probability for the projection layer. + use_qk_norm (`bool`, *optional*, defaults to `False`): + Whether to normalize the query and key matrices. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated normal initializer for initializing all weight matrices. + depth (`int`, *optional*, defaults to 2): + Number of hidden layers in the aligner module. + num_image_tokens (`int`, *optional*, defaults to 576): + Number of image tokens. + """ + + model_type = "janus_vision_model" + base_config_key = "vision_config" + + def __init__( + self, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + num_channels=3, + patch_size=16, + image_size=384, + attention_dropout=0.0, + layer_norm_eps=1e-6, + hidden_act="gelu", + mlp_ratio=4.0, + attention_bias=True, + hidden_dropout_rate=0.0, + projection_dim=2048, + projection_dropout=0.0, + use_qk_norm=False, + initializer_range=0.02, + depth=2, + num_image_tokens=576, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + + self.mlp_ratio = mlp_ratio + self.attention_bias = attention_bias + self.hidden_dropout_rate = hidden_dropout_rate + self.projection_dim = projection_dim + self.projection_dropout = projection_dropout + self.use_qk_norm = use_qk_norm + self.initializer_range = initializer_range + self.depth = depth + self.num_image_tokens = num_image_tokens + + +class JanusVQVAEConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`JanusVQVAEModel`]. It is used to instantiate a + `JanusVQVAEModel` according to the specified arguments, defining the model architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. Instantiating a + configuration with the defaults will yield a similar configuration to the VQModel of the + [deepseek-community/Janus-Pro-1B](https://huggingface.co/deepseek-community/Janus-Pro-1B). + + Args: + embed_dim (`int`, *optional*, defaults to 8): + Dimensionality of each embedding vector. + num_embeddings (`int`, *optional*, defaults to 16384): + Number of codebook embeddings. 
+ double_latent (`bool`, *optional*, defaults to `False`): + Whether to use double z channels. + latent_channels (`int`, *optional*, defaults to 256): + Number of channels for the latent space. + num_patches (`int`, *optional*, defaults to 32): + Num of patches the input images can be divided into. + in_channels (`int`, *optional*, defaults to 3): + Number of input channels. + out_channels (`int`, *optional*, defaults to 3): + Number of out channels. + base_channels (`int`, *optional*, defaults to 128): + Base channel count. + channel_multiplier (`List[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`): + Channel multipliers for each resolution. + num_res_blocks (`int`, *optional*, defaults to 2): + Number of residual blocks. + dropout (`float`, *optional*, defaults to 0.0): + Dropout rate. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + projection_dim (`int`, *optional*, defaults to 2048): + Dimensionality of the MLP projection head. + num_hidden_layers (`int`, *optional*, defaults to 2): + Number of hidden layers in VAVAE MLP Connecter module. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + image_token_embed_dim (`int`, *optional*, defaults to 2048): + Dimension of image embeddings. It should be same as the dimensionality of text embeddings. + """ + + model_type = "janus_vqgan" + base_config_key = "vq_config" + + def __init__( + self, + embed_dim: int = 8, + num_embeddings: int = 16384, + double_latent: bool = False, + latent_channels: int = 256, + num_patches: int = 32, + in_channels: int = 3, + out_channels: int = 3, + base_channels: int = 128, + channel_multiplier: List[int] = [1, 1, 2, 2, 4], + num_res_blocks: int = 2, + dropout: float = 0.0, + initializer_range=0.02, + projection_dim=2048, + num_hidden_layers=2, + hidden_act="gelu", + image_token_embed_dim=2048, + **kwargs, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + self.num_embeddings = num_embeddings + self.double_latent = double_latent + self.latent_channels = latent_channels + self.in_channels = in_channels + self.base_channels = base_channels + self.channel_multiplier = channel_multiplier + self.num_res_blocks = num_res_blocks + self.dropout = dropout + self.initializer_range = initializer_range + self.num_patches = num_patches + self.out_channels = out_channels + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.hidden_act = hidden_act + self.image_token_embed_dim = image_token_embed_dim + + +class JanusConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`JanusModel`]. It is used to instantiate an + Janus model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Janus-1B or Janus-7B models. + + e.g. [deepseek-community/Janus-Pro-1B](https://huggingface.co/deepseek-community/Janus-Pro-1B) or + [deepseek-community/Janus-Pro-7B](https://huggingface.co/deepseek-community/Janus-Pro-7B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + Args: + text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): + The config object or dictionary of the text backbone. + vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `JanusVisionConfig`): + The config object or dictionary of the vision backbone. + vq_config (`Union[AutoConfig, dict]`, *optional*, defaults to `JanusVQVAEConfig`): + The config object or dictionary of the VQVAE backbone. + + Example: + + ```python + >>> from transformers import JanusForConditionalGeneration, JanusConfig, JanusVisionConfig, JanusVQVAEConfig, LlamaConfig + + >>> # Initializing a Janus vision config + >>> vision_config = JanusVisionConfig() + + >>> # Initializing a Llama config + >>> text_config = LlamaConfig() + + >>> # Initializing a VQ config + >>> vq_config = JanusVQVAEConfig() + + >>> # Initializing a Janus Pro 1B style configuration + >>> configuration = JanusConfig(vision_config=vision_config, text_config=text_config, vq_config=vq_config) + + >>> # Initializing a model from the Janus Pro 1B style configuration + >>> model = JanusForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "janus" + sub_configs = { + "text_config": AutoConfig, + "vision_config": JanusVisionConfig, + "vq_config": JanusVQVAEConfig, + } + + def __init__(self, text_config=None, vision_config=None, vq_config=None, **kwargs): + if isinstance(text_config, dict): + text_config["model_type"] = text_config.get("model_type", "llama") + self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + + elif text_config is None: + logger.info("`text_config` is None. Initializing with default values") + self.text_config = CONFIG_MAPPING["llama"]() + elif isinstance(text_config, PretrainedConfig): + self.text_config = text_config + else: + raise ValueError( + f"Invalid type for `text_config`. Must be either `dict` or `LlamaConfig`." + f" Type found: {type(text_config)}" + ) + + if vision_config is None: + logger.info("`vision_config` is None. Initializing with default JanusVisionConfig values") + self.vision_config = JanusVisionConfig() + elif isinstance(vision_config, dict): + self.vision_config = JanusVisionConfig(**vision_config) + elif isinstance(vision_config, JanusVisionConfig): + self.vision_config = vision_config + else: + raise ValueError( + f"Invalid type for `vision_config`. Must be either `dict` or `JanusVisionConfig`." + f" Type found: {type(vision_config)}" + ) + + if vq_config is None: + logger.info("`vq_config` is None. Initializing with default JanusVQVAEConfig values") + self.vq_config = JanusVQVAEConfig() + elif isinstance(vq_config, dict): + self.vq_config = JanusVQVAEConfig(**vq_config) + elif isinstance(vq_config, JanusVQVAEConfig): + self.vq_config = vq_config + else: + raise ValueError( + f"Invalid type for `vq_config`. Must be either `dict` or `JanusVQVAEConfig`." + f" Type found: {type(vq_config)}" + ) + + # This dimension is required when decoding discrete image tokens to continuous input. 
+ self.vq_config.num_patches = self.vision_config.image_size // self.vision_config.patch_size + # The default is only the index for the 1B model, 7B uses a different one + self.image_token_index = kwargs.get("image_token_index", 100581) + super().__init__(**kwargs) + + +__all__ = ["JanusVQVAEConfig", "JanusVisionConfig", "JanusConfig"] diff --git a/src/transformers/models/janus/convert_janus_weights_to_hf.py b/src/transformers/models/janus/convert_janus_weights_to_hf.py new file mode 100644 index 00000000000..32e16780bbe --- /dev/null +++ b/src/transformers/models/janus/convert_janus_weights_to_hf.py @@ -0,0 +1,501 @@ +# coding=utf-8 +# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example of run command (run from root): + +python src/transformers/models/janus/convert_janus_weights_to_hf.py --repo_id deepseek-ai/Janus-Pro-1B --local_dir tmp/hub_code_in --output_dir tmp/hub_code_out --safe_serialization +Using provided local directory: tmp/hub_code_in +""" + +import argparse +import gc +import json +import os +import re + +import torch +from accelerate import init_empty_weights +from huggingface_hub import snapshot_download + +from transformers import ( + AutoTokenizer, + JanusConfig, + JanusForConditionalGeneration, + JanusVisionConfig, + JanusVQVAEConfig, + LlamaConfig, +) +from transformers.models.janus.image_processing_janus import JanusImageProcessor +from transformers.models.janus.processing_janus import JanusProcessor + + +# Mappings +MAPPINGS = { + # Vision model + r"(?\b(vision_model|model\.vision_model)\b.*\.)proj(?=\.|\s|$)": r"\g
<prefix>projection_layer",
+    r"(?P
\b(vision_model|model\.vision_model)\b.*\.)norm(?=\.|\s|$)": r"\g
layer_norm",
+    r"(?P
\b(vision_model|model\.vision_model)\b.*\.)norm1(?=\.|\s|$)": r"\g
layer_norm1",
+    r"(?P
\b(vision_model|model\.vision_model)\b.*\.)norm2(?=\.|\s|$)": r"\g
layer_norm2",
+    r"\bvision_model\.vision_tower\.attn_pool\.[^\s$]*": None,
+    # VQ Model
+    r"gen_vision_model": "model.vqmodel",
+    r"(?P
\b(gen_vision_model|model\.vqmodel)\b.*\.)decoder\.conv_blocks(?=\.|\s|$)": r"\g
decoder.up",
+    r"(?P
\b(gen_vision_model|model\.vqmodel)\b.*\.)encoder\.conv_blocks(?=\.|\s|$)": r"\g
encoder.down",
+    r"(?P
\b(gen_vision_model|model\.vqmodel)\b.*\.)res(?=\.|\s|$)": r"\g
block",
+    r"(?P
\b(gen_vision_model|model\.vqmodel)\b.*\.)mid\.0(?=\.|\s|$)": r"\g
mid.block_1",
+    r"(?P
\b(gen_vision_model|model\.vqmodel)\b.*\.)mid\.1(?=\.|\s|$)": r"\g
mid.attn_1",
+    r"(?P
\b(gen_vision_model|model\.vqmodel)\b.*\.)mid\.2(?=\.|\s|$)": r"\g
mid.block_2",
+    # Aligner Modules
+    r"(gen_aligner)\.layers\.0": r"model.generation_aligner.fc1",
+    r"(gen_aligner)\.layers\.2": r"model.generation_aligner.hidden_layers.0",
+    r"(?']%}"
+    "{%set i=0%}"
+    "{%for message in messages%}"
+    "{%if message['role']|lower=='user'%}"
+    "<|User|>: "
+    "{%elif message['role']|lower=='assistant'%}"
+    "<|Assistant|>:{%if not (loop.last and not add_generation_prompt and message['content'][0]['type']=='text' and message['content'][0]['text']=='')%} {%endif%}"
+    "{%else%}"
+    "{{message['role'].capitalize()}}: "
+    "{%endif%}"
+    "{%for content in message['content']%}"
+    "{%if content['type']=='image'%}"
+    "{%if not loop.first%}{{'\n'}}{%endif%}"
+    ""
+    "{%if not loop.last%}{{'\n'}}{%endif%}"
+    "{%elif content['type']=='text'%}"
+    "{%set text=content['text']%}"
+    "{%if loop.first%}{%set text=text.lstrip()%}{%endif%}"
+    "{%if loop.last%}{%set text=text.rstrip()%}{%endif%}"
+    "{%if not loop.first and message['content'][loop.index0-1]['type']=='text'%}"
+    "{{' '+text}}"
+    "{%else%}"
+    "{{text}}"
+    "{%endif%}"
+    "{%endif%}"
+    "{%endfor%}"
+    "{%if not loop.last or add_generation_prompt%}"
+    "{%if message['role']|lower=='user'%}"
+    "{{seps[0]}}"
+    "{%else%}"
+    "{{seps[1]}}"
+    "{%endif%}"
+    "{%endif%}"
+    "{%endfor%}"
+    "{%if add_generation_prompt%}<|Assistant|>:{%endif%}"
+)
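+# The template above renders DeepSeek-style turns ("<|User|>: ..." / "<|Assistant|>: ..."), inserts the image
+# placeholder token for image content entries, and joins turns with the separators defined in `seps`.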
+
+
+def convert_old_keys_to_new_keys(state_dict):
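+    # Run every regex mapping over all keys at once by joining them into a single newline-separated string;
+    # patterns mapped to None erase their match so the corresponding key is dropped later.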
+    keys_as_text = "\n".join(state_dict.keys())
+    new_keys_as_text = keys_as_text
+    for old, repl in MAPPINGS.items():
+        if repl is None:
+            new_keys_as_text = re.sub(old, "", new_keys_as_text)
+        else:
+            new_keys_as_text = re.sub(old, repl, new_keys_as_text)
+    output_dict = dict(zip(keys_as_text.split("\n"), new_keys_as_text.split("\n")))
+    return output_dict
+
+
+def split_tensor(tensor, key):
+    """Splits a merged tensor (qkv or kv) into separate tensors and creates keys for each part."""
+
+    if "qkv" in key:
+        prefix_to_replace = "qkv"
+        num_splits = 3
+        new_keys = ["q_proj", "k_proj", "v_proj"]
+    elif "kv" in key:
+        prefix_to_replace = "kv"
+        num_splits = 2
+        new_keys = ["k_proj", "v_proj"]
+    else:
+        raise ValueError(f"Unrecognized tensor type in key: {key}")
+
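+    # Fused projections are stacked along dim 0 (e.g. [q; k; v]), so an even split along dim 0 recovers
+    # the individual projection weights.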
+    split_size = tensor.shape[0] // num_splits
+    tensors = torch.split(tensor, split_size, dim=0)
+    return {key.replace(prefix_to_replace, new_keys[i]): tensors[i] for i in range(num_splits)}
+
+
+def convert_state_dict_to_hf(state_dict):
+    """Convert state dict keys to HF format."""
+    conversion_dict = convert_old_keys_to_new_keys(state_dict)
+    converted_state_dict = {}
+
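+    # `conversion_dict` maps each original key to its HF name; keys erased by a `None` mapping show up as
+    # empty strings and are skipped below.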
+    for old_key, new_key in conversion_dict.items():
+        if new_key:
+            if "qkv" in new_key or "kv" in new_key:  # Detect merged attention keys and split them.
+                qkv_split_dict = split_tensor(state_dict[old_key], new_key)
+                converted_state_dict.update(qkv_split_dict)
+            else:
+                converted_state_dict[new_key] = state_dict[old_key]
+
+    # The original position embedding carries an extra leading dimension; squeeze it out to match the HF model.
+    pos_embed_key = "model.vision_model.embeddings.position_embedding.weight"
+    converted_state_dict[pos_embed_key] = converted_state_dict[pos_embed_key].squeeze(0)
+
+    return converted_state_dict
+
+
+def ensure_model_downloaded(repo_id: str = None, revision: str = None, local_dir: str = None) -> str:
+    """
+    Ensures that model files are available locally, downloading them if they are not.
+    Returns path to local files.
+
+    Args:
+        repo_id: The Hugging Face model repo ID (required if local_dir not provided)
+        revision: Optional git revision to use
+        local_dir: Optional local directory path where model files should be stored/found
+    """
+    if local_dir is not None:
+        if os.path.exists(local_dir):
+            print(f"Using provided local directory: {local_dir}")
+        else:
+            # Create the local directory if it doesn't exist
+            os.makedirs(local_dir, exist_ok=True)
+            print(f"Created local directory: {local_dir}")
+
+    if repo_id is None:
+        raise ValueError("Either repo_id or local_dir must be provided")
+
+    print(f"Ensuring {repo_id} (revision: {revision or 'latest'}) is downloaded...")
+
+    try:
+        # First try to find files locally
+        download_dir = snapshot_download(repo_id, revision=revision, local_files_only=True, local_dir=local_dir)
+        print(f"Found model files locally at {download_dir}")
+        return download_dir
+    except Exception:
+        # If files not found locally, download them
+        print(f"Downloading model files for {repo_id}...")
+        download_dir = snapshot_download(repo_id, revision=revision, local_files_only=False, local_dir=local_dir)
+        print(f"Downloaded model files to {download_dir}")
+        return download_dir
+
+
+def load_model_state_dict(input_path: str) -> dict:
+    """
+    Load model state dict, handling both single and sharded files.
+    """
+    index_path = os.path.join(input_path, "pytorch_model.bin.index.json")
+    single_file_path = os.path.join(input_path, "pytorch_model.bin")
+
+    # Check if we have a sharded model
+    if os.path.exists(index_path):
+        print("Loading sharded model...")
+        state_dict = {}
+        with open(index_path, "r") as f:
+            index = json.load(f)
+
+        # Get unique shard files and load each one only once
+        unique_shard_files = sorted(set(index["weight_map"].values()))
+        for shard_file in unique_shard_files:
+            print(f"Loading shard {shard_file}...")
+            shard_path = os.path.join(input_path, shard_file)
+            shard_dict = torch.load(shard_path, map_location="cpu")
+            state_dict.update(shard_dict)
+
+        return state_dict
+
+    # Single file model
+    elif os.path.exists(single_file_path):
+        print("Loading single file model...")
+        return torch.load(single_file_path, map_location="cpu")
+
+    else:
+        raise ValueError(f"No model files found in {input_path}")
+
+
+def convert_model(
+    repo_id=None,
+    local_dir=None,
+    text_model_id=None,
+    output_dir=None,
+    output_hub_path=None,
+    safe_serialization=True,
+    revision=None,
+):
+    """Convert and save the model weights, processor, and configuration."""
+    if output_dir is None and output_hub_path is None:
+        raise ValueError("At least one of output_dir or output_hub_path must be specified")
+
+    if repo_id is None and local_dir is None:
+        raise ValueError("Either repo_id or local_dir must be specified")
+
+    # Create output directory if specified
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+        print(f"Created/verified output directory: {output_dir}")
+
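+    # Materialize newly created tensors in float16 by default during conversion.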
+    torch.set_default_dtype(torch.float16)
+
+    # Download or locate model files
+    input_path = ensure_model_downloaded(repo_id=repo_id, revision=revision, local_dir=local_dir)
+
+    # Load configuration files
+    required_files = ["config.json", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json"]
+
+    missing_files = [f for f in required_files if not os.path.exists(os.path.join(input_path, f))]
+    if missing_files:
+        raise ValueError(
+            f"The following required configuration files are missing from {input_path}: {', '.join(missing_files)}. "
+            "Please ensure you have downloaded all necessary model files."
+        )
+
+    with open(os.path.join(input_path, "config.json"), "r") as f:
+        config_data = json.load(f)
+    with open(os.path.join(input_path, "preprocessor_config.json"), "r") as f:
+        preprocessor_config = json.load(f)
+    with open(os.path.join(input_path, "special_tokens_map.json"), "r") as f:
+        special_tokens_map = json.load(f)
+    with open(os.path.join(input_path, "tokenizer_config.json"), "r") as f:
+        tokenizer_config = json.load(f)
+
+    # Create tokenizer directly from tokenizer.json if it exists
+    tokenizer_json_path = os.path.join(input_path, "tokenizer.json")
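+    # Janus-specific multimodal special tokens, registered on the tokenizer as extra special tokens.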
+    special_image_tokens = {
+        "image_token": "",
+        "boi_token": "",
+        "eoi_token": "",
+    }
+
+    if os.path.exists(tokenizer_json_path) and not text_model_id:
+        tokenizer = AutoTokenizer.from_pretrained(
+            input_path,  # This will load tokenizer.json directly
+            model_max_length=tokenizer_config["model_max_length"],
+            extra_special_tokens=special_image_tokens,
+        )
+    else:
+        # Fallback to creating from text_model_id with special tokens
+        tokenizer = AutoTokenizer.from_pretrained(
+            text_model_id,
+            bos_token=special_tokens_map["bos_token"],
+            eos_token=special_tokens_map["eos_token"],
+            pad_token=special_tokens_map["pad_token"],
+            additional_special_tokens=special_tokens_map["additional_special_tokens"],
+            model_max_length=tokenizer_config["model_max_length"],
+            extra_special_tokens=special_image_tokens,
+        )
+
+    # Create image processor from config
+    image_processor_kwargs = {}
+    for key in ["do_normalize", "image_mean", "image_std", "min_size", "rescale_factor"]:
+        if key in preprocessor_config:
+            image_processor_kwargs[key] = preprocessor_config[key]
+
+    if "image_size" in preprocessor_config:
+        image_processor_kwargs["size"] = {
+            "height": preprocessor_config["image_size"],
+            "width": preprocessor_config["image_size"],
+        }
+
+    image_processor = JanusImageProcessor(**image_processor_kwargs)
+
+    # Create processor with chat template
+    processor = JanusProcessor(
+        image_processor=image_processor,
+        tokenizer=tokenizer,
+        chat_template=CHAT_TEMPLATE,
+        use_default_system_prompt=True,
+    )
+
+    if output_dir:
+        print(f"Saving processor to {output_dir}...")
+        processor.save_pretrained(output_dir)
+    if output_hub_path:
+        print(f"Pushing processor to hub at {output_hub_path}...")
+        processor.push_to_hub(output_hub_path)
+
+    # Create model configurations
+    text_config_kwargs = {}
+    for key in [
+        "vocab_size",
+        "hidden_size",
+        "intermediate_size",
+        "num_hidden_layers",
+        "num_attention_heads",
+        "num_key_value_heads",
+        "hidden_act",
+        "max_position_embeddings",
+        "torch_dtype",
+    ]:
+        if key in config_data["language_config"]:
+            text_config_kwargs[key] = config_data["language_config"][key]
+
+    # Add token IDs from tokenizer
+    text_config_kwargs.update(
+        {
+            "pad_token_id": tokenizer.pad_token_id,
+            "bos_token_id": tokenizer.bos_token_id,
+            "eos_token_id": tokenizer.eos_token_id,
+        }
+    )
+
+    text_config = LlamaConfig(**text_config_kwargs)
+
+    # Create vision config
+    vision_config_kwargs = {}
+    if "image_size" in config_data["vision_config"]["params"]:
+        vision_config_kwargs["image_size"] = config_data["vision_config"]["params"]["image_size"]
+
+    # Add aligner params if present
+    if "aligner_config" in config_data and "params" in config_data["aligner_config"]:
+        if "n_embed" in config_data["aligner_config"]["params"]:
+            vision_config_kwargs["projection_dim"] = config_data["aligner_config"]["params"]["n_embed"]
+        if "depth" in config_data["aligner_config"]["params"]:
+            vision_config_kwargs["depth"] = config_data["aligner_config"]["params"]["depth"]
+
+    vision_config = JanusVisionConfig(**vision_config_kwargs)
+
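+    # VQ-VAE config: codebook size and embedding dim come from `gen_vision_config`, aligner and
+    # generation-head dimensions from `gen_aligner_config` / `gen_head_config`.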
+    vq_config = JanusVQVAEConfig(
+        embed_dim=config_data["gen_vision_config"]["params"]["n_embed"],
+        num_embeddings=config_data["gen_vision_config"]["params"]["image_token_size"],
+        projection_dim=config_data["gen_aligner_config"]["params"]["n_embed"],
+        depth=config_data["gen_aligner_config"]["params"]["depth"],
+        image_token_embed_dim=config_data["gen_head_config"]["params"]["image_token_embed"],
+    )
+
+    # Create the main config
+    config = JanusConfig(
+        text_config=text_config,
+        vision_config=vision_config,
+        vq_config=vq_config,
+        image_token_index=tokenizer.vocab.get("<image_placeholder>"),
+    )
+
+    # Save the config
+    if output_dir:
+        config.save_pretrained(output_dir)
+    if output_hub_path:
+        config.push_to_hub(output_hub_path)
+
+    # Initialize model with empty weights
+    print("Creating empty model...")
+    with init_empty_weights():
+        model = JanusForConditionalGeneration(config)
+
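+    # Populate generation defaults: sampling temperature, classifier-free guidance scale used for image
+    # generation, and the pad / begin-of-image token ids.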
+    model.generation_config.temperature = 1
+    model.generation_config.guidance_scale = 5
+    model.generation_config.pad_token_id = tokenizer.vocab.get("<\uff5c\u2581pad\u2581\uff5c>")
+    model.generation_config.generation_kwargs["boi_token_id"] = tokenizer.vocab.get("<begin_of_image>")
+
+    # Load and convert state dict
+    print("Loading state dict...")
+    state_dict = load_model_state_dict(input_path)
+    state_dict = convert_state_dict_to_hf(state_dict)
+
+    # Load converted state dict
+    print("Loading converted weights into model...")
+    model.load_state_dict(state_dict, strict=True, assign=True)
+
+    # Tie weights before any device mapping
+    print("Tying weights...")
+    model.tie_weights()
+
+    # Save the model
+    if output_dir:
+        print(f"Saving model to {output_dir}...")
+        model.save_pretrained(output_dir, safe_serialization=safe_serialization)
+    if output_hub_path:
+        print(f"Pushing model to hub at {output_hub_path}...")
+        model.push_to_hub(output_hub_path, safe_serialization=safe_serialization)
+
+    del state_dict, model
+    gc.collect()
+
+    # Validate the saved model if saved locally
+    if output_dir:
+        print("Reloading the local model to check if it's saved correctly...")
+        # TODO: warning about weights not being tied is raised here regardless of model.tie_weights() above
+        JanusForConditionalGeneration.from_pretrained(output_dir, device_map="auto")
+        print("Local model reloaded successfully.")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--repo_id",
+        help="HuggingFace Hub repo ID for the model",
+        default=None,
+    )
+    parser.add_argument(
+        "--local_dir",
+        help="Local directory containing the model files",
+        default=None,
+    )
+    parser.add_argument(
+        "--revision",
+        help="Specific revision to download from the Hub",
+        default=None,
+    )
+    parser.add_argument(
+        "--output_dir",
+        help="Location to write HF model locally",
+        default=None,
+    )
+    parser.add_argument(
+        "--output_hub_path",
+        help="Repository ID to push model to hub (e.g. 'username/model-name')",
+        default=None,
+    )
+    parser.add_argument(
+        "--text_model_id",
+        help="Hub ID of the text model to get tokenizer from. Optional if tokenizer.json exists in the model directory.",
+        required=False,
+    )
+    parser.add_argument(
+        "--safe_serialization",
+        action="store_true",
+        help="Whether to save using safetensors",
+    )
+    args = parser.parse_args()
+
+    if args.output_dir is None and args.output_hub_path is None:
+        raise ValueError("At least one of --output_dir or --output_hub_path must be specified")
+
+    if args.repo_id is None and args.local_dir is None:
+        raise ValueError("Either --repo_id or --local_dir must be specified")
+
+    convert_model(
+        repo_id=args.repo_id,
+        local_dir=args.local_dir,
+        text_model_id=args.text_model_id,
+        output_dir=args.output_dir,
+        output_hub_path=args.output_hub_path,
+        safe_serialization=args.safe_serialization,
+        revision=args.revision,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/transformers/models/janus/image_processing_janus.py b/src/transformers/models/janus/image_processing_janus.py
new file mode 100644
index 00000000000..1dac2fec481
--- /dev/null
+++ b/src/transformers/models/janus/image_processing_janus.py
@@ -0,0 +1,508 @@
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+#           This file was automatically generated from src/transformers/models/janus/modular_janus.py.
+#               Do NOT edit this file manually as any edits will be overwritten by the generation of
+#             the file from the modular. If any change should be done, please apply the change to the
+#                          modular_janus.py file directly. One of our CI enforces this.
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format
+from ...image_utils import (
+    OPENAI_CLIP_MEAN,
+    OPENAI_CLIP_STD,
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    get_image_size,
+    infer_channel_dimension_format,
+    is_scaled_image,
+    make_flat_list_of_images,
+    make_list_of_images,
+    to_numpy_array,
+    valid_images,
+    validate_preprocess_arguments,
+)
+from ...utils import (
+    TensorType,
+    filter_out_non_signature_kwargs,
+    is_vision_available,
+    logging,
+)
+
+
+if is_vision_available():
+    import PIL
+
+
+logger = logging.get_logger(__name__)
+
+
+class JanusImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a JANUS image processor.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+            `do_resize` parameter in the `preprocess` method.
+        size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
+            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+            method.
+        min_size (`int`, *optional*, defaults to 14):
+            The minimum allowed size for the resized image. Ensures that neither the height nor width
+            falls below this value after resizing.
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+            Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
+            overridden by the `resample` parameter in the `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+            `do_rescale` parameter in the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
+            overridden by the `rescale_factor` parameter in the `preprocess` method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to RGB.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Dict[str, int] = None,
+        min_size: int = 14,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        do_rescale: bool = True,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_convert_rgb: bool = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        size = size if size is not None else {"height": 384, "width": 384}
+        size = get_size_dict(size, default_to_square=True)
+
+        self.do_resize = do_resize
+        self.size = size
+        self.resample = resample
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
+        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
+        self.do_convert_rgb = do_convert_rgb
+
+        self.min_size = min_size
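+        # Padding color used by `pad_to_square` in `resize`: neutral gray by default, otherwise derived from
+        # the normalization mean.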
+        if image_mean is None:
+            self.background_color = (127, 127, 127)
+        else:
+            self.background_color = tuple([int(x * 255) for x in image_mean])
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: Union[Dict[str, int], int],
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image to a dynamically calculated size.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
+            data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the output image. If unset, the channel dimension format of the input
+                image is used. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `None`: will be inferred from input
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+        Returns:
+            `np.ndarray`: The resized image.
+        """
+        if input_data_format is None:
+            input_data_format = infer_channel_dimension_format(image)
+
+        height, width = get_image_size(image, input_data_format)
+        max_size = max(height, width)
+
+        size = get_size_dict(size, default_to_square=True)
+        if size["height"] != size["width"]:
+            raise ValueError(
+                f"Output height and width must be the same. Got height={size['height']} and width={size['width']}"
+            )
+        size = size["height"]
+
+        delta = size / max_size
+        # Largest side becomes `size` and the other side is scaled according to the aspect ratio.
+        output_size_nonpadded = [
+            max(int(height * delta), self.min_size),
+            max(int(width * delta), self.min_size),
+        ]
+
+        image = resize(
+            image,
+            size=output_size_nonpadded,
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )
+        # Expand and pad the images to obtain a square image of dimensions `size x size`
+        image = self.pad_to_square(
+            image=image,
+            background_color=self.background_color,
+            input_data_format=input_data_format,
+        )
+        return image
+
+    @filter_out_non_signature_kwargs()
+    def preprocess(
+        self,
+        images: ImageInput,
+        do_resize: Optional[bool] = None,
+        size: Optional[Dict[str, int]] = None,
+        resample: PILImageResampling = None,
+        do_rescale: Optional[bool] = None,
+        rescale_factor: Optional[float] = None,
+        do_normalize: Optional[bool] = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        do_convert_rgb: Optional[bool] = None,
+        data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> PIL.Image.Image:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            images (`ImageInput`):
+                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+                Controls the size of the image after `resize`. The longest edge of the image is resized to
+                `size["height"]` (which must equal `size["width"]`) while preserving the aspect ratio, and the
+                image is then padded to a square of side `size["height"]`.
+            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image values to the [0, 1] range.
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the image.
+            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                Image mean to normalize the image by if `do_normalize` is set to `True`.
+            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
+            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+                Whether to convert the image to RGB.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                    - Unset: Return a list of `np.ndarray`.
+                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                    - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format for the output image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - Unset: Use the channel dimension format of the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+        """
+        do_resize = do_resize if do_resize is not None else self.do_resize
+        resample = resample if resample is not None else self.resample
+        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+        image_mean = image_mean if image_mean is not None else self.image_mean
+        image_std = image_std if image_std is not None else self.image_std
+        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
+
+        size = size if size is not None else self.size
+        size = get_size_dict(size, default_to_square=False)
+        images = make_flat_list_of_images(images)
+
+        if not valid_images(images):
+            raise ValueError(
+                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                "torch.Tensor, tf.Tensor or jax.ndarray."
+            )
+
+        validate_preprocess_arguments(
+            do_rescale=do_rescale,
+            rescale_factor=rescale_factor,
+            do_normalize=do_normalize,
+            image_mean=image_mean,
+            image_std=image_std,
+            do_resize=do_resize,
+            size=size,
+            resample=resample,
+        )
+        # PIL RGBA images are converted to RGB
+        if do_convert_rgb:
+            images = [convert_to_rgb(image) for image in images]
+
+        # All transformations expect numpy arrays.
+        images = [to_numpy_array(image) for image in images]
+
+        if do_rescale and is_scaled_image(images[0]):
+            logger.warning_once(
+                "It looks like you are trying to rescale already rescaled images. If the input"
+                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+            )
+
+        if input_data_format is None:
+            # We assume that all images have the same channel dimension format.
+            input_data_format = infer_channel_dimension_format(images[0])
+
+        if do_resize:
+            images = [
+                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+                for image in images
+            ]
+
+        if do_rescale:
+            images = [
+                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+                for image in images
+            ]
+
+        if do_normalize:
+            images = [
+                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+                for image in images
+            ]
+
+        images = [
+            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+        ]
+
+        encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
+
+        return encoded_outputs
+
+    def pad_to_square(
+        self,
+        image: np.ndarray,
+        background_color: Union[int, Tuple[int, int, int]] = 0,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> np.ndarray:
+        """
+        Pads an image to a square based on the longest edge.
+
+        Args:
+            image (`np.ndarray`):
+                The image to pad.
+            background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
+                The color to use for the padding. Can be an integer for single-channel images or a tuple of
+                integers for multi-channel images. If an integer is passed for a multi-channel image, only the
+                first channel is filled with it and the remaining channels default to `0`.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format for the output image. Can be one of:
+                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                If unset, will use same as the input image.
+            input_data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format for the input image. Can be one of:
+                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+        Returns:
+            `np.ndarray`: The padded image.
+        """
+        height, width = get_image_size(image, input_data_format)
+        num_channels = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1]
+
+        if height == width:
+            image = (
+                to_channel_dimension_format(image, data_format, input_data_format)
+                if data_format is not None
+                else image
+            )
+            return image
+
+        max_dim = max(height, width)
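+        # Center the image on a `max_dim x max_dim` canvas filled with `background_color`;
+        # e.g. a 288x384 image is padded to 384x384 with 48 rows of padding above and below.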
+
+        # Ensure background_color is the correct shape
+        if isinstance(background_color, int):
+            background_color = [background_color]
+        elif len(background_color) != num_channels:
+            raise ValueError(
+                f"background_color must have no more than {num_channels} elements to match the number of channels"
+            )
+
+        if input_data_format == ChannelDimension.FIRST:
+            result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype)
+            for i, color in enumerate(background_color):
+                result[i, :, :] = color
+            if width > height:
+                start = (max_dim - height) // 2
+                result[:, start : start + height, :] = image
+            else:
+                start = (max_dim - width) // 2
+                result[:, :, start : start + width] = image
+        else:
+            result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype)
+            for i, color in enumerate(background_color):
+                result[:, :, i] = color
+            if width > height:
+                start = (max_dim - height) // 2
+                result[start : start + height, :, :] = image
+            else:
+                start = (max_dim - width) // 2
+                result[:, start : start + width, :] = image
+
+        return result
+
+    def postprocess(
+        self,
+        images: ImageInput,
+        do_rescale: Optional[bool] = None,
+        rescale_factor: Optional[float] = None,
+        do_normalize: Optional[bool] = None,
+        image_mean: Optional[List[float]] = None,
+        image_std: Optional[List[float]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        return_tensors: Optional[str] = None,
+    ):
+        """Applies post-processing to the decoded image tokens by reversing transformations applied during preprocessing."""
+        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+        rescale_factor = 1.0 / self.rescale_factor if rescale_factor is None else rescale_factor
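+        # By default this inverts the preprocessing rescale (e.g. multiplies by 255 when preprocessing divided by 255).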
+        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+        image_mean = image_mean if image_mean is not None else self.image_mean
+        image_std = image_std if image_std is not None else self.image_std
+
+        images = make_list_of_images(images)  # Ensures input is a list
+
+        if isinstance(images[0], PIL.Image.Image):
+            return images if len(images) > 1 else images[0]
+
+        if input_data_format is None:
+            input_data_format = infer_channel_dimension_format(images[0])  # Determine format dynamically
+
+        pixel_values = []
+
+        for image in images:
+            image = to_numpy_array(image)  # Ensure NumPy format
+
+            if do_normalize:
+                image = self.unnormalize(
+                    image=image, image_mean=image_mean, image_std=image_std, input_data_format=input_data_format
+                )
+
+            if do_rescale:
+                image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
+                image = image.clip(0, 255).astype(np.uint8)
+
+            if do_normalize and do_rescale and return_tensors == "PIL.Image.Image":
+                image = to_channel_dimension_format(image, ChannelDimension.LAST, input_channel_dim=input_data_format)
+                image = PIL.Image.fromarray(image)
+
+            pixel_values.append(image)
+
+        data = {"pixel_values": pixel_values}
+        return_tensors = return_tensors if return_tensors != "PIL.Image.Image" else None
+
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    def unnormalize(
+        self,
+        image: np.ndarray,
+        image_mean: Union[float, Iterable[float]],
+        image_std: Union[float, Iterable[float]],
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> np.ndarray:
+        """
+        Unnormalizes `image` using the mean and standard deviation specified by `mean` and `std`.
+        image = (image * image_std) + image_mean
+        Args:
+            image (`torch.Tensor` of shape `(batch_size, num_channels, image_size, image_size)` or `(num_channels, image_size, image_size)`):
+                Batch of pixel values to postprocess.
+            image_mean (`float` or `Iterable[float]`):
+                The mean to use for unnormalization.
+            image_std (`float` or `Iterable[float]`):
+                The standard deviation to use for unnormalization.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+        """
+        num_channels = 3
+
+        if isinstance(image_mean, Iterable):
+            if len(image_mean) != num_channels:
+                raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(image_mean)}")
+        else:
+            image_mean = [image_mean] * num_channels
+
+        if isinstance(image_std, Iterable):
+            if len(image_std) != num_channels:
+                raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(image_std)}")
+        else:
+            image_std = [image_std] * num_channels
+
+        rev_image_mean = tuple(-mean / std for mean, std in zip(image_mean, image_std))
+        rev_image_std = tuple(1 / std for std in image_std)
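+        # Reusing `normalize` with these values computes (image - (-mean / std)) / (1 / std) = image * std + mean,
+        # i.e. the inverse of the normalization applied during preprocessing.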
+        image = self.normalize(
+            image=image, mean=rev_image_mean, std=rev_image_std, input_data_format=input_data_format
+        )
+        return image
+
+
+__all__ = ["JanusImageProcessor"]
diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py
new file mode 100644
index 00000000000..a4c7937bc45
--- /dev/null
+++ b/src/transformers/models/janus/modeling_janus.py
@@ -0,0 +1,1660 @@
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+#           This file was automatically generated from src/transformers/models/janus/modular_janus.py.
+#               Do NOT edit this file manually as any edits will be overwritten by the generation of
+#             the file from the modular. If any change should be done, please apply the change to the
+#                          modular_janus.py file directly. One of our CI enforces this.
+#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from dataclasses import dataclass
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...generation import ClassifierFreeGuidanceLogitsProcessor, GenerationMixin, GenerationMode, LogitsProcessorList
+from ...generation.utils import GenerateDecoderOnlyOutput
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    can_return_tuple,
+    is_torch_available,
+    logging,
+    replace_return_docstrings,
+    torch_int,
+)
+from ..auto import AutoModel
+from .configuration_janus import JanusConfig, JanusVisionConfig, JanusVQVAEConfig
+
+
+if is_torch_available():
+    import torch.nn.functional as F
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "JanusConfig"
+
+
+JANUS_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`JanusConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare Janus Model outputting raw hidden-states without any specific head on top.",
+    JANUS_START_DOCSTRING,
+)
+class JanusPreTrainedModel(PreTrainedModel):
+    config_class = JanusConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["LlamaDecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_quantized_cache = True
+    _supports_cache_class = True
+    _supports_static_cache = True
+    _supports_param_buffer_assignment = False
+
+    def _init_weights(self, module):
+        std = (
+            self.config.vision_config.initializer_range
+            if hasattr(self.config, "vision_config")
+            else self.config.initializer_range
+        )
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+@dataclass
+class JanusVQVAEOutput(ModelOutput):
+    """
+    Base class for Janus VQ-VAE model outputs.
+
+    Args:
+        decoded_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+            Reconstructed pixel values after encoding and decoding the input.
+        embedding_loss (`torch.FloatTensor`):
+            Embedding loss.
+    """
+
+    decoded_pixel_values: Optional[torch.FloatTensor] = None
+    embedding_loss: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class JanusBaseModelOutputWithPast(ModelOutput):
+    """
+    Base class for Janus model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+
+            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
+            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+            input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+            Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size,
+            num_images, sequence_length, hidden_size)`.
+
+            Image hidden states of the model produced by the vision encoder.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class JanusCausalLMOutputWithPast(ModelOutput):
+    """
+    Base class for Janus causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
+            Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size,
+            num_images, sequence_length, hidden_size)`.
+
+            Image hidden states of the model produced by the vision encoder.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[List[torch.FloatTensor]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class JanusVisionEmbeddings(nn.Module):
+    def __init__(self, config: JanusVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            padding="valid",
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows interpolating the pre-trained position encodings so that the model can be used on
+        higher-resolution images. It is also adapted to support torch.jit tracing and no class embeddings.
+
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+        """
+
+        num_patches = embeddings.shape[1]
+        num_positions = self.position_embedding.weight.shape[0]
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+            return self.position_embedding(self.position_ids)
+
+        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
+
+        dim = embeddings.shape[-1]
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = torch_int(num_positions**0.5)
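+        # The pre-trained position embeddings form a square patch grid (e.g. 576 positions -> a 24x24 grid).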
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            size=(new_height, new_width),
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return patch_pos_embed
+
+    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+        _, _, height, width = pixel_values.shape
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [batch_size, embed_dim, grid, grid]
+        embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+        if interpolate_pos_encoding:
+            pos_embeds = self.interpolate_pos_encoding(embeddings, height, width)
+        else:
+            pos_embeds = self.position_embedding(self.position_ids)
+
+        embeddings = embeddings + pos_embeds
+
+        return embeddings
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
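+    # Scaled dot-product attention: softmax(Q K^T * scaling) V, where `scaling` is typically head_dim**-0.5.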
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+    if attention_mask is not None:
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
+class JanusVisionAttention(nn.Module):
+    """Attention Class for Janus Vision Encoder"""
+
+    def __init__(self, config: JanusVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads})."
+            )
+        self.scale = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
+        proj_dropout = config.projection_dropout
+        qk_norm = config.use_qk_norm
+
+        # The vision attention does not use grouped-query attention (every head has its own key/value),
+        # so `num_key_value_groups` is set to 1 for the `eager_attention_forward` call.
+        self.num_key_value_groups = 1
+
+        self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.projection_layer = nn.Linear(self.embed_dim, self.embed_dim)
+        self.projection_dropout = nn.Dropout(proj_dropout) if proj_dropout > 0 else nn.Identity()
+
+        self.q_norm = nn.LayerNorm(self.embed_dim) if qk_norm else nn.Identity()
+        self.k_norm = nn.LayerNorm(self.embed_dim) if qk_norm else nn.Identity()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ):
+        batch_size, seq_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
+        query_states = self.q_norm(query_states)
+
+        key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
+        key_states = self.k_norm(key_states)
+
+        query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scale,
+            is_causal=False,
+            **kwargs,
+        )
+        attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
+
+        output = self.projection_layer(attn_output)
+        output = self.projection_dropout(output)
+
+        outputs = (output, attn_weights) if output_attentions else (output, None)
+        return outputs
+
+
+class JanusVisionMLP(nn.Module):
+    def __init__(self, config: JanusVisionConfig):
+        super().__init__()
+        self.config = config
+        self.intermediate_size = int(config.hidden_size * config.mlp_ratio)
+        self.activation_fn = ACT2FN[config.hidden_act]  # Gelu act
+        self.fc1 = nn.Linear(config.hidden_size, self.intermediate_size)
+        self.fc2 = nn.Linear(self.intermediate_size, config.hidden_size)
+        self.dropout1 = nn.Dropout(config.hidden_dropout_rate)
+        self.dropout2 = nn.Dropout(config.hidden_dropout_rate)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.dropout1(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = self.dropout2(hidden_states)
+        return hidden_states
+
+
+class JanusVisionEncoderLayer(nn.Module):
+    def __init__(self, config: JanusVisionConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.self_attn = JanusVisionAttention(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = JanusVisionMLP(config)
+        self.config = config
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`):
+                Input to the layer of shape `(batch, seq_len, embed_dim)`.
+            attention_mask (`torch.FloatTensor`):
+                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states, attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+class JanusVisionEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+    [`JanusVisionEncoderLayer`].
+
+    Args:
+        config: JanusVisionConfig
+    """
+
+    def __init__(self, config: JanusVisionConfig):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList([JanusVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    # Ignore copy
+    @can_return_tuple
+    def forward(
+        self,
+        inputs_embeds,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) -> BaseModelOutput:
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        hidden_states = inputs_embeds
+        for encoder_layer in self.layers:
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    output_attentions=output_attentions,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=encoder_states,
+            attentions=all_attentions,
+        )
+
+
+JANUS_VISION_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`JanusProcessor`]. See [`JanusProcessor.__call__`] for
+            details.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
+"""
+
+
+class JanusVisionModel(JanusPreTrainedModel):
+    main_input_name = "pixel_values"
+    config_class = JanusVisionConfig
+
+    def __init__(self, config: JanusVisionConfig):
+        super().__init__(config)
+        self.config = config
+        embed_dim = config.hidden_size
+
+        self.embeddings = JanusVisionEmbeddings(config)
+        self.encoder = JanusVisionEncoder(config)
+        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(JANUS_VISION_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=JanusVisionConfig)
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        r"""
+        Returns:
+
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        last_hidden_state = encoder_outputs[0]
+        last_hidden_state = self.post_layernorm(last_hidden_state)
+
+        pooled_output = last_hidden_state[:, 0, :]
+        pooled_output = self.post_layernorm(pooled_output)
+
+        if not return_dict:
+            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+    def get_input_embeddings(self):
+        return self.embeddings
+
+
+class JanusVisionAlignerMLP(nn.Module):
+    def __init__(self, config: JanusVisionConfig):
+        super().__init__()
+
+        self.fc1 = nn.Linear(config.hidden_size, config.projection_dim)
+        self.hidden_layers = nn.ModuleList(
+            [nn.Linear(config.projection_dim, config.projection_dim) for _ in range(1, config.depth)]
+        )
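+        # `config.depth` linear layers in total: `fc1` followed by (config.depth - 1) hidden layers,
+        # with the activation applied before each hidden layer.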
+        self.activation_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, hidden_states):
+        hidden_states = self.fc1(hidden_states)
+        for layer in self.hidden_layers:
+            hidden_states = self.activation_fn(hidden_states)
+            hidden_states = layer(hidden_states)
+        return hidden_states
+
+
+class JanusVQVAEVectorQuantizer(nn.Module):
+    """
+    A module for vector quantization using learned embedding vectors.
+
+    This module implements the quantization process similar to the one described in
+    the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
+    input vectors into discrete codebook vectors, which are learned during training.
+    Current implementation improves over previous ones by avoiding costly matrix multiplications
+    and allowing for post-hoc remapping of indices.
+    """
+
+    def __init__(self, config: JanusVQVAEConfig):
+        super().__init__()
+        self.num_embeddings = config.num_embeddings
+        self.embedding_dim = config.embed_dim
+        self.beta = getattr(config, "beta", 0.25)
+
+        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
+        self.quant_state_dims = [config.num_patches] * 2
+
+    def forward(self, hidden_state: torch.Tensor):
+        hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
+        hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)
+
+        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+        distances = (
+            torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
+            + torch.sum(self.embedding.weight**2, dim=1)
+            - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, self.embedding.weight.transpose(0, 1))
+        )
+
+        min_encoding_indices = torch.argmin(distances, dim=1)
+        hidden_state_quant = self.embedding(min_encoding_indices).view(hidden_state.shape)
+
+        # compute loss for embedding
+        loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean(
+            (hidden_state_quant - hidden_state.detach()) ** 2
+        )
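+        # The first term pulls `hidden_state` (the encoder output) toward the detached codebook vectors;
+        # the second, weighted by `beta`, pulls the codebook vectors toward the detached encoder output.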
+
+        # preserve gradients
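+        # (straight-through estimator: the forward pass outputs the quantized values while gradients flow to `hidden_state`)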
+        hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach()
+
+        # reshape back to match original input shape
+        hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
+
+        return hidden_state_quant, loss, min_encoding_indices
+
+    def get_codebook_entry(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
+        batch_size = image_tokens.shape[0]
+        emb_dim: int = self.embedding.weight.shape[-1]
+
+        # get quantized latent vectors
+        hidden_state_quant = self.embedding(image_tokens)
+        # l2 normalization on the last dimension
+        hidden_state_quant = F.normalize(hidden_state_quant, p=2, dim=-1)
+
+        # reshape back to match original input shape
+        hidden_state_quant = hidden_state_quant.view((batch_size, *self.quant_state_dims, emb_dim))
+        hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
+
+        return hidden_state_quant
+
+
+class JanusVQVAEResnetBlock(nn.Module):
+    def __init__(
+        self,
+        config,
+        in_channels,
+        out_channels=None,
+        conv_shortcut=False,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.use_conv_shortcut = conv_shortcut
+
+        self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
+        self.dropout = torch.nn.Dropout(config.dropout)
+        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+            else:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, hidden_states):
+        residual = hidden_states
+        hidden_states = self.norm1(hidden_states)
+        hidden_states *= torch.sigmoid(hidden_states)
+        hidden_states = self.conv1(hidden_states)
+
+        hidden_states = self.norm2(hidden_states)
+        hidden_states *= torch.sigmoid(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                residual = self.conv_shortcut(residual)
+            else:
+                residual = self.nin_shortcut(residual)
+
+        return residual + hidden_states
+
+
+class JanusVQVAEAttnBlock(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, hidden_states):
+        residual = hidden_states
+        hidden_states = self.norm(hidden_states)
+        query_states = self.q(hidden_states)
+        key_states = self.k(hidden_states)
+        value_states = self.v(hidden_states)
+
+        # compute attention
+        batch_size, channels, height, width = query_states.shape
+        query_states = query_states.reshape(batch_size, channels, height * width).permute(0, 2, 1)
+        key_states = key_states.reshape(batch_size, channels, height * width)
+        attn_weights = torch.bmm(query_states, key_states)
+        attn_weights = attn_weights * (int(channels) ** (-0.5))
+        attn_weights = F.softmax(attn_weights, dim=2)
+
+        # attend to values
+        value_states = value_states.reshape(batch_size, channels, height * width)
+        attn_weights = attn_weights.permute(0, 2, 1)
+        attn_output = torch.bmm(value_states, attn_weights).reshape(batch_size, channels, height, width)
+
+        attn_output = self.proj_out(attn_output)
+        return residual + attn_output
+
+
+class JanusVQVAEConvDownsample(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+    def forward(self, hidden_states):
+        # no asymmetric padding in torch conv, must do it ourselves
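+        # Pad by 1 on the right and bottom so the stride-2 convolution halves the spatial size exactly (e.g. 24x24 -> 12x12).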
+        hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
+        hidden_states = self.conv(hidden_states)
+        return hidden_states
+
+
+class JanusVQVAEConvUpsample(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, hidden_states):
+        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        hidden_states = self.conv(hidden_states)
+        return hidden_states
+
+
+class JanusVQVAEMidBlock(nn.Module):
+    def __init__(self, config: JanusVQVAEConfig, channels: int):
+        super().__init__()
+        self.block_1 = JanusVQVAEResnetBlock(
+            config=config,
+            in_channels=channels,
+            out_channels=channels,
+        )
+        self.attn_1 = JanusVQVAEAttnBlock(channels)
+        self.block_2 = JanusVQVAEResnetBlock(
+            config=config,
+            in_channels=channels,
+            out_channels=channels,
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.block_1(hidden_states)
+        hidden_states = self.attn_1(hidden_states)
+        hidden_states = self.block_2(hidden_states)
+        return hidden_states
+
+
+class JanusVQVAEEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        self.num_resolutions = len(config.channel_multiplier)
+        self.num_res_blocks = config.num_res_blocks
+        base_channels = config.base_channels
+        in_channels = config.in_channels
+        double_latent = config.double_latent
+        latent_channels = config.latent_channels
+        channel_multiplier = config.channel_multiplier
+
+        self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
+
+        in_channel_multiplier = (1,) + tuple(channel_multiplier)
+        self.in_channel_multiplier = in_channel_multiplier
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = base_channels * in_channel_multiplier[i_level]
+            block_out = base_channels * channel_multiplier[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(
+                    JanusVQVAEResnetBlock(
+                        config=config,
+                        in_channels=block_in,
+                        out_channels=block_out,
+                    )
+                )
+                block_in = block_out
+                if i_level == self.num_resolutions - 1:
+                    attn.append(JanusVQVAEAttnBlock(block_in))
+
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions - 1:
+                down.downsample = JanusVQVAEConvDownsample(block_in)
+            self.down.append(down)
+
+        self.mid = JanusVQVAEMidBlock(config, block_in)
+
+        self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+        self.conv_out = torch.nn.Conv2d(
+            block_in,
+            2 * latent_channels if double_latent else latent_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+
+    def forward(self, pixel_values: torch.LongTensor):
+        # downsampling
+        hidden_states = [self.conv_in(pixel_values)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                hidden_state = self.down[i_level].block[i_block](
+                    hidden_states[-1],
+                )
+                if len(self.down[i_level].attn) > 0:
+                    hidden_state = self.down[i_level].attn[i_block](hidden_state)
+                hidden_states.append(hidden_state)
+            if i_level != self.num_resolutions - 1:
+                hidden_states.append(self.down[i_level].downsample(hidden_states[-1]))
+
+        # middle
+        last_hidden_state = hidden_states[-1]
+        last_hidden_state = self.mid(last_hidden_state)
+
+        # end
+        last_hidden_state = self.norm_out(last_hidden_state)
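+        # swish (x * sigmoid(x)) non-linearity before the output convolution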
+        last_hidden_state *= torch.sigmoid(last_hidden_state)
+        last_hidden_state = self.conv_out(last_hidden_state)
+        return last_hidden_state
+
+
+class JanusVQVAEDecoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        self.num_resolutions = len(config.channel_multiplier)
+        self.num_res_blocks = config.num_res_blocks
+        base_channels = config.base_channels
+        latent_channels = config.latent_channels
+        out_channels = config.out_channels
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        block_in = base_channels * config.channel_multiplier[self.num_resolutions - 1]
+
+        # z to block_in
+        self.conv_in = torch.nn.Conv2d(latent_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+        # middle
+        self.mid = JanusVQVAEMidBlock(config, block_in)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = base_channels * config.channel_multiplier[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                block.append(
+                    JanusVQVAEResnetBlock(
+                        config=config,
+                        in_channels=block_in,
+                        out_channels=block_out,
+                    )
+                )
+                block_in = block_out
+                if i_level == self.num_resolutions - 1:
+                    attn.append(JanusVQVAEAttnBlock(block_in))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = JanusVQVAEConvUpsample(block_in)
+            self.up.append(up)
+
+        # end
+        self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+        self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, hidden_state: torch.FloatTensor) -> torch.FloatTensor:
+        hidden_state = self.conv_in(hidden_state)
+
+        # middle
+        hidden_state = self.mid(hidden_state)
+
+        # upsampling
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks + 1):
+                hidden_state = self.up[i_level].block[i_block](hidden_state)
+                if len(self.up[i_level].attn) > 0:
+                    hidden_state = self.up[i_level].attn[i_block](hidden_state)
+            if i_level != self.num_resolutions - 1:
+                hidden_state = self.up[i_level].upsample(hidden_state)
+
+        hidden_state = self.norm_out(hidden_state)
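+        # swish (x * sigmoid(x)) non-linearity before the output convolution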
+        hidden_state *= torch.sigmoid(hidden_state)
+        hidden_state = self.conv_out(hidden_state)
+        return hidden_state
+
+
+JANUS_VQ_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`JanusVQVAEConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    """The VQ-VAE model used in Janus for encoding/decoding images into discrete tokens.
+    This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
+    [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131).
+    """,
+    JANUS_VQ_START_DOCSTRING,
+)
+class JanusVQVAE(JanusPreTrainedModel):
+    """Vision Transformer-based VQ-VAE model for encoding and decoding pixel values."""
+
+    config_class = JanusVQVAEConfig
+
+    _no_split_modules = [
+        "JanusVQVAEAttnBlock",
+        "JanusVQVAEResnetBlock",
+        "JanusVQVAEVectorQuantizer",
+    ]
+    main_input_name = "pixel_values"
+
+    def __init__(self, config: JanusVQVAEConfig):
+        super().__init__(config)
+
+        self.encoder = JanusVQVAEEncoder(config)
+        self.quantize = JanusVQVAEVectorQuantizer(config)
+        self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
+        self.eval()  # Janus's VQ model is frozen
+        self.decoder = JanusVQVAEDecoder(config)
+        self.gradient_checkpointing = False
+
+        # Initialize the VQVAE model.
+        self.post_init()
+
+    def encode(self, pixel_values: torch.LongTensor):
+        hidden_states = self.encoder(pixel_values)
+        hidden_states = self.quant_conv(hidden_states)
+        quant, emb_loss, indices = self.quantize(hidden_states)
+        return quant, emb_loss, indices
+
+    def decode(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
+        """
+        Decodes quantized token IDs into pixel values.
+        Args:
+            image_tokens (torch.LongTensor): Batch of token IDs.
+        Returns:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                Pixel values decoded from the token IDs.
+        """
+        if image_tokens.shape[1] != self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]:
+            raise ValueError(
+                f"Expected `image_tokens` to have shape `(batch_size, {self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]})`, "
+                f"but got shape `{image_tokens.shape}`."
+            )
+        codebook_entry = self.quantize.get_codebook_entry(image_tokens)
+        hidden_states = self.post_quant_conv(codebook_entry)
+        pixel_values = self.decoder(hidden_states)
+        return pixel_values
+
+    @can_return_tuple
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        """
+        Encodes pixel values into quantized tokens and decodes them back.
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input images.
+        Returns:
+            decoded_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                Reconstructed pixel values after encoding and decoding the input.
+            embedding_loss (`torch.FloatTensor`): Embedding loss.
+        """
+
+        batch_size = pixel_values.shape[0]
+        quant, embedding_loss, indices = self.encode(pixel_values)
+        decoded_pixel_values = self.decode(indices.view(batch_size, -1))
+        output = JanusVQVAEOutput(decoded_pixel_values, embedding_loss)
+
+        return output
+
+
+class JanusVQVAEAlignerMLP(nn.Module):
+    def __init__(self, config: JanusVQVAEConfig):
+        super().__init__()
+
+        self.fc1 = nn.Linear(config.embed_dim, config.projection_dim)
+        self.hidden_layers = nn.ModuleList(
+            [nn.Linear(config.projection_dim, config.projection_dim) for _ in range(1, config.num_hidden_layers)]
+        )
+        self.activation_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, hidden_states):
+        hidden_states = self.fc1(hidden_states)
+        for layer in self.hidden_layers:
+            hidden_states = self.activation_fn(hidden_states)
+            hidden_states = layer(hidden_states)
+        return hidden_states
+
+
+class JanusVQVAEHead(nn.Module):
+    """Head used for sampling tokens in image generation, replacing the usual lm head."""
+
+    def __init__(self, config: JanusVQVAEConfig):
+        super().__init__()
+        self.proj_out = nn.Linear(config.image_token_embed_dim, config.projection_dim)
+        self.activation_fn = ACT2FN[config.hidden_act]
+        self.vision_head = nn.Linear(config.projection_dim, config.num_embeddings)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.vision_head(hidden_states)
+        return hidden_states
+
+
+JANUS_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+            The tensors corresponding to the input images. Pixel values can be obtained using
+            [`AutoImageProcessor`].
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+    """The Janus model which consists of a siglip vision backbone, a Llama language model and a VQ model.""",
+    JANUS_START_DOCSTRING,
+)
+class JanusModel(JanusPreTrainedModel):
+    def __init__(self, config: JanusConfig):
+        super().__init__(config)
+        self.config = config
+        # This is necessary for backward compatibility, see SiglipModel initialization
+        self.vision_model = JanusVisionModel._from_config(config.vision_config)
+        self.aligner = JanusVisionAlignerMLP(self.vision_model.config)
+
+        self.vqmodel = JanusVQVAE._from_config(config.vq_config)
+
+        # Below generation_* modules are used for Image generation.
+        # Embeddings used for image generation, instead of Janus vision embeddings.
+        self.generation_embeddings = nn.Embedding(self.vqmodel.config.num_embeddings, self.vqmodel.config.embed_dim)
+        self.generation_aligner = JanusVQVAEAlignerMLP(self.vqmodel.config)
+        self.generation_head = JanusVQVAEHead(self.vqmodel.config)
+
+        self.language_model = AutoModel.from_config(config=config.text_config)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing.
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.language_model.set_input_embeddings(value)
+
+    def get_image_features(self, pixel_values):
+        image_embeds = self.vision_model(pixel_values)
+        image_embeds = self.aligner(image_embeds.last_hidden_state)
+        return image_embeds
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(JANUS_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        pixel_values: torch.FloatTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs,
+    ):
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        if pixel_values is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+            )
+
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+        if pixel_values is not None:
+            image_embeds = self.get_image_features(pixel_values)
+            image_attention_mask = input_ids == self.config.image_token_index
+
+            embed_dim = inputs_embeds.shape[-1]
+            image_features = image_embeds.reshape(-1, embed_dim)
+            image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
+
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
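+            # Scatter the projected image features into the placeholder image-token positions of the text embeddings.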
+            inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
+
+        lm_output = self.language_model(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            logits_to_keep=logits_to_keep,
+            **kwargs,
+        )
+
+        output = JanusBaseModelOutputWithPast(
+            last_hidden_state=lm_output.last_hidden_state,
+            past_key_values=lm_output.past_key_values,
+            hidden_states=lm_output.hidden_states,
+            attentions=lm_output.attentions,
+            image_hidden_states=image_embeds if pixel_values is not None else None,
+        )
+
+        return output
+
+
+class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
+    _supports_static_cache = True
+
+    def __init__(self, config: JanusConfig):
+        super().__init__(config)
+        self.config = config
+        self.model = JanusModel(config)
+        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing.
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.language_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.model.language_model.set_input_embeddings(value)
+
+    def prepare_embeddings_for_image_generation(self, inputs: torch.Tensor) -> torch.Tensor:
+        hidden_state = self.model.generation_embeddings(inputs)
+        hidden_state = self.model.generation_aligner(hidden_state)
+        return hidden_state
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(JANUS_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=JanusCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        pixel_values: torch.FloatTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs,
+    ):
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+            logits_to_keep (`int` or `torch.Tensor`, *optional*):
+                If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+                This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+        Returns:
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        outputs = self.model(
+            input_ids=input_ids,
+            pixel_values=pixel_values,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+        output = JanusCausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            image_hidden_states=outputs.image_hidden_states,
+        )
+        return output
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        pixel_values=None,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        logits_to_keep=None,
+        **kwargs,
+    ):
+        # Overwritten -- extra custom processing
+
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            logits_to_keep=logits_to_keep,
+            **kwargs,
+        )
+
+        # In the cached decoding stage, pixel values should be None because the input ids no longer contain the special image tokens.
+        # Otherwise, the pixel values need to be passed to the model.
+        if cache_position[0] == 0:
+            model_inputs["pixel_values"] = pixel_values
+
+        return model_inputs
+
+    def decode_image_tokens(self, image_tokens: torch.Tensor):
+        """
+        Decodes generated image tokens from language model to continuous pixel values
+        with VQGAN module via upsampling.
+        Args:
+            image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
+                The tensors corresponding to the input images.
+        """
+        decoded_image = self.model.vqmodel.decode(image_tokens)
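+        # Move channels last: (batch, channels, height, width) -> (batch, height, width, channels).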
+        decoded_image = decoded_image.permute(0, 2, 3, 1)
+        return decoded_image
+
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: torch.Tensor = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        **kwargs,
+    ):
+        # 1. Handle generation config and model kwargs
+        generation_config = kwargs.pop("generation_config", self.generation_config)
+        generation_config = copy.deepcopy(generation_config)
+
+        # Default to "text" generation if mode isn't provided
+        generation_mode = kwargs.pop("generation_mode", "text")
+        if generation_mode == "text":
+            # Set guidance_scale=None to prevent running UnbatchedCFG processor.
+            return super().generate(
+                inputs=inputs,
+                attention_mask=attention_mask,
+                generation_config=generation_config,
+                guidance_scale=None,
+                **kwargs,
+            )
+
+        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+
+        # Validate generation mode
+        if generation_config.get_generation_mode() not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
+            raise ValueError(
+                "Got incompatible mode for Image Generation, should be one of greedy or sampling. "
+                "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
+            )
+
+        # Validate the configuration and model kwargs
+        generation_config.validate()
+        self._validate_model_kwargs(model_kwargs.copy())
+
+        # 2. Initialize logit processors
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+
+        # Set `use_cache=True` as we will be using input embeds for generation.
+        model_kwargs["use_cache"] = True
+
+        if generation_config.guidance_scale is None:
+            logger.warning("`guidance_scale` is required for CFG but not provided. Setting to default value of 5.")
+            generation_config.guidance_scale = 5
+        model_kwargs["guidance_scale"] = generation_config.guidance_scale
+
+        # 3. Prepare model inputs
+        input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
+            inputs, generation_config.bos_token_id, model_kwargs
+        )
+        dtype, device = input_ids.dtype, input_ids.device
+
+        if len(input_ids.shape) != 2:
+            raise ValueError(
+                f"Expected input ids of shape (batch_size, seq_len), but got {input_ids.shape}"
+                "Passing `inputs embeds` is not supported currently."
+            )
+
+        # Prepare special tokens which will be used by `generate` internally.
+        kwargs_has_attention_mask = attention_mask is not None
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
+
+        # 4. Add CFG processor along with user passed logit processor.
+        if generation_config.guidance_scale and generation_config.guidance_scale > 1:
+            logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
+            generation_config.guidance_scale = None  # Reset to prevent processor duplication.
+
+        # 5. Prepare logits processor
+        logits_processor = self._get_logits_processor(
+            generation_config=generation_config,
+            input_ids_seq_length=input_ids.shape[1],
+            encoder_input_ids=input_ids,
+            prefix_allowed_tokens_fn=None,
+            logits_processor=logits_processor,
+            device=device,
+        )
+
+        # 6. Expand inputs for multiple image generations per prompt.
+        input_ids, model_kwargs = self._expand_inputs_for_generation(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            expand_size=generation_config.num_return_sequences,
+            **model_kwargs,
+        )
+
+        # 7. Prepare input and model caches
+        num_image_tokens = self.model.vision_model.config.num_image_tokens
+        batch_size, seq_len = input_ids.shape
+
+        input_tokens = input_ids.repeat(2, 1)  # Double batch size for conditional/unconditional logits
+        attention_mask = model_kwargs.pop("attention_mask", None)
+        attention_mask = attention_mask.repeat(2, 1)
+        model_kwargs["attention_mask"] = attention_mask
+
+        # In the unconditional branch, replace every token that is neither BOS nor BOI with the pad token.
+        mask = (input_tokens[batch_size:, :] != generation_config.bos_token_id) & (
+            input_tokens[batch_size:, :] != generation_config.generation_kwargs["boi_token_id"]
+        )
+        input_tokens[batch_size:, :].masked_fill_(mask, generation_config.pad_token_id)
+
+        inputs_embeds = self.get_input_embeddings()(input_tokens)
+
+        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+
+        if model_kwargs.get("past_key_values", None) is None:
+            # Prepare cache if not provided.
+            model_kwargs["past_key_values"] = self._get_cache(
+                cache_implementation=generation_config.cache_implementation or "static",
+                # batch_size should account for both conditional/unconditional input; hence multiplied by 2.
+                batch_size=batch_size * 2,
+                # we should have at least a cache len of seq_len + num_image_tokens.
+                max_cache_len=max(generation_config.max_length, num_image_tokens + seq_len),
+                device=device,
+                model_kwargs=model_kwargs,
+            )
+
+        # Placeholder for generated tokens.
+        generated_tokens = torch.zeros((batch_size, num_image_tokens), dtype=dtype, device=device)
+
+        # 8. init attention / hidden states / scores tuples
+        output_attentions = generation_config.output_attentions
+        output_hidden_states = generation_config.output_hidden_states
+        output_scores = generation_config.output_scores
+        output_logits = generation_config.output_logits
+        return_dict_in_generate = generation_config.return_dict_in_generate
+
+        raw_scores = () if (return_dict_in_generate and output_scores) else None
+        raw_logits = () if (return_dict_in_generate and output_logits) else None
+        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+
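+        # 9. run the image generation loop: sample one image token per step, scoring each step with the generation head and the CFG logits processor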
+        for i in range(num_image_tokens):
+            model_inputs = self.prepare_inputs_for_generation(
+                inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs
+            )
+
+            model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device)
+            model_inputs["cache_position"] = model_inputs["cache_position"].to(inputs_embeds.device)
+
+            outputs = self.model.language_model(
+                **model_inputs,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+
+            # Update model_kwargs like cache_position for next generation.
+            model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
+            hidden_state = outputs.last_hidden_state[:, -1, :].clone()
+
+            # Compute scores with the generation head (not the LM head defined above).
+            scores = self.model.generation_head(hidden_state)
+            next_token_scores = logits_processor(input_ids, scores)
+
+            # Sample next token.
+            if generation_config.do_sample:
+                probs = torch.softmax(next_token_scores, dim=-1)
+                next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
+            else:
+                next_token = torch.argmax(next_token_scores, dim=-1)
+
+            generated_tokens[:, i] = next_token
+
+            # Prepare embeddings for the next step.
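+            # Duplicate the sampled token so both the conditional and unconditional (CFG) branches receive it.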
+            next_token = torch.cat([next_token, next_token])
+            next_token = next_token.unsqueeze(-1)
+
+            inputs_embeds = self.prepare_embeddings_for_image_generation(next_token)
+
+            # Store scores, attentions and hidden states for the current step when requested.
+            if return_dict_in_generate:
+                if output_scores:
+                    raw_scores += (scores,)
+                if output_logits:
+                    raw_logits += (hidden_state.float(),)
+                if output_attentions:
+                    decoder_attentions += outputs.attentions
+                if output_hidden_states:
+                    decoder_hidden_states += outputs.hidden_states
+
+        if return_dict_in_generate:
+            return GenerateDecoderOnlyOutput(
+                sequences=generated_tokens,
+                scores=raw_scores,
+                logits=raw_logits,
+                attentions=decoder_attentions,
+                hidden_states=decoder_hidden_states,
+                past_key_values=outputs.past_key_values,
+            )
+        else:
+            return generated_tokens
+
+
+__all__ = ["JanusPreTrainedModel", "JanusForConditionalGeneration", "JanusModel", "JanusVQVAE", "JanusVisionModel"]
diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py
new file mode 100644
index 00000000000..03e3a05a27a
--- /dev/null
+++ b/src/transformers/models/janus/modular_janus.py
@@ -0,0 +1,1770 @@
+# coding=utf-8
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from dataclasses import dataclass
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+from transformers.models.blip.image_processing_blip import BlipImageProcessor
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...generation import (
+    ClassifierFreeGuidanceLogitsProcessor,
+    GenerationMixin,
+    GenerationMode,
+    LogitsProcessorList,
+)
+from ...generation.utils import GenerateDecoderOnlyOutput
+from ...image_processing_utils import BatchFeature, get_size_dict
+from ...image_transforms import (
+    resize,
+    to_channel_dimension_format,
+)
+from ...image_utils import (
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    get_image_size,
+    infer_channel_dimension_format,
+    make_list_of_images,
+    to_numpy_array,
+)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import ModelOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    can_return_tuple,
+    is_torch_available,
+    is_vision_available,
+    logging,
+    replace_return_docstrings,
+)
+from ..auto import AutoModel
+from ..blip_2.modeling_blip_2 import Blip2VisionModel
+from ..chameleon.configuration_chameleon import ChameleonVQVAEConfig
+from ..chameleon.modeling_chameleon import (
+    ChameleonVQVAE,
+    ChameleonVQVAEEncoder,
+    ChameleonVQVAEEncoderAttnBlock,
+    ChameleonVQVAEEncoderConvDownsample,
+    ChameleonVQVAEEncoderResnetBlock,
+    ChameleonVQVAEVectorQuantizer,
+)
+from ..idefics.modeling_idefics import IdeficsBaseModelOutputWithPast, IdeficsCausalLMOutputWithPast
+from ..llama.modeling_llama import eager_attention_forward
+from ..siglip.configuration_siglip import SiglipVisionConfig
+from ..siglip.modeling_siglip import (
+    SiglipEncoder,
+    SiglipEncoderLayer,
+    SiglipVisionEmbeddings,
+)
+
+
+if is_torch_available():
+    import torch
+    import torch.nn as nn
+    import torch.nn.functional as F
+    import torch.utils.checkpoint
+
+if is_vision_available():
+    import PIL
+
+
+from ...configuration_utils import PretrainedConfig
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "JanusConfig"
+
+
+class JanusVisionConfig(SiglipVisionConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`JanusVisionModel`]. It is used to instantiate a
+    `JanusVisionModel` according to the specified arguments, defining the model architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 1024):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        patch_size (`int`, *optional*, defaults to 16):
+            The size (resolution) of each patch.
+        image_size (`int`, *optional*, defaults to 384):
+            The size (resolution) of each image.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            Dropout probability for attention weights.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the layer normalization layers.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"`, and `"gelu_new"` are supported.
+        mlp_ratio (`float`, *optional*, defaults to 4.0):
+            Ratio of MLP hidden dimensionality to embedding dimensionality.
+        attention_bias (`bool`, *optional*, defaults to `True`):
+            Whether to add a bias to the queries, keys, and values in the attention layers.
+        hidden_dropout_rate (`float`, *optional*, defaults to 0.0):
+            The dropout probability for fully connected layers in the encoder.
+        projection_dim (`int`, *optional*, defaults to 2048):
+            Dimensionality of the MLP projection head.
+        projection_dropout (`float`, *optional*, defaults to 0.0):
+            Dropout probability for the projection layer.
+        use_qk_norm (`bool`, *optional*, defaults to `False`):
+            Whether to normalize the query and key matrices.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated normal initializer for initializing all weight matrices.
+        depth (`int`, *optional*, defaults to 2):
+            Number of hidden layers in the aligner module.
+        num_image_tokens (`int`, *optional*, defaults to 576):
+            Number of image tokens.
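+
+    Example (a minimal usage sketch; it assumes `JanusVisionConfig` and `JanusVisionModel` are exported from `transformers`):
+
+    ```python
+    >>> from transformers import JanusVisionConfig, JanusVisionModel
+
+    >>> # Initializing a JanusVisionConfig with the default values
+    >>> configuration = JanusVisionConfig()
+
+    >>> # Initializing a JanusVisionModel (with random weights) from the configuration
+    >>> model = JanusVisionModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```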
+    """
+
+    model_type = "janus_vision_model"
+    base_config_key = "vision_config"
+
+    def __init__(
+        self,
+        hidden_size=1024,
+        num_hidden_layers=24,
+        num_attention_heads=16,
+        num_channels=3,
+        patch_size=16,
+        image_size=384,
+        attention_dropout=0.0,
+        layer_norm_eps=1e-6,
+        hidden_act="gelu",
+        mlp_ratio=4.0,
+        attention_bias=True,
+        hidden_dropout_rate=0.0,
+        projection_dim=2048,
+        projection_dropout=0.0,
+        use_qk_norm=False,
+        initializer_range=0.02,
+        depth=2,
+        num_image_tokens=576,
+        **kwargs,
+    ):
+        super().__init__(
+            hidden_size=hidden_size,
+            num_hidden_layers=num_hidden_layers,
+            num_attention_heads=num_attention_heads,
+            num_channels=num_channels,
+            patch_size=patch_size,
+            image_size=image_size,
+            attention_dropout=attention_dropout,
+            layer_norm_eps=layer_norm_eps,
+            hidden_act=hidden_act,
+            **kwargs,
+        )
+        del self.intermediate_size
+
+        self.mlp_ratio = mlp_ratio
+        self.attention_bias = attention_bias
+        self.hidden_dropout_rate = hidden_dropout_rate
+        self.projection_dim = projection_dim
+        self.projection_dropout = projection_dropout
+        self.use_qk_norm = use_qk_norm
+        self.initializer_range = initializer_range
+        self.depth = depth
+        self.num_image_tokens = num_image_tokens
+
+
+class JanusVQVAEConfig(ChameleonVQVAEConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`JanusVQVAE`]. It is used to instantiate a
+    `JanusVQVAE` model according to the specified arguments, defining the model architecture.
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information. Instantiating a
+    configuration with the defaults will yield a similar configuration to the VQModel of the
+    [deepseek-community/Janus-Pro-1B](https://huggingface.co/deepseek-community/Janus-Pro-1B).
+
+    Args:
+        embed_dim (`int`, *optional*, defaults to 8):
+            Dimensionality of each embedding vector.
+        num_embeddings (`int`, *optional*, defaults to 16384):
+            Number of codebook embeddings.
+        double_latent (`bool`, *optional*, defaults to `False`):
+            Whether to use double z channels.
+        latent_channels (`int`, *optional*, defaults to 256):
+            Number of channels for the latent space.
+        num_patches (`int`, *optional*, defaults to 32):
+            Number of patches the input images can be divided into.
+        in_channels (`int`, *optional*, defaults to 3):
+            Number of input channels.
+        out_channels (`int`, *optional*, defaults to 3):
+            Number of output channels.
+        base_channels (`int`, *optional*, defaults to 128):
+            Base channel count.
+        channel_multiplier (`List[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`):
+            Channel multipliers for each resolution.
+        num_res_blocks (`int`, *optional*, defaults to 2):
+            Number of residual blocks.
+        dropout (`float`, *optional*, defaults to 0.0):
+            Dropout rate.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        projection_dim (`int`, *optional*, defaults to 2048):
+            Dimensionality of the MLP projection head.
+        num_hidden_layers (`int`, *optional*, defaults to 2):
+            Number of hidden layers in the VQVAE MLP connector module.
+        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        image_token_embed_dim (`int`, *optional*, defaults to 2048):
+            Dimension of the image embeddings. It should be the same as the dimensionality of the text embeddings.
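+
+    Example (a minimal usage sketch; it assumes `JanusVQVAEConfig` and `JanusVQVAE` are exported from `transformers`):
+
+    ```python
+    >>> from transformers import JanusVQVAE, JanusVQVAEConfig
+
+    >>> # Initializing a JanusVQVAEConfig with the default values
+    >>> configuration = JanusVQVAEConfig()
+
+    >>> # Initializing a JanusVQVAE model (with random weights) from the configuration
+    >>> model = JanusVQVAE(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```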
+    """
+
+    def __init__(
+        self,
+        embed_dim: int = 8,
+        num_embeddings: int = 16384,
+        double_latent: bool = False,
+        latent_channels: int = 256,
+        num_patches: int = 32,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        base_channels: int = 128,
+        channel_multiplier: List[int] = [1, 1, 2, 2, 4],
+        num_res_blocks: int = 2,
+        dropout: float = 0.0,
+        initializer_range=0.02,
+        projection_dim=2048,
+        num_hidden_layers=2,
+        hidden_act="gelu",
+        image_token_embed_dim=2048,
+        **kwargs,
+    ):
+        super().__init__(
+            embed_dim=embed_dim,
+            num_embeddings=num_embeddings,
+            double_latent=double_latent,
+            latent_channels=latent_channels,
+            in_channels=in_channels,
+            base_channels=base_channels,
+            channel_multiplier=channel_multiplier,
+            num_res_blocks=num_res_blocks,
+            dropout=dropout,
+            initializer_range=initializer_range,
+            **kwargs,
+        )
+        self.num_patches = num_patches
+        self.out_channels = out_channels
+        self.projection_dim = projection_dim
+        self.num_hidden_layers = num_hidden_layers
+        self.hidden_act = hidden_act
+        self.image_token_embed_dim = image_token_embed_dim
+
+        del self.resolution
+        del self.attn_resolutions
+        del self.attn_type
+
+
+class JanusConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`JanusModel`]. It is used to instantiate a
+    Janus model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the Janus-1B or Janus-7B models.
+
+    e.g. [deepseek-community/Janus-Pro-1B](https://huggingface.co/deepseek-community/Janus-Pro-1B) or
+    [deepseek-community/Janus-Pro-7B](https://huggingface.co/deepseek-community/Janus-Pro-7B)
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
+            The config object or dictionary of the text backbone.
+        vision_config (`Union[AutoConfig, dict]`,  *optional*, defaults to `JanusVisionConfig`):
+            The config object or dictionary of the vision backbone.
+        vq_config (`Union[AutoConfig, dict]`,  *optional*, defaults to `JanusVQVAEConfig`):
+            The config object or dictionary of the VQVAE backbone.
+
+    Example:
+
+    ```python
+    >>> from transformers import JanusForConditionalGeneration, JanusConfig, JanusVisionConfig, JanusVQVAEConfig, LlamaConfig
+
+    >>> # Initializing a Janus vision config
+    >>> vision_config = JanusVisionConfig()
+
+    >>> # Initializing a Llama config
+    >>> text_config = LlamaConfig()
+
+    >>> # Initializing a VQ config
+    >>> vq_config = JanusVQVAEConfig()
+
+    >>> # Initializing a Janus Pro 1B style configuration
+    >>> configuration = JanusConfig(vision_config=vision_config, text_config=text_config, vq_config=vq_config)
+
+    >>> # Initializing a model from the Janus Pro 1B style configuration
+    >>> model = JanusForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "janus"
+    sub_configs = {
+        "text_config": AutoConfig,
+        "vision_config": JanusVisionConfig,
+        "vq_config": JanusVQVAEConfig,
+    }
+
+    def __init__(self, text_config=None, vision_config=None, vq_config=None, **kwargs):
+        if isinstance(text_config, dict):
+            text_config["model_type"] = text_config.get("model_type", "llama")
+            self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+
+        elif text_config is None:
+            logger.info("`text_config` is None. Initializing with default values")
+            self.text_config = CONFIG_MAPPING["llama"]()
+        elif isinstance(text_config, PretrainedConfig):
+            self.text_config = text_config
+        else:
+            raise ValueError(
+                f"Invalid type for `text_config`. Must be either `dict` or `LlamaConfig`."
+                f" Type found: {type(text_config)}"
+            )
+
+        if vision_config is None:
+            logger.info("`vision_config` is None. Initializing with default JanusVisionConfig values")
+            self.vision_config = JanusVisionConfig()
+        elif isinstance(vision_config, dict):
+            self.vision_config = JanusVisionConfig(**vision_config)
+        elif isinstance(vision_config, JanusVisionConfig):
+            self.vision_config = vision_config
+        else:
+            raise ValueError(
+                f"Invalid type for `vision_config`. Must be either `dict` or `JanusVisionConfig`."
+                f" Type found: {type(vision_config)}"
+            )
+
+        if vq_config is None:
+            logger.info("`vq_config` is None. Initializing with default JanusVQVAEConfig values")
+            self.vq_config = JanusVQVAEConfig()
+        elif isinstance(vq_config, dict):
+            self.vq_config = JanusVQVAEConfig(**vq_config)
+        elif isinstance(vq_config, JanusVQVAEConfig):
+            self.vq_config = vq_config
+        else:
+            raise ValueError(
+                f"Invalid type for `vq_config`. Must be either `dict` or `JanusVQVAEConfig`."
+                f" Type found: {type(vq_config)}"
+            )
+
+        # This dimension is required when decoding discrete image tokens to continuous input.
+        self.vq_config.num_patches = self.vision_config.image_size // self.vision_config.patch_size
+        # This default image token index only applies to the 1B model; the 7B model uses a different one.
+        self.image_token_index = kwargs.get("image_token_index", 100581)
+        super().__init__(**kwargs)
+
+
+JANUS_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`JanusConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare Janus Model outputting raw hidden-states without any specific head on top.",
+    JANUS_START_DOCSTRING,
+)
+class JanusPreTrainedModel(PreTrainedModel):
+    config_class = JanusConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["LlamaDecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_quantized_cache = True
+    _supports_cache_class = True
+    _supports_static_cache = True
+    _supports_param_buffer_assignment = False
+
+    def _init_weights(self, module):
+        std = (
+            self.config.vision_config.initializer_range
+            if hasattr(self.config, "vision_config")
+            else self.config.initializer_range
+        )
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+@dataclass
+class JanusVQVAEOutput(ModelOutput):
+    """
+    Base class for Janus VQ-VAE model outputs.
+    Args:
+        decoded_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+            Reconstructed pixel values after encoding and decoding the input.
+        embedding_loss (`torch.FloatTensor`):
+            Embedding loss.
+    """
+
+    decoded_pixel_values: Optional[torch.FloatTensor] = None
+    embedding_loss: torch.FloatTensor = None
+
+
+@dataclass
+class JanusBaseModelOutputWithPast(IdeficsBaseModelOutputWithPast):
+    pass
+
+
+@dataclass
+class JanusCausalLMOutputWithPast(IdeficsCausalLMOutputWithPast):
+    pass
+
+
+class JanusVisionEmbeddings(SiglipVisionEmbeddings):
+    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+        _, _, height, width = pixel_values.shape
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
+        embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+        if interpolate_pos_encoding:
+            pos_embeds = self.interpolate_pos_encoding(embeddings, height, width)
+        else:
+            pos_embeds = self.position_embedding(self.position_ids)
+
+        embeddings = embeddings + pos_embeds
+
+        return embeddings
+
+
+class JanusVisionAttention(nn.Module):
+    """Attention Class for Janus Vision Encoder"""
+
+    def __init__(self, config: JanusVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads})."
+            )
+        self.scale = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
+        proj_dropout = config.projection_dropout
+        qk_norm = config.use_qk_norm
+
+        # Janus vision attention has no grouped key/value heads (no MQA/GQA), so `num_key_value_groups` is set to 1 for the `eager_attention_forward` call.
+        self.num_key_value_groups = 1
+
+        self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.projection_layer = nn.Linear(self.embed_dim, self.embed_dim)
+        self.projection_dropout = nn.Dropout(proj_dropout) if proj_dropout > 0 else nn.Identity()
+
+        self.q_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ):
+        batch_size, seq_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
+        query_states = self.q_norm(query_states)
+
+        key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
+        key_states = self.k_norm(key_states)
+
+        query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scale,
+            is_causal=False,
+            **kwargs,
+        )
+        attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
+
+        output = self.projection_layer(attn_output)
+        output = self.projection_dropout(output)
+
+        outputs = (output, attn_weights) if output_attentions else (output, None)
+        return outputs
+
+
+class JanusVisionMLP(nn.Module):
+    def __init__(self, config: JanusVisionConfig):
+        super().__init__()
+        self.config = config
+        self.intermediate_size = int(config.hidden_size * config.mlp_ratio)
+        self.activation_fn = ACT2FN[config.hidden_act]  # Gelu act
+        self.fc1 = nn.Linear(config.hidden_size, self.intermediate_size)
+        self.fc2 = nn.Linear(self.intermediate_size, config.hidden_size)
+        self.dropout1 = nn.Dropout(config.hidden_dropout_rate)
+        self.dropout2 = nn.Dropout(config.hidden_dropout_rate)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.dropout1(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = self.dropout2(hidden_states)
+        return hidden_states
+
+
+class JanusVisionEncoderLayer(SiglipEncoderLayer):
+    def __init__(self, config: JanusVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.self_attn = JanusVisionAttention(config)
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = JanusVisionMLP(config)
+
+
+class JanusVisionEncoder(SiglipEncoder):
+    def __init__(self, config: JanusVisionConfig):
+        super().__init__(config)
+        self.layers = nn.ModuleList([JanusVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+
+
+class JanusVisionModel(Blip2VisionModel):
+    def __init__(self, config: JanusVisionConfig):
+        super().__init__(config)
+        self.encoder = JanusVisionEncoder(config)
+
+
+class JanusVisionAlignerMLP(nn.Module):
+    def __init__(self, config: JanusVisionConfig):
+        super().__init__()
+
+        self.fc1 = nn.Linear(config.hidden_size, config.projection_dim)
+        self.hidden_layers = nn.ModuleList(
+            [nn.Linear(config.projection_dim, config.projection_dim) for _ in range(1, config.depth)]
+        )
+        self.activation_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, hidden_states):
+        hidden_states = self.fc1(hidden_states)
+        for layer in self.hidden_layers:
+            hidden_states = self.activation_fn(hidden_states)
+            hidden_states = layer(hidden_states)
+        return hidden_states
+
+
+class JanusVQVAEVectorQuantizer(ChameleonVQVAEVectorQuantizer):
+    def __init__(self, config: JanusVQVAEConfig):
+        super().__init__(config)
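+        # The quantized latent is a square grid of `num_patches` x `num_patches` codes; keeping the grid
+        # dims lets `get_codebook_entry` reshape a flat sequence of image tokens back into a 2D feature map.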
+        self.quant_state_dims = [config.num_patches] * 2
+
+    def get_codebook_entry(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
+        batch_size = image_tokens.shape[0]
+        emb_dim: int = self.embedding.weight.shape[-1]
+
+        # get quantized latent vectors
+        hidden_state_quant = self.embedding(image_tokens)
+        # l2 normalization on the last dimension
+        hidden_state_quant = F.normalize(hidden_state_quant, p=2, dim=-1)
+
+        # reshape back to match original input shape
+        hidden_state_quant = hidden_state_quant.view((batch_size, *self.quant_state_dims, emb_dim))
+        hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
+
+        return hidden_state_quant
+
+
+class JanusVQVAEResnetBlock(ChameleonVQVAEEncoderResnetBlock):
+    pass
+
+
+class JanusVQVAEAttnBlock(ChameleonVQVAEEncoderAttnBlock):
+    pass
+
+
+class JanusVQVAEConvDownsample(ChameleonVQVAEEncoderConvDownsample):
+    pass
+
+
+class JanusVQVAEConvUpsample(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, hidden_states):
+        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        hidden_states = self.conv(hidden_states)
+        return hidden_states
+
+
+class JanusVQVAEMidBlock(nn.Module):
+    def __init__(self, config: JanusVQVAEConfig, channels: int):
+        super().__init__()
+        self.block_1 = JanusVQVAEResnetBlock(
+            config=config,
+            in_channels=channels,
+            out_channels=channels,
+        )
+        self.attn_1 = JanusVQVAEAttnBlock(channels)
+        self.block_2 = JanusVQVAEResnetBlock(
+            config=config,
+            in_channels=channels,
+            out_channels=channels,
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.block_1(hidden_states)
+        hidden_states = self.attn_1(hidden_states)
+        hidden_states = self.block_2(hidden_states)
+        return hidden_states
+
+
+class JanusVQVAEEncoder(ChameleonVQVAEEncoder, nn.Module):
+    def __init__(self, config):
+        nn.Module.__init__(self)
+
+        self.num_resolutions = len(config.channel_multiplier)
+        self.num_res_blocks = config.num_res_blocks
+        base_channels = config.base_channels
+        in_channels = config.in_channels
+        double_latent = config.double_latent
+        latent_channels = config.latent_channels
+        channel_multiplier = config.channel_multiplier
+
+        self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
+
+        in_channel_multiplier = (1,) + tuple(channel_multiplier)
+        self.in_channel_multiplier = in_channel_multiplier
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = base_channels * in_channel_multiplier[i_level]
+            block_out = base_channels * channel_multiplier[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(
+                    JanusVQVAEResnetBlock(
+                        config=config,
+                        in_channels=block_in,
+                        out_channels=block_out,
+                    )
+                )
+                block_in = block_out
+                if i_level == self.num_resolutions - 1:
+                    attn.append(JanusVQVAEAttnBlock(block_in))
+
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions - 1:
+                down.downsample = JanusVQVAEConvDownsample(block_in)
+            self.down.append(down)
+
+        self.mid = JanusVQVAEMidBlock(config, block_in)
+
+        self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+        self.conv_out = torch.nn.Conv2d(
+            block_in,
+            2 * latent_channels if double_latent else latent_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+
+    def forward(self, pixel_values: torch.LongTensor):
+        # downsampling
+        hidden_states = [self.conv_in(pixel_values)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                hidden_state = self.down[i_level].block[i_block](
+                    hidden_states[-1],
+                )
+                if len(self.down[i_level].attn) > 0:
+                    hidden_state = self.down[i_level].attn[i_block](hidden_state)
+                hidden_states.append(hidden_state)
+            if i_level != self.num_resolutions - 1:
+                hidden_states.append(self.down[i_level].downsample(hidden_states[-1]))
+
+        # middle
+        last_hidden_state = hidden_states[-1]
+        last_hidden_state = self.mid(last_hidden_state)
+
+        # end
+        last_hidden_state = self.norm_out(last_hidden_state)
+        last_hidden_state *= torch.sigmoid(last_hidden_state)
+        last_hidden_state = self.conv_out(last_hidden_state)
+        return last_hidden_state
+
+
+class JanusVQVAEDecoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        self.num_resolutions = len(config.channel_multiplier)
+        self.num_res_blocks = config.num_res_blocks
+        base_channels = config.base_channels
+        latent_channels = config.latent_channels
+        out_channels = config.out_channels
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        block_in = base_channels * config.channel_multiplier[self.num_resolutions - 1]
+
+        # z to block_in
+        self.conv_in = torch.nn.Conv2d(latent_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+        # middle
+        self.mid = JanusVQVAEMidBlock(config, block_in)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = base_channels * config.channel_multiplier[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                block.append(
+                    JanusVQVAEResnetBlock(
+                        config=config,
+                        in_channels=block_in,
+                        out_channels=block_out,
+                    )
+                )
+                block_in = block_out
+                if i_level == self.num_resolutions - 1:
+                    attn.append(JanusVQVAEAttnBlock(block_in))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = JanusVQVAEConvUpsample(block_in)
+            self.up.append(up)
+
+        # end
+        self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+        self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, hidden_state: torch.FloatTensor) -> torch.FloatTensor:
+        hidden_state = self.conv_in(hidden_state)
+
+        # middle
+        hidden_state = self.mid(hidden_state)
+
+        # upsampling
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks + 1):
+                hidden_state = self.up[i_level].block[i_block](hidden_state)
+                if len(self.up[i_level].attn) > 0:
+                    hidden_state = self.up[i_level].attn[i_block](hidden_state)
+            if i_level != self.num_resolutions - 1:
+                hidden_state = self.up[i_level].upsample(hidden_state)
+
+        hidden_state = self.norm_out(hidden_state)
+        hidden_state *= torch.sigmoid(hidden_state)
+        hidden_state = self.conv_out(hidden_state)
+        return hidden_state
+
+
+class JanusVQVAE(ChameleonVQVAE):
+    """Vision Transformer-based VQ-VAE model for encoding and decoding pixel values."""
+
+    _no_split_modules = [
+        "JanusVQVAEAttnBlock",
+        "JanusVQVAEResnetBlock",
+        "JanusVQVAEVectorQuantizer",
+    ]
+    main_input_name = "pixel_values"
+
+    def __init__(self, config: JanusVQVAEConfig):
+        super().__init__(config)
+        self.decoder = JanusVQVAEDecoder(config)
+        self.gradient_checkpointing = False
+
+        # Initialize the VQVAE model.
+        self.post_init()
+
+    def decode(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
+        """
+        Decodes quantized token IDs into pixel values.
+        Args:
+            image_tokens (torch.LongTensor): Batch of token IDs.
+        Returns:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                Pixel values decoded from the token IDs.
+        """
+        if image_tokens.shape[1] != self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]:
+            raise ValueError(
+                f"Expected `image_tokens` to have shape `(batch_size, {self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]})`, "
+                f"but got shape `{image_tokens.shape}`."
+            )
+        codebook_entry = self.quantize.get_codebook_entry(image_tokens)
+        hidden_states = self.post_quant_conv(codebook_entry)
+        pixel_values = self.decoder(hidden_states)
+        return pixel_values
+
+    @can_return_tuple
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        """
+        Encodes pixel values into quantized tokens and decodes them back.
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                The tensors corresponding to the input images.
+        Returns:
+            decoded_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                Reconstructed pixel values after encoding and decoding the input.
+            embedding_loss (`torch.FloatTensor`): Embedding loss.
+        """
+
+        batch_size = pixel_values.shape[0]
+        quant, embedding_loss, indices = self.encode(pixel_values)
+        decoded_pixel_values = self.decode(indices.view(batch_size, -1))
+        output = JanusVQVAEOutput(decoded_pixel_values, embedding_loss)
+
+        return output
+
+
+class JanusVQVAEAlignerMLP(nn.Module):
+    def __init__(self, config: JanusVQVAEConfig):
+        super().__init__()
+
+        self.fc1 = nn.Linear(config.embed_dim, config.projection_dim)
+        self.hidden_layers = nn.ModuleList(
+            [nn.Linear(config.projection_dim, config.projection_dim) for _ in range(1, config.num_hidden_layers)]
+        )
+        self.activation_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, hidden_states):
+        hidden_states = self.fc1(hidden_states)
+        for layer in self.hidden_layers:
+            hidden_states = self.activation_fn(hidden_states)
+            hidden_states = layer(hidden_states)
+        return hidden_states
+
+
+class JanusVQVAEHead(nn.Module):
+    """Head used for sampling tokens in image generation, replacing the usual lm head."""
+
+    def __init__(self, config: JanusVQVAEConfig):
+        super().__init__()
+        self.proj_out = nn.Linear(config.image_token_embed_dim, config.projection_dim)
+        self.activation_fn = ACT2FN[config.hidden_act]
+        self.vision_head = nn.Linear(config.projection_dim, config.num_embeddings)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.vision_head(hidden_states)
+        return hidden_states
+
+
+JANUS_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+            The tensors corresponding to the input images. Pixel values can be obtained using
+            [`AutoImageProcessor`].
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+    """The Janus model which consists of a siglip vision backbone, a Llama language model and a VQ model.""",
+    JANUS_START_DOCSTRING,
+)
+class JanusModel(JanusPreTrainedModel):
+    def __init__(self, config: JanusConfig):
+        super().__init__(config)
+        self.config = config
+        # This is necessary for backward compatibility, see SiglipModel initialization
+        self.vision_model = JanusVisionModel._from_config(config.vision_config)
+        self.aligner = JanusVisionAlignerMLP(self.vision_model.config)
+
+        self.vqmodel = JanusVQVAE._from_config(config.vq_config)
+
+        # The `generation_*` modules below are only used for image generation.
+        # These embeddings are used in place of the Janus vision embeddings when generating images.
+        self.generation_embeddings = nn.Embedding(self.vqmodel.config.num_embeddings, self.vqmodel.config.embed_dim)
+        self.generation_aligner = JanusVQVAEAlignerMLP(self.vqmodel.config)
+        self.generation_head = JanusVQVAEHead(self.vqmodel.config)
+
+        self.language_model = AutoModel.from_config(config=config.text_config)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing.
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.language_model.set_input_embeddings(value)
+
+    def get_image_features(self, pixel_values):
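+        # Encode the images with the vision tower and project the patch features with the aligner MLP so
+        # that they match the language model's embedding size.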
+        image_embeds = self.vision_model(pixel_values)
+        image_embeds = self.aligner(image_embeds.last_hidden_state)
+        return image_embeds
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(JANUS_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        pixel_values: torch.FloatTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs,
+    ):
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        if pixel_values is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+            )
+
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+        if pixel_values is not None:
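+            # Replace the image placeholder tokens in the text embeddings with the projected vision features.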
+            image_embeds = self.get_image_features(pixel_values)
+            image_attention_mask = input_ids == self.config.image_token_index
+
+            embed_dim = inputs_embeds.shape[-1]
+            image_features = image_embeds.reshape(-1, embed_dim)
+            image_attention_mask = image_attention_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
+
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
+
+        lm_output = self.language_model(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            logits_to_keep=logits_to_keep,
+            **kwargs,
+        )
+
+        output = JanusBaseModelOutputWithPast(
+            last_hidden_state=lm_output.last_hidden_state,
+            past_key_values=lm_output.past_key_values,
+            hidden_states=lm_output.hidden_states,
+            attentions=lm_output.attentions,
+            image_hidden_states=image_embeds if pixel_values is not None else None,
+        )
+
+        return output
+
+
+class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
+    _supports_static_cache = True
+
+    def __init__(self, config: JanusConfig):
+        super().__init__(config)
+        self.config = config
+        self.model = JanusModel(config)
+        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing.
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.language_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.model.language_model.set_input_embeddings(value)
+
+    def prepare_embeddings_for_image_generation(self, inputs: torch.Tensor) -> torch.Tensor:
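+        # Sampled image-token ids are embedded with the dedicated generation embedding table and aligned to
+        # the language model's hidden size before being fed back in as `inputs_embeds` for the next step.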
+        hidden_state = self.model.generation_embeddings(inputs)
+        hidden_state = self.model.generation_aligner(hidden_state)
+        return hidden_state
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    @can_return_tuple
+    @add_start_docstrings_to_model_forward(JANUS_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=JanusCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        pixel_values: torch.FloatTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs,
+    ):
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+            logits_to_keep (`int` or `torch.Tensor`, *optional*):
+                If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+                This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+        Returns:
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        outputs = self.model(
+            input_ids=input_ids,
+            pixel_values=pixel_values,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        hidden_states = outputs.last_hidden_state
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+        output = JanusCausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            image_hidden_states=outputs.image_hidden_states,
+        )
+        return output
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        pixel_values=None,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        logits_to_keep=None,
+        **kwargs,
+    ):
+        # Overwritten -- extra custom processing
+
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            logits_to_keep=logits_to_keep,
+            **kwargs,
+        )
+
+        # In the cached decoding stage, pixel values should be None because the input ids no longer contain
+        # the special image token. Otherwise, pixel values need to be passed to the model.
+        if cache_position[0] == 0:
+            model_inputs["pixel_values"] = pixel_values
+
+        return model_inputs
+
+    def decode_image_tokens(self, image_tokens: torch.Tensor):
+        """
+        Decodes image tokens generated by the language model into continuous pixel values
+        with the VQ-VAE module via upsampling.
+        Args:
+            image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
+                The image tokens produced by the language model, to be decoded into pixel values.
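+
+        Example (a minimal, illustrative sketch; `model`, `processor` and the generated `image_tokens` are assumed to already exist):
+
+        ```python
+        >>> # `image_tokens` are the tokens returned by `model.generate(..., generation_mode="image")`.
+        >>> pixels = model.decode_image_tokens(image_tokens)
+        >>> images = processor.image_processor.postprocess(list(pixels.float()), return_tensors="PIL.Image.Image")
+        >>> images["pixel_values"][0].save("generated.png")
+        ```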
+        """
+        decoded_image = self.model.vqmodel.decode(image_tokens)
+        decoded_image = decoded_image.permute(0, 2, 3, 1)
+        return decoded_image
+
+    @torch.no_grad
+    def generate(
+        self,
+        inputs: torch.Tensor = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        **kwargs,
+    ):
+        # 1. Handle generation config and model kwargs
+        generation_config = kwargs.pop("generation_config", self.generation_config)
+        generation_config = copy.deepcopy(generation_config)
+
+        # Default to "text" generation if mode isn't provided
+        generation_mode = kwargs.pop("generation_mode", "text")
+        if generation_mode == "text":
+            # Set guidance_scale=None to prevent running UnbatchedCFG processor.
+            return super().generate(
+                inputs=inputs,
+                attention_mask=attention_mask,
+                generation_config=generation_config,
+                guidance_scale=None,
+                **kwargs,
+            )
+
+        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+
+        # Validate generation mode
+        if generation_config.get_generation_mode() not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
+            raise ValueError(
+                "Got incompatible mode for Image Generation, should be one of greedy or sampling. "
+                "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
+            )
+
+        # Validate the configuration and model kwargs
+        generation_config.validate()
+        self._validate_model_kwargs(model_kwargs.copy())
+
+        # 2. Initialize logit processors
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+
+        # Set `use_cache=True` as we will be using input embeds for generation.
+        model_kwargs["use_cache"] = True
+
+        if generation_config.guidance_scale is None:
+            logger.warning("`guidance_scale` is required for CFG but not provided. Setting to default value of 5.")
+            generation_config.guidance_scale = 5
+        model_kwargs["guidance_scale"] = generation_config.guidance_scale
+
+        # 3. Prepare model inputs
+        input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
+            inputs, generation_config.bos_token_id, model_kwargs
+        )
+        dtype, device = input_ids.dtype, input_ids.device
+
+        if len(input_ids.shape) != 2:
+            raise ValueError(
+                f"Expected input ids of shape (batch_size, seq_len), but got {input_ids.shape}"
+                "Passing `inputs embeds` is not supported currently."
+            )
+
+        # Prepare special tokens which will be used by `generate` internally.
+        kwargs_has_attention_mask = attention_mask is not None
+        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
+
+        # 4. Add CFG processor along with user passed logit processor.
+        if generation_config.guidance_scale and generation_config.guidance_scale > 1:
+            logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
+            generation_config.guidance_scale = None  # Reset to prevent processor duplication.
+
+        # 5. Prepare logits processor
+        logits_processor = self._get_logits_processor(
+            generation_config=generation_config,
+            input_ids_seq_length=input_ids.shape[1],
+            encoder_input_ids=input_ids,
+            prefix_allowed_tokens_fn=None,
+            logits_processor=logits_processor,
+            device=device,
+        )
+
+        # 6. Expand inputs for multiple image generations per prompt.
+        input_ids, model_kwargs = self._expand_inputs_for_generation(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            expand_size=generation_config.num_return_sequences,
+            **model_kwargs,
+        )
+
+        # 7. Prepare input and model caches
+        num_image_tokens = self.model.vision_model.config.num_image_tokens
+        batch_size, seq_len = input_ids.shape
+
+        input_tokens = input_ids.repeat(2, 1)  # Double batch size for conditional/unconditional logits
+        attention_mask = model_kwargs.pop("attention_mask", None)
+        attention_mask = attention_mask.repeat(2, 1)
+        model_kwargs["attention_mask"] = attention_mask
+
+        # Replace every token that is neither BOS nor BOI with the pad token in the unconditional half of the batch.
+        mask = (input_tokens[batch_size:, :] != generation_config.bos_token_id) & (
+            input_tokens[batch_size:, :] != generation_config.generation_kwargs["boi_token_id"]
+        )
+        input_tokens[batch_size:, :].masked_fill_(mask, generation_config.pad_token_id)
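+        # The second half of the batch therefore only sees BOS/BOI plus padding: it produces the
+        # "unconditional" logits that the classifier-free guidance processor contrasts with the
+        # prompt-conditioned logits of the first half.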
+
+        inputs_embeds = self.get_input_embeddings()(input_tokens)
+
+        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+
+        if model_kwargs.get("past_key_values", None) is None:
+            # Prepare cache if not provided.
+            model_kwargs["past_key_values"] = self._get_cache(
+                cache_implementation=generation_config.cache_implementation or "static",
+                # batch_size should account for both conditional/unconditional input; hence multiplied by 2.
+                batch_size=batch_size * 2,
+                # we should have at least a cache len of seq_len + num_image_tokens.
+                max_cache_len=max(generation_config.max_length, num_image_tokens + seq_len),
+                device=device,
+                model_kwargs=model_kwargs,
+            )
+
+        # Placeholder for generated tokens.
+        generated_tokens = torch.zeros((batch_size, num_image_tokens), dtype=dtype, device=device)
+
+        # 8. init attention / hidden states / scores tuples
+        output_attentions = generation_config.output_attentions
+        output_hidden_states = generation_config.output_hidden_states
+        output_scores = generation_config.output_scores
+        output_logits = generation_config.output_logits
+        return_dict_in_generate = generation_config.return_dict_in_generate
+
+        raw_scores = () if (return_dict_in_generate and output_scores) else None
+        raw_logits = () if (return_dict_in_generate and output_logits) else None
+        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+
+        for i in range(num_image_tokens):
+            model_inputs = self.prepare_inputs_for_generation(
+                inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs
+            )
+
+            model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device)
+            model_inputs["cache_position"] = model_inputs["cache_position"].to(inputs_embeds.device)
+
+            outputs = self.model.language_model(
+                **model_inputs,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+
+            # Update model_kwargs like cache_position for next generation.
+            model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
+            hidden_state = outputs.last_hidden_state[:, -1, :].clone()
+
+            # Compute scores with the generation head (not the LM head defined above).
+            scores = self.model.generation_head(hidden_state)
+            next_token_scores = logits_processor(input_ids, scores)
+
+            # Sample next token.
+            if generation_config.do_sample:
+                probs = torch.softmax(next_token_scores, dim=-1)
+                next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
+            else:
+                next_token = torch.argmax(next_token_scores, dim=-1)
+
+            generated_tokens[:, i] = next_token
+
+            # Prepare embeddings for the next step.
+            next_token = torch.cat([next_token, next_token])
+            next_token = next_token.unsqueeze(-1)
+
+            inputs_embeds = self.prepare_embeddings_for_image_generation(next_token)
+
+            if return_dict_in_generate:
+                if output_scores:
+                    raw_scores += (scores,)
+                if output_logits:
+                    raw_logits += (hidden_state.float(),)
+                if output_attentions:
+                    decoder_attentions += outputs.attentions
+                if output_hidden_states:
+                    decoder_hidden_states += outputs.hidden_states
+
+        if return_dict_in_generate:
+            return GenerateDecoderOnlyOutput(
+                sequences=generated_tokens,
+                scores=raw_scores,
+                logits=raw_logits,
+                attentions=decoder_attentions,
+                hidden_states=decoder_hidden_states,
+                past_key_values=outputs.past_key_values,
+            )
+        else:
+            return generated_tokens
+
+
+class JanusImageProcessor(BlipImageProcessor):
+    r"""
+    Constructs a Janus image processor.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+            `do_resize` parameter in the `preprocess` method.
+        size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
+            Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+            method.
+        min_size (`int`, *optional*, defaults to 14):
+            The minimum allowed size for the resized image. Ensures that neither the height nor width
+            falls below this value after resizing.
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
+            Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
+            overridden by the `resample` parameter in the `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+            `do_rescale` parameter in the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
+            overridden by the `rescale_factor` parameter in the `preprocess` method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_convert_rgb (`bool`, *optional*, defaults to `True`):
+            Whether to convert the image to RGB.
+    """
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Dict[str, int] = None,
+        min_size: int = 14,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        do_rescale: bool = True,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        do_convert_rgb: bool = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.min_size = min_size
+        if image_mean is None:
+            self.background_color = (127, 127, 127)
+        else:
+            self.background_color = tuple([int(x * 255) for x in image_mean])
+
+    def pad_to_square(
+        self,
+        image: np.ndarray,
+        background_color: Union[int, Tuple[int, int, int]] = 0,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> np.array:
+        """
+        Pads an image to a square based on the longest edge.
+
+        Args:
+            image (`np.ndarray`):
+                The image to pad.
+            background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
+                The color to use for the padding. Can be an integer for single-channel images or a tuple of
+                integers for multi-channel images. If passed as an integer for a multi-channel image, it is
+                used for the first channel and subsequent channels default to `0`.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format for the output image. Can be one of:
+                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                If unset, will use same as the input image.
+            input_data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format for the input image. Can be one of:
+                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+
+        Returns:
+            `np.ndarray`: The padded image.
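+
+        Example (a minimal sketch using a default `JanusImageProcessor` and an illustrative array):
+
+        ```python
+        >>> import numpy as np
+        >>> processor = JanusImageProcessor()
+        >>> image = np.zeros((3, 200, 300), dtype=np.uint8)  # channels-first, non-square
+        >>> padded = processor.pad_to_square(
+        ...     image, background_color=processor.background_color, input_data_format="channels_first"
+        ... )
+        >>> padded.shape
+        (3, 300, 300)
+        ```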
+        """
+        height, width = get_image_size(image, input_data_format)
+        num_channels = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1]
+
+        if height == width:
+            image = (
+                to_channel_dimension_format(image, data_format, input_data_format)
+                if data_format is not None
+                else image
+            )
+            return image
+
+        max_dim = max(height, width)
+
+        # Ensure background_color is the correct shape
+        if isinstance(background_color, int):
+            background_color = [background_color]
+        elif len(background_color) != num_channels:
+            raise ValueError(
+                f"background_color must have no more than {num_channels} elements to match the number of channels"
+            )
+
+        if input_data_format == ChannelDimension.FIRST:
+            result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype)
+            for i, color in enumerate(background_color):
+                result[i, :, :] = color
+            if width > height:
+                start = (max_dim - height) // 2
+                result[:, start : start + height, :] = image
+            else:
+                start = (max_dim - width) // 2
+                result[:, :, start : start + width] = image
+        else:
+            result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype)
+            for i, color in enumerate(background_color):
+                result[:, :, i] = color
+            if width > height:
+                start = (max_dim - height) // 2
+                result[start : start + height, :, :] = image
+            else:
+                start = (max_dim - width) // 2
+                result[:, start : start + width, :] = image
+
+        return result
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: Union[Dict[str, int], int],
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image to a dynamically calculated size.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
+            data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the output image. If unset, the channel dimension format of the input
+                image is used. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `None`: will be inferred from input
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+        Returns:
+            `np.ndarray`: The resized image.
+        """
+        if input_data_format is None:
+            input_data_format = infer_channel_dimension_format(image)
+
+        height, width = get_image_size(image, input_data_format)
+        max_size = max(height, width)
+
+        size = get_size_dict(size, default_to_square=True)
+        if size["height"] != size["width"]:
+            raise ValueError(
+                f"Output height and width must be the same. Got height={size['height']} and width={size['width']}"
+            )
+        size = size["height"]
+
+        delta = size / max_size
+        # Largest side becomes `size` and the other side is scaled according to the aspect ratio.
+        output_size_nonpadded = [
+            max(int(height * delta), self.min_size),
+            max(int(width * delta), self.min_size),
+        ]
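+        # For example, a 480x640 input with `size=384` gives `delta=0.6`, so the intermediate size is
+        # (288, 384); `pad_to_square` below then pads the shorter side to obtain a square 384x384 image.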
+
+        image = resize(
+            image,
+            size=output_size_nonpadded,
+            resample=resample,
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )
+        # Expand and pad the images to obtain a square image of dimensions `size x size`
+        image = self.pad_to_square(
+            image=image,
+            background_color=self.background_color,
+            input_data_format=input_data_format,
+        )
+        return image
+
+    def postprocess(
+        self,
+        images: ImageInput,
+        do_rescale: bool = None,
+        rescale_factor: float = None,
+        do_normalize: bool = None,
+        image_mean: List[float] = None,
+        image_std: List[float] = None,
+        input_data_format: str = None,
+        return_tensors: str = None,
+    ):
+        """Applies post-processing to the decoded image tokens by reversing transformations applied during preprocessing."""
+        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+        rescale_factor = 1.0 / self.rescale_factor if rescale_factor is None else rescale_factor
+        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+        image_mean = image_mean if image_mean is not None else self.image_mean
+        image_std = image_std if image_std is not None else self.image_std
+
+        images = make_list_of_images(images)  # Ensures input is a list
+
+        if isinstance(images[0], PIL.Image.Image):
+            return images if len(images) > 1 else images[0]
+
+        if input_data_format is None:
+            input_data_format = infer_channel_dimension_format(images[0])  # Determine format dynamically
+
+        pixel_values = []
+
+        for image in images:
+            image = to_numpy_array(image)  # Ensure NumPy format
+
+            if do_normalize:
+                image = self.unnormalize(
+                    image=image, image_mean=image_mean, image_std=image_std, input_data_format=input_data_format
+                )
+
+            if do_rescale:
+                image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
+                image = image.clip(0, 255).astype(np.uint8)
+
+            if do_normalize and do_rescale and return_tensors == "PIL.Image.Image":
+                image = to_channel_dimension_format(image, ChannelDimension.LAST, input_channel_dim=input_data_format)
+                image = PIL.Image.fromarray(image)
+
+            pixel_values.append(image)
+
+        data = {"pixel_values": pixel_values}
+        return_tensors = return_tensors if return_tensors != "PIL.Image.Image" else None
+
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    def unnormalize(
+        self,
+        image: np.array,
+        image_mean: Union[float, Iterable[float]],
+        image_std: Union[float, Iterable[float]],
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> np.array:
+        """
+        Unnormalizes `image` using the mean and standard deviation specified by `mean` and `std`.
+        image = (image * image_std) + image_mean
+        Args:
+            image (`torch.Tensor` of shape `(batch_size, num_channels, image_size, image_size)` or `(num_channels, image_size, image_size)`):
+                Batch of pixel values to postprocess.
+            image_mean (`float` or `Iterable[float]`):
+                The mean to use for unnormalization.
+            image_std (`float` or `Iterable[float]`):
+                The standard deviation to use for unnormalization.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+        """
+        num_channels = 3
+
+        if isinstance(image_mean, Iterable):
+            if len(image_mean) != num_channels:
+                raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(image_mean)}")
+        else:
+            image_mean = [image_mean] * num_channels
+
+        if isinstance(image_std, Iterable):
+            if len(image_std) != num_channels:
+                raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(image_std)}")
+        else:
+            image_std = [image_std] * num_channels
+
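+        # `normalize` computes `(image - mean) / std`. With `mean' = -mean / std` and `std' = 1 / std`, this
+        # becomes `(image + mean / std) * std = image * std + mean`, i.e. the exact inverse of normalization.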
+        rev_image_mean = tuple(-mean / std for mean, std in zip(image_mean, image_std))
+        rev_image_std = tuple(1 / std for std in image_std)
+        image = self.normalize(
+            image=image, mean=rev_image_mean, std=rev_image_std, input_data_format=input_data_format
+        )
+        return image
+
+
+__all__ = [
+    "JanusImageProcessor",
+    "JanusPreTrainedModel",
+    "JanusForConditionalGeneration",
+    "JanusModel",
+    "JanusVQVAE",
+    "JanusVisionModel",
+    "JanusVQVAEConfig",
+    "JanusVisionConfig",
+    "JanusConfig",
+]
diff --git a/src/transformers/models/janus/processing_janus.py b/src/transformers/models/janus/processing_janus.py
new file mode 100644
index 00000000000..4132ca8f43d
--- /dev/null
+++ b/src/transformers/models/janus/processing_janus.py
@@ -0,0 +1,188 @@
+# coding=utf-8
+# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for Janus.
+"""
+
+from typing import List, Union
+
+from ...feature_extraction_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
+from ...tokenization_utils_base import (
+    PreTokenizedInput,
+    TextInput,
+)
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+DEFAULT_SYSTEM_PROMPT = (
+    "You are a helpful language and vision assistant. "
+    "You are able to understand the visual content that the user provides, "
+    "and assist the user with a variety of tasks using natural language.\n\n"
+)
+
+
+class JanusTextKwargs(TextKwargs, total=False):
+    generation_mode: str
+
+
+class JanusProcessorKwargs(ProcessingKwargs, total=False):
+    text_kwargs: JanusTextKwargs
+    _defaults = {
+        "text_kwargs": {"padding": False, "generation_mode": "text"},
+        "common_kwargs": {"return_tensors": "pt"},
+    }
+
+
+class JanusProcessor(ProcessorMixin):
+    r"""
+    Constructs a Janus processor which wraps a Janus Image Processor and a Llama tokenizer into a single processor.
+
+    [`JanusProcessor`] offers all the functionalities of [`JanusImageProcessor`] and [`LlamaTokenizerFast`]. See the
+    [`~JanusProcessor.__call__`] and [`~JanusProcessor.decode`] for more information.
+
+    Args:
+        image_processor ([`JanusImageProcessor`]):
+            The image processor is a required input.
+        tokenizer ([`LlamaTokenizerFast`]):
+            The tokenizer is a required input.
+        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+            in a chat into a tokenizable string.
+        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
+            Whether to prepend the default system prompt when preparing prompts for text generation.
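+
+    Example (illustrative sketch only; the dummy image and prompt are made up):
+
+    ```python
+    >>> import numpy as np
+    >>> from PIL import Image
+    >>> from transformers import JanusProcessor
+
+    >>> processor = JanusProcessor.from_pretrained("deepseek-community/Janus-Pro-1B")
+    >>> image = Image.fromarray(np.zeros((384, 384, 3), dtype=np.uint8))  # dummy black image
+    >>> prompt = processor.image_token + " Describe the image."
+    >>> inputs = processor(text=prompt, images=image, generation_mode="text", return_tensors="pt")
+    ```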
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    valid_kwargs = ["chat_template", "use_default_system_prompt"]
+    image_processor_class = "JanusImageProcessor"
+    tokenizer_class = "LlamaTokenizerFast"
+
+    def __init__(self, image_processor, tokenizer, chat_template=None, use_default_system_prompt=False, **kwargs):
+        self.num_image_tokens = 576
+        self.image_token = tokenizer.image_token
+        self.image_start_token = tokenizer.boi_token
+        self.image_end_token = tokenizer.eoi_token
+        self.use_default_system_prompt = use_default_system_prompt
+
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+    def __call__(
+        self,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        images: ImageInput = None,
+        videos=None,
+        audio=None,
+        **kwargs: Unpack[JanusProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+        and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
+        JanusImageProcessor's [`~JanusImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+        of the above two methods for more information.
+
+        Args:
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework. Acceptable values are:
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+        Returns:
+            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+              `None`).
+            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
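+
+        A minimal sketch of how `generation_mode` changes the prepared inputs (reusing `processor` and `image` from the
+        class docstring example; the prompts are only placeholders):
+
+        ```python
+        >>> # Understanding: each image token in the prompt is expanded to 576 image placeholder tokens.
+        >>> text_inputs = processor(text=processor.image_token + " What is this?", images=image, generation_mode="text")
+
+        >>> # Text-to-image: no images are processed and the image-start token is appended to the prompt.
+        >>> image_inputs = processor(text="A photo of a cat.", generation_mode="image")
+        ```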
+        """
+
+        output_kwargs = self._merge_kwargs(
+            JanusProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs
+        )
+
+        if text is None and images is None:
+            raise ValueError("You must specify either text or images.")
+
+        if text is not None:
+            if isinstance(text, str):
+                text = [text]
+            elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
+                raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+
+        generation_mode = output_kwargs["text_kwargs"].pop("generation_mode")
+
+        # Replace the image token with expanded image tokens.
+        prompt_strings = []
+        one_img_tokens = self.image_start_token + (self.image_token * self.num_image_tokens) + self.image_end_token
+        for prompt in text:
+            prompt = prompt.replace(self.image_token, one_img_tokens)
+            if self.use_default_system_prompt and generation_mode == "text":
+                prompt = DEFAULT_SYSTEM_PROMPT + prompt
+            if generation_mode == "image":
+                prompt += self.image_start_token
+            prompt_strings.append(prompt)
+
+        data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
+
+        # Process the input images for multimodal understanding; they are skipped when generating images.
+        if images is not None and generation_mode != "image":
+            data["pixel_values"] = self.image_processor(images=images, **output_kwargs["images_kwargs"])[
+                "pixel_values"
+            ]
+
+        return BatchFeature(data=data)
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+        refer to the docstring of this method for more information.
+        """
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+        the docstring of this method for more information.
+        """
+        return self.tokenizer.decode(*args, **kwargs)
+
+    def postprocess(self, images: ImageInput, **kwargs):
+        """
+        Forwards all arguments to the image processor's `postprocess` method.
+        Refer to the original method's docstring for more details.
+        """
+        return self.image_processor.postprocess(images, **kwargs)
+
+    @property
+    def model_input_names(self):
+        tokenizer_input_names = self.tokenizer.model_input_names
+        image_processor_input_names = self.image_processor.model_input_names
+        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+
+__all__ = ["JanusProcessor"]
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index c738ca46457..0672589769a 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -126,6 +126,7 @@ VLM_CLASS_NAMES = [
     "qwen2vl",
     "qwen2_5_vl",
     "ayavision",
+    "janus",
     "gemma3",
     "mistral3",
     "chameleon",
diff --git a/tests/models/janus/__init__.py b/tests/models/janus/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/models/janus/test_image_processing_janus.py b/tests/models/janus/test_image_processing_janus.py
new file mode 100644
index 00000000000..184f669e6a5
--- /dev/null
+++ b/tests/models/janus/test_image_processing_janus.py
@@ -0,0 +1,188 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_torch_available, is_vision_available
+
+from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
+
+
+if is_torch_available():
+    import torch
+
+if is_vision_available():
+    from PIL import Image
+
+    from transformers import JanusImageProcessor
+
+
+class JanusImageProcessingTester:
+    def __init__(
+        self,
+        parent,
+        batch_size=7,
+        num_channels=3,
+        image_size=384,
+        min_resolution=30,
+        max_resolution=200,
+        do_resize=True,
+        size=None,
+        do_normalize=True,
+        image_mean=[1.0, 1.0, 1.0],
+        image_std=[1.0, 1.0, 1.0],
+        do_convert_rgb=True,
+    ):
+        size = size if size is not None else {"height": 384, "width": 384}
+        self.parent = parent
+        self.batch_size = batch_size
+        self.num_channels = num_channels
+        self.image_size = image_size
+        self.min_resolution = min_resolution
+        self.max_resolution = max_resolution
+        self.do_resize = do_resize
+        self.size = size
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.do_convert_rgb = do_convert_rgb
+
+    def prepare_image_processor_dict(self):
+        return {
+            "do_resize": self.do_resize,
+            "size": self.size,
+            "do_normalize": self.do_normalize,
+            "image_mean": self.image_mean,
+            "image_std": self.image_std,
+            "do_convert_rgb": self.do_convert_rgb,
+        }
+
+    # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs
+    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
+        return prepare_image_inputs(
+            batch_size=self.batch_size,
+            num_channels=self.num_channels,
+            min_resolution=self.min_resolution,
+            max_resolution=self.max_resolution,
+            equal_resolution=equal_resolution,
+            numpify=numpify,
+            torchify=torchify,
+        )
+
+
+@require_torch
+@require_vision
+class JanusImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
+    image_processing_class = JanusImageProcessor if is_vision_available() else None
+
+    # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->Janus
+    def setUp(self):
+        super().setUp()
+        self.image_processor_tester = JanusImageProcessingTester(self)
+
+    @property
+    # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict
+    def image_processor_dict(self):
+        return self.image_processor_tester.prepare_image_processor_dict()
+
+    def test_image_processor_properties(self):
+        image_processing = self.image_processing_class(**self.image_processor_dict)
+        self.assertTrue(hasattr(image_processing, "do_resize"))
+        self.assertTrue(hasattr(image_processing, "size"))
+        self.assertTrue(hasattr(image_processing, "do_normalize"))
+        self.assertTrue(hasattr(image_processing, "image_mean"))
+        self.assertTrue(hasattr(image_processing, "image_std"))
+        self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
+
+    def test_image_processor_from_dict_with_kwargs(self):
+        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
+        self.assertEqual(image_processor.size, {"height": 384, "width": 384})
+        self.assertEqual(image_processor.image_mean, [1.0, 1.0, 1.0])
+
+        image_processor = self.image_processing_class.from_dict(
+            self.image_processor_dict, size=42, image_mean=[1.0, 2.0, 1.0]
+        )
+        self.assertEqual(image_processor.size, {"height": 42, "width": 42})
+        self.assertEqual(image_processor.image_mean, [1.0, 2.0, 1.0])
+
+    def test_call_pil(self):
+        image_processing = self.image_processing_class(**self.image_processor_dict)
+        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
+        for image in image_inputs:
+            self.assertIsInstance(image, Image.Image)
+
+        # Test Non batched input
+        encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+        expected_output_image_shape = (1, 3, 384, 384)
+        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
+
+        # Test batched
+        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+        expected_output_image_shape = (7, 3, 384, 384)
+        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
+
+    def test_call_numpy(self):
+        image_processing = self.image_processing_class(**self.image_processor_dict)
+        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
+        for image in image_inputs:
+            self.assertIsInstance(image, np.ndarray)
+
+        encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+        expected_output_image_shape = (1, 3, 384, 384)
+        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
+
+        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+        expected_output_image_shape = (7, 3, 384, 384)
+        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
+
+    def test_call_pytorch(self):
+        image_processing = self.image_processing_class(**self.image_processor_dict)
+        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
+
+        for image in image_inputs:
+            self.assertIsInstance(image, torch.Tensor)
+
+        encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+        expected_output_image_shape = (1, 3, 384, 384)
+        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
+
+        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+        expected_output_image_shape = (7, 3, 384, 384)
+        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
+
+    def test_nested_input(self):
+        image_processing = self.image_processing_class(**self.image_processor_dict)
+        image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
+
+        # Test batched as a list of images.
+        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+        expected_output_image_shape = (7, 3, 384, 384)
+        self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
+
+        # Test batched as a nested list of images, where each sublist is one batch.
+        image_inputs_nested = [image_inputs[:3], image_inputs[3:]]
+        encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values
+        expected_output_image_shape = (7, 3, 384, 384)
+        self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)
+
+        # Image processor should return same pixel values, independently of input format.
+        self.assertTrue((encoded_images_nested == encoded_images).all())
+
+    @unittest.skip(reason="Not supported")
+    def test_call_numpy_4_channels(self):
+        pass
diff --git a/tests/models/janus/test_modeling_janus.py b/tests/models/janus/test_modeling_janus.py
new file mode 100644
index 00000000000..03208f388e8
--- /dev/null
+++ b/tests/models/janus/test_modeling_janus.py
@@ -0,0 +1,553 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch Janus model."""
+
+import re
+import tempfile
+import unittest
+from functools import reduce
+
+import numpy as np
+import requests
+
+from transformers import (
+    AutoProcessor,
+    JanusConfig,
+    JanusForConditionalGeneration,
+    JanusModel,
+    JanusVQVAE,
+    JanusVQVAEConfig,
+    is_torch_available,
+    is_vision_available,
+)
+from transformers.models.auto import get_values
+from transformers.models.auto.modeling_auto import MODEL_FOR_BACKBONE_MAPPING_NAMES, MODEL_MAPPING_NAMES
+from transformers.testing_utils import (
+    require_torch,
+    slow,
+    torch_device,
+)
+
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_torch_available():
+    import torch
+
+
+if is_vision_available():
+    from PIL import Image
+
+
+class JanusVisionText2TextModelTester:
+    def __init__(
+        self,
+        parent,
+        image_token_index=0,
+        seq_length=25,
+        initializer_range=0.02,
+        text_config={
+            "model_type": "llama",
+            "seq_length": 7,
+            "is_training": True,
+            "use_input_mask": True,
+            "use_token_type_ids": False,
+            "use_labels": True,
+            "vocab_size": 99,
+            "hidden_size": 32,
+            "num_hidden_layers": 2,
+            "num_attention_heads": 4,
+            "intermediate_size": 37,
+            "hidden_act": "gelu",
+            "hidden_dropout_prob": 0.1,
+            "attention_probs_dropout_prob": 0.1,
+            "max_position_embeddings": 512,
+            "type_vocab_size": 16,
+            "type_sequence_label_size": 2,
+            "initializer_range": 0.02,
+            "num_labels": 3,
+            "num_choices": 4,
+            "pad_token_id": 1,
+        },
+        is_training=True,
+        vision_config={
+            "use_labels": True,
+            "image_size": 20,
+            "patch_size": 5,
+            "num_image_tokens": 4,
+            "num_channels": 3,
+            "is_training": True,
+            "hidden_size": 32,
+            "projection_dim": 32,
+            "num_key_value_heads": 1,
+            "num_hidden_layers": 2,
+            "num_attention_heads": 4,
+            "mlp_ratio": 2,
+            "dropout": 0.1,
+            "attention_dropout": 0.1,
+            "initializer_range": 0.02,
+            "vision_feature_select_strategy": "default",
+            "vision_feature_layer": -1,
+        },
+        use_cache=False,
+        vq_num_embeds=12,
+        vq_embed_dim=12,
+        vq_channel_multiplier=[1, 1],
+    ):
+        self.parent = parent
+        self.initializer_range = initializer_range
+        # `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify
+        self.image_token_index = image_token_index
+        self.text_config = text_config
+        self.vision_config = vision_config
+        self.seq_length = seq_length
+        self.pad_token_id = text_config["pad_token_id"]
+
+        self.num_hidden_layers = text_config["num_hidden_layers"]
+        self.vocab_size = text_config["vocab_size"]
+        self.hidden_size = text_config["hidden_size"]
+        self.num_attention_heads = text_config["num_attention_heads"]
+        self.is_training = is_training
+
+        self.batch_size = 3
+        self.num_channels = vision_config["num_channels"]
+        self.image_size = vision_config["image_size"]
+        self.num_image_tokens = vision_config["num_image_tokens"]
+        self.use_cache = use_cache
+
+        # vq model params
+        self.vq_num_embeds = vq_num_embeds
+        self.vq_embed_dim = vq_embed_dim
+        self.vq_channel_multiplier = vq_channel_multiplier
+
+    def get_vq_config(self):
+        return {
+            "embed_dim": self.vq_embed_dim,
+            "num_embeddings": self.vq_num_embeds,
+            "latent_channels": self.vq_embed_dim,
+            "in_channels": 3,
+            "base_channels": 32,  # we have a GroupNorm of 32 groups, so can't do less
+            "channel_multiplier": self.vq_channel_multiplier,
+            "initializer_range": self.initializer_range,
+            "projection_dim": 10,
+            "image_token_embed_dim": 32,  # Same as text model hidden size
+        }
+
+    def get_config(self):
+        return JanusConfig(
+            text_config=self.text_config,
+            vision_config=self.vision_config,
+            vq_config=self.get_vq_config(),
+        )
+
+    def prepare_config_and_inputs(self):
+        config = self.get_config()
+        pixel_values = floats_tensor(
+            [
+                self.batch_size,
+                3,
+                self.image_size,
+                self.image_size,
+            ]
+        )
+        return config, pixel_values
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, pixel_values = config_and_inputs
+        input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
+        attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
+
+        # set the first `num_image_tokens` tokens to be image tokens, and ensure that no other tokens are image tokens
+        # do not change this unless you modified image size or patch size
+        input_ids[input_ids == self.image_token_index] = self.pad_token_id
+        input_ids[:, : self.num_image_tokens] = self.image_token_index
+        inputs_dict = {
+            "pixel_values": pixel_values,
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": input_ids,
+            "generation_mode": "text",  # Required to perform text generation instead of image generation.
+        }
+        return config, inputs_dict
+
+
+@require_torch
+class JanusVisionText2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+    all_model_classes = (JanusModel, JanusForConditionalGeneration) if is_torch_available() else ()
+    all_generative_model_classes = (JanusForConditionalGeneration,) if is_torch_available() else ()
+    fx_compatible = False
+    test_pruning = False
+    test_head_masking = False
+    _is_composite = True
+
+    def setUp(self):
+        self.model_tester = JanusVisionText2TextModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=JanusConfig, has_text_modality=False)
+
+    # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
+    def test_inputs_embeds(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = self._prepare_for_class(inputs_dict, model_class)
+
+            input_ids = inputs["input_ids"]
+            del inputs["input_ids"]
+            del inputs["pixel_values"]
+            del inputs["generation_mode"]
+
+            wte = model.get_input_embeddings()
+            inputs["inputs_embeds"] = wte(input_ids)
+
+            with torch.no_grad():
+                model(**inputs)
+
+    # Overwrite inputs_embeds tests because we need to delete "pixel values" for VLMs.
+    def test_inputs_embeds_matches_input_ids(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = self._prepare_for_class(inputs_dict, model_class)
+            input_ids = inputs["input_ids"]
+            del inputs["input_ids"]
+            del inputs["pixel_values"]
+            del inputs["generation_mode"]
+
+            inputs_embeds = model.get_input_embeddings()(input_ids)
+
+            with torch.no_grad():
+                out_ids = model(input_ids=input_ids, **inputs)[0]
+                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
+            torch.testing.assert_close(out_embeds, out_ids)
+
+    def test_sdpa_can_dispatch_composite_models(self):
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+
+                # Load the model with SDPA
+                model_sdpa = model_class.from_pretrained(tmpdirname)
+                model_sdpa = model_sdpa.eval().to(torch_device)
+
+                # Load model with eager attention
+                model_eager = model_class.from_pretrained(
+                    tmpdirname,
+                    attn_implementation="eager",
+                )
+                model_eager = model_eager.eval().to(torch_device)
+
+            # SigLip has one shared cls attr for all models, so we assign both submodels here
+            vision_attn = language_attn = "sdpa" if model._supports_sdpa else "eager"
+
+            if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "language_model"):
+                self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn)
+                self.assertTrue(model_sdpa.language_model.config._attn_implementation == language_attn)
+                self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
+                self.assertTrue(model_eager.language_model.config._attn_implementation == "eager")
+
+            self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
+            self.assertTrue(model_eager.config._attn_implementation == "eager")
+
+            for name, submodule in model_eager.named_modules():
+                class_name = submodule.__class__.__name__
+                if any(re.finditer(r"Attention(?!Pool)", class_name)):
+                    self.assertTrue(submodule.config._attn_implementation == "eager")
+
+            for name, submodule in model_sdpa.named_modules():
+                class_name = submodule.__class__.__name__
+                if any(re.finditer(r"Attention(?!Pool)", class_name)):
+                    self.assertTrue(submodule.config._attn_implementation == "sdpa")
+
+    def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
+        if not self.model_tester.is_training:
+            self.skipTest(reason="ModelTester is not configured to run training tests")
+        """
+        We skip some parameters when checking for gradient checkpointing:
+        - VQ model, as its training is not supported.
+        - A few other modules used for image generation.
+        """
+        skip_patterns = ["vqmodel", "generation_embeddings", "generation_aligner", "generation_head"]
+
+        for model_class in self.all_model_classes:
+            with self.subTest(model_class.__name__):
+                if (
+                    model_class.__name__
+                    in [
+                        *get_values(MODEL_MAPPING_NAMES),
+                        *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
+                    ]
+                    or not model_class.supports_gradient_checkpointing
+                ):
+                    # TODO (ydshieh): use `skipTest` once pytest-dev/pytest-subtests/pull/169 is merged
+                    # self.skipTest(reason=f"`supports_gradient_checkpointing` is False for {model_class.__name__}.")
+                    continue
+
+                config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+                config.use_cache = False
+                config.return_dict = True
+                model = model_class(config)
+
+                model.to(torch_device)
+                model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
+                model.train()
+
+                # unfreeze additional layers
+                for p in model.parameters():
+                    p.requires_grad_(True)
+
+                optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
+                inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+                loss = model(**inputs).loss
+                loss.backward()
+                optimizer.step()
+
+                if self.test_all_params_have_gradient:
+                    for k, v in model.named_parameters():
+                        if v.requires_grad and not reduce(lambda t, s: t | (s in k), skip_patterns, False):
+                            self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
+                        else:
+                            pass
+
+
+class JanusVQModelTester:
+    def __init__(
+        self,
+        parent,
+        batch_size=5,
+        is_training=False,
+        initializer_range=0.02,
+        image_size=30,
+        num_embeds=12,
+        base_channels=32,  # we have a GroupNorm of 32 groups, so can't do less
+        embed_dim=12,
+        channel_multiplier=[1, 2],
+        patch_size=2,
+        scope=None,
+    ):
+        self.parent = parent
+        self.batch_size = batch_size
+        self.is_training = is_training
+        self.initializer_range = initializer_range
+        self.image_size = image_size
+        self.base_channels = base_channels
+        self.num_embeds = num_embeds
+        self.embed_dim = embed_dim
+        self.channel_multiplier = channel_multiplier
+        self.num_patches = image_size // patch_size
+
+    def prepare_config_and_inputs(self):
+        pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size])
+        config = self.get_config()
+        return config, pixel_values
+
+    def get_config(self):
+        return JanusVQVAEConfig(
+            embed_dim=self.embed_dim,
+            num_embeddings=self.num_embeds,
+            latent_channels=self.embed_dim,
+            in_channels=3,
+            base_channels=self.base_channels,
+            channel_multiplier=self.channel_multiplier,
+            initializer_range=self.initializer_range,
+            resolution=self.image_size,
+            num_patches=self.num_patches,
+        )
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        config, pixel_values = config_and_inputs
+        inputs_dict = {"pixel_values": pixel_values}
+        return config, inputs_dict
+
+
+@require_torch
+class JanusVQModelTest(ModelTesterMixin, unittest.TestCase):
+    all_model_classes = (JanusVQVAE,) if is_torch_available() else ()
+    test_head_masking = False
+    test_pruning = False
+    fx_compatible = False
+    has_attentions = False
+    test_resize_embeddings = False
+
+    def setUp(self):
+        self.model_tester = JanusVQModelTester(self)
+        self.config_tester = ConfigTester(
+            self,
+            config_class=JanusVQVAEConfig,
+            has_text_modality=False,
+            common_properties=["embed_dim", "num_embeddings"],
+        )
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    @unittest.skip("Janus VQ module cannot offload due to using `self.weight` directly")
+    def test_cpu_offload(self):
+        pass
+
+    @unittest.skip("Janus VQ module cannot offload due to using `self.weight` directly")
+    def test_disk_offload_bin(self):
+        pass
+
+    @unittest.skip("Janus VQ module cannot offload due to using `self.weight` directly")
+    def test_disk_offload_safetensors(self):
+        pass
+
+    @unittest.skip("Janus VQ module has no hidden states")
+    def test_hidden_states_output(self):
+        pass
+
+    @unittest.skip("Janus VQ module has no hidden states")
+    def test_model_outputs_equivalence(self):
+        pass
+
+    @unittest.skip("Janus VQ module has no get/set embeddings method")
+    def test_model_get_set_embeddings(self):
+        pass
+
+    @unittest.skip("Janus VQ module has no hidden states")
+    def test_retain_grad_hidden_states_attentions(self):
+        pass
+
+
+class JanusIntegrationTest(unittest.TestCase):
+    def setUp(self):
+        self.model_id = "deepseek-community/Janus-Pro-1B"
+
+    @slow
+    def test_model_text_generation(self):
+        model = JanusForConditionalGeneration.from_pretrained(self.model_id, device_map="auto")
+        model.eval()
+        processor = AutoProcessor.from_pretrained(self.model_id)
+        image = Image.open(
+            requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
+        )
+        prompt = "\nDescribe what do you see here and tell me about the history behind it?"
+        inputs = processor(images=image, text=prompt, generation_mode="text", return_tensors="pt").to(model.device)
+
+        output = model.generate(**inputs, max_new_tokens=20, generation_mode="text", do_sample=False)
+        EXPECTED_DECODED_TEXT = 'You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n\n\nDescribe what do you see here and tell me about the history behind it?\n\nThe image depicts the constellation of Leo, which is often referred to as the "Lion"'  # fmt: skip
+        text = processor.decode(output[0], skip_special_tokens=True)
+        self.assertEqual(
+            text,
+            EXPECTED_DECODED_TEXT,
+        )
+
+    @slow
+    def test_model_text_generation_batched(self):
+        model = JanusForConditionalGeneration.from_pretrained(self.model_id, device_map="auto")
+        processor = AutoProcessor.from_pretrained(self.model_id)
+
+        image_1 = Image.open(
+            requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
+        )
+        image_2 = Image.open(
+            requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw
+        )
+        prompts = [
+            "\nDescribe what do you see here and tell me about the history behind it?",
+            "What constellation is this image showing?\n",
+        ]
+
+        inputs = processor(
+            images=[image_1, image_2], text=prompts, generation_mode="text", padding=True, return_tensors="pt"
+        ).to(model.device, torch.float16)
+
+        EXPECTED_TEXT_COMPLETION = [
+            'You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n\n\nDescribe what do you see here and tell me about the history behind it?\n\nThe image depicts the constellation of Leo, which is often referred to as the "Lion"',
+            "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n\nWhat constellation is this image showing?\n\nThe image shows a constellation that is shaped like a stylized figure with a long tail. This",
+        ]
+        generated_ids = model.generate(**inputs, max_new_tokens=20, generation_mode="text", do_sample=False)
+        text = processor.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
+
+    @slow
+    def test_model_text_generation_with_multi_image(self):
+        model = JanusForConditionalGeneration.from_pretrained(self.model_id, device_map="auto")
+        processor = AutoProcessor.from_pretrained(self.model_id)
+
+        image_1 = Image.open(
+            requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
+        )
+        image_2 = Image.open(
+            requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw
+        )
+        prompt = "What do these two images  and  have in common?"
+
+        inputs = processor(images=[image_1, image_2], text=prompt, generation_mode="text", return_tensors="pt").to(
+            model.device, torch.float16
+        )
+
+        EXPECTED_TEXT_COMPLETION = ['You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n\nWhat do these two images  and  have in common?\n\nThe two images you provided are of the same constellation. The first image shows the constellation of Leo, and the second image shows the constellation of Ursa Major. Both constellations are part of']  # fmt: skip
+        generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False)
+        text = processor.batch_decode(generated_ids, skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
+
+    @slow
+    def test_model_generate_images(self):
+        model = JanusForConditionalGeneration.from_pretrained(self.model_id, device_map="auto")
+        processor = AutoProcessor.from_pretrained(self.model_id)
+
+        inputs = processor(
+            text=["A portrait of young girl. masterpiece, film grained, best quality."],
+            padding=True,
+            generation_mode="image",
+            return_tensors="pt",
+        ).to(model.device)
+
+        self.assertTrue(inputs.input_ids.shape[1] == 17)
+
+        out = model.generate(
+            **inputs,
+            generation_mode="image",
+            do_sample=False,
+        )
+
+        # It should run for num_image_tokens in this case 576.
+        self.assertTrue(out.shape[1] == 576)
+
+        # fmt: off
+        expected_tokens = torch.tensor([4484,  4015, 15750,   506,  3758, 11651,  8597,  5739,  4861,   971,
+         14985, 14834, 15438,  7548,  1820,  1465, 13529, 12761, 10503, 12761,
+         14303,  6155,  4015, 11766,   705, 15736, 14146, 10417,  1951,  7713,
+         14305, 15617,  6169,  2706,  8006, 14893,  3855, 10188, 15652,  6297,
+          1097, 12108, 15038,   311, 14998, 15165,   897,  4044,  1762,  4676,
+        ]).to(model.device)
+        # fmt: on
+
+        # Compare the first 50 generated tokens.
+        self.assertTrue(torch.allclose(expected_tokens, out[0][:50]))
+
+        # Decode generated tokens to pixel values and postprocess them.
+        decoded_pixel_values = model.decode_image_tokens(out)
+        images = processor.postprocess(list(decoded_pixel_values.float()), return_tensors="np")
+
+        self.assertTrue(images["pixel_values"].shape == (1, 384, 384, 3))
+        self.assertTrue(isinstance(images["pixel_values"], np.ndarray))
diff --git a/tests/models/janus/test_processor_janus.py b/tests/models/janus/test_processor_janus.py
new file mode 100644
index 00000000000..8b664bb7432
--- /dev/null
+++ b/tests/models/janus/test_processor_janus.py
@@ -0,0 +1,455 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch Janus model."""
+
+import tempfile
+import unittest
+
+import numpy as np
+
+from transformers import AutoProcessor, AutoTokenizer, JanusProcessor
+from transformers.models.janus.convert_janus_weights_to_hf import CHAT_TEMPLATE
+from transformers.utils import is_vision_available
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+if is_vision_available():
+    pass
+
+
+class JanusProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = JanusProcessor
+
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+        special_image_tokens = {
+            "image_token": "",
+            "boi_token": "",
+            "eoi_token": "",
+        }
+
+        processor = self.processor_class.from_pretrained(
+            "deepseek-community/Janus-Pro-1B",
+            extra_special_tokens=special_image_tokens,
+        )
+        processor.save_pretrained(self.tmpdirname)
+
+    def get_tokenizer(self, **kwargs):
+        return AutoTokenizer.from_pretrained(self.tmpdirname, **kwargs)
+
+    def get_image_processor(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
+    def prepare_processor_dict(self):
+        # similar to Emu3 and Qwen2VLProcessorTest, but keep the template in the convert script to avoid duplicated code
+        return {
+            "chat_template": CHAT_TEMPLATE,
+        }
+
+    def test_chat_template_single(self):
+        """
+        Tests that the chat template matches the original implementation when applied to a single message.
+        """
+        processor = self.get_processor()
+        if processor.chat_template is None:
+            self.skipTest("Processor has no chat template")
+
+        # Single image message
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is shown in this image?"},
+                        {"type": "image"},
+                    ],
+                },
+            ]
+        ]
+
+        correct_prompt = ["<|User|>: What is shown in this image?\n\n\n<|Assistant|>:"]
+
+        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+        self.assertEqual(formatted_prompt, correct_prompt)
+
+        # Single image message with capitalization
+        messages = [
+            [
+                {
+                    "role": "User",
+                    "content": [
+                        {"type": "text", "text": "What is shown in this image?"},
+                        {"type": "image"},
+                    ],
+                },
+            ]
+        ]
+
+        correct_prompt = ["<|User|>: What is shown in this image?\n\n\n<|Assistant|>:"]
+
+        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+        self.assertEqual(formatted_prompt, correct_prompt)
+
+        # Single image message with uppercase
+        messages = [
+            [
+                {
+                    "role": "USER",
+                    "content": [
+                        {"type": "text", "text": "What is shown in this image?"},
+                        {"type": "image"},
+                    ],
+                },
+            ]
+        ]
+
+        correct_prompt = ["<|User|>: What is shown in this image?\n\n\n<|Assistant|>:"]
+
+        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+        self.assertEqual(formatted_prompt, correct_prompt)
+
+        """
+        Warning: normally, the other models have a test comparing chat template+tokenization as two separate steps
+        versus as a single step (i.e. processor.apply_chat_template(..., tokenize=True)). However, our processor has
+        some extra steps other than simply applying prompt to tokenizer. These include prepending the default system
+        prompts and, following the implementation from the Janus codebase, expanding the image token.
+        """
+
+        # Checking the output dict keys
+        out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
+        self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
+
+        # Now test the ability to return dict
+        messages[0][0]["content"][1].update(
+            {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
+        )
+        out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
+        self.assertTrue(self.images_input_name in out_dict)
+        # should always have input_ids and attention_mask
+        self.assertEqual(len(out_dict["input_ids"]), 1)
+        self.assertEqual(len(out_dict["attention_mask"]), 1)
+        self.assertEqual(len(out_dict[self.images_input_name]), 1)
+
+        # Passing generation prompt explicitly
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is shown in this image?"},
+                        {"type": "image"},
+                    ],
+                },
+                {
+                    "role": "assistant",
+                    "content": [
+                        {"type": "text", "text": ""},
+                    ],
+                },
+            ]
+        ]
+
+        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=False)
+        self.assertEqual(formatted_prompt, correct_prompt)
+
+        # Single prompt with multiple images
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Compare this image"},
+                        {"type": "image"},
+                        {"type": "text", "text": "with this image"},
+                        {"type": "image"},
+                    ],
+                },
+            ]
+        ]
+
+        correct_prompt = [
+            "<|User|>: Compare this image\n\nwith this image\n\n\n<|Assistant|>:"
+        ]
+        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+        self.assertEqual(formatted_prompt, correct_prompt)
+
+        # Multiple turns and multiple images
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Compare this image"},
+                        {"type": "image"},
+                        {"type": "text", "text": "with this image"},
+                        {"type": "image"},
+                    ],
+                },
+                {
+                    "role": "assistant",
+                    "content": [
+                        {"type": "text", "text": "The first image is an equation, the second is a pie chart."},
+                    ],
+                },
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image"},
+                        {
+                            "type": "text",
+                            "text": "What about this third image? To which of the previous two is it more similar?",
+                        },
+                    ],
+                },
+            ]
+        ]
+
+        correct_prompt = [
+            "<|User|>: Compare this image\n\nwith this image\n\n\n<|Assistant|>: The first image is an equation, the second is a pie chart.<|end▁of▁sentence|><|User|>: \nWhat about this third image? To which of the previous to is it more similar?\n\n<|Assistant|>:"
+        ]
+        formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+        self.assertEqual(formatted_prompt, correct_prompt)
+
+    def test_chat_template_batched(self):
+        """
+        Tests that the chat template matches the original implementation when applied to a batch of messages.
+        """
+        processor = self.get_processor()
+        if processor.chat_template is None:
+            self.skipTest("Processor has no chat template")
+
+        # Test 1: Simple single image per message batch
+        batched_messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is shown in this image?"},
+                        {"type": "image"},
+                    ],
+                },
+            ],
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is shown in this image?"},
+                        {"type": "image"},
+                    ],
+                },
+            ],
+        ]
+
+        correct_prompts = [
+            "<|User|>: What is shown in this image?\n\n\n<|Assistant|>:",
+            "<|User|>: What is shown in this image?\n\n\n<|Assistant|>:",
+        ]
+
+        formatted_prompts = processor.apply_chat_template(batched_messages, add_generation_prompt=True)
+        self.assertEqual(formatted_prompts, correct_prompts)
+
+        # Similarly to the single case, no test for chat template+tokenization as two separate steps versus as a single step
+
+        # Checking the output dict keys
+        out_dict = processor.apply_chat_template(
+            batched_messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            padding=True,
+        )
+        self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])
+
+        # Verify image inputs are included in the output dict
+        batched_messages[0][0]["content"][1].update(
+            {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
+        )
+        batched_messages[1][0]["content"][1].update(
+            {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}
+        )
+        out_dict = processor.apply_chat_template(
+            batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, padding=True
+        )
+        self.assertTrue(self.images_input_name in out_dict)
+        self.assertEqual(len(out_dict["input_ids"]), 2)  # Batch size for text
+        self.assertEqual(len(out_dict["attention_mask"]), 2)  # Batch size for attention mask
+        self.assertEqual(len(out_dict[self.images_input_name]), 2)  # Batch size for images
+
+        # Test 2: Two images per message batch with different prompts
+        batched_messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Compare this image"},
+                        {"type": "image"},
+                        {"type": "text", "text": "with this image"},
+                        {"type": "image"},
+                    ],
+                },
+            ],
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image"},
+                        {"type": "text", "text": "Describe how the previous image compares to the following"},
+                        {"type": "image"},
+                    ],
+                },
+            ],
+        ]
+
+        correct_prompts = [
+            "<|User|>: Compare this image\n\nwith this image\n\n\n<|Assistant|>:",
+            "<|User|>: \nDescribe how the previous image compares to the following\n\n\n<|Assistant|>:",
+        ]
+        formatted_prompts = processor.apply_chat_template(batched_messages, add_generation_prompt=True)
+        self.assertEqual(formatted_prompts, correct_prompts)
+
+        # Test 3: Multi-turn conversations with multiple images
+        batched_messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Compare this image"},
+                        {"type": "image"},
+                        {"type": "text", "text": "with this image"},
+                        {"type": "image"},
+                    ],
+                },
+                {
+                    "role": "assistant",
+                    "content": [
+                        {"type": "text", "text": "The first image is an equation, the second is a pie chart."},
+                    ],
+                },
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image"},
+                        {
+                            "type": "text",
+                            "text": "What about this third image? To which of the previous two is it more similar?",
+                        },
+                    ],
+                },
+            ],
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image"},
+                        {"type": "text", "text": "Describe how the previous image compares to the following"},
+                        {"type": "image"},
+                    ],
+                },
+                {
+                    "role": "assistant",
+                    "content": [
+                        {"type": "text", "text": "The first image is a formula, the second is a plot."},
+                    ],
+                },
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "Which of them is closer to the following?"},
+                        {"type": "image"},
+                    ],
+                },
+            ],
+        ]
+
+        correct_prompts = [
+            "<|User|>: Compare this image\n\nwith this image\n\n\n<|Assistant|>: The first image is an equation, the second is a pie chart.<|end▁of▁sentence|><|User|>: \nWhat about this third image? To which of the previous to is it more similar?\n\n<|Assistant|>:",
+            "<|User|>: \nDescribe how the previous image compares to the following\n\n\n<|Assistant|>: The first image is a formula, the second is a plot.<|end▁of▁sentence|><|User|>: Which of them is closer to the following?\n\n\n<|Assistant|>:",
+        ]
+        formatted_prompts = processor.apply_chat_template(batched_messages, add_generation_prompt=True)
+        self.assertEqual(formatted_prompts, correct_prompts)
+
+    def test_chat_template_accepts_processing_kwargs(self):
+        """Tests that the chat template correctly handles additional processing arguments."""
+        # Get processor and skip if it doesn't have a chat template
+        processor = self.get_processor()
+        if processor.chat_template is None:
+            self.skipTest("Processor has no chat template")
+
+        # Create a simple text message for testing
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is shown in this image?"},
+                    ],
+                },
+            ]
+        ]
+
+        # Test 1: Padding to max_length
+        # Note: we override the parent test's max_length of 50 with 80 because the formatted prompt is already 51 tokens long
+        formatted_prompt_tokenized = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            padding="max_length",
+            max_length=80,
+        )
+        self.assertEqual(len(formatted_prompt_tokenized[0]), 80)
+
+        # Test 2: Truncation
+        # Verify that the output is truncated to exactly 5 tokens
+        formatted_prompt_tokenized = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            truncation=True,
+            max_length=5,
+        )
+        self.assertEqual(len(formatted_prompt_tokenized[0]), 5)
+
+        # Test 3: Image processing kwargs
+        # Add an image and test image processing parameters
+        messages[0][0]["content"].append(
+            {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
+        )
+        # Process with image rescaling and verify the pixel values are negative
+        out_dict = processor.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            do_rescale=True,
+            rescale_factor=-1,
+            return_tensors="np",
+        )
+        self.assertLessEqual(out_dict[self.images_input_name][0][0].mean(), 0)
+
+    def test_processor_postprocess(self):
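+        """Tests that the processor's `postprocess` method maps normalized pixel values back to the original 0-255 image range."""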
+        processor_components = self.prepare_components()
+        processor = self.processor_class(**processor_components)
+
+        input_str = "lower newer"
+        orig_image_input = self.prepare_image_inputs()
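+        # Convert the PIL image to a channels-first (C, H, W) NumPy array for comparison with the processor output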
+        orig_image = np.array(orig_image_input).transpose(2, 0, 1)
+
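+        # Skip resizing so the postprocessed output can be compared element-wise with the original image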
+        inputs = processor(text=input_str, images=orig_image, do_resize=False, return_tensors="np")
+        normalized_image_input = inputs.pixel_values
+        unnormalized_images = processor.postprocess(normalized_image_input, return_tensors="np")["pixel_values"]
+
+        # For an image with pixel values in [0, 255], the diff can be up to 1 due to numerical precision errors when scaling and unscaling
+        self.assertTrue(np.abs(orig_image - unnormalized_images).max() <= 1)
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 85178b663e4..95dac3d6b76 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -156,6 +156,7 @@ IGNORE_NON_TESTED = (
         "Llama4VisionModel",  # Building part of bigger (tested) model. # TODO: add tests
         "Emu3VQVAE",  # Building part of bigger (tested) model
         "Emu3TextModel",  # Building part of bigger (tested) model
+        "JanusVisionModel",  # Building part of bigger (tested) model
         "TimesFmModel",  # Building part of bigger (tested) model
     ]
 )
@@ -356,6 +357,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
     "MoshiForConditionalGeneration",  # no auto class for speech-to-speech
     "Emu3VQVAE",  # no autoclass for VQ-VAE models
     "Emu3TextModel",  # Building part of bigger (tested) model
+    "JanusVQVAE",  # no autoclass for VQ-VAE models
+    "JanusVisionModel",  # Building part of bigger (tested) model
     "Qwen2_5OmniTalkerForConditionalGeneration",  # Building part of a bigger model
     "Qwen2_5OmniTalkerModel",  # Building part of a bigger model
     "Qwen2_5OmniThinkerForConditionalGeneration",  # Building part of a bigger model