Yaswanth Gali 2025-07-03 00:02:35 +05:30 committed by GitHub
commit cb0e8edecb
12 changed files with 4030 additions and 0 deletions

View File

@@ -45,6 +45,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
        ("audio-spectrogram-transformer", "ASTConfig"),
        ("autoformer", "AutoformerConfig"),
        ("aya_vision", "AyaVisionConfig"),
        ("bagel", "BagelConfig"),
        ("bamba", "BambaConfig"),
        ("bark", "BarkConfig"),
        ("bart", "BartConfig"),
@@ -415,6 +416,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
        ("audio-spectrogram-transformer", "Audio Spectrogram Transformer"),
        ("autoformer", "Autoformer"),
        ("aya_vision", "AyaVision"),
        ("bagel", "bagel"),
        ("bamba", "Bamba"),
        ("bark", "Bark"),
        ("bart", "BART"),

View File

@@ -41,6 +41,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("audio-spectrogram-transformer", "ASTModel"),
        ("autoformer", "AutoformerModel"),
        ("aya_vision", "AyaVisionModel"),
        ("bagel", "BagelModel"),
        ("bamba", "BambaModel"),
        ("bark", "BarkModel"),
        ("bart", "BartModel"),
@@ -375,6 +376,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "AlbertForPreTraining"),
        ("bagel", "BagelForConditionalGeneration"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForPreTraining"),
        ("big_bird", "BigBirdForPreTraining"),
@@ -908,6 +910,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
    [
        ("aria", "AriaForConditionalGeneration"),
        ("aya_vision", "AyaVisionForConditionalGeneration"),
        ("bagel", "BagelForConditionalGeneration"),
        ("blip", "BlipForConditionalGeneration"),
        ("blip-2", "Blip2ForConditionalGeneration"),
        ("chameleon", "ChameleonForConditionalGeneration"),

View File

@@ -49,6 +49,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
        ("altclip", "AltCLIPProcessor"),
        ("aria", "AriaProcessor"),
        ("aya_vision", "AyaVisionProcessor"),
        ("bagel", "BagelProcessor"),
        ("bark", "BarkProcessor"),
        ("blip", "BlipProcessor"),
        ("blip-2", "Blip2Processor"),

View File

@@ -0,0 +1,27 @@
# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
    from .configuration_bagel import *
    from .modeling_bagel import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@@ -0,0 +1,190 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/bagel/modular_bagel.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_bagel.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 ByteDance and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
from ...modeling_utils import logging
from ..auto import CONFIG_MAPPING, AutoConfig
logger = logging.get_logger(__name__)
class BagelVQVAEConfig(PretrainedConfig):
r"""
    This is the configuration class to store the configuration of the VQ-VAE component of the Bagel model. It is used
    to instantiate a Bagel VQ-VAE according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
"""
model_type = "bagel_vqvae"
base_config_key = "vq_config"
def __init__(
self,
double_latent: bool = False,
latent_channels: int = 16,
num_patches: int = 32,
latent_patch_size=2,
in_channels: int = 3,
out_channels: int = 3,
base_channels: int = 128,
channel_multiplier: list[int] = [1, 2, 4, 4],
num_res_blocks: int = 2,
dropout: float = 0.0,
initializer_range=0.02,
scale_factor=0.2611,
shift_factor=0.1159,
**kwargs,
):
super().__init__(**kwargs)
self.double_latent = double_latent
self.latent_channels = latent_channels
self.in_channels = in_channels
self.base_channels = base_channels
self.channel_multiplier = channel_multiplier
self.num_res_blocks = num_res_blocks
self.dropout = dropout
self.initializer_range = initializer_range
self.num_patches = num_patches
self.out_channels = out_channels
self.scale_factor = scale_factor
self.shift_factor = shift_factor
self.latent_patch_size = latent_patch_size
llm_config = {
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"hidden_act": "silu",
"hidden_size": 3584,
"initializer_range": 0.02,
"intermediate_size": 18944,
"max_position_embeddings": 32768,
"max_window_layers": 28,
"model_type": "qwen2",
"num_attention_heads": 28,
"num_hidden_layers": 28,
"num_key_value_heads": 4,
"rms_norm_eps": 1e-06,
"rope_theta": 1000000.0,
"sliding_window": 131072,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"transformers_version": "4.43.1",
"use_cache": True,
"use_sliding_window": False,
"vocab_size": 152064,
}
vit_config = {
"hidden_size": 1152,
"image_size": 980,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 26,
"patch_size": 14,
"vision_use_head": False,
}
class BagelConfig(PretrainedConfig):
r"""
    This is the configuration class to store the configuration of a [`BagelModel`]. It is used to instantiate a Bagel
    model according to the specified arguments, defining the model architecture. A Bagel configuration is composed of
    a text configuration, a vision configuration and a VQ-VAE configuration (see `sub_configs` below).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
"""
model_type = "bagel"
sub_configs = {
"text_config": AutoConfig,
"vision_config": AutoConfig,
"vq_config": BagelVQVAEConfig,
}
def __init__(
self,
text_config=None,
vision_config=None,
vq_config=None,
**kwargs,
):
if isinstance(text_config, dict):
text_config["model_type"] = text_config.get("model_type", "qwen2")
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
logger.info("`text_config` is None. Initializing with default values")
self.text_config = CONFIG_MAPPING["qwen2"](**llm_config)
elif isinstance(text_config, PretrainedConfig):
self.text_config = text_config
else:
raise ValueError(
f"Invalid type for `text_config`. Must be either `dict` or `Qwen2Config`."
f" Type found: {type(text_config)}"
)
        if isinstance(vision_config, dict):
            vision_config["model_type"] = vision_config.get("model_type", "siglip_vision_model")
            self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        elif vision_config is None:
logger.info("`vision_config` is None. Initializing with default values")
self.vision_config = CONFIG_MAPPING["siglip_vision_model"](**vit_config)
elif isinstance(vision_config, PretrainedConfig):
self.vision_config = vision_config
else:
raise ValueError(
f"Invalid type for `vision_config`. Must be either `dict` or `SiglipVisionConfig`."
f" Type found: {type(vision_config)}"
)
if vq_config is None:
            logger.info("`vq_config` is None. Initializing with default BagelVQVAEConfig values")
self.vq_config = BagelVQVAEConfig()
elif isinstance(vq_config, dict):
self.vq_config = BagelVQVAEConfig(**vq_config)
elif isinstance(vq_config, BagelVQVAEConfig):
self.vq_config = vq_config
else:
raise ValueError(
f"Invalid type for `vq_config`. Must be either `dict` or `JanusVQVAEConfig`."
f" Type found: {type(vq_config)}"
)
self.vit_max_num_patch_per_side = 70
self.max_latent_size = 64
self.timestep_shift = 1.0
super().__init__(**kwargs)
__all__ = ["BagelVQVAEConfig", "BagelConfig"]

View File

@@ -0,0 +1,430 @@
# coding=utf-8
# Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example of run command (run from root):
python src/transformers/models/janus/convert_janus_weights_to_hf.py --repo_id deepseek-ai/Janus-Pro-1B --local_dir tmp/hub_code_in --output_dir tmp/hub_code_out --safe_serialization
Using provided local directory: tmp/hub_code_in
"""
import argparse
import gc
import json
import os
import re
from typing import Optional
import torch
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from transformers import (
AutoTokenizer,
BagelConfig,
JanusForConditionalGeneration,
JanusVisionConfig,
JanusVQVAEConfig,
LlamaConfig,
)
from transformers.models.janus.image_processing_janus import JanusImageProcessor
from transformers.models.janus.processing_janus import JanusProcessor
# Mappings
MAPPINGS = {}
def convert_old_keys_to_new_keys(state_dict):
keys_as_text = "\n".join(state_dict.keys())
new_keys_as_text = keys_as_text
for old, repl in MAPPINGS.items():
if repl is None:
new_keys_as_text = re.sub(old, "", new_keys_as_text)
else:
new_keys_as_text = re.sub(old, repl, new_keys_as_text)
output_dict = dict(zip(keys_as_text.split("\n"), new_keys_as_text.split("\n")))
return output_dict
def split_tensor(tensor, key):
"""Splits a merged tensor (qkv or kv) into separate tensors and creates keys for each part."""
if "qkv" in key:
prefix_to_replace = "qkv"
num_splits = 3
new_keys = ["q_proj", "k_proj", "v_proj"]
else:
raise ValueError(f"Unrecognized tensor type in key: {key}")
split_size = tensor.shape[0] // num_splits
tensors = torch.split(tensor, split_size, dim=0)
return {key.replace(prefix_to_replace, new_keys[i]): tensors[i] for i in range(num_splits)}
def convert_state_dict_to_hf(state_dict):
"""Convert state dict keys to HF format."""
conversion_dict = convert_old_keys_to_new_keys(state_dict)
converted_state_dict = {}
for old_key, new_key in conversion_dict.items():
if new_key:
if "qkv" in new_key or "kv" in new_key: # Detect merged attention keys and split them.
qkv_split_dict = split_tensor(state_dict[old_key], new_key)
converted_state_dict.update(qkv_split_dict)
else:
converted_state_dict[new_key] = state_dict[old_key]
    # The position embedding is stored with an extra leading dimension in the original checkpoint; remove it.
pos_embed_key = "model.vision_model.embeddings.position_embedding.weight"
converted_state_dict[pos_embed_key] = converted_state_dict[pos_embed_key].squeeze(0)
return converted_state_dict
def ensure_model_downloaded(
repo_id: Optional[str] = None, revision: Optional[str] = None, local_dir: Optional[str] = None
) -> str:
"""
Ensures model files are downloaded locally, downloads them if not.
Returns path to local files.
Args:
repo_id: The Hugging Face model repo ID (required if local_dir not provided)
revision: Optional git revision to use
local_dir: Optional local directory path where model files should be stored/found
"""
if local_dir is not None:
if os.path.exists(local_dir):
print(f"Using provided local directory: {local_dir}")
else:
# Create the local directory if it doesn't exist
os.makedirs(local_dir, exist_ok=True)
print(f"Created local directory: {local_dir}")
if repo_id is None:
raise ValueError("Either repo_id or local_dir must be provided")
print(f"Ensuring {repo_id} (revision: {revision or 'latest'}) is downloaded...")
try:
# First try to find files locally
download_dir = snapshot_download(repo_id, revision=revision, local_files_only=True, local_dir=local_dir)
print(f"Found model files locally at {download_dir}")
return download_dir
except Exception:
# If files not found locally, download them
print(f"Downloading model files for {repo_id}...")
download_dir = snapshot_download(repo_id, revision=revision, local_files_only=False, local_dir=local_dir)
print(f"Downloaded model files to {download_dir}")
return download_dir
def load_model_state_dict(input_path: str) -> dict:
"""
Load model state dict, handling both single and sharded files.
"""
index_path = os.path.join(input_path, "pytorch_model.bin.index.json")
single_file_path = os.path.join(input_path, "pytorch_model.bin")
# Check if we have a sharded model
if os.path.exists(index_path):
print("Loading sharded model...")
state_dict = {}
with open(index_path, "r") as f:
index = json.load(f)
# Get unique shard files and load each one only once
unique_shard_files = sorted(set(index["weight_map"].values()))
for shard_file in unique_shard_files:
print(f"Loading shard {shard_file}...")
shard_path = os.path.join(input_path, shard_file)
shard_dict = torch.load(shard_path, map_location="cpu")
state_dict.update(shard_dict)
return state_dict
# Single file model
elif os.path.exists(single_file_path):
print("Loading single file model...")
return torch.load(single_file_path, map_location="cpu")
else:
raise ValueError(f"No model files found in {input_path}")
def convert_model(
repo_id=None,
local_dir=None,
text_model_id=None,
output_dir=None,
output_hub_path=None,
safe_serialization=True,
revision=None,
):
"""Convert and save the model weights, processor, and configuration."""
if output_dir is None and output_hub_path is None:
raise ValueError("At least one of output_dir or output_hub_path must be specified")
if repo_id is None and local_dir is None:
raise ValueError("Either repo_id or local_dir must be specified")
# Create output directory if specified
if output_dir:
os.makedirs(output_dir, exist_ok=True)
print(f"Created/verified output directory: {output_dir}")
torch.set_default_dtype(torch.float16)
# Download or locate model files
input_path = ensure_model_downloaded(repo_id=repo_id, revision=revision, local_dir=local_dir)
# Load configuration files
required_files = ["config.json", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json"]
missing_files = [f for f in required_files if not os.path.exists(os.path.join(input_path, f))]
if missing_files:
raise ValueError(
f"The following required configuration files are missing from {input_path}: {', '.join(missing_files)}. "
"Please ensure you have downloaded all necessary model files."
)
with open(os.path.join(input_path, "config.json"), "r") as f:
config_data = json.load(f)
with open(os.path.join(input_path, "preprocessor_config.json"), "r") as f:
preprocessor_config = json.load(f)
with open(os.path.join(input_path, "special_tokens_map.json"), "r") as f:
special_tokens_map = json.load(f)
with open(os.path.join(input_path, "tokenizer_config.json"), "r") as f:
tokenizer_config = json.load(f)
# Create tokenizer directly from tokenizer.json if it exists
tokenizer_json_path = os.path.join(input_path, "tokenizer.json")
special_image_tokens = {
"image_token": "<image_placeholder>",
"boi_token": "<begin_of_image>",
"eoi_token": "<end_of_image>",
}
if os.path.exists(tokenizer_json_path) and not text_model_id:
tokenizer = AutoTokenizer.from_pretrained(
input_path, # This will load tokenizer.json directly
model_max_length=tokenizer_config["model_max_length"],
extra_special_tokens=special_image_tokens,
)
else:
# Fallback to creating from text_model_id with special tokens
tokenizer = AutoTokenizer.from_pretrained(
text_model_id,
bos_token=special_tokens_map["bos_token"],
eos_token=special_tokens_map["eos_token"],
pad_token=special_tokens_map["pad_token"],
additional_special_tokens=special_tokens_map["additional_special_tokens"],
model_max_length=tokenizer_config["model_max_length"],
extra_special_tokens=special_image_tokens,
)
# Create image processor from config
image_processor_kwargs = {}
for key in ["do_normalize", "image_mean", "image_std", "min_size", "rescale_factor"]:
if key in preprocessor_config:
image_processor_kwargs[key] = preprocessor_config[key]
if "image_size" in preprocessor_config:
image_processor_kwargs["size"] = {
"height": preprocessor_config["image_size"],
"width": preprocessor_config["image_size"],
}
image_processor = JanusImageProcessor(**image_processor_kwargs)
# Create processor with chat template
processor = JanusProcessor(
image_processor=image_processor,
tokenizer=tokenizer,
chat_template="",
use_default_system_prompt=True,
)
if output_dir:
print(f"Saving processor to {output_dir}...")
processor.save_pretrained(output_dir)
if output_hub_path:
print(f"Pushing processor to hub at {output_hub_path}...")
processor.push_to_hub(output_hub_path)
# Create model configurations
text_config_kwargs = {}
for key in [
"vocab_size",
"hidden_size",
"intermediate_size",
"num_hidden_layers",
"num_attention_heads",
"num_key_value_heads",
"hidden_act",
"max_position_embeddings",
"torch_dtype",
]:
if key in config_data["language_config"]:
text_config_kwargs[key] = config_data["language_config"][key]
# Add token IDs from tokenizer
text_config_kwargs.update(
{
"pad_token_id": tokenizer.pad_token_id,
"bos_token_id": tokenizer.bos_token_id,
"eos_token_id": tokenizer.eos_token_id,
}
)
text_config = LlamaConfig(**text_config_kwargs)
# Create vision config
vision_config_kwargs = {}
if "image_size" in config_data["vision_config"]["params"]:
vision_config_kwargs["image_size"] = config_data["vision_config"]["params"]["image_size"]
# Add aligner params if present
if "aligner_config" in config_data and "params" in config_data["aligner_config"]:
if "n_embed" in config_data["aligner_config"]["params"]:
vision_config_kwargs["projection_dim"] = config_data["aligner_config"]["params"]["n_embed"]
if "depth" in config_data["aligner_config"]["params"]:
vision_config_kwargs["depth"] = config_data["aligner_config"]["params"]["depth"]
vision_config = JanusVisionConfig(**vision_config_kwargs)
vq_config = JanusVQVAEConfig(
embed_dim=config_data["gen_vision_config"]["params"]["n_embed"],
num_embeddings=config_data["gen_vision_config"]["params"]["image_token_size"],
projection_dim=config_data["gen_aligner_config"]["params"]["n_embed"],
depth=config_data["gen_aligner_config"]["params"]["depth"],
image_token_embed_dim=config_data["gen_head_config"]["params"]["image_token_embed"],
)
# Create the main config
config = BagelConfig(
text_config=text_config,
vision_config=vision_config,
vq_config=vq_config,
image_token_id=tokenizer.vocab.get("<image_placeholder>"),
)
# Save the config
if output_dir:
config.save_pretrained(output_dir)
if output_hub_path:
config.push_to_hub(output_hub_path)
# Initialize model with empty weights
print("Creating empty model...")
with init_empty_weights():
model = JanusForConditionalGeneration(config)
model.generation_config.temperature = 1
model.generation_config.guidance_scale = 5
model.generation_config.pad_token_id = tokenizer.vocab.get("<\uff5c\u2581pad\u2581\uff5c>")
model.generation_config.generation_kwargs["boi_token_id"] = tokenizer.vocab.get("<begin_of_image>")
# Load and convert state dict
print("Loading state dict...")
state_dict = load_model_state_dict(input_path)
state_dict = convert_state_dict_to_hf(state_dict)
# Load converted state dict
print("Loading converted weights into model...")
model.load_state_dict(state_dict, strict=True, assign=True)
# Tie weights before any device mapping
print("Tying weights...")
model.tie_weights()
# Save the model
if output_dir:
print(f"Saving model to {output_dir}...")
model.save_pretrained(output_dir, safe_serialization=safe_serialization)
if output_hub_path:
print(f"Pushing model to hub at {output_hub_path}...")
model.push_to_hub(output_hub_path, safe_serialization=safe_serialization)
del state_dict, model
gc.collect()
# Validate the saved model if saved locally
if output_dir:
print("Reloading the local model to check if it's saved correctly...")
# TODO: warning about weights not being tied is raised here regardless of model.tie_weights() above
JanusForConditionalGeneration.from_pretrained(output_dir, device_map="auto")
print("Local model reloaded successfully.")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo_id",
help="HuggingFace Hub repo ID for the model",
default=None,
)
parser.add_argument(
"--local_dir",
help="Local directory containing the model files",
default=None,
)
parser.add_argument(
"--revision",
help="Specific revision to download from the Hub",
default=None,
)
parser.add_argument(
"--output_dir",
help="Location to write HF model locally",
default=None,
)
parser.add_argument(
"--output_hub_path",
help="Repository ID to push model to hub (e.g. 'username/model-name')",
default=None,
)
parser.add_argument(
"--text_model_id",
help="Hub ID of the text model to get tokenizer from. Optional if tokenizer.json exists in the model directory.",
required=False,
)
parser.add_argument(
"--safe_serialization",
action="store_true",
help="Whether to save using safetensors",
)
args = parser.parse_args()
if args.output_dir is None and args.output_hub_path is None:
raise ValueError("At least one of --output_dir or --output_hub_path must be specified")
if args.repo_id is None and args.local_dir is None:
raise ValueError("Either --repo_id or --local_dir must be specified")
convert_model(
repo_id=args.repo_id,
local_dir=args.local_dir,
text_model_id=args.text_model_id,
output_dir=args.output_dir,
output_hub_path=args.output_hub_path,
safe_serialization=args.safe_serialization,
revision=args.revision,
)
if __name__ == "__main__":
main()
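For reference, a hedged sketch of invoking `convert_model` directly (for example from a Python session with this script imported), mirroring the CLI flags parsed in `main()`; the paths reuse the placeholder directories from the docstring at the top of the file:

# A minimal sketch with placeholder paths: convert from a local checkpoint
# directory and write the HF-format model locally using safetensors.
convert_model(
    local_dir="tmp/hub_code_in",    # directory containing the original checkpoint files
    output_dir="tmp/hub_code_out",  # where to write the converted model and processor
    safe_serialization=True,        # save with safetensors
)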

View File

@@ -0,0 +1,371 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Bagel."""
from typing import Optional, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format
from ...image_utils import (
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
make_flat_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import (
TensorType,
filter_out_non_signature_kwargs,
is_vision_available,
logging,
)
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
class BagelImageProcessor(BaseImageProcessor):
r"""
Constructs a Bagel image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
overridden by the `resample` parameter in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
`do_rescale` parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
overridden by the `rescale_factor` parameter in the `preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Optional[dict[str, int]] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None,
do_convert_rgb: Optional[bool] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 384, "width": 384}
size = get_size_dict(size, default_to_square=True)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_convert_rgb = do_convert_rgb
self.background_color = tuple([int(x * 255) for x in self.image_mean])
def resize(
self,
image: np.ndarray,
size: dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize and pad an image to a square based on the longest edge in `size`.
Args:
image (`np.ndarray`):
Image to resize.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `None`: will be inferred from input
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
Returns:
`np.ndarray`: The resized image.
"""
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
height, width = get_image_size(image, input_data_format)
max_size = max(height, width)
size = get_size_dict(size, default_to_square=True)
if size["height"] != size["width"]:
raise ValueError(
f"Output height and width must be the same. Got height={size['height']} and width={size['width']}"
)
size = size["height"]
delta = size / max_size
# Largest side becomes `size` and the other side is scaled according to the aspect ratio.
output_size_nonpadded = [int(height * delta), int(width * delta)]
image = resize(
image,
size=output_size_nonpadded,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
return_numpy=True,
**kwargs,
)
# Expand and pad the images to obtain a square image of dimensions `size x size`
image = self.pad_to_square(
image=image,
input_data_format=input_data_format,
)
return image
@filter_out_non_signature_kwargs()
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Optional[dict[str, int]] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
do_convert_rgb: Optional[bool] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Controls the size of the image after `resize`. Height and width must be equal. The longest edge of the
                image is resized to `size["height"]` while preserving the aspect ratio, and the image is then padded to
                a square of side `size["height"]` using the background color.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to normalize the image by if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
images = make_flat_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_resize=do_resize,
size=size,
resample=resample,
)
# PIL RGBA images are converted to RGB
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if do_rescale and is_scaled_image(images[0]):
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
all_images = []
for image in images:
if do_resize:
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
if do_normalize:
image = self.normalize(
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
)
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
all_images.append(image)
data = {"pixel_values": all_images}
return BatchFeature(data=data, tensor_type=return_tensors)
def pad_to_square(
self,
image: np.ndarray,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.array:
"""
Pads an image to a square based on the longest edge.
Args:
image (`np.ndarray`):
The image to pad.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
If unset, will use same as the input image.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format for the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The padded image.
"""
height, width = get_image_size(image, input_data_format)
num_channels = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1]
if height == width:
image = (
to_channel_dimension_format(image, data_format, input_data_format)
if data_format is not None
else image
)
return image
max_dim = max(height, width)
if input_data_format == ChannelDimension.FIRST:
result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype)
for i, color in enumerate(self.background_color):
result[i, :, :] = color
if width > height:
start = (max_dim - height) // 2
result[:, start : start + height, :] = image
else:
start = (max_dim - width) // 2
result[:, :, start : start + width] = image
else:
result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype)
for i, color in enumerate(self.background_color):
result[:, :, i] = color
if width > height:
start = (max_dim - height) // 2
result[start : start + height, :, :] = image
else:
start = (max_dim - width) // 2
result[:, start : start + width, :] = image
return result
__all__ = ["BagelImageProcessor"]

File diff suppressed because it is too large.

File diff suppressed because it is too large.

View File

@@ -0,0 +1,17 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Bagel.
"""

View File

View File

@@ -0,0 +1,557 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Janus model."""
import re
import tempfile
import unittest
from functools import reduce
import numpy as np
import requests
from transformers import (
AutoProcessor,
BagelConfig,
JanusForConditionalGeneration,
JanusModel,
JanusVQVAE,
JanusVQVAEConfig,
is_torch_available,
is_vision_available,
)
from transformers.models.auto import get_values
from transformers.models.auto.modeling_auto import MODEL_FOR_BACKBONE_MAPPING_NAMES, MODEL_MAPPING_NAMES
from transformers.testing_utils import (
require_torch,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
class JanusVisionText2TextModelTester:
def __init__(
self,
parent,
image_token_index=0,
seq_length=25,
initializer_range=0.02,
text_config={
"model_type": "llama",
"seq_length": 7,
"is_training": True,
"use_input_mask": True,
"use_token_type_ids": False,
"use_labels": True,
"vocab_size": 99,
"hidden_size": 32,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"intermediate_size": 37,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"attention_probs_dropout_prob": 0.1,
"max_position_embeddings": 512,
"type_vocab_size": 16,
"type_sequence_label_size": 2,
"initializer_range": 0.02,
"num_labels": 3,
"num_choices": 4,
"pad_token_id": 1,
},
is_training=True,
vision_config={
"use_labels": True,
"image_size": 20,
"patch_size": 5,
"num_image_tokens": 4,
"num_channels": 3,
"is_training": True,
"hidden_size": 32,
"projection_dim": 32,
"num_key_value_heads": 1,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"mlp_ratio": 2,
"dropout": 0.1,
"attention_dropout": 0.1,
"initializer_range": 0.02,
"vision_feature_select_strategy": "default",
"vision_feature_layer": -1,
},
use_cache=False,
vq_num_embeds=12,
vq_embed_dim=12,
vq_channel_multiplier=[1, 1],
):
self.parent = parent
self.initializer_range = initializer_range
# `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify
self.image_token_index = image_token_index
self.text_config = text_config
self.vision_config = vision_config
self.seq_length = seq_length
self.pad_token_id = text_config["pad_token_id"]
self.num_hidden_layers = text_config["num_hidden_layers"]
self.vocab_size = text_config["vocab_size"]
self.hidden_size = text_config["hidden_size"]
self.num_attention_heads = text_config["num_attention_heads"]
self.is_training = is_training
self.batch_size = 3
self.num_channels = vision_config["num_channels"]
self.image_size = vision_config["image_size"]
self.num_image_tokens = vision_config["num_image_tokens"]
self.use_cache = use_cache
# vq model params
self.vq_num_embeds = vq_num_embeds
self.vq_embed_dim = vq_embed_dim
self.vq_channel_multiplier = vq_channel_multiplier
def get_vq_config(self):
return {
"embed_dim": self.vq_embed_dim,
"num_embeddings": self.vq_num_embeds,
"latent_channels": self.vq_embed_dim,
"in_channels": 3,
"base_channels": 32, # we have a GroupNorm of 32 groups, so can't do less
"channel_multiplier": self.vq_channel_multiplier,
"initializer_range": self.initializer_range,
"projection_dim": 10,
"image_token_embed_dim": 32, # Same as text model hidden size
}
def get_config(self):
return BagelConfig(
text_config=self.text_config,
vision_config=self.vision_config,
vq_config=self.get_vq_config(),
)
def prepare_config_and_inputs(self):
config = self.get_config()
pixel_values = floats_tensor(
[
self.batch_size,
3,
self.image_size,
self.image_size,
]
)
return config, pixel_values
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
        # Set the first `num_image_tokens` tokens to be image tokens, and ensure that no other tokens are image tokens.
        # Do not change this unless you modified image size or patch size.
input_ids[input_ids == self.image_token_index] = self.pad_token_id
input_ids[:, : self.num_image_tokens] = self.image_token_index
inputs_dict = {
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": input_ids,
"generation_mode": "text", # Required to perform text generation instead of image generation.
}
return config, inputs_dict
@require_torch
class JanusVisionText2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (JanusModel, JanusForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (JanusForConditionalGeneration,) if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_head_masking = False
_is_composite = True
def setUp(self):
self.model_tester = JanusVisionText2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=BagelConfig, has_text_modality=False)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
del inputs["generation_mode"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# Overwrite inputs_embeds tests because we need to delete "pixel values" for VLMs.
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
del inputs["generation_mode"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)
def test_sdpa_can_dispatch_composite_models(self):
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# Load the model with SDPA
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
# Load model with eager attention
model_eager = model_class.from_pretrained(
tmpdirname,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
            # SigLip has one shared cls attr for all models, so we assign both submodels here
vision_attn = language_attn = "sdpa" if model._supports_sdpa else "eager"
if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "language_model"):
self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn)
self.assertTrue(model_sdpa.language_model.config._attn_implementation == language_attn)
self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")
self.assertTrue(model_eager.language_model.config._attn_implementation == "eager")
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if any(re.finditer(r"Attention(?!Pool)", class_name)):
self.assertTrue(submodule.config._attn_implementation == "eager")
for name, submodule in model_sdpa.named_modules():
class_name = submodule.__class__.__name__
if any(re.finditer(r"Attention(?!Pool)", class_name)):
self.assertTrue(submodule.config._attn_implementation == "sdpa")
def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=None):
if not self.model_tester.is_training:
self.skipTest(reason="ModelTester is not configured to run training tests")
"""
We skip some parameters when checking for gradient checkpointing:
- VQ model, as its training is not supported.
- A few other modules used for image generation.
"""
skip_patterns = ["vqmodel", "generation_embeddings", "generation_aligner", "generation_head"]
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
if (
model_class.__name__
in [
*get_values(MODEL_MAPPING_NAMES),
*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
]
or not model_class.supports_gradient_checkpointing
):
# TODO (ydshieh): use `skipTest` once pytest-dev/pytest-subtests/pull/169 is merged
# self.skipTest(reason=f"`supports_gradient_checkpointing` is False for {model_class.__name__}.")
continue
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.use_cache = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
model.train()
# unfreeze additional layers
for p in model.parameters():
p.requires_grad_(True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
optimizer.step()
if self.test_all_params_have_gradient:
for k, v in model.named_parameters():
if v.requires_grad and not reduce(lambda t, s: t | (s in k), skip_patterns, False):
self.assertTrue(v.grad is not None, f"{k} in {model_class.__name__} has no gradient!")
else:
pass
@unittest.skip("There are recompilations in Janus") # TODO (joao, raushan): fix me
def test_generate_compile_model_forward(self):
pass
class JanusVQModelTester:
def __init__(
self,
parent,
batch_size=5,
is_training=False,
initializer_range=0.02,
image_size=30,
num_embeds=12,
base_channels=32, # we have a GroupNorm of 32 groups, so can't do less
embed_dim=12,
channel_multiplier=[1, 2],
patch_size=2,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.is_training = is_training
self.initializer_range = initializer_range
self.image_size = image_size
self.base_channels = base_channels
self.num_embeds = num_embeds
self.embed_dim = embed_dim
self.channel_multiplier = channel_multiplier
self.num_patches = image_size // patch_size
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size])
config = self.get_config()
return config, pixel_values
def get_config(self):
return JanusVQVAEConfig(
embed_dim=self.embed_dim,
num_embeddings=self.num_embeds,
latent_channels=self.embed_dim,
in_channels=3,
base_channels=self.base_channels,
channel_multiplier=self.channel_multiplier,
initializer_range=self.initializer_range,
resolution=self.image_size,
num_patches=self.num_patches,
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class JanusVQModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (JanusVQVAE,) if is_torch_available() else ()
test_head_masking = False
test_pruning = False
fx_compatible = False
has_attentions = False
test_resize_embeddings = False
def setUp(self):
self.model_tester = JanusVQModelTester(self)
self.config_tester = ConfigTester(
self,
config_class=JanusVQVAEConfig,
has_text_modality=False,
common_properties=["embed_dim", "num_embeddings"],
)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip("Janus VQ module cannot offload due to using `self.weight` directly")
def test_cpu_offload(self):
pass
@unittest.skip("Janus VQ module cannot offload due to using `self.weight` directly")
def test_disk_offload_bin(self):
pass
@unittest.skip("Janus VQ module cannot offload due to using `self.weight` directly")
def test_disk_offload_safetensors(self):
pass
@unittest.skip("Janus VQ module has no hidden states")
def test_hidden_states_output(self):
pass
@unittest.skip("Janus VQ module has no hidden states")
def test_model_outputs_equivalence(self):
pass
@unittest.skip("Janus VQ module has no get/set embeddings method")
def test_model_get_set_embeddings(self):
pass
@unittest.skip("Janus VQ module has no hidden states")
def test_retain_grad_hidden_states_attentions(self):
pass
class JanusIntegrationTest(unittest.TestCase):
def setUp(self):
self.model_id = "deepseek-community/Janus-Pro-1B"
@slow
def test_model_text_generation(self):
model = JanusForConditionalGeneration.from_pretrained(self.model_id, device_map="auto")
model.eval()
processor = AutoProcessor.from_pretrained(self.model_id)
image = Image.open(
requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
)
prompt = "<image_placeholder>\nDescribe what do you see here and tell me about the history behind it?"
inputs = processor(images=image, text=prompt, generation_mode="text", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20, generation_mode="text", do_sample=False)
EXPECTED_DECODED_TEXT = 'You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n\n\nDescribe what do you see here and tell me about the history behind it?\n\nThe image depicts the constellation of Leo, which is often referred to as the "Lion"' # fmt: skip
text = processor.decode(output[0], skip_special_tokens=True)
self.assertEqual(
text,
EXPECTED_DECODED_TEXT,
)
@slow
def test_model_text_generation_batched(self):
model = JanusForConditionalGeneration.from_pretrained(self.model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(self.model_id)
image_1 = Image.open(
requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
)
image_2 = Image.open(
requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw
)
prompts = [
"<image_placeholder>\nDescribe what do you see here and tell me about the history behind it?",
"What constellation is this image showing?<image_placeholder>\n",
]
inputs = processor(
images=[image_1, image_2], text=prompts, generation_mode="text", padding=True, return_tensors="pt"
).to(model.device, torch.float16)
EXPECTED_TEXT_COMPLETION = [
'You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n\n\nDescribe what do you see here and tell me about the history behind it?\n\nThe image depicts the constellation of Leo, which is often referred to as the "Lion"',
"You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n\nWhat constellation is this image showing?\n\nThe image shows a constellation that is shaped like a stylized figure with a long tail. This",
]
generated_ids = model.generate(**inputs, max_new_tokens=20, generation_mode="text", do_sample=False)
text = processor.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
@slow
def test_model_text_generation_with_multi_image(self):
model = JanusForConditionalGeneration.from_pretrained(self.model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(self.model_id)
image_1 = Image.open(
requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
)
image_2 = Image.open(
requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw
)
prompt = "What do these two images <image_placeholder> and <image_placeholder> have in common?"
inputs = processor(images=[image_1, image_2], text=prompt, generation_mode="text", return_tensors="pt").to(
model.device, torch.float16
)
EXPECTED_TEXT_COMPLETION = ['You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n\nWhat do these two images and have in common?\n\nThe two images you provided are of the same constellation. The first image shows the constellation of Leo, and the second image shows the constellation of Ursa Major. Both constellations are part of'] # fmt: skip
generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False)
text = processor.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
@slow
def test_model_generate_images(self):
model = JanusForConditionalGeneration.from_pretrained(self.model_id, device_map="auto")
processor = AutoProcessor.from_pretrained(self.model_id)
inputs = processor(
text=["A portrait of young girl. masterpiece, film grained, best quality."],
padding=True,
generation_mode="image",
return_tensors="pt",
).to(model.device)
self.assertTrue(inputs.input_ids.shape[1] == 17)
out = model.generate(
**inputs,
generation_mode="image",
do_sample=False,
)
        # The model should generate `num_image_tokens` image tokens, which is 576 in this case.
self.assertTrue(out.shape[1] == 576)
# fmt: off
expected_tokens = torch.tensor([4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971,
14985, 14834, 15438, 7548, 1820, 1465, 13529, 12761, 10503, 12761,
14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713,
14305, 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297,
1097, 12108, 15038, 311, 14998, 15165, 897, 4044, 1762, 4676,
]).to(model.device)
# fmt: on
# Compare the first 50 generated tokens.
self.assertTrue(torch.allclose(expected_tokens, out[0][:50]))
# Decode generated tokens to pixel values and postprocess them.
decoded_pixel_values = model.decode_image_tokens(out)
images = processor.postprocess(list(decoded_pixel_values.float()), return_tensors="np")
self.assertTrue(images["pixel_values"].shape == (1, 384, 384, 3))
self.assertTrue(isinstance(images["pixel_values"], np.ndarray))