mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
Add MLCD model (#36182)
* Add MLCD model
* Update codes for auto-mapping
* Add test scripts for MLCD
* Update doc for MLCD model
* Fix import error
* Fix import error
* Fix CI error for attention_outputs
* Fix code style for CI
* Fix code style for CI
* Fix code style for CI
* Fix code style for CI
* Fix code style for CI
* Fix CI error for initialization
* Fix code style for CI
* Fix code style for CI
* Reformat codes and docs for CI test
* Reformat codes and docs for CI test
* Remove unused attributes for CI test
* Fix style for CI test
* List MLCD in flash_attn doc
* Fix: typos, modulars, refactors from suggestions
* Refactoring convert_mlcd_weights_to_hf.py from suggestions
* Fix: docs conflicts
* Fix error for CI test
* Fix style for CI test
* Add integration test for MLCD
* Refactoring by class inheritance
* Fix: refactor attention interface, adjust codes
* Fix: merging conflicts
* Fix: merging conflicts
* Fix: style for CI test
* Fix: style for CI test
* Fix: set test_resize_embeddings to be False
* Fix: initializer for CI test
* Fix: conflicts, CI test, warning and refactoring
* Fix: merging conflicts
* Refactor
* Update docs
* Fix mistakes
* Remove unused args and fix multi-gpu error
* Revert position_embeddings
* Solve conflicts
* Solve conflicts
* Remove dummy
* Update _init_weights
* Update _init_weights
* Update _init_weights for CI test
This commit is contained in:
parent d6ac923ad9
commit 6f7ea1cf00
@@ -737,6 +737,8 @@
        title: Mask2Former
      - local: model_doc/maskformer
        title: MaskFormer
      - local: model_doc/mlcd
        title: MLCD
      - local: model_doc/mobilenet_v1
        title: MobileNetV1
      - local: model_doc/mobilenet_v2
81
docs/source/en/model_doc/mlcd.md
Normal file
@@ -0,0 +1,81 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# MLCD

<div class="flex flex-wrap space-x-1">
    <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
    <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

## Overview

The MLCD models were released by the DeepGlint-AI team in [unicom](https://github.com/deepglint/unicom), which focuses on building foundational visual models for large multimodal language models using large-scale datasets such as LAION400M and COYO700M, and employs sample-to-cluster contrastive learning to optimize performance. MLCD models are primarily used as the vision tower of multimodal large language models such as LLaVA.

The **MLCD-ViT-bigG** series is a state-of-the-art vision transformer enhanced with 2D Rotary Position Embedding (RoPE2D), achieving strong results on document understanding and visual question answering tasks.

Tips:

- We adopted the official [LLaVA-NeXT](https://github.com/LLaVA-VL/LLaVA-NeXT) codebase and the official training dataset [LLaVA-NeXT-Data](https://huggingface.co/datasets/lmms-lab/LLaVA-NeXT-Data) to evaluate the foundational visual models.

- The language model is [Qwen2.5-7B](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct).

Results:

| Vision Tower                                                                                   | RoPE2D | ChartQA   | DocVQA    | InfoVQA   | OCRBench   | MMMU      |
| :--------------------------------------------------------------------------------------------- | :----: | :-------- | :-------- | :-------- | :--------- | :-------- |
| CLIP (ViT-L-14-336px)                                                                          | ×      | 66.52     | 75.21     | 38.88     | 525.00     | 44.20     |
| SigLIP (ViT-SO400M-384px)                                                                      | ×      | 69.28     | 76.71     | 41.38     | 554.00     | 46.78     |
| DFN5B (ViT-H-14-378px)                                                                         | ×      | 64.36     | 70.87     | 38.59     | 473.00     | **48.00** |
| **[MLCD (ViT-L-14-336px)](https://huggingface.co/DeepGlint-AI/mlcd-vit-large-patch14-336)**    | ×      | 67.84     | 76.46     | 43.48     | 531.00     | 44.30     |
| **[MLCD (ViT-bigG-14-336px)](https://huggingface.co/DeepGlint-AI/mlcd-vit-bigG-patch14-336)**  | √      | 71.07     | 79.63     | 44.38     | 572.00     | 46.78     |
| **[MLCD (ViT-bigG-14-448px)](https://huggingface.co/DeepGlint-AI/mlcd-vit-bigG-patch14-448)**  | √      | **73.80** | **83.34** | **46.59** | **582.00** | 46.00     |

## Usage

```python
import requests
import torch
from PIL import Image

from transformers import AutoProcessor, MLCDVisionModel

# Load model and processor
model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")

# Process a single image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")

# Run inference
with torch.no_grad():
    outputs = model(**inputs)

# Get visual features
features = outputs.last_hidden_state

print(f"Extracted features shape: {features.shape}")
```
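
The badges above indicate that MLCD supports PyTorch's SDPA attention. The variant below is a minimal sketch of the same workflow that requests the SDPA implementation explicitly, runs in half precision when a GPU is available, and also returns the per-layer hidden states; the `attn_implementation` and `torch_dtype` arguments are standard `from_pretrained` options rather than anything MLCD-specific.

```python
import requests
import torch
from PIL import Image

from transformers import AutoProcessor, MLCDVisionModel

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Explicitly request PyTorch's scaled_dot_product_attention implementation
model = MLCDVisionModel.from_pretrained(
    "DeepGlint-AI/mlcd-vit-bigG-patch14-448",
    attn_implementation="sdpa",
    torch_dtype=dtype,
).to(device)
processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt").to(device)

# Request the hidden states of every encoder layer in addition to the last one
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

print(f"Pooled output shape: {outputs.pooler_output.shape}")
print(f"Number of hidden states: {len(outputs.hidden_states)}")
```
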
## MLCDVisionConfig

[[autodoc]] MLCDVisionConfig

## MLCDVisionModel

[[autodoc]] MLCDVisionModel
    - forward
@@ -179,6 +179,7 @@ if TYPE_CHECKING:
    from .mistral import *
    from .mistral3 import *
    from .mixtral import *
    from .mlcd import *
    from .mllama import *
    from .mluke import *
    from .mobilebert import *
@@ -200,6 +200,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
        ("mistral", "MistralConfig"),
        ("mistral3", "Mistral3Config"),
        ("mixtral", "MixtralConfig"),
        ("mlcd", "MLCDVisionConfig"),
        ("mllama", "MllamaConfig"),
        ("mobilebert", "MobileBertConfig"),
        ("mobilenet_v1", "MobileNetV1Config"),
@@ -559,6 +560,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("mistral", "Mistral"),
        ("mistral3", "Mistral3"),
        ("mixtral", "Mixtral"),
        ("mlcd", "MLCD"),
        ("mllama", "Mllama"),
        ("mluke", "mLUKE"),
        ("mms", "MMS"),
@@ -114,6 +114,7 @@ else:
        ("maskformer", ("MaskFormerImageProcessor",)),
        ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
        ("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
        ("mlcd", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
        ("mllama", ("MllamaImageProcessor",)),
        ("mobilenet_v1", ("MobileNetV1ImageProcessor",)),
        ("mobilenet_v2", ("MobileNetV2ImageProcessor", "MobileNetV2ImageProcessorFast")),
@@ -183,6 +183,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("mimi", "MimiModel"),
        ("mistral", "MistralModel"),
        ("mixtral", "MixtralModel"),
        ("mlcd", "MLCDVisionModel"),
        ("mobilebert", "MobileBertModel"),
        ("mobilenet_v1", "MobileNetV1Model"),
        ("mobilenet_v2", "MobileNetV2Model"),
@@ -640,6 +641,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
        ("imagegpt", "ImageGPTModel"),
        ("levit", "LevitModel"),
        ("llama4", "Llama4VisionModel"),
        ("mlcd", "MLCDVisionModel"),
        ("mllama", "MllamaVisionModel"),
        ("mobilenet_v1", "MobileNetV1Model"),
        ("mobilenet_v2", "MobileNetV2Model"),
27
src/transformers/models/mlcd/__init__.py
Normal file
@@ -0,0 +1,27 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_mlcd import *
    from .modeling_mlcd import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
117
src/transformers/models/mlcd/configuration_mlcd.py
Normal file
@@ -0,0 +1,117 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/mlcd/modular_mlcd.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_mlcd.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import PretrainedConfig


class MLCDVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MLCDVisionModel`]. It is used to instantiate a MLCD
    vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the vision encoder of the MLCD
    [DeepGlint-AI/mlcd-vit-bigG-patch14-336](https://huggingface.co/DeepGlint-AI/mlcd-vit-bigG-patch14-336) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 1664):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 8192):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        projection_dim (`int`, *optional*, defaults to 1024):
            Dimensionality of text and vision projection layers.
        num_hidden_layers (`int`, *optional*, defaults to 48):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        image_size (`int`, *optional*, defaults to 336):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_factor (`float`, *optional*, defaults to 1.0):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).

    Example:

    ```python
    >>> from transformers import MLCDVisionConfig, MLCDVisionModel

    >>> # Initializing a MLCDVisionConfig with DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration
    >>> configuration = MLCDVisionConfig()

    >>> # Initializing a MLCDVisionModel (with random weights) from the DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration
    >>> model = MLCDVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "mlcd_vision_model"
    base_config_key = "vision_config"

    def __init__(
        self,
        hidden_size=1664,
        intermediate_size=8192,
        num_hidden_layers=48,
        num_attention_heads=16,
        num_key_value_groups=1,
        num_channels=3,
        image_size=336,
        patch_size=14,
        hidden_act="gelu",
        layer_norm_eps=1e-5,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_groups = num_key_value_groups
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act


__all__ = ["MLCDVisionConfig"]
336
src/transformers/models/mlcd/convert_mlcd_weights_to_hf.py
Normal file
@@ -0,0 +1,336 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert MLCD checkpoints from the original repository.
|
||||
|
||||
URL: https://github.com/deepglint/unicom/tree/main
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import os
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from transformers import CLIPImageProcessor
|
||||
|
||||
from ...utils import logging
|
||||
from .configuration_mlcd import MLCDVisionConfig
|
||||
from .modeling_mlcd import MLCDVisionModel
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
COMMON_CONFIG_PARAMS = {
|
||||
"mlcd-vit-bigG-patch14-336": {
|
||||
"hidden_size": 1664,
|
||||
"image_size": 336,
|
||||
"intermediate_size": 8192,
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 48,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 1024,
|
||||
},
|
||||
"mlcd-vit-bigG-patch14-448": {
|
||||
"hidden_size": 1664,
|
||||
"image_size": 448,
|
||||
"intermediate_size": 8192,
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 48,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 1024,
|
||||
},
|
||||
}
|
||||
|
||||
MODEL_NAME_TO_CHECKPOINT_PATH = {
|
||||
# base checkpoints
|
||||
"mlcd-vit-bigG-patch14-336": "MLCD_ViT_bigG_14_336px_pytorch.pt",
|
||||
"mlcd-vit-bigG-patch14-448": "MLCD_ViT_bigG_14_448px_pytorch.pt",
|
||||
}
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUTS = {
|
||||
"mlcd-vit-bigG-patch14-336": torch.tensor([
|
||||
[-0.8921, -0.1069, 0.2989, 0.6018, -0.5892],
|
||||
[ 0.4093, -1.4592, 0.6048, -0.5147, -0.5929],
|
||||
[ 0.7796, -0.7133, -0.5649, -0.7843, -0.5548],
|
||||
[ 0.0041, 0.0286, 0.4310, -0.1403, -0.2399],
|
||||
[ 0.0839, 0.2152, -0.3822, -0.1668, -0.7886]
|
||||
]),
|
||||
"mlcd-vit-bigG-patch14-448": torch.tensor([
|
||||
[-0.8978, -0.1181, 0.4769, 0.4761, -0.5779],
|
||||
[ 0.2640, -2.6150, 0.4853, 0.5743, -1.1003],
|
||||
[ 0.3314, -0.3328, -0.4305, -0.1874, -0.7701],
|
||||
[-1.5174, -1.0238, -1.1854, 0.1749, -0.8786],
|
||||
[ 0.2323, -0.8346, -0.9680, -0.2951, 0.0867],
|
||||
]),
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
# fmt: off
|
||||
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
|
||||
# Vision embeddings
|
||||
r"conv1.weight": r"vision_model.embeddings.patch_embedding.weight",
|
||||
r"class_embedding": r"vision_model.embeddings.class_embedding",
|
||||
r"vision_rotary_embedding": r"vision_model.vision_rotary_embedding",
|
||||
r"class_pos_emb": r"vision_model.class_pos_emb",
|
||||
# Vision encoder
|
||||
r"transformer.resblocks_(\d+).ln_1.weight": r"vision_model.encoder.layers.\1.layer_norm1.weight",
|
||||
r"transformer.resblocks_(\d+).ln_1.bias": r"vision_model.encoder.layers.\1.layer_norm1.bias",
|
||||
r"transformer.resblocks_(\d+).ln_2.weight": r"vision_model.encoder.layers.\1.layer_norm2.weight",
|
||||
r"transformer.resblocks_(\d+).ln_2.bias": r"vision_model.encoder.layers.\1.layer_norm2.bias",
|
||||
r"transformer.resblocks_(\d+).mlp.c_fc.weight": r"vision_model.encoder.layers.\1.mlp.fc1.weight",
|
||||
r"transformer.resblocks_(\d+).mlp.c_fc.bias": r"vision_model.encoder.layers.\1.mlp.fc1.bias",
|
||||
r"transformer.resblocks_(\d+).mlp.c_proj.weight": r"vision_model.encoder.layers.\1.mlp.fc2.weight",
|
||||
r"transformer.resblocks_(\d+).mlp.c_proj.bias": r"vision_model.encoder.layers.\1.mlp.fc2.bias",
|
||||
r"transformer.resblocks_(\d+).attn.(q|k|v|out)_proj.weight": r"vision_model.encoder.layers.\1.self_attn.\2_proj.weight",
|
||||
r"transformer.resblocks_(\d+).attn.(q|k|v|out)_proj.bias": r"vision_model.encoder.layers.\1.self_attn.\2_proj.bias",
|
||||
# Vision norm
|
||||
r"ln_post.weight": r"vision_model.post_layernorm.weight",
|
||||
r"ln_post.bias": r"vision_model.post_layernorm.bias",
|
||||
r"ln_pre.weight": r"vision_model.pre_layernorm.weight",
|
||||
r"ln_pre.bias": r"vision_model.pre_layernorm.bias",
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Model objects: configuration, image processor
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_mlcd_config(model_name: str) -> MLCDVisionConfig:
|
||||
"""
|
||||
Create a configuration for the MLCD model based on the model name.
|
||||
"""
|
||||
assert model_name in COMMON_CONFIG_PARAMS, f"Model {model_name} not found in the list of COMMON_CONFIG_PARAMS."
|
||||
config_params = COMMON_CONFIG_PARAMS[model_name]
|
||||
config = MLCDVisionConfig(
|
||||
hidden_size=config_params["hidden_size"],
|
||||
image_size=config_params["image_size"],
|
||||
intermediate_size=config_params["intermediate_size"],
|
||||
num_attention_heads=config_params["num_attention_heads"],
|
||||
num_hidden_layers=config_params["num_hidden_layers"],
|
||||
patch_size=config_params["patch_size"],
|
||||
projection_dim=config_params["projection_dim"],
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def get_mlcd_image_processor(model_name: str) -> CLIPImageProcessor:
|
||||
"""
|
||||
Create an image processor for the MLCD model based on the model name.
|
||||
"""
|
||||
assert model_name in COMMON_CONFIG_PARAMS, f"Model {model_name} not found in the list of COMMON_CONFIG_PARAMS."
|
||||
config_params = COMMON_CONFIG_PARAMS[model_name]
|
||||
image_processor = CLIPImageProcessor(
|
||||
do_center_crop=True,
|
||||
do_normalize=True,
|
||||
do_resize=True,
|
||||
feature_extractor_type="CLIPFeatureExtractor",
|
||||
image_mean=[0.48145466, 0.4578275, 0.40821073],
|
||||
image_std=[0.26862954, 0.26130258, 0.27577711],
|
||||
resample=3,
|
||||
size=config_params["image_size"],
|
||||
crop_size=config_params["image_size"],
|
||||
)
|
||||
return image_processor
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Helper functions for state dict conversion
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def flatten_nested_dict(params: dict, parent_key: str = "", sep: str = ".") -> dict:
|
||||
"""
|
||||
Flatten a nested original checkpoint dictionary into a flat dictionary.
|
||||
"""
|
||||
items = []
|
||||
for k, v in params.items():
|
||||
new_key = parent_key + sep + k if parent_key else k
|
||||
if isinstance(v, collections.abc.MutableMapping):
|
||||
items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
|
||||
else:
|
||||
items.append((new_key, v))
|
||||
return dict(items)
|
||||
|
||||
|
||||
def split_resblocks_layers(state_dict: dict) -> dict:
|
||||
"""
|
||||
Split the resblocks weight into layers. In some cases they are concatenated in
|
||||
the original checkpoints.
|
||||
"""
|
||||
# Make shallow copy
|
||||
state_dict = state_dict.copy()
|
||||
# Split resblocks weight into layers
|
||||
keys = list(state_dict.keys())
|
||||
for key in keys:
|
||||
if ".resblocks." in key:
|
||||
weight = state_dict.pop(key)
|
||||
for i, weight_i in enumerate(weight):
|
||||
new_name = key.replace("resblocks", f"resblocks_{i}")
|
||||
state_dict[new_name] = weight_i
|
||||
return state_dict
|
||||
|
||||
|
||||
def chunk_qkv_for_attn(state_dict: dict) -> dict:
|
||||
"""
|
||||
Chunk the q/k/v weights and biases for the attention layers.
|
||||
"""
|
||||
# Make shallow copy
|
||||
state_dict = state_dict.copy()
|
||||
# Read and process q/k/v weights and biases
|
||||
keys = list(state_dict.keys())
|
||||
for key in keys:
|
||||
if ".in_proj." in key:
|
||||
weight = state_dict.pop(key)
|
||||
qkv_weights = weight.chunk(3, dim=0)
|
||||
for name, weight_i in zip(["q_proj", "k_proj", "v_proj"], qkv_weights):
|
||||
new_name = key.replace("in_proj", name)
|
||||
state_dict[new_name] = weight_i
|
||||
return state_dict
|
||||
|
||||
|
||||
def convert_old_keys_to_new_keys(state_dict_keys: list) -> dict:
|
||||
"""
|
||||
This function should be applied only once, on the concatenated keys to efficiently rename using
|
||||
the key mappings.
|
||||
"""
|
||||
output_dict = {}
|
||||
if state_dict_keys is not None:
|
||||
old_text = "\n".join(state_dict_keys)
|
||||
new_text = old_text
|
||||
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
|
||||
if replacement is None:
|
||||
new_text = re.sub(pattern, "", new_text) # an empty line
|
||||
continue
|
||||
new_text = re.sub(pattern, replacement, new_text)
|
||||
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
|
||||
return output_dict
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Convert model
|
||||
# --------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_mlcd_checkpoint(model_name, input_dir, output_dir, verify_hidden_state=True, push_to_hub=False):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our MLCD structure.
|
||||
"""
|
||||
|
||||
# Define MLCD configuration
|
||||
config = get_mlcd_config(model_name)
|
||||
|
||||
checkpoint = MODEL_NAME_TO_CHECKPOINT_PATH[model_name]
|
||||
checkpoint_path = os.path.join(input_dir, checkpoint)
|
||||
assert os.path.exists(checkpoint_path), f"Checkpoint path ({checkpoint_path}) not found."
|
||||
|
||||
# Load original checkpoint
|
||||
print(f"Loading checkpoint from {checkpoint_path}...")
|
||||
state_dict = torch.load(checkpoint_path, "cpu")
|
||||
|
||||
# Flatten nested dictionary
|
||||
print("Flattening nested dictionary...")
|
||||
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
|
||||
if "positional_embedding" in state_dict:
|
||||
state_dict.pop("positional_embedding")
|
||||
state_dict = flatten_nested_dict(state_dict)
|
||||
state_dict = split_resblocks_layers(state_dict)
|
||||
state_dict = chunk_qkv_for_attn(state_dict)
|
||||
|
||||
# Rename and transform weights
|
||||
print("Renaming and transforming weights...")
|
||||
original_keys = list(state_dict.keys())
|
||||
hf_keys = convert_old_keys_to_new_keys(original_keys)
|
||||
new_state_dict = {}
|
||||
for original_key in original_keys:
|
||||
new_key = hf_keys[original_key]
|
||||
parameter = state_dict.pop(original_key)
|
||||
new_state_dict[new_key] = parameter  # values coming from torch.load are already tensors
|
||||
|
||||
# load HuggingFace model
|
||||
print("Loading HuggingFace model...")
|
||||
model = MLCDVisionModel(config).eval()
|
||||
model.load_state_dict(new_state_dict)
|
||||
|
||||
# Create processor
|
||||
print("Creating processor...")
|
||||
image_processor = get_mlcd_image_processor(model_name)
|
||||
|
||||
# Verify hidden state
|
||||
if verify_hidden_state:
|
||||
print("Verifying hidden state for {model_name}...")
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
|
||||
last_hidden_state = model(pixel_values, output_hidden_states=True).last_hidden_state[0, :5, :5]
|
||||
expected_hidden_state = EXPECTED_OUTPUTS[model_name]
|
||||
np.testing.assert_allclose(last_hidden_state.cpu().numpy(), expected_hidden_state.numpy(), atol=1e-4)
|
||||
|
||||
# Save model
|
||||
if output_dir is not None:
|
||||
dst_dir = os.path.join(output_dir, model_name)
|
||||
print(f"Saving model {model_name} to {dst_dir}...")
|
||||
model.save_pretrained(dst_dir)
|
||||
print(f"Saving processor to {dst_dir}...")
|
||||
image_processor.save_pretrained(dst_dir)
|
||||
|
||||
if push_to_hub:
|
||||
print(f"Pushing model and processor for {model_name} to the HuggingFace Hub...")
|
||||
model.push_to_hub(f"deepglint-hf/{model_name}", private=True)
|
||||
image_processor.push_to_hub(f"deepglint-hf/{model_name}", private=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="mlcd-vit-bigG-patch14-448",
|
||||
type=str,
|
||||
choices=MODEL_NAME_TO_CHECKPOINT_PATH.keys(),
|
||||
help="Name of the model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_dir",
|
||||
default="mlcd/original",
|
||||
help="Location of MLCD original weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default="mlcd/checkpoint",
|
||||
help="Location to write HF model and processor",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verify_hidden_state",
|
||||
action="store_true",
|
||||
help="Whether to verify hidden_state against the original implementation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_mlcd_checkpoint(
|
||||
args.model_name, args.input_dir, args.output_dir, args.verify_hidden_state, args.push_to_hub
|
||||
)
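
For reference, the conversion above can also be driven from Python instead of the command line. The snippet below is a sketch under the assumption that the original `MLCD_ViT_bigG_14_448px_pytorch.pt` checkpoint has already been downloaded into the default `mlcd/original` directory; it simply calls `convert_mlcd_checkpoint` with the same values the argument parser would use.

```python
from transformers.models.mlcd.convert_mlcd_weights_to_hf import convert_mlcd_checkpoint

# Equivalent to invoking the script's CLI with --model_name mlcd-vit-bigG-patch14-448
# --input_dir mlcd/original --output_dir mlcd/checkpoint --verify_hidden_state
convert_mlcd_checkpoint(
    model_name="mlcd-vit-bigG-patch14-448",
    input_dir="mlcd/original",      # directory containing MLCD_ViT_bigG_14_448px_pytorch.pt
    output_dir="mlcd/checkpoint",   # where the converted model and processor are written
    verify_hidden_state=True,       # compare a slice of the output against EXPECTED_OUTPUTS
    push_to_hub=False,
)
```
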
|
679
src/transformers/models/mlcd/modeling_mlcd.py
Normal file
@@ -0,0 +1,679 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/mlcd/modular_mlcd.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_mlcd.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
can_return_tuple,
|
||||
logging,
|
||||
torch_int,
|
||||
)
|
||||
from .configuration_mlcd import MLCDVisionConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class MLCDMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MLCDRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the Rotary Position Embedding (RoPE) for MLCDVisionModel based on the grid size.
|
||||
|
||||
Args:
|
||||
num_patches_height (int): Number of patches in the height dimension.
|
||||
num_patches_width (int): Number of patches in the width dimension.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Rotary positional embeddings for the given grid size.
|
||||
"""
|
||||
# Generate position IDs for height and width dimensions
|
||||
hpos_ids = (
|
||||
torch.arange(num_patches_height, device=self.inv_freq.device).unsqueeze(1).expand(-1, num_patches_width)
|
||||
)
|
||||
wpos_ids = (
|
||||
torch.arange(num_patches_width, device=self.inv_freq.device).unsqueeze(0).expand(num_patches_height, -1)
|
||||
)
|
||||
|
||||
# Flatten and stack the position IDs
|
||||
pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1)
|
||||
|
||||
# Generate the full rotary positional embeddings for the maximum grid size
|
||||
max_grid_size = max(num_patches_height, num_patches_width)
|
||||
seq = torch.arange(max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
||||
rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
|
||||
|
||||
# Select and flatten the embeddings based on the position IDs
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
|
||||
return rotary_pos_emb
|
||||
|
||||
|
||||
class MLCDVisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: MLCDVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches + 1
|
||||
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing.
|
||||
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
position_embedding = self.position_embedding.weight.unsqueeze(0)
|
||||
num_positions = position_embedding.shape[1] - 1
|
||||
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embedding(self.position_ids)
|
||||
|
||||
class_pos_embed = position_embedding[:, :1]
|
||||
patch_pos_embed = position_embedding[:, 1:]
|
||||
|
||||
dim = embeddings.shape[-1]
|
||||
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
|
||||
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
# patch_embeds -> shape = [batch, width, grid, grid]
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(
|
||||
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
orig_q_dtype = q.dtype
|
||||
orig_k_dtype = k.dtype
|
||||
q, k = q.float(), k.float()
|
||||
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
q_embed = q_embed.to(orig_q_dtype)
|
||||
k_embed = k_embed.to(orig_k_dtype)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class MLCDAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper
|
||||
Multi-headed attention with RoPE. Refer to papers:
|
||||
- Attention is all you need:
|
||||
https://arxiv.org/abs/1706.03762
|
||||
- RoFormer: Enhanced Transformer with Rotary Position Embedding:
|
||||
https://arxiv.org/abs/2104.09864
|
||||
"""
|
||||
|
||||
def __init__(self, config: MLCDVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.num_key_value_groups = config.num_key_value_groups
|
||||
self.is_causal = False
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
batch_size, seq_length = hidden_states.shape[:-1]
|
||||
|
||||
# Each of shape: [batch_size, seq_length, num_heads, head_dim]
|
||||
query_states = self.q_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
|
||||
key_states = self.k_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
|
||||
value_states = self.v_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
|
||||
|
||||
# Apply positional embeddings
|
||||
cos = position_embeddings[0].unsqueeze(0).float()
|
||||
sin = position_embeddings[1].unsqueeze(0).float()
|
||||
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
|
||||
|
||||
# Each of shape: [batch_size, num_heads, seq_length, head_dim]
|
||||
query_states = query_states.permute(0, 2, 1, 3).contiguous()
|
||||
key_states = key_states.permute(0, 2, 1, 3).contiguous()
|
||||
value_states = value_states.permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
scaling=self.scale,
|
||||
is_causal=self.is_causal,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.permute(1, 0, 2, 3).contiguous() # [seq_length, batch_size, num_heads, head_dim]
|
||||
attn_output = attn_output.view(seq_length, batch_size, -1) # [seq_length, batch_size, embedding_dim]
|
||||
attn_output = self.out_proj(attn_output)
|
||||
attn_output = attn_output.permute(1, 0, 2).contiguous() # [batch_size, seq_length, embedding_dim]
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class MLCDEncoderLayer(nn.Module):
|
||||
def __init__(self, config: MLCDVisionConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = MLCDAttention(config)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = MLCDMLP(config)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input to the layer of shape `(batch, seq_len, embed_dim)`.
|
||||
Represents the hidden states from the previous layer or the input embeddings.
|
||||
position_embeddings (`Tuple[torch.Tensor, torch.Tensor]`):
|
||||
A tuple `(cos, sin)` of rotary position embedding tensors applied to the query and key states in the attention mechanism.
|
||||
attention_mask (`torch.FloatTensor`):
|
||||
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
|
||||
output_attentions (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class MLCDEncoder(nn.Module):
|
||||
"""
|
||||
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
||||
[`MLCDEncoderLayer`].
|
||||
|
||||
Args:
|
||||
config: MLCDVisionConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: MLCDVisionConfig):
|
||||
"""Overwrite dummy `MLCDConfig` to `MLCDVisionConfig`."""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList([MLCDEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.FloatTensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutput]:
|
||||
r"""
|
||||
Args:
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
position_embeddings (`Tuple[torch.Tensor, torch.Tensor]`):
|
||||
A tuple `(cos, sin)` of rotary position embedding tensors applied to the query and key states in the attention mechanism.
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
position_embeddings,
|
||||
attention_mask,
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=encoder_states,
|
||||
attentions=all_attentions,
|
||||
)
|
||||
|
||||
|
||||
MLCD_VISION_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
||||
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
class MLCDVisionTransformer(nn.Module):
|
||||
def __init__(self, config: MLCDVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.embeddings = MLCDVisionEmbeddings(config)
|
||||
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
self.encoder = MLCDEncoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
|
||||
self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2))
|
||||
|
||||
@add_start_docstrings_to_model_forward(MLCD_VISION_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
num_patches_height = pixel_values.shape[-2] // self.config.patch_size
|
||||
num_patches_width = pixel_values.shape[-1] // self.config.patch_size
|
||||
rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width)
|
||||
rotary_pos_emb = rotary_pos_emb.to(self.class_pos_emb.device)
|
||||
rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0)
|
||||
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
||||
position_embeddings = (emb.cos(), emb.sin())
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.pre_layrnorm(hidden_states)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
        last_hidden_state = encoder_outputs[0]
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class MLCDPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = MLCDVisionConfig
    base_model_prefix = "mlcd"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor
        if isinstance(module, MLCDVisionEmbeddings):
            factor = self.config.initializer_factor
            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
            nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
        elif isinstance(module, MLCDAttention):
            factor = self.config.initializer_factor
            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            out_proj_std = (module.embed_dim**-0.5) * factor
            nn.init.normal_(module.q_proj.weight, std=in_proj_std)
            nn.init.normal_(module.k_proj.weight, std=in_proj_std)
            nn.init.normal_(module.v_proj.weight, std=in_proj_std)
            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
        elif isinstance(module, MLCDMLP):
            factor = self.config.initializer_factor
            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
            nn.init.normal_(module.fc1.weight, std=fc_std)
            nn.init.normal_(module.fc2.weight, std=in_proj_std)
        elif isinstance(module, MLCDVisionTransformer):
            factor = self.config.initializer_factor
            pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor
            nn.init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


MLCD_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MLCDVisionConfig`]):
            Model configuration class with all the parameters of the vision encoder. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare MLCD vision encoder outputting raw hidden-states without any specific head on top.",
    MLCD_START_DOCSTRING,
)
class MLCDVisionModel(MLCDPreTrainedModel):
    config_class = MLCDVisionConfig
    main_input_name = "pixel_values"
    _no_split_modules = ["MLCDEncoderLayer"]

    def __init__(self, config: MLCDVisionConfig):
        super().__init__(config)
        self.vision_model = MLCDVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(MLCD_VISION_INPUTS_DOCSTRING)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> import requests
|
||||
>>> from PIL import Image
|
||||
>>> from transformers import AutoProcessor, MLCDVisionModel
|
||||
>>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
|
||||
>>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
>>> inputs = processor(images=image, return_tensors="pt")
|
||||
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs, output_attentions=True)
|
||||
|
||||
>>> features = outputs.last_hidden_state
|
||||
>>> print(f"Extracted features shape: {features.shape}")
|
||||
>>> print(f"Number of attention layers: {len(outputs.attentions)}")
|
||||
>>> print(f"Attention shape: {outputs.attentions[0].shape}")
|
||||
```"""
|
||||
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


__all__ = ["MLCDPreTrainedModel", "MLCDVisionModel"]
596
src/transformers/models/mlcd/modular_mlcd.py
Normal file
@ -0,0 +1,596 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn as nn

from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
)
from ...modeling_utils import (
    ALL_ATTENTION_FUNCTIONS,
    PreTrainedModel,
)
from ...processing_utils import Unpack
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
from ..clip.modeling_clip import (
    CLIPMLP,
    CLIPAttention,
    CLIPEncoder,
    CLIPEncoderLayer,
    CLIPVisionEmbeddings,
    CLIPVisionModel,
    CLIPVisionTransformer,
)
from ..llama.modeling_llama import eager_attention_forward
from ..qwen2_vl.modeling_qwen2_vl import (
    VisionRotaryEmbedding,
    apply_rotary_pos_emb_vision,
)


logger = logging.get_logger(__name__)
class MLCDVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MLCDVisionModel`]. It is used to instantiate a MLCD
    vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the vision encoder of the MLCD
    [DeepGlint-AI/mlcd-vit-bigG-patch14-336](https://huggingface.co/DeepGlint-AI/mlcd-vit-bigG-patch14-336) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 1664):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 8192):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        projection_dim (`int`, *optional*, defaults to 1024):
            Dimensionality of text and vision projection layers.
        num_hidden_layers (`int`, *optional*, defaults to 48):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        image_size (`int`, *optional*, defaults to 336):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_factor (`float`, *optional*, defaults to 1.0):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).

    Example:

    ```python
    >>> from transformers import MLCDVisionConfig, MLCDVisionModel

    >>> # Initializing a MLCDVisionConfig with DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration
    >>> configuration = MLCDVisionConfig()

    >>> # Initializing a MLCDVisionModel (with random weights) from the DeepGlint-AI/mlcd-vit-bigG-patch14-336 style configuration
    >>> model = MLCDVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "mlcd_vision_model"
    base_config_key = "vision_config"

    def __init__(
        self,
        hidden_size=1664,
        intermediate_size=8192,
        num_hidden_layers=48,
        num_attention_heads=16,
        num_key_value_groups=1,
        num_channels=3,
        image_size=336,
        patch_size=14,
        hidden_act="gelu",
        layer_norm_eps=1e-5,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_groups = num_key_value_groups
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
class MLCDMLP(CLIPMLP):
    pass


class MLCDRotaryEmbedding(VisionRotaryEmbedding):
    def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor:
        """
        Calculate the Rotary Position Embedding (RoPE) for MLCDVisionModel based on the grid size.

        Args:
            num_patches_height (int): Number of patches in the height dimension.
            num_patches_width (int): Number of patches in the width dimension.

        Returns:
            torch.Tensor: Rotary positional embeddings for the given grid size.
        """
        # Generate position IDs for height and width dimensions
        hpos_ids = (
            torch.arange(num_patches_height, device=self.inv_freq.device).unsqueeze(1).expand(-1, num_patches_width)
        )
        wpos_ids = (
            torch.arange(num_patches_width, device=self.inv_freq.device).unsqueeze(0).expand(num_patches_height, -1)
        )

        # Flatten and stack the position IDs
        pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1)

        # Generate the full rotary positional embeddings for the maximum grid size
        max_grid_size = max(num_patches_height, num_patches_width)
        seq = torch.arange(max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        rotary_pos_emb_full = torch.outer(seq, self.inv_freq)

        # Select and flatten the embeddings based on the position IDs
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)

        return rotary_pos_emb


class MLCDVisionEmbeddings(CLIPVisionEmbeddings):
    def __init__(self, config: MLCDVisionConfig):
        super().__init__(config)
        del self.position_embedding

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        # patch_embeds -> shape = [batch, width, grid, grid]
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        return embeddings
class MLCDAttention(CLIPAttention):
    """Multi-headed attention with RoPE. Refer to papers:
    - Attention is all you need:
        https://arxiv.org/abs/1706.03762
    - RoFormer: Enhanced Transformer with Rotary Position Embedding:
        https://arxiv.org/abs/2104.09864
    """

    def __init__(self, config: MLCDVisionConfig):
        super().__init__(config)
        self.num_key_value_groups = config.num_key_value_groups
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch_size, seq_length = hidden_states.shape[:-1]

        # Each of shape: [batch_size, seq_length, num_heads, head_dim]
        query_states = self.q_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
        key_states = self.k_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
        value_states = self.v_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))

        # Apply positional embeddings
        cos = position_embeddings[0].unsqueeze(0).float()
        sin = position_embeddings[1].unsqueeze(0).float()
        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

        # Each of shape: [batch_size, num_heads, seq_length, head_dim]
        query_states = query_states.permute(0, 2, 1, 3).contiguous()
        key_states = key_states.permute(0, 2, 1, 3).contiguous()
        value_states = value_states.permute(0, 2, 1, 3).contiguous()

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scale,
            is_causal=self.is_causal,
            **kwargs,
        )

        attn_output = attn_output.permute(1, 0, 2, 3).contiguous()  # [seq_length, batch_size, num_heads, head_dim]
        attn_output = attn_output.view(seq_length, batch_size, -1)  # [seq_length, batch_size, embedding_dim]
        attn_output = self.out_proj(attn_output)
        attn_output = attn_output.permute(1, 0, 2).contiguous()  # [batch_size, seq_length, embedding_dim]
        return attn_output, attn_weights
class MLCDEncoderLayer(CLIPEncoderLayer):
    def __init__(self, config: MLCDVisionConfig):
        super().__init__(config)
        self.self_attn = MLCDAttention(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
                Represents the hidden states from the previous layer or the input embeddings.
            position_embeddings (`Tuple[torch.Tensor, torch.Tensor]`):
                A tuple of two tensors, each of shape `(batch, seq_len, embed_dim)`.
                Represents absolute positional embeddings for the query and key in the attention mechanism.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class MLCDEncoder(CLIPEncoder):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`MLCDEncoderLayer`].

    Args:
        config: MLCDVisionConfig
    """

    def __init__(self, config: MLCDVisionConfig):
        """Overwrite dummy `MLCDConfig` to `MLCDVisionConfig`."""
        super().__init__(config)

    def forward(
        self,
        inputs_embeds: torch.FloatTensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            position_embeddings (`Tuple[torch.Tensor, torch.Tensor]`):
                A tuple of two tensors, each of shape `(batch, seq_len, embed_dim)`.
                Represents absolute positional embeddings for the query and key in the attention mechanism.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    position_embeddings,
                    attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states=hidden_states,
                    position_embeddings=position_embeddings,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
MLCD_VISION_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
||||
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
class MLCDVisionTransformer(CLIPVisionTransformer):
|
||||
def __init__(self, config: MLCDVisionConfig):
|
||||
super().__init__(config)
|
||||
self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
|
||||
self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2))
|
||||
|
||||
@add_start_docstrings_to_model_forward(MLCD_VISION_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
num_patches_height = pixel_values.shape[-2] // self.config.patch_size
|
||||
num_patches_width = pixel_values.shape[-1] // self.config.patch_size
|
||||
rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width)
|
||||
rotary_pos_emb = rotary_pos_emb.to(self.class_pos_emb.device)
|
||||
rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0)
|
||||
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
|
||||
position_embeddings = (emb.cos(), emb.sin())
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.pre_layrnorm(hidden_states)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
MLCD_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`MLCDVisionConfig`]):
|
||||
Model configuration class with all the parameters of the vision encoder. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
class MLCDPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = MLCDVisionConfig
|
||||
base_model_prefix = "mlcd"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
factor = self.config.initializer_factor
|
||||
if isinstance(module, MLCDVisionEmbeddings):
|
||||
factor = self.config.initializer_factor
|
||||
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
|
||||
nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
|
||||
elif isinstance(module, MLCDAttention):
|
||||
factor = self.config.initializer_factor
|
||||
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
||||
out_proj_std = (module.embed_dim**-0.5) * factor
|
||||
nn.init.normal_(module.q_proj.weight, std=in_proj_std)
|
||||
nn.init.normal_(module.k_proj.weight, std=in_proj_std)
|
||||
nn.init.normal_(module.v_proj.weight, std=in_proj_std)
|
||||
nn.init.normal_(module.out_proj.weight, std=out_proj_std)
|
||||
elif isinstance(module, MLCDMLP):
|
||||
factor = self.config.initializer_factor
|
||||
in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
|
||||
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
|
||||
nn.init.normal_(module.fc1.weight, std=fc_std)
|
||||
nn.init.normal_(module.fc2.weight, std=in_proj_std)
|
||||
elif isinstance(module, MLCDVisionTransformer):
|
||||
factor = self.config.initializer_factor
|
||||
pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor
|
||||
nn.init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
|
||||
@add_start_docstrings(
    "The bare MLCD vision encoder outputting raw hidden-states without any specific head on top.",
    MLCD_START_DOCSTRING,
)
class MLCDVisionModel(CLIPVisionModel):
    @add_start_docstrings_to_model_forward(MLCD_VISION_INPUTS_DOCSTRING)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        ```python
        >>> import requests
        >>> import torch
        >>> from PIL import Image
        >>> from transformers import AutoProcessor, MLCDVisionModel
        >>> model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")
        >>> processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-448")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> inputs = processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs, output_attentions=True)

        >>> features = outputs.last_hidden_state
        >>> print(f"Extracted features shape: {features.shape}")
        >>> print(f"Number of attention layers: {len(outputs.attentions)}")
        >>> print(f"Attention shape: {outputs.attentions[0].shape}")
        ```"""
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


__all__ = [
    "MLCDVisionConfig",
    "MLCDPreTrainedModel",
    "MLCDVisionModel",
]
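The `MLCDRotaryEmbedding` and `MLCDVisionTransformer` code above builds a 2D RoPE table from the patch grid (row and column indices share one frequency table, and a learned `class_pos_emb` row is prepended for the CLS token). Below is a standalone sketch of that table construction, re-derived from the code with plain `torch`; the `theta=10000.0` base is assumed from `VisionRotaryEmbedding`'s default and the `dim=52` value is simply `head_dim // 2` for the bigG checkpoint (1664 / 16 / 2).

```python
# Hedged sketch of the 2D RoPE table that MLCDRotaryEmbedding.forward produces.
import torch


def rope_2d(num_patches_height: int, num_patches_width: int, dim: int, theta: float = 10000.0) -> torch.Tensor:
    # Frequency table, as in VisionRotaryEmbedding (dim is head_dim // 2).
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

    # Row / column index of every patch in the grid.
    hpos = torch.arange(num_patches_height).unsqueeze(1).expand(-1, num_patches_width)
    wpos = torch.arange(num_patches_width).unsqueeze(0).expand(num_patches_height, -1)
    pos_ids = torch.stack([hpos.flatten(), wpos.flatten()], dim=-1)  # [H*W, 2]

    # Full table for the larger grid side, then gather one row per (h, w) pair.
    max_grid = max(num_patches_height, num_patches_width)
    full = torch.outer(torch.arange(max_grid).float(), inv_freq)  # [max_grid, dim // 2]
    return full[pos_ids].flatten(1)  # [H*W, dim]


# 448x448 image, 14x14 patches -> 32x32 grid; bigG head_dim = 1664 // 16 = 104, so dim = 52.
emb = rope_2d(32, 32, dim=52)
print(emb.shape)  # torch.Size([1024, 52])
```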
0
tests/models/mlcd/__init__.py
Normal file
221
tests/models/mlcd/test_modeling_mlcd.py
Normal file
@ -0,0 +1,221 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch MLCD model."""

import unittest

import requests
from PIL import Image

from transformers import (
    AutoProcessor,
    MLCDVisionConfig,
    MLCDVisionModel,
    is_torch_available,
)
from transformers.testing_utils import (
    require_torch,
    slow,
    torch_device,
)

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor


if is_torch_available():
    import torch
class MLCDVisionModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.scope = scope

        # in MLCD, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 1

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        return MLCDVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, pixel_values):
        model = MLCDVisionModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)
        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
        image_size = (self.image_size, self.image_size)
        patch_size = (self.patch_size, self.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict
@require_torch
class MLCDVisionModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Model tester for `MLCDVisionModel`.
    """

    all_model_classes = (MLCDVisionModel,) if is_torch_available() else ()
    test_pruning = False
    test_head_masking = False
    test_torchscript = False
    test_resize_embeddings = False
    test_torch_exportable = True

    def setUp(self):
        self.model_tester = MLCDVisionModelTester(self)
        self.config_tester = ConfigTester(self, config_class=MLCDVisionConfig, has_text_modality=False)

    def test_model_get_set_embeddings(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Module))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, torch.nn.Linear))

    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad and "class_pos_emb" not in name:
                    self.assertIn(
                        ((param.data.mean() * 1e9).round() / 1e9).item(),
                        [0.0, 1.0],
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )
@require_torch
class MLCDVisionModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference(self):
        model_name = "DeepGlint-AI/mlcd-vit-bigG-patch14-448"
        model = MLCDVisionModel.from_pretrained(model_name).to(torch_device)
        processor = AutoProcessor.from_pretrained(model_name)

        # process single image
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        image = Image.open(requests.get(url, stream=True).raw)
        inputs = processor(images=image, return_tensors="pt")

        # move inputs to the same device as the model
        inputs = {k: v.to(torch_device) for k, v in inputs.items()}

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)

        last_hidden_state = outputs.last_hidden_state
        last_attention = outputs.attentions[-1]

        # verify the shapes of last_hidden_state and last_attention
        self.assertEqual(
            last_hidden_state.shape,
            torch.Size([1, 1025, 1664]),
        )
        self.assertEqual(
            last_attention.shape,
            torch.Size([1, 16, 1025, 1025]),
        )

        # verify the values of last_hidden_state and last_attention
        # fmt: off
        expected_partial_5x5_last_hidden_state = torch.tensor(
            [
                [-0.8978, -0.1181,  0.4769,  0.4761, -0.5779],
                [ 0.2640, -2.6150,  0.4853,  0.5743, -1.1003],
                [ 0.3314, -0.3328, -0.4305, -0.1874, -0.7701],
                [-1.5174, -1.0238, -1.1854,  0.1749, -0.8786],
                [ 0.2323, -0.8346, -0.9680, -0.2951,  0.0867],
            ]
        ).to(torch_device)

        expected_partial_5x5_last_attention = torch.tensor(
            [
                [2.0930e-01, 6.3073e-05, 1.4717e-03, 2.6881e-05, 3.0513e-03],
                [1.5828e-04, 2.1056e-03, 4.6784e-04, 1.8276e-03, 5.3233e-04],
                [5.7824e-04, 1.1446e-03, 1.3854e-03, 1.1775e-03, 1.2750e-03],
                [9.6343e-05, 1.6365e-03, 2.9066e-04, 3.1089e-03, 2.0607e-04],
                [6.2688e-04, 1.1656e-03, 1.5030e-03, 8.2819e-04, 2.6992e-03],
            ]
        ).to(torch_device)
        # fmt: on

        torch.testing.assert_close(
            last_hidden_state[0, :5, :5], expected_partial_5x5_last_hidden_state, rtol=1e-3, atol=1e-3
        )
        torch.testing.assert_close(
            last_attention[0, 0, :5, :5], expected_partial_5x5_last_attention, rtol=1e-4, atol=1e-4
        )
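The expected tensor sizes in the integration test follow directly from the bigG configuration used by the checkpoint (448x448 inputs, 14x14 patches, hidden size 1664, 16 attention heads). A small arithmetic sketch of where `[1, 1025, 1664]` and `[1, 16, 1025, 1025]` come from:

```python
# Hedged sketch: derive the integration-test shapes from the bigG checkpoint geometry.
image_size, patch_size = 448, 14
hidden_size, num_heads = 1664, 16

num_patches = (image_size // patch_size) ** 2  # 32 * 32 = 1024
seq_len = num_patches + 1                      # +1 for the CLS token -> 1025

print((1, seq_len, hidden_size))               # last_hidden_state: (1, 1025, 1664)
print((1, num_heads, seq_len, seq_len))        # per-layer attention: (1, 16, 1025, 1025)
```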
@ -383,6 +383,7 @@ OBJECTS_TO_IGNORE = [
    "MegatronBertConfig",
    "MegatronBertForPreTraining",
    "MegatronBertModel",
    "MLCDVisionConfig",
    "MobileBertConfig",
    "MobileBertModel",
    "MobileBertTokenizerFast",