# Janus

## Overview

The Janus Model was originally proposed in [Janus: Decoupling Visual Encoding for Unified Multimodal Understanding and Generation](https://arxiv.org/abs/2410.13848) by the DeepSeek AI team and later refined in [Janus-Pro: Unified Multimodal Understanding and Generation with Data and Model Scaling](https://arxiv.org/abs/2501.17811). Janus is a vision-language model that accepts both images and text as input and can generate either text or images as output.

> [!NOTE]
> The model doesn't generate both images and text in an interleaved format. The user has to pass the `generation_mode` argument to indicate whether to generate text or an image.
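
For example, the same `generate` call switches between the two modes; the runnable examples below show how `model` and `inputs` are built for each mode:

```python
# Text generation
output = model.generate(**inputs, generation_mode="text", max_new_tokens=40)

# Image generation (inputs must also be prepared with generation_mode="image")
output = model.generate(**inputs, generation_mode="image")
```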

The abstract from the original paper is the following:

*In this paper, we introduce Janus, an autoregressive framework that unifies multimodal understanding and generation. Prior research often relies on a single visual encoder for both tasks, such as Chameleon. However, due to the differing levels of information granularity required by multimodal understanding and generation, this approach can lead to suboptimal performance, particularly in multimodal understanding. To address this issue, we decouple visual encoding into separate pathways, while still leveraging a single, unified transformer architecture for processing. The decoupling not only alleviates the conflict between the visual encoder's roles in understanding and generation, but also enhances the framework's flexibility. For instance, both the multimodal understanding and generation components can independently select their most suitable encoding methods. Experiments show that Janus surpasses previous unified model and matches or exceeds the performance of task-specific models. The simplicity, high flexibility, and effectiveness of Janus make it a strong candidate for next-generation unified multimodal models.*

The abstract from the follow-up Janus-Pro paper is the following:

*In this work, we introduce Janus-Pro, an advanced version of the previous work Janus. Specifically, Janus-Pro incorporates (1) an optimized training strategy, (2) expanded training data, and (3) scaling to larger model size. With these improvements, Janus-Pro achieves significant advancements in both multimodal understanding and text-to-image instruction-following capabilities, while also enhancing the stability of text-to-image generation. We hope this work will inspire further exploration in the field. Code and models are publicly available.*

This model was contributed by Yaswanth Gali and Hugo Silva. The original code can be found [here](https://github.com/deepseek-ai/Janus).

## Usage Example

### Single image inference

Here is an example of visual understanding with a single image.

> [!NOTE]
> The model has been trained with a specific prompt format for chatting. Use `processor.apply_chat_template(my_conversation_dict)` to correctly format your prompts.

```python
import torch
from transformers import JanusForConditionalGeneration, JanusProcessor

model_id = "deepseek-community/Janus-Pro-1B"

# Prepare the input for generation.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
            {"type": "text", "text": "What do you see in this image?"}
        ]
    },
]

processor = JanusProcessor.from_pretrained(model_id)
model = JanusForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# Set generation mode to `text` to perform text generation.
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    generation_mode="text",
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device, dtype=torch.bfloat16)

output = model.generate(**inputs, max_new_tokens=40, generation_mode="text", do_sample=True)
text = processor.decode(output[0], skip_special_tokens=True)
print(text)
```
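
The decoded string above still contains the prompt. To keep only the model's reply, slice off the prompt tokens before decoding, reusing `inputs` and `output` from the example above:

```python
# Keep only the newly generated tokens by dropping the prompt portion.
new_tokens = output[:, inputs["input_ids"].shape[1]:]
print(processor.decode(new_tokens[0], skip_special_tokens=True))
```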

### Multi-image inference

Janus can also take multiple images as input. The images can belong to the same prompt or to different prompts in batched inference, in which case the model processes several conversations in parallel. Here is how you can do it:

```python
import torch
from transformers import JanusForConditionalGeneration, JanusProcessor

model_id = "deepseek-community/Janus-Pro-1B"

image_urls = [
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    "https://www.ilankelman.org/stopsigns/australia.jpg",
    "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
]

messages = [
    [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What's the difference between"},
                {"type": "image", "url": image_urls[0]},
                {"type": "text", "text": " and "},
                {"type": "image", "url": image_urls[1]}
            ]
        }
    ],
    [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": image_urls[2]},
                {"type": "text", "text": "What do you see in this image?"}
            ]
        }
    ]
]

# Load the model and processor.
processor = JanusProcessor.from_pretrained(model_id)
model = JanusForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    generation_mode="text",
    tokenize=True,
    padding=True,
    return_dict=True,
    return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)

# Generate the responses.
output = model.generate(**inputs, max_new_tokens=40, generation_mode="text", do_sample=False)
text = processor.batch_decode(output, skip_special_tokens=True)
print(text)
```
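
As in the single-image case, the decoded strings include the prompts. Since `padding=True` pads every conversation to the same length, a single slice strips the prompt tokens for the whole batch:

```python
# Drop the (padded) prompt tokens for every sequence in the batch.
new_tokens = output[:, inputs["input_ids"].shape[1]:]
print(processor.batch_decode(new_tokens, skip_special_tokens=True))
```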

### Text-to-image generation

Janus can also generate images given a prompt.

```python
import torch
from transformers import JanusForConditionalGeneration, JanusProcessor

model_id = "deepseek-community/Janus-Pro-1B"
processor = JanusProcessor.from_pretrained(model_id)
model = JanusForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "A dog running under the rain."},
        ],
    }
]

prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
# Set generation mode to `image` to prepare inputs for image generation.
inputs = processor(text=prompt, generation_mode="image", return_tensors="pt").to(
    model.device, dtype=torch.bfloat16
)

# Set the `num_return_sequences` parameter to generate multiple images per prompt.
model.generation_config.num_return_sequences = 2
outputs = model.generate(
    **inputs,
    generation_mode="image",
    do_sample=True,
    use_cache=True,
)

# Post-process the generated token ids into images.
decoded_image = model.decode_image_tokens(outputs)
images = processor.postprocess(list(decoded_image.float()), return_tensors="PIL.Image.Image")

# Save the images.
for i, image in enumerate(images["pixel_values"]):
    image.save(f"result{i}.png")
```
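
Janus generates images with classifier-free guidance, and the guidance scale can be passed to `generate` as a kwarg to override the default from the model's generation config. A minimal sketch; the value `5` below is only an illustrative assumption:

```python
# Override the classifier-free guidance scale for this call only;
# the default value lives in model.generation_config.
outputs = model.generate(
    **inputs,
    generation_mode="image",
    do_sample=True,
    use_cache=True,
    guidance_scale=5,  # illustrative value, not a recommendation
)
```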

## JanusConfig

[[autodoc]] JanusConfig

## JanusVisionConfig

[[autodoc]] JanusVisionConfig

## JanusVQVAEConfig

[[autodoc]] JanusVQVAEConfig

## JanusProcessor

[[autodoc]] JanusProcessor

## JanusImageProcessor

[[autodoc]] JanusImageProcessor

## JanusVisionModel

[[autodoc]] JanusVisionModel
    - forward

## JanusVQVAE

[[autodoc]] JanusVQVAE
    - forward

## JanusModel

[[autodoc]] JanusModel
    - forward

## JanusForConditionalGeneration

[[autodoc]] JanusForConditionalGeneration
    - forward