Add MLCD model (#36182)

* Add MLCD model

* Update codes for auto-mapping

* Add test scripts for MLCD

* Update doc for MLCD model

* Fix import error

* Fix import error

* Fix CI error for attention_outputs

* Fix code style for CI

* Fix code style for CI

* Fix code style for CI

* Fix code style for CI

* Fix code style for CI

* Fix CI error for initialization

* Fix code style for CI

* Fix code style for CI

* Reformat codes and docs for CI test

* Reformat codes and docs for CI test

* Remove unused attributes for CI test

* Fix style for CI test

* List MLCD in flash_attn doc

* Fix: typos, modulars, refactors from suggestions

* Refactoring convert_mlcd_weights_to_hf.py from suggestions

* Fix: docs conflicts

* Fix error for CI test

* Fix style for CI test

* Add integration test for MLCD

* Refactoring by class inheritance

* Fix: refactor attention interface, adjust codes

* Fix: merging conflicts

* Fix: merging conflicts

* Fix: style for CI test

* Fix: style for CI test

* Fix: set test_resize_embeddings to be False

* Fix: initializer for CI test

* Fix: conflicts, CI test, warning and refactoring

* Fix: merging conflicts

* Refactor

* Update docs

* Fix mistakes

* Remove unused args and fix multi-gpu error

* Revert position_embeddings

* Solve conflicts

* Solve conflicts

* Remove dummy

* Update _init_weights

* Update _init_weights

* Update _init_weights for CI test
Huajie Tan 2025-04-15 18:33:09 +08:00 committed by GitHub
parent d6ac923ad9
commit 6f7ea1cf00
14 changed files with 2066 additions and 0 deletions


@@ -737,6 +737,8 @@
title: Mask2Former
- local: model_doc/maskformer
title: MaskFormer
- local: model_doc/mlcd
title: MLCD
- local: model_doc/mobilenet_v1
title: MobileNetV1
- local: model_doc/mobilenet_v2


@@ -0,0 +1,81 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# MLCD
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
## Overview
The MLCD models were released by the DeepGlint-AI team in [unicom](https://github.com/deepglint/unicom), which focuses on building foundational visual models for large multimodal language models using large-scale datasets such as LAION400M and COYO700M, and employs sample-to-cluster contrastive learning to optimize performance. MLCD models are primarily used as vision encoders for multimodal large language models, such as LLaVA.
Developed by DeepGlint AI, the 🔥**MLCD-ViT-bigG**🔥 series is a vision transformer enhanced with 2D Rotary Position Embedding (RoPE2D), achieving strong performance on document understanding and visual question answering tasks.
Tips:
- We adopted the official [LLaVA-NeXT](https://github.com/LLaVA-VL/LLaVA-NeXT) and the official training dataset [LLaVA-NeXT-Data](https://huggingface.co/datasets/lmms-lab/LLaVA-NeXT-Data) for evaluating the foundational visual models.
- The language model is [Qwen2.5-7B](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct).
Results:
| Vision Tower | RoPE2D | ChartQA | DocVQA | InfoVQA | OCRBench | MMMU |
| :-------------------------------------------------------------------------------------------- | :----: | :-------- | :-------- | :-------- | :--------- | :-------- |
| CLIP (ViT-L-14-336px) | × | 66.52 | 75.21 | 38.88 | 525.00 | 44.20 |
| SigLIP (ViT-SO400M-384px) | × | 69.28 | 76.71 | 41.38 | 554.00 | 46.78 |
| DFN5B (ViT-H-14-378px) | × | 64.36 | 70.87 | 38.59 | 473.00 | **48.00** |
| **[MLCD (ViT-L-14-336px)](https://huggingface.co/DeepGlint-AI/mlcd-vit-large-patch14-336)** | × | 67.84 | 76.46 | 43.48 | 531.00 | 44.30 |
| **[MLCD (ViT-bigG-14-336px)](https://huggingface.co/DeepGlint-AI/mlcd-vit-bigG-patch14-336)** | √ | 71.07 | 79.63 | 44.38 | 572.00 | 46.78 |
| **[MLCD (ViT-bigG-14-448px)](https://huggingface.co/DeepGlint-AI/mlcd-vit-bigG-patch14-448)** | √ | **73.80** | **83.34** | **46.59** | **582.00** | 46.00 |
## Usage
```python
import torch
import requests
from PIL import Image
from transformers import AutoProcessor, MLCDVisionModel
# Load model and processor
model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
# Process single image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")
# Generate outputs
with torch.no_grad():
outputs = model(**inputs)
# Get visual features
features = outputs.last_hidden_state
print(f"Extracted features shape: {features.shape}")
```
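
Beyond basic feature extraction, the model also returns a pooled output and can be loaded with an explicit attention backend. The snippet below is a minimal sketch assuming the same `DeepGlint-AI/mlcd-vit-bigG-patch14-448` checkpoint as above; `attn_implementation="sdpa"` selects the SDPA backend indicated by the badge at the top of this page.

```python
import torch
import requests
from PIL import Image
from transformers import AutoProcessor, MLCDVisionModel

# Load the model with the SDPA attention backend explicitly selected
model = MLCDVisionModel.from_pretrained(
    "DeepGlint-AI/mlcd-vit-bigG-patch14-448",
    attn_implementation="sdpa",
)
processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Patch-level features: (batch_size, num_patches + 1, hidden_size)
print(outputs.last_hidden_state.shape)
# Layer-normalized class-token embedding: (batch_size, hidden_size)
print(outputs.pooler_output.shape)
```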
## MLCDVisionConfig
[[autodoc]] MLCDVisionConfig
## MLCDVisionModel
[[autodoc]] MLCDVisionModel
- forward


@@ -179,6 +179,7 @@ if TYPE_CHECKING:
from .mistral import *
from .mistral3 import *
from .mixtral import *
from .mlcd import *
from .mllama import *
from .mluke import *
from .mobilebert import *


@@ -200,6 +200,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("mistral", "MistralConfig"),
("mistral3", "Mistral3Config"),
("mixtral", "MixtralConfig"),
("mlcd", "MLCDVisionConfig"),
("mllama", "MllamaConfig"),
("mobilebert", "MobileBertConfig"),
("mobilenet_v1", "MobileNetV1Config"),
@@ -559,6 +560,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("mistral", "Mistral"),
("mistral3", "Mistral3"),
("mixtral", "Mixtral"),
("mlcd", "MLCD"),
("mllama", "Mllama"),
("mluke", "mLUKE"),
("mms", "MMS"),


@@ -114,6 +114,7 @@ else:
("maskformer", ("MaskFormerImageProcessor",)),
("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
("mlcd", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("mllama", ("MllamaImageProcessor",)),
("mobilenet_v1", ("MobileNetV1ImageProcessor",)),
("mobilenet_v2", ("MobileNetV2ImageProcessor", "MobileNetV2ImageProcessorFast")),


@@ -183,6 +183,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("mimi", "MimiModel"),
("mistral", "MistralModel"),
("mixtral", "MixtralModel"),
("mlcd", "MLCDVisionModel"),
("mobilebert", "MobileBertModel"),
("mobilenet_v1", "MobileNetV1Model"),
("mobilenet_v2", "MobileNetV2Model"),
@@ -640,6 +641,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
("imagegpt", "ImageGPTModel"),
("levit", "LevitModel"),
("llama4", "Llama4VisionModel"),
("mlcd", "MLCDVisionModel"),
("mllama", "MllamaVisionModel"),
("mobilenet_v1", "MobileNetV1Model"),
("mobilenet_v2", "MobileNetV2Model"),


@@ -0,0 +1,27 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_mlcd import *
from .modeling_mlcd import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)


@@ -0,0 +1,117 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/mlcd/modular_mlcd.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_mlcd.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
class MLCDVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MLCDVisionModel`]. It is used to instantiate an MLCD
vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the vision encoder of the MLCD
[DeepGlint-AI/mlcd-vit-bigG-patch14-336](https://huggingface.co/DeepGlint-AI/mlcd-vit-bigG-patch14-336) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 1664):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 8192):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
projection_dim (`int`, *optional*, defaults to 1024):
Dimensionality of text and vision projection layers.
num_hidden_layers (`int`, *optional*, defaults to 48):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
image_size (`int`, *optional*, defaults to 336):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 14):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
initializer_factor (`float`, *optional*, defaults to 1.0):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
Example:
```python
>>> from transformers import MLCDVisionConfig, MLCDVisionModel
>>> # Initializing a MLCDVisionConfig with DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration
>>> configuration = MLCDVisionConfig()
>>> # Initializing a MLCDVisionModel (with random weights) from the DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration
>>> model = MLCDVisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "mlcd_vision_model"
base_config_key = "vision_config"
def __init__(
self,
hidden_size=1664,
intermediate_size=8192,
num_hidden_layers=48,
num_attention_heads=16,
num_key_value_groups=1,
num_channels=3,
image_size=336,
patch_size=14,
hidden_act="gelu",
layer_norm_eps=1e-5,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_groups = num_key_value_groups
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.initializer_range = initializer_range
self.initializer_factor = initializer_factor
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
__all__ = ["MLCDVisionConfig"]


@@ -0,0 +1,336 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert MLCD checkpoints from the original repository.
URL: https://github.com/deepglint/unicom/tree/main
"""
import argparse
import collections
import os
import re
import numpy as np
import requests
import torch
from PIL import Image
from transformers import CLIPImageProcessor
from ...utils import logging
from .configuration_mlcd import MLCDVisionConfig
from .modeling_mlcd import MLCDVisionModel
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
COMMON_CONFIG_PARAMS = {
"mlcd-vit-bigG-patch14-336": {
"hidden_size": 1664,
"image_size": 336,
"intermediate_size": 8192,
"num_attention_heads": 16,
"num_hidden_layers": 48,
"patch_size": 14,
"projection_dim": 1024,
},
"mlcd-vit-bigG-patch14-448": {
"hidden_size": 1664,
"image_size": 448,
"intermediate_size": 8192,
"num_attention_heads": 16,
"num_hidden_layers": 48,
"patch_size": 14,
"projection_dim": 1024,
},
}
MODEL_NAME_TO_CHECKPOINT_PATH = {
# base checkpoints
"mlcd-vit-bigG-patch14-336": "MLCD_ViT_bigG_14_336px_pytorch.pt",
"mlcd-vit-bigG-patch14-448": "MLCD_ViT_bigG_14_448px_pytorch.pt",
}
# fmt: off
EXPECTED_OUTPUTS = {
"mlcd-vit-bigG-patch14-336": torch.tensor([
[-0.8921, -0.1069, 0.2989, 0.6018, -0.5892],
[ 0.4093, -1.4592, 0.6048, -0.5147, -0.5929],
[ 0.7796, -0.7133, -0.5649, -0.7843, -0.5548],
[ 0.0041, 0.0286, 0.4310, -0.1403, -0.2399],
[ 0.0839, 0.2152, -0.3822, -0.1668, -0.7886]
]),
"mlcd-vit-bigG-patch14-448": torch.tensor([
[-0.8978, -0.1181, 0.4769, 0.4761, -0.5779],
[ 0.2640, -2.6150, 0.4853, 0.5743, -1.1003],
[ 0.3314, -0.3328, -0.4305, -0.1874, -0.7701],
[-1.5174, -1.0238, -1.1854, 0.1749, -0.8786],
[ 0.2323, -0.8346, -0.9680, -0.2951, 0.0867],
]),
}
# fmt: on
# fmt: off
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
# Vision embeddings
r"conv1.weight": r"vision_model.embeddings.patch_embedding.weight",
r"class_embedding": r"vision_model.embeddings.class_embedding",
r"vision_rotary_embedding": r"vision_model.vision_rotary_embedding",
r"class_pos_emb": r"vision_model.class_pos_emb",
# Vision encoder
r"transformer.resblocks_(\d+).ln_1.weight": r"vision_model.encoder.layers.\1.layer_norm1.weight",
r"transformer.resblocks_(\d+).ln_1.bias": r"vision_model.encoder.layers.\1.layer_norm1.bias",
r"transformer.resblocks_(\d+).ln_2.weight": r"vision_model.encoder.layers.\1.layer_norm2.weight",
r"transformer.resblocks_(\d+).ln_2.bias": r"vision_model.encoder.layers.\1.layer_norm2.bias",
r"transformer.resblocks_(\d+).mlp.c_fc.weight": r"vision_model.encoder.layers.\1.mlp.fc1.weight",
r"transformer.resblocks_(\d+).mlp.c_fc.bias": r"vision_model.encoder.layers.\1.mlp.fc1.bias",
r"transformer.resblocks_(\d+).mlp.c_proj.weight": r"vision_model.encoder.layers.\1.mlp.fc2.weight",
r"transformer.resblocks_(\d+).mlp.c_proj.bias": r"vision_model.encoder.layers.\1.mlp.fc2.bias",
r"transformer.resblocks_(\d+).attn.(q|k|v|out)_proj.weight": r"vision_model.encoder.layers.\1.self_attn.\2_proj.weight",
r"transformer.resblocks_(\d+).attn.(q|k|v|out)_proj.bias": r"vision_model.encoder.layers.\1.self_attn.\2_proj.bias",
# Vision norm
r"ln_post.weight": r"vision_model.post_layernorm.weight",
r"ln_post.bias": r"vision_model.post_layernorm.bias",
r"ln_pre.weight": r"vision_model.pre_layernorm.weight",
r"ln_pre.bias": r"vision_model.pre_layernorm.bias",
}
# fmt: on
# --------------------------------------------------------------------------------------------
# Model objects: configuration, image processor
# --------------------------------------------------------------------------------------------
def get_mlcd_config(model_name: str) -> MLCDVisionConfig:
"""
Create a configuration for the MLCD model based on the model name.
"""
assert model_name in COMMON_CONFIG_PARAMS, f"Model {model_name} not found in the list of COMMON_CONFIG_PARAMS."
config_params = COMMON_CONFIG_PARAMS[model_name]
config = MLCDVisionConfig(
hidden_size=config_params["hidden_size"],
image_size=config_params["image_size"],
intermediate_size=config_params["intermediate_size"],
num_attention_heads=config_params["num_attention_heads"],
num_hidden_layers=config_params["num_hidden_layers"],
patch_size=config_params["patch_size"],
projection_dim=config_params["projection_dim"],
)
return config
def get_mlcd_image_processor(model_name: str) -> CLIPImageProcessor:
"""
Create an image processor for the MLCD model based on the model name.
"""
assert model_name in COMMON_CONFIG_PARAMS, f"Model {model_name} not found in the list of COMMON_CONFIG_PARAMS."
config_params = COMMON_CONFIG_PARAMS[model_name]
image_processor = CLIPImageProcessor(
do_center_crop=True,
do_normalize=True,
do_resize=True,
feature_extractor_type="CLIPFeatureExtractor",
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
resample=3,
size=config_params["image_size"],
crop_size=config_params["image_size"],
)
return image_processor
# --------------------------------------------------------------------------------------------
# Helper functions for state dict conversion
# --------------------------------------------------------------------------------------------
def flatten_nested_dict(params: dict, parent_key: str = "", sep: str = ".") -> dict:
"""
Flatten a nested original checkpoint dictionary into a flat dictionary.
"""
items = []
for k, v in params.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.abc.MutableMapping):
items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
def split_resblocks_layers(state_dict: dict) -> dict:
"""
Split the resblocks weight into layers. In some cases they are concatenated in
the original checkpoints.
"""
# Make shallow copy
state_dict = state_dict.copy()
# Split resblocks weight into layers
keys = list(state_dict.keys())
for key in keys:
if ".resblocks." in key:
weight = state_dict.pop(key)
for i, weight_i in enumerate(weight):
new_name = key.replace("resblocks", f"resblocks_{i}")
state_dict[new_name] = weight_i
return state_dict
def chunk_qkv_for_attn(state_dict: dict) -> dict:
"""
Chunk the q/k/v weights and biases for the attention layers.
"""
# Make shallow copy
state_dict = state_dict.copy()
# Read and process q/k/v weights and biases
keys = list(state_dict.keys())
for key in keys:
if ".in_proj." in key:
weight = state_dict.pop(key)
qkv_weights = weight.chunk(3, dim=0)
for name, weight_i in zip(["q_proj", "k_proj", "v_proj"], qkv_weights):
new_name = key.replace("in_proj", name)
state_dict[new_name] = weight_i
return state_dict
def convert_old_keys_to_new_keys(state_dict_keys: list) -> dict:
"""
This function should be applied only once, on the concatenated keys to efficiently rename using
the key mappings.
"""
output_dict = {}
if state_dict_keys is not None:
old_text = "\n".join(state_dict_keys)
new_text = old_text
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
if replacement is None:
new_text = re.sub(pattern, "", new_text) # an empty line
continue
new_text = re.sub(pattern, replacement, new_text)
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
return output_dict
# --------------------------------------------------------------------------------------------
# Convert model
# --------------------------------------------------------------------------------------------
@torch.no_grad()
def convert_mlcd_checkpoint(model_name, input_dir, output_dir, verify_hidden_state=True, push_to_hub=False):
"""
Copy/paste/tweak model's weights to our MLCD structure.
"""
# Define MLCD configuration
config = get_mlcd_config(model_name)
checkpoint = MODEL_NAME_TO_CHECKPOINT_PATH[model_name]
checkpoint_path = os.path.join(input_dir, checkpoint)
assert os.path.exists(checkpoint_path), f"Checkpoint path ({checkpoint_path}) not found."
# Load original checkpoint
print(f"Loading checkpoint from {checkpoint_path}...")
state_dict = torch.load(checkpoint_path, "cpu")
# Flatten nested dictionary
print("Flattening nested dictionary...")
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
if "positional_embedding" in state_dict:
state_dict.pop("positional_embedding")
state_dict = flatten_nested_dict(state_dict)
state_dict = split_resblocks_layers(state_dict)
state_dict = chunk_qkv_for_attn(state_dict)
# Rename and transform weights
print("Renaming and transforming weights...")
original_keys = list(state_dict.keys())
hf_keys = convert_old_keys_to_new_keys(original_keys)
new_state_dict = {}
for original_key in original_keys:
new_key = hf_keys[original_key]
parameter = state_dict.pop(original_key)
new_state_dict[new_key] = torch.from_numpy(parameter)
# load HuggingFace model
print("Loading HuggingFace model...")
model = MLCDVisionModel(config).eval()
model.load_state_dict(new_state_dict)
# Create processor
print("Creating processor...")
image_processor = get_mlcd_image_processor(model_name)
# Verify hidden state
if verify_hidden_state:
print("Verifying hidden state for {model_name}...")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
last_hidden_state = model(pixel_values, output_hidden_states=True).last_hidden_state[0, :5, :5]
expected_hidden_state = EXPECTED_OUTPUTS[model_name]
np.testing.assert_allclose(last_hidden_state.cpu().numpy(), expected_hidden_state.numpy(), atol=1e-4)
# Save model
if output_dir is not None:
dst_dir = os.path.join(output_dir, model_name)
print(f"Saving model {model_name} to {dst_dir}...")
model.save_pretrained(dst_dir)
print(f"Saving processor to {dst_dir}...")
image_processor.save_pretrained(dst_dir)
if push_to_hub:
print(f"Pushing model and processor for {model_name} to the HuggingFace Hub...")
model.push_to_hub(f"deepglint-hf/{model_name}", private=True)
image_processor.push_to_hub(f"deepglint-hf/{model_name}", private=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="mlcd-vit-bigG-patch14-448",
type=str,
choices=MODEL_NAME_TO_CHECKPOINT_PATH.keys(),
help="Name of the model you'd like to convert.",
)
parser.add_argument(
"--input_dir",
default="mlcd/original",
help="Location of MLCD original weights",
)
parser.add_argument(
"--output_dir",
default="mlcd/checkpoint",
help="Location to write HF model and processor",
)
parser.add_argument(
"--verify_hidden_state",
action="store_true",
help="Whether to verify hidden_state against the original implementation.",
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
args = parser.parse_args()
convert_mlcd_checkpoint(
args.model_name, args.input_dir, args.output_dir, args.verify_hidden_state, args.push_to_hub
)


@@ -0,0 +1,679 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/mlcd/modular_mlcd.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_mlcd.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
from ...activations import ACT2FN
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
torch_int,
)
from .configuration_mlcd import MLCDVisionConfig
logger = logging.get_logger(__name__)
class MLCDMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class MLCDRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor:
"""
Calculate the Rotary Position Embedding (RoPE) for MLCDVisionModel based on the grid size.
Args:
num_patches_height (int): Number of patches in the height dimension.
num_patches_width (int): Number of patches in the width dimension.
Returns:
torch.Tensor: Rotary positional embeddings for the given grid size.
"""
# Generate position IDs for height and width dimensions
hpos_ids = (
torch.arange(num_patches_height, device=self.inv_freq.device).unsqueeze(1).expand(-1, num_patches_width)
)
wpos_ids = (
torch.arange(num_patches_width, device=self.inv_freq.device).unsqueeze(0).expand(num_patches_height, -1)
)
# Flatten and stack the position IDs
pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1)
# Generate the full rotary positional embeddings for the maximum grid size
max_grid_size = max(num_patches_height, num_patches_width)
seq = torch.arange(max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
# Select and flatten the embeddings based on the position IDs
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
class MLCDVisionEmbeddings(nn.Module):
def __init__(self, config: MLCDVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""
num_patches = embeddings.shape[1] - 1
position_embedding = self.position_embedding.weight.unsqueeze(0)
num_positions = position_embedding.shape[1] - 1
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embedding(self.position_ids)
class_pos_embed = position_embedding[:, :1]
patch_pos_embed = position_embedding[:, 1:]
dim = embeddings.shape[-1]
new_height = height // self.patch_size
new_width = width // self.patch_size
sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
# patch_embeds -> shape = [batch, width, grid, grid]
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
return embeddings
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
def apply_rotary_pos_emb_vision(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
orig_q_dtype = q.dtype
orig_k_dtype = k.dtype
q, k = q.float(), k.float()
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.to(orig_q_dtype)
k_embed = k_embed.to(orig_k_dtype)
return q_embed, k_embed
class MLCDAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper
Multi-headed attention with RoPE. Refer to papers:
- Attention is all you need:
https://arxiv.org/abs/1706.03762
- RoFormer: Enhanced Transformer with Rotary Position Embedding:
https://arxiv.org/abs/2104.09864
"""
def __init__(self, config: MLCDVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.num_key_value_groups = config.num_key_value_groups
self.is_causal = False
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch_size, seq_length = hidden_states.shape[:-1]
# Each of shape: [batch_size, seq_length, num_heads, head_dim]
query_states = self.q_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
key_states = self.k_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
value_states = self.v_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
# Apply positional embeddings
cos = position_embeddings[0].unsqueeze(0).float()
sin = position_embeddings[1].unsqueeze(0).float()
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
# Each of shape: [batch_size, num_heads, seq_length, head_dim]
query_states = query_states.permute(0, 2, 1, 3).contiguous()
key_states = key_states.permute(0, 2, 1, 3).contiguous()
value_states = value_states.permute(0, 2, 1, 3).contiguous()
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.dropout,
scaling=self.scale,
is_causal=self.is_causal,
**kwargs,
)
attn_output = attn_output.permute(1, 0, 2, 3).contiguous() # [seq_length, batch_size, num_heads, head_dim]
attn_output = attn_output.view(seq_length, batch_size, -1) # [seq_length, batch_size, embedding_dim]
attn_output = self.out_proj(attn_output)
attn_output = attn_output.permute(1, 0, 2).contiguous() # [batch_size, seq_length, embedding_dim]
return attn_output, attn_weights
class MLCDEncoderLayer(nn.Module):
def __init__(self, config: MLCDVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = MLCDAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = MLCDMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
Represents the hidden states from the previous layer or the input embeddings.
position_embeddings (`Tuple[torch.Tensor, torch.Tensor]`):
A tuple of two tensors, each of shape `(batch, seq_len, embed_dim)`.
Represents absolute positional embeddings for the query and key in the attention mechanism.
attention_mask (`torch.FloatTensor`):
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class MLCDEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`MLCDEncoderLayer`].
Args:
config: MLCDVisionConfig
"""
def __init__(self, config: MLCDVisionConfig):
"""Overwrite dummy `MLCDConfig` to `MLCDVisionConfig`."""
super().__init__()
self.config = config
self.layers = nn.ModuleList([MLCDEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
@can_return_tuple
def forward(
self,
inputs_embeds: torch.FloatTensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
position_embeddings (`Tuple[torch.Tensor, torch.Tensor]`):
A tuple of two tensors, each of shape `(batch, seq_len, embed_dim)`.
Represents absolute positional embeddings for the query and key in the attention mechanism.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
position_embeddings,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
)
MLCD_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class MLCDVisionTransformer(nn.Module):
def __init__(self, config: MLCDVisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = MLCDVisionEmbeddings(config)
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = MLCDEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2))
@add_start_docstrings_to_model_forward(MLCD_VISION_INPUTS_DOCSTRING)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
num_patches_height = pixel_values.shape[-2] // self.config.patch_size
num_patches_width = pixel_values.shape[-1] // self.config.patch_size
rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width)
rotary_pos_emb = rotary_pos_emb.to(self.class_pos_emb.device)
rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
position_embeddings=position_embeddings,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class MLCDPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = MLCDVisionConfig
base_model_prefix = "mlcd"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""
factor = self.config.initializer_factor
if isinstance(module, MLCDVisionEmbeddings):
factor = self.config.initializer_factor
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
elif isinstance(module, MLCDAttention):
factor = self.config.initializer_factor
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
out_proj_std = (module.embed_dim**-0.5) * factor
nn.init.normal_(module.q_proj.weight, std=in_proj_std)
nn.init.normal_(module.k_proj.weight, std=in_proj_std)
nn.init.normal_(module.v_proj.weight, std=in_proj_std)
nn.init.normal_(module.out_proj.weight, std=out_proj_std)
elif isinstance(module, MLCDMLP):
factor = self.config.initializer_factor
in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
nn.init.normal_(module.fc1.weight, std=fc_std)
nn.init.normal_(module.fc2.weight, std=in_proj_std)
elif isinstance(module, MLCDVisionTransformer):
factor = self.config.initializer_factor
pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor
nn.init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
MLCD_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`MLCDVisionConfig`]):
Model configuration class with all the parameters of the vision encoder. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare MLCD vision encoder outputting raw hidden-states without any specific head on top.",
MLCD_START_DOCSTRING,
)
class MLCDVisionModel(MLCDPreTrainedModel):
config_class = MLCDVisionConfig
main_input_name = "pixel_values"
_no_split_modules = ["MLCDEncoderLayer"]
def __init__(self, config: MLCDVisionConfig):
super().__init__(config)
self.vision_model = MLCDVisionTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
@add_start_docstrings_to_model_forward(MLCD_VISION_INPUTS_DOCSTRING)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
Examples:
```python
>>> import torch
>>> import requests
>>> from PIL import Image
>>> from transformers import AutoProcessor, MLCDVisionModel
>>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
>>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs, output_attentions=True)
>>> features = outputs.last_hidden_state
>>> print(f"Extracted features shape: {features.shape}")
>>> print(f"Number of attention layers: {len(outputs.attentions)}")
>>> print(f"Attention shape: {outputs.attentions[0].shape}")
```"""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
return self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
__all__ = ["MLCDPreTrainedModel", "MLCDVisionModel"]


@@ -0,0 +1,596 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
)
from ...modeling_utils import (
ALL_ATTENTION_FUNCTIONS,
PreTrainedModel,
)
from ...processing_utils import Unpack
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from ..clip.modeling_clip import (
CLIPMLP,
CLIPAttention,
CLIPEncoder,
CLIPEncoderLayer,
CLIPVisionEmbeddings,
CLIPVisionModel,
CLIPVisionTransformer,
)
from ..llama.modeling_llama import eager_attention_forward
from ..qwen2_vl.modeling_qwen2_vl import (
VisionRotaryEmbedding,
apply_rotary_pos_emb_vision,
)
logger = logging.get_logger(__name__)
class MLCDVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MLCDVisionModel`]. It is used to instantiate an MLCD
vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the vision encoder of the MLCD
[DeepGlint-AI/mlcd-vit-bigG-patch14-336](https://huggingface.co/DeepGlint-AI/mlcd-vit-bigG-patch14-336) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 1664):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 8192):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
projection_dim (`int`, *optional*, defaults to 1024):
Dimensionality of text and vision projection layers.
num_hidden_layers (`int`, *optional*, defaults to 48):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
image_size (`int`, *optional*, defaults to 336):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 14):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
initializer_factor (`float`, *optional*, defaults to 1.0):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
Example:
```python
>>> from transformers import MLCDVisionConfig, MLCDVisionModel
>>> # Initializing a MLCDVisionConfig with DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration
>>> configuration = MLCDVisionConfig()
>>> # Initializing a MLCDVisionModel (with random weights) from the DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration
>>> model = MLCDVisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "mlcd_vision_model"
base_config_key = "vision_config"
def __init__(
self,
hidden_size=1664,
intermediate_size=8192,
num_hidden_layers=48,
num_attention_heads=16,
num_key_value_groups=1,
num_channels=3,
image_size=336,
patch_size=14,
hidden_act="gelu",
layer_norm_eps=1e-5,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_groups = num_key_value_groups
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.initializer_range = initializer_range
self.initializer_factor = initializer_factor
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
class MLCDMLP(CLIPMLP):
pass
class MLCDRotaryEmbedding(VisionRotaryEmbedding):
def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor:
"""
Calculate the Rotary Position Embedding (RoPE) for MLCDVisionModel based on the grid size.
Args:
num_patches_height (int): Number of patches in the height dimension.
num_patches_width (int): Number of patches in the width dimension.
Returns:
torch.Tensor: Rotary positional embeddings for the given grid size.
"""
# Generate position IDs for height and width dimensions
hpos_ids = (
torch.arange(num_patches_height, device=self.inv_freq.device).unsqueeze(1).expand(-1, num_patches_width)
)
wpos_ids = (
torch.arange(num_patches_width, device=self.inv_freq.device).unsqueeze(0).expand(num_patches_height, -1)
)
# Flatten and stack the position IDs
pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1)
# Generate the full rotary positional embeddings for the maximum grid size
max_grid_size = max(num_patches_height, num_patches_width)
seq = torch.arange(max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
# Select and flatten the embeddings based on the position IDs
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
class MLCDVisionEmbeddings(CLIPVisionEmbeddings):
def __init__(self, config: MLCDVisionConfig):
super().__init__(config)
del self.position_embedding
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
# patch_embeds -> shape = [batch, width, grid, grid]
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
return embeddings
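# Shape note: with the default configuration (image_size=336, patch_size=14), the patch embedding
# produces a 24 x 24 grid of 576 patch tokens; prepending the class embedding yields
# embeddings of shape [batch_size, 577, hidden_size].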
class MLCDAttention(CLIPAttention):
"""Multi-headed attention with RoPE. Refer to papers:
- Attention is all you need:
https://arxiv.org/abs/1706.03762
- RoFormer: Enhanced Transformer with Rotary Position Embedding:
https://arxiv.org/abs/2104.09864
"""
def __init__(self, config: MLCDVisionConfig):
super().__init__(config)
self.num_key_value_groups = config.num_key_value_groups
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
batch_size, seq_length = hidden_states.shape[:-1]
# Each of shape: [batch_size, seq_length, num_heads, head_dim]
query_states = self.q_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
key_states = self.k_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
value_states = self.v_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
# Apply positional embeddings
cos = position_embeddings[0].unsqueeze(0).float()
sin = position_embeddings[1].unsqueeze(0).float()
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
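# `apply_rotary_pos_emb_vision` rotates the query and key channels by the position-dependent
# angles in (cos, sin), following the usual rotate-half RoPE convention; the value states are left unchanged.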
# Each of shape: [batch_size, num_heads, seq_length, head_dim]
query_states = query_states.permute(0, 2, 1, 3).contiguous()
key_states = key_states.permute(0, 2, 1, 3).contiguous()
value_states = value_states.permute(0, 2, 1, 3).contiguous()
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.dropout,
scaling=self.scale,
is_causal=self.is_causal,
**kwargs,
)
attn_output = attn_output.permute(1, 0, 2, 3).contiguous() # [seq_length, batch_size, num_heads, head_dim]
attn_output = attn_output.view(seq_length, batch_size, -1) # [seq_length, batch_size, embedding_dim]
attn_output = self.out_proj(attn_output)
attn_output = attn_output.permute(1, 0, 2).contiguous() # [batch_size, seq_length, embedding_dim]
return attn_output, attn_weights
class MLCDEncoderLayer(CLIPEncoderLayer):
def __init__(self, config: MLCDVisionConfig):
super().__init__(config)
self.self_attn = MLCDAttention(config)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
Represents the hidden states from the previous layer or the input embeddings.
position_embeddings (`Tuple[torch.Tensor, torch.Tensor]`):
A tuple of two tensors, each of shape `(batch, seq_len, embed_dim)`.
Represents absolute positional embeddings for the query and key in the attention mechanism.
attention_mask (`torch.FloatTensor`):
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class MLCDEncoder(CLIPEncoder):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`MLCDEncoderLayer`].
Args:
config: MLCDVisionConfig
"""
def __init__(self, config: MLCDVisionConfig):
"""Overwrite dummy `MLCDConfig` to `MLCDVisionConfig`."""
super().__init__(config)
def forward(
self,
inputs_embeds: torch.FloatTensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
position_embeddings (`Tuple[torch.Tensor, torch.Tensor]`):
A tuple of two tensors, each of shape `(batch, seq_len, embed_dim)`.
Represents absolute positional embeddings for the query and key in the attention mechanism.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
position_embeddings,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
)
MLCD_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class MLCDVisionTransformer(CLIPVisionTransformer):
def __init__(self, config: MLCDVisionConfig):
super().__init__(config)
self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2))
@add_start_docstrings_to_model_forward(MLCD_VISION_INPUTS_DOCSTRING)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
num_patches_height = pixel_values.shape[-2] // self.config.patch_size
num_patches_width = pixel_values.shape[-1] // self.config.patch_size
rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width)
rotary_pos_emb = rotary_pos_emb.to(self.class_pos_emb.device)
rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
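# Shape note: after prepending the learnable class-token position, `rotary_pos_emb` has shape
# (num_patches + 1, head_dim // 2), so `cos` and `sin` in `position_embeddings` are each
# (num_patches + 1, head_dim), where head_dim = hidden_size // num_attention_heads.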
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
position_embeddings=position_embeddings,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
pooled_output = last_hidden_state[:, 0, :]
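# The pooled output is the class-token (position 0) hidden state passed through the final layer norm.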
pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
MLCD_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`MLCDVisionConfig`]):
Model configuration class with all the parameters of the vision encoder. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
class MLCDPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = MLCDVisionConfig
base_model_prefix = "mlcd"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""
factor = self.config.initializer_factor
if isinstance(module, MLCDVisionEmbeddings):
factor = self.config.initializer_factor
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
elif isinstance(module, MLCDAttention):
factor = self.config.initializer_factor
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
out_proj_std = (module.embed_dim**-0.5) * factor
nn.init.normal_(module.q_proj.weight, std=in_proj_std)
nn.init.normal_(module.k_proj.weight, std=in_proj_std)
nn.init.normal_(module.v_proj.weight, std=in_proj_std)
nn.init.normal_(module.out_proj.weight, std=out_proj_std)
elif isinstance(module, MLCDMLP):
factor = self.config.initializer_factor
in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
nn.init.normal_(module.fc1.weight, std=fc_std)
nn.init.normal_(module.fc2.weight, std=in_proj_std)
elif isinstance(module, MLCDVisionTransformer):
factor = self.config.initializer_factor
pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor
nn.init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@add_start_docstrings(
"The bare MLCD vision encoder outputting raw hidden-states without any specific head on top.",
MLCD_START_DOCSTRING,
)
class MLCDVisionModel(CLIPVisionModel):
@add_start_docstrings_to_model_forward(MLCD_VISION_INPUTS_DOCSTRING)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
```python
>>> import requests
>>> from PIL import Image
>>> from transformers import AutoProcessor, MLCDVisionModel
>>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
>>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs, output_attentions=True)
>>> features = outputs.last_hidden_state
>>> print(f"Extracted features shape: {features.shape}")
>>> print(f"Number of attention layers: {len(outputs.attentions)}")
>>> print(f"Attention shape: {outputs.attentions[0].shape}")
```"""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
return self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
__all__ = [
"MLCDVisionConfig",
"MLCDPreTrainedModel",
"MLCDVisionModel",
]
@@ -0,0 +1,221 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch MLCD model."""
import unittest
import requests
from PIL import Image
from transformers import (
AutoProcessor,
MLCDVisionConfig,
MLCDVisionModel,
is_torch_available,
)
from transformers.testing_utils import (
require_torch,
slow,
torch_device,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
if is_torch_available():
import torch
class MLCDVisionModelTester:
def __init__(
self,
parent,
batch_size=12,
image_size=30,
patch_size=2,
num_channels=3,
is_training=True,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
dropout=0.1,
attention_dropout=0.1,
initializer_range=0.02,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.dropout = dropout
self.attention_dropout = attention_dropout
self.initializer_range = initializer_range
self.scope = scope
# in MLCD, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
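# e.g. with the defaults above (image_size=30, patch_size=2): (30 // 2) ** 2 = 225 patches, seq_length = 226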
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
config = self.get_config()
return config, pixel_values
def get_config(self):
return MLCDVisionConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
dropout=self.dropout,
attention_dropout=self.attention_dropout,
initializer_range=self.initializer_range,
)
def create_and_check_model(self, config, pixel_values):
model = MLCDVisionModel(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = (self.image_size, self.image_size)
patch_size = (self.patch_size, self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class MLCDVisionModelTest(ModelTesterMixin, unittest.TestCase):
"""
Model tester for `MLCDVisionModel`.
"""
all_model_classes = (MLCDVisionModel,) if is_torch_available() else ()
test_pruning = False
test_head_masking = False
test_torchscript = False
test_resize_embeddings = False
test_torch_exportable = True
def setUp(self):
self.model_tester = MLCDVisionModelTester(self)
self.config_tester = ConfigTester(self, config_class=MLCDVisionConfig, has_text_modality=False)
def test_model_get_set_embeddings(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Module))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad and "class_pos_emb" not in name:
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@require_torch
class MLCDVisionModelIntegrationTest(unittest.TestCase):
@slow
def test_inference(self):
model_name = "DeepGlint-AI/mlcd-vit-bigG-patch14-448"
model = MLCDVisionModel.from_pretrained(model_name).to(torch_device)
processor = AutoProcessor.from_pretrained(model_name)
# process single image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")
# move inputs to the same device as the model
inputs = {k: v.to(torch_device) for k, v in inputs.items()}
# forward pass
with torch.no_grad():
outputs = model(**inputs, output_attentions=True)
last_hidden_state = outputs.last_hidden_state
last_attention = outputs.attentions[-1]
# verify the shapes of last_hidden_state and last_attention
self.assertEqual(
last_hidden_state.shape,
torch.Size([1, 1025, 1664]),
)
self.assertEqual(
last_attention.shape,
torch.Size([1, 16, 1025, 1025]),
)
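# Shape check: 448 / 14 = 32 patches per side -> 32 * 32 + 1 (class token) = 1025 tokens,
# with hidden_size = 1664 and 16 attention heads for this checkpoint.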
# verify the values of last_hidden_state and last_attention
# fmt: off
expected_partial_5x5_last_hidden_state = torch.tensor(
[
[-0.8978, -0.1181, 0.4769, 0.4761, -0.5779],
[ 0.2640, -2.6150, 0.4853, 0.5743, -1.1003],
[ 0.3314, -0.3328, -0.4305, -0.1874, -0.7701],
[-1.5174, -1.0238, -1.1854, 0.1749, -0.8786],
[ 0.2323, -0.8346, -0.9680, -0.2951, 0.0867],
]
).to(torch_device)
expected_partial_5x5_last_attention = torch.tensor(
[
[2.0930e-01, 6.3073e-05, 1.4717e-03, 2.6881e-05, 3.0513e-03],
[1.5828e-04, 2.1056e-03, 4.6784e-04, 1.8276e-03, 5.3233e-04],
[5.7824e-04, 1.1446e-03, 1.3854e-03, 1.1775e-03, 1.2750e-03],
[9.6343e-05, 1.6365e-03, 2.9066e-04, 3.1089e-03, 2.0607e-04],
[6.2688e-04, 1.1656e-03, 1.5030e-03, 8.2819e-04, 2.6992e-03],
]
).to(torch_device)
# fmt: on
torch.testing.assert_close(
last_hidden_state[0, :5, :5], expected_partial_5x5_last_hidden_state, rtol=1e-3, atol=1e-3
)
torch.testing.assert_close(
last_attention[0, 0, :5, :5], expected_partial_5x5_last_attention, rtol=1e-4, atol=1e-4
)
@@ -383,6 +383,7 @@ OBJECTS_TO_IGNORE = [
"MegatronBertConfig",
"MegatronBertForPreTraining",
"MegatronBertModel",
"MLCDVisionConfig",
"MobileBertConfig",
"MobileBertModel",
"MobileBertTokenizerFast",