mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Add BeitBackbone (#25952)
* First draft * Add backwards compatibility * More improvements * More improvements * Improve error message * Address comment * Add conversion script * Fix style * Update code snippet * Adddress comment * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
7a757bb694
commit
1fb3c23b41
@ -1261,6 +1261,7 @@ else:
|
||||
_import_structure["models.beit"].extend(
|
||||
[
|
||||
"BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"BeitBackbone",
|
||||
"BeitForImageClassification",
|
||||
"BeitForMaskedImageModeling",
|
||||
"BeitForSemanticSegmentation",
|
||||
@ -5414,6 +5415,7 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .models.beit import (
|
||||
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
BeitBackbone,
|
||||
BeitForImageClassification,
|
||||
BeitForMaskedImageModeling,
|
||||
BeitForSemanticSegmentation,
|
||||
|
@ -1080,6 +1080,7 @@ MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Backbone mapping
|
||||
("beit", "BeitBackbone"),
|
||||
("bit", "BitBackbone"),
|
||||
("convnext", "ConvNextBackbone"),
|
||||
("convnextv2", "ConvNextV2Backbone"),
|
||||
|
@ -47,6 +47,7 @@ else:
|
||||
"BeitForSemanticSegmentation",
|
||||
"BeitModel",
|
||||
"BeitPreTrainedModel",
|
||||
"BeitBackbone",
|
||||
]
|
||||
|
||||
|
||||
@ -83,6 +84,7 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
from .modeling_beit import (
|
||||
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
BeitBackbone,
|
||||
BeitForImageClassification,
|
||||
BeitForMaskedImageModeling,
|
||||
BeitForSemanticSegmentation,
|
||||
|
@ -21,6 +21,7 @@ from packaging import version
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...onnx import OnnxConfig
|
||||
from ...utils import logging
|
||||
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -33,7 +34,7 @@ BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
}
|
||||
|
||||
|
||||
class BeitConfig(PretrainedConfig):
|
||||
class BeitConfig(BackboneConfigMixin, PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`BeitModel`]. It is used to instantiate an BEiT
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
@ -84,8 +85,6 @@ class BeitConfig(PretrainedConfig):
|
||||
use_mean_pooling (`bool`, *optional*, defaults to `True`):
|
||||
Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
|
||||
CLS token, before applying the classification head.
|
||||
out_indices (`List[int]`, *optional*, defaults to `[3, 5, 7, 11]`):
|
||||
Indices of the feature maps to use for semantic segmentation.
|
||||
pool_scales (`Tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`):
|
||||
Pooling scales used in Pooling Pyramid Module applied on the last feature map.
|
||||
use_auxiliary_head (`bool`, *optional*, defaults to `True`):
|
||||
@ -100,6 +99,20 @@ class BeitConfig(PretrainedConfig):
|
||||
Whether to concatenate the output of the auxiliary head with the input before the classification layer.
|
||||
semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
|
||||
The index that is ignored by the loss function of the semantic segmentation model.
|
||||
out_features (`List[str]`, *optional*):
|
||||
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
|
||||
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the
|
||||
corresponding stages. If unset and `out_indices` is unset, will default to the last stage.
|
||||
out_indices (`List[int]`, *optional*):
|
||||
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
|
||||
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
|
||||
If unset and `out_features` is unset, will default to the last stage.
|
||||
add_fpn (`bool`, *optional*, defaults to `False`):
|
||||
Whether to add a FPN as part of the backbone. Only relevant for [`BeitBackbone`].
|
||||
reshape_hidden_states (`bool`, *optional*, defaults to `True`):
|
||||
Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
|
||||
case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
|
||||
seq_len, hidden_size)`. Only relevant for [`BeitBackbone`].
|
||||
|
||||
Example:
|
||||
|
||||
@ -140,7 +153,6 @@ class BeitConfig(PretrainedConfig):
|
||||
layer_scale_init_value=0.1,
|
||||
drop_path_rate=0.1,
|
||||
use_mean_pooling=True,
|
||||
out_indices=[3, 5, 7, 11],
|
||||
pool_scales=[1, 2, 3, 6],
|
||||
use_auxiliary_head=True,
|
||||
auxiliary_loss_weight=0.4,
|
||||
@ -148,6 +160,10 @@ class BeitConfig(PretrainedConfig):
|
||||
auxiliary_num_convs=1,
|
||||
auxiliary_concat_input=False,
|
||||
semantic_loss_ignore_index=255,
|
||||
out_features=None,
|
||||
out_indices=None,
|
||||
add_fpn=False,
|
||||
reshape_hidden_states=True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@ -174,7 +190,6 @@ class BeitConfig(PretrainedConfig):
|
||||
self.drop_path_rate = drop_path_rate
|
||||
self.use_mean_pooling = use_mean_pooling
|
||||
# decode head attributes (semantic segmentation)
|
||||
self.out_indices = out_indices
|
||||
self.pool_scales = pool_scales
|
||||
# auxiliary head attributes (semantic segmentation)
|
||||
self.use_auxiliary_head = use_auxiliary_head
|
||||
@ -184,6 +199,22 @@ class BeitConfig(PretrainedConfig):
|
||||
self.auxiliary_concat_input = auxiliary_concat_input
|
||||
self.semantic_loss_ignore_index = semantic_loss_ignore_index
|
||||
|
||||
# handle backwards compatibility
|
||||
if "segmentation_indices" in kwargs:
|
||||
logger.warning(
|
||||
"The `segmentation_indices` argument is deprecated and will be removed in a future version, use `out_indices` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
out_indices = kwargs.pop("segmentation_indices")
|
||||
|
||||
# backbone attributes
|
||||
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, self.num_hidden_layers + 1)]
|
||||
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
|
||||
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
|
||||
)
|
||||
self.add_fpn = add_fpn
|
||||
self.reshape_hidden_states = reshape_hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.vit.configuration_vit.ViTOnnxConfig
|
||||
class BeitOnnxConfig(OnnxConfig):
|
||||
|
@ -22,11 +22,12 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import (
|
||||
BackboneOutput,
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
ImageClassifierOutput,
|
||||
@ -42,6 +43,7 @@ from ...utils import (
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...utils.backbone_utils import BackboneMixin
|
||||
from .configuration_beit import BeitConfig
|
||||
|
||||
|
||||
@ -149,22 +151,26 @@ class BeitEmbeddings(nn.Module):
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
|
||||
embeddings = self.patch_embeddings(pixel_values)
|
||||
embeddings, (patch_height, patch_width) = self.patch_embeddings(
|
||||
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
|
||||
)
|
||||
batch_size, seq_len, _ = embeddings.size()
|
||||
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||
if bool_masked_pos is not None:
|
||||
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
|
||||
# replace the masked visual tokens by mask_tokens
|
||||
w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
|
||||
embeddings = embeddings * (1 - w) + mask_tokens * w
|
||||
|
||||
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||
if self.position_embeddings is not None:
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
|
||||
|
||||
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
||||
|
||||
embeddings = self.dropout(embeddings)
|
||||
|
||||
return embeddings
|
||||
return embeddings, (patch_height, patch_width)
|
||||
|
||||
|
||||
class BeitPatchEmbeddings(nn.Module):
|
||||
@ -191,19 +197,29 @@ class BeitPatchEmbeddings(nn.Module):
|
||||
|
||||
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, pixel_values: torch.Tensor, position_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = pixel_values.shape
|
||||
if num_channels != self.num_channels:
|
||||
raise ValueError(
|
||||
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||
)
|
||||
if height != self.image_size[0] or width != self.image_size[1]:
|
||||
raise ValueError(
|
||||
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
||||
)
|
||||
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
||||
|
||||
return embeddings
|
||||
embeddings = self.projection(pixel_values)
|
||||
patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
|
||||
|
||||
if position_embedding is not None:
|
||||
# interpolate the position embedding to the corresponding size
|
||||
position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(
|
||||
0, 3, 1, 2
|
||||
)
|
||||
position_embedding = nn.functional.interpolate(
|
||||
position_embedding, size=(patch_height, patch_width), mode="bicubic"
|
||||
)
|
||||
embeddings = embeddings + position_embedding
|
||||
|
||||
embeddings = embeddings.flatten(2).transpose(1, 2)
|
||||
|
||||
return embeddings, (patch_height, patch_width)
|
||||
|
||||
|
||||
class BeitSelfAttention(nn.Module):
|
||||
@ -669,7 +685,7 @@ class BeitModel(BeitPreTrainedModel):
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(pixel_values, bool_masked_pos)
|
||||
embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values, bool_masked_pos)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
@ -1151,6 +1167,12 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
||||
self.beit = BeitModel(config, add_pooling_layer=False)
|
||||
|
||||
# FPNs
|
||||
if len(self.config.out_indices) != 4:
|
||||
raise ValueError(
|
||||
"BeitForSemanticSegmentation requires config.out_indices to be a list of 4 integers, "
|
||||
"specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of "
|
||||
"a base-sized architecture."
|
||||
)
|
||||
self.fpn1 = nn.Sequential(
|
||||
nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
|
||||
nn.BatchNorm2d(config.hidden_size),
|
||||
@ -1280,3 +1302,126 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
||||
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
BEiT backbone, to be used with frameworks like DETR and MaskFormer.
|
||||
""",
|
||||
BEIT_START_DOCSTRING,
|
||||
)
|
||||
class BeitBackbone(BeitPreTrainedModel, BackboneMixin):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
super()._init_backbone(config)
|
||||
|
||||
self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
|
||||
self.embeddings = BeitEmbeddings(config)
|
||||
self.encoder = BeitEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape)
|
||||
|
||||
if config.add_fpn:
|
||||
if len(self.config.out_indices) != 4:
|
||||
raise ValueError(
|
||||
"BeitBackbone requires config.out_indices to be a list of 4 integers, "
|
||||
"specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of "
|
||||
"a base-sized architecture."
|
||||
)
|
||||
hidden_size = config.hidden_size
|
||||
self.fpn1 = nn.Sequential(
|
||||
nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2),
|
||||
nn.BatchNorm2d(hidden_size, eps=config.batch_norm_eps),
|
||||
nn.GELU(),
|
||||
nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2),
|
||||
)
|
||||
|
||||
self.fpn2 = nn.Sequential(nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2))
|
||||
self.fpn3 = nn.Identity()
|
||||
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
|
||||
# initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.patch_embeddings
|
||||
|
||||
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Tensor,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> BackboneOutput:
|
||||
"""
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoImageProcessor, AutoBackbone
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224")
|
||||
>>> model = AutoBackbone.from_pretrained(
|
||||
... "microsoft/beit-base-patch16-224", out_features=["stage1", "stage2", "stage3", "stage4"]
|
||||
... )
|
||||
|
||||
>>> inputs = processor(image, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> feature_maps = outputs.feature_maps
|
||||
>>> list(feature_maps[-1].shape)
|
||||
[1, 768, 14, 14]
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
|
||||
batch_size = pixel_values.shape[0]
|
||||
embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values)
|
||||
|
||||
outputs = self.encoder(
|
||||
embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict
|
||||
)
|
||||
|
||||
hidden_states = outputs.hidden_states if return_dict else outputs[1]
|
||||
|
||||
feature_maps = ()
|
||||
for stage, hidden_state in zip(self.stage_names, hidden_states):
|
||||
if stage in self.out_features:
|
||||
if self.config.reshape_hidden_states:
|
||||
hidden_state = hidden_state[:, 1:, :]
|
||||
hidden_state = hidden_state.permute(0, 2, 1)
|
||||
hidden_state = hidden_state.reshape(batch_size, -1, patch_height, patch_width)
|
||||
|
||||
feature_maps += (hidden_state,)
|
||||
|
||||
if self.config.add_fpn:
|
||||
feature_maps = [
|
||||
self.fpn1(feature_maps[0]),
|
||||
self.fpn2(feature_maps[1]),
|
||||
self.fpn3(feature_maps[2]),
|
||||
self.fpn4(feature_maps[3]),
|
||||
]
|
||||
feature_maps = tuple(feature_maps)
|
||||
|
||||
if not return_dict:
|
||||
if output_hidden_states:
|
||||
output = (feature_maps,) + outputs[1:]
|
||||
else:
|
||||
output = (feature_maps,) + outputs[2:]
|
||||
return output
|
||||
|
||||
return BackboneOutput(
|
||||
feature_maps=feature_maps,
|
||||
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
@ -150,22 +150,26 @@ class Data2VecVisionEmbeddings(nn.Module):
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.BoolTensor] = None) -> torch.Tensor:
|
||||
embeddings = self.patch_embeddings(pixel_values)
|
||||
embeddings, (patch_height, patch_width) = self.patch_embeddings(
|
||||
pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
|
||||
)
|
||||
batch_size, seq_len, _ = embeddings.size()
|
||||
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||
if bool_masked_pos is not None:
|
||||
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
|
||||
# replace the masked visual tokens by mask_tokens
|
||||
w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
|
||||
embeddings = embeddings * (1 - w) + mask_tokens * w
|
||||
|
||||
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||
if self.position_embeddings is not None:
|
||||
embeddings = embeddings + self.position_embeddings
|
||||
cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]
|
||||
|
||||
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
||||
|
||||
embeddings = self.dropout(embeddings)
|
||||
|
||||
return embeddings
|
||||
return embeddings, (patch_height, patch_width)
|
||||
|
||||
|
||||
# Copied from transformers.models.beit.modeling_beit.BeitPatchEmbeddings with Beit->Data2VecVision
|
||||
@ -193,19 +197,29 @@ class Data2VecVisionPatchEmbeddings(nn.Module):
|
||||
|
||||
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, pixel_values: torch.Tensor, position_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = pixel_values.shape
|
||||
if num_channels != self.num_channels:
|
||||
raise ValueError(
|
||||
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
||||
)
|
||||
if height != self.image_size[0] or width != self.image_size[1]:
|
||||
raise ValueError(
|
||||
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
||||
)
|
||||
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
||||
|
||||
return embeddings
|
||||
embeddings = self.projection(pixel_values)
|
||||
patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
|
||||
|
||||
if position_embedding is not None:
|
||||
# interpolate the position embedding to the corresponding size
|
||||
position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(
|
||||
0, 3, 1, 2
|
||||
)
|
||||
position_embedding = nn.functional.interpolate(
|
||||
position_embedding, size=(patch_height, patch_width), mode="bicubic"
|
||||
)
|
||||
embeddings = embeddings + position_embedding
|
||||
|
||||
embeddings = embeddings.flatten(2).transpose(1, 2)
|
||||
|
||||
return embeddings, (patch_height, patch_width)
|
||||
|
||||
|
||||
# Copied from transformers.models.beit.modeling_beit.BeitSelfAttention with Beit->Data2VecVision
|
||||
@ -683,7 +697,7 @@ class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
embedding_output = self.embeddings(pixel_values, bool_masked_pos)
|
||||
embedding_output, (patch_height, patch_width) = self.embeddings(pixel_values, bool_masked_pos)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
@ -1079,6 +1093,12 @@ class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel):
|
||||
self.data2vec_vision = Data2VecVisionModel(config, add_pooling_layer=False)
|
||||
|
||||
# FPNs
|
||||
if len(self.config.out_indices) != 4:
|
||||
raise ValueError(
|
||||
"Data2VecVisionForSemanticSegmentation requires config.out_indices to be a list of 4 integers, "
|
||||
"specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of "
|
||||
"a base-sized architecture."
|
||||
)
|
||||
self.fpn1 = nn.Sequential(
|
||||
nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
|
||||
nn.BatchNorm2d(config.hidden_size),
|
||||
|
305
src/transformers/models/dpt/convert_dpt_beit_to_hf.py
Normal file
305
src/transformers/models/dpt/convert_dpt_beit_to_hf.py
Normal file
@ -0,0 +1,305 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert DPT 3.1 checkpoints from the MiDaS repository. URL: https://github.com/isl-org/MiDaS"""
|
||||
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from transformers import BeitConfig, DPTConfig, DPTForDepthEstimation, DPTImageProcessor
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_dpt_config(model_name):
|
||||
hidden_size = 768
|
||||
num_hidden_layers = 12
|
||||
num_attention_heads = 12
|
||||
intermediate_size = 3072
|
||||
out_features = ["stage3", "stage6", "stage9", "stage12"] # beit-base-384 uses [2, 5, 8, 11]
|
||||
|
||||
if "large" in model_name:
|
||||
hidden_size = 1024
|
||||
num_hidden_layers = 24
|
||||
num_attention_heads = 16
|
||||
intermediate_size = 4096
|
||||
out_features = ["stage6", "stage12", "stage18", "stage24"] # beit-large-512 uses [5, 11, 17, 23]
|
||||
|
||||
if "512" in model_name:
|
||||
image_size = 512
|
||||
elif "384" in model_name:
|
||||
image_size = 384
|
||||
else:
|
||||
raise ValueError("Model not supported")
|
||||
|
||||
backbone_config = BeitConfig(
|
||||
image_size=image_size,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_relative_position_bias=True,
|
||||
reshape_hidden_states=False,
|
||||
out_features=out_features,
|
||||
)
|
||||
|
||||
neck_hidden_sizes = [256, 512, 1024, 1024] if "large" in model_name else [96, 192, 384, 768]
|
||||
config = DPTConfig(backbone_config=backbone_config, neck_hidden_sizes=neck_hidden_sizes)
|
||||
|
||||
return config, image_size
|
||||
|
||||
|
||||
# here we list all keys to be renamed (original name on the left, our name on the right)
|
||||
def create_rename_keys(config):
|
||||
rename_keys = []
|
||||
|
||||
# fmt: off
|
||||
# stem
|
||||
rename_keys.append(("pretrained.model.cls_token", "backbone.embeddings.cls_token"))
|
||||
rename_keys.append(("pretrained.model.patch_embed.proj.weight", "backbone.embeddings.patch_embeddings.projection.weight"))
|
||||
rename_keys.append(("pretrained.model.patch_embed.proj.bias", "backbone.embeddings.patch_embeddings.projection.bias"))
|
||||
|
||||
# Transfomer encoder
|
||||
for i in range(config.backbone_config.num_hidden_layers):
|
||||
rename_keys.append((f"pretrained.model.blocks.{i}.gamma_1", f"backbone.encoder.layer.{i}.lambda_1"))
|
||||
rename_keys.append((f"pretrained.model.blocks.{i}.gamma_2", f"backbone.encoder.layer.{i}.lambda_2"))
|
||||
rename_keys.append((f"pretrained.model.blocks.{i}.norm1.weight", f"backbone.encoder.layer.{i}.layernorm_before.weight"))
|
||||
rename_keys.append((f"pretrained.model.blocks.{i}.norm1.bias", f"backbone.encoder.layer.{i}.layernorm_before.bias"))
|
||||
rename_keys.append((f"pretrained.model.blocks.{i}.norm2.weight", f"backbone.encoder.layer.{i}.layernorm_after.weight"))
|
||||
rename_keys.append((f"pretrained.model.blocks.{i}.norm2.bias", f"backbone.encoder.layer.{i}.layernorm_after.bias"))
|
||||
rename_keys.append((f"pretrained.model.blocks.{i}.mlp.fc1.weight", f"backbone.encoder.layer.{i}.intermediate.dense.weight"))
|
||||
rename_keys.append((f"pretrained.model.blocks.{i}.mlp.fc1.bias", f"backbone.encoder.layer.{i}.intermediate.dense.bias"))
|
||||
rename_keys.append((f"pretrained.model.blocks.{i}.mlp.fc2.weight", f"backbone.encoder.layer.{i}.output.dense.weight"))
|
||||
rename_keys.append((f"pretrained.model.blocks.{i}.mlp.fc2.bias", f"backbone.encoder.layer.{i}.output.dense.bias"))
|
||||
rename_keys.append((f"pretrained.model.blocks.{i}.attn.proj.weight", f"backbone.encoder.layer.{i}.attention.output.dense.weight"))
|
||||
rename_keys.append((f"pretrained.model.blocks.{i}.attn.proj.bias", f"backbone.encoder.layer.{i}.attention.output.dense.bias"))
|
||||
rename_keys.append((f"pretrained.model.blocks.{i}.attn.relative_position_bias_table", f"backbone.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"))
|
||||
rename_keys.append((f"pretrained.model.blocks.{i}.attn.relative_position_index", f"backbone.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"))
|
||||
|
||||
# activation postprocessing (readout projections + resize blocks)
|
||||
for i in range(4):
|
||||
rename_keys.append((f"pretrained.act_postprocess{i+1}.0.project.0.weight", f"neck.reassemble_stage.readout_projects.{i}.0.weight"))
|
||||
rename_keys.append((f"pretrained.act_postprocess{i+1}.0.project.0.bias", f"neck.reassemble_stage.readout_projects.{i}.0.bias"))
|
||||
|
||||
rename_keys.append((f"pretrained.act_postprocess{i+1}.3.weight", f"neck.reassemble_stage.layers.{i}.projection.weight"))
|
||||
rename_keys.append((f"pretrained.act_postprocess{i+1}.3.bias", f"neck.reassemble_stage.layers.{i}.projection.bias"))
|
||||
|
||||
if i != 2:
|
||||
rename_keys.append((f"pretrained.act_postprocess{i+1}.4.weight", f"neck.reassemble_stage.layers.{i}.resize.weight"))
|
||||
rename_keys.append((f"pretrained.act_postprocess{i+1}.4.bias", f"neck.reassemble_stage.layers.{i}.resize.bias"))
|
||||
|
||||
# refinenet (tricky here)
|
||||
mapping = {1:3, 2:2, 3:1, 4:0}
|
||||
|
||||
for i in range(1, 5):
|
||||
j = mapping[i]
|
||||
rename_keys.append((f"scratch.refinenet{i}.out_conv.weight", f"neck.fusion_stage.layers.{j}.projection.weight"))
|
||||
rename_keys.append((f"scratch.refinenet{i}.out_conv.bias", f"neck.fusion_stage.layers.{j}.projection.bias"))
|
||||
rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv1.weight", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.weight"))
|
||||
rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv1.bias", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.bias"))
|
||||
rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv2.weight", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.weight"))
|
||||
rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv2.bias", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.bias"))
|
||||
rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv1.weight", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.weight"))
|
||||
rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv1.bias", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.bias"))
|
||||
rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv2.weight", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.weight"))
|
||||
rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv2.bias", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.bias"))
|
||||
|
||||
# scratch convolutions
|
||||
for i in range(4):
|
||||
rename_keys.append((f"scratch.layer{i+1}_rn.weight", f"neck.convs.{i}.weight"))
|
||||
|
||||
# head
|
||||
for i in range(0, 5, 2):
|
||||
rename_keys.append((f"scratch.output_conv.{i}.weight", f"head.head.{i}.weight"))
|
||||
rename_keys.append((f"scratch.output_conv.{i}.bias", f"head.head.{i}.bias"))
|
||||
|
||||
return rename_keys
|
||||
|
||||
|
||||
def remove_ignore_keys_(state_dict):
|
||||
ignore_keys = ["pretrained.model.head.weight", "pretrained.model.head.bias"]
|
||||
for k in ignore_keys:
|
||||
state_dict.pop(k, None)
|
||||
|
||||
|
||||
# we split up the matrix of each encoder layer into queries, keys and values
|
||||
def read_in_q_k_v(state_dict, config):
|
||||
hidden_size = config.backbone_config.hidden_size
|
||||
for i in range(config.backbone_config.num_hidden_layers):
|
||||
# read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
|
||||
in_proj_weight = state_dict.pop(f"pretrained.model.blocks.{i}.attn.qkv.weight")
|
||||
q_bias = state_dict.pop(f"pretrained.model.blocks.{i}.attn.q_bias")
|
||||
v_bias = state_dict.pop(f"pretrained.model.blocks.{i}.attn.v_bias")
|
||||
# next, add query, keys and values (in that order) to the state dict
|
||||
state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[:hidden_size, :]
|
||||
state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
|
||||
state_dict[f"backbone.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
||||
hidden_size : hidden_size * 2, :
|
||||
]
|
||||
state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-hidden_size:, :]
|
||||
state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
|
||||
|
||||
|
||||
def rename_key(dct, old, new):
|
||||
val = dct.pop(old)
|
||||
dct[new] = val
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
im = Image.open(requests.get(url, stream=True).raw)
|
||||
return im
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to our DPT structure.
|
||||
"""
|
||||
|
||||
name_to_url = {
|
||||
"dpt-beit-large-512": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt",
|
||||
"dpt-beit-large-384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt",
|
||||
"dpt-beit-base-384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt",
|
||||
}
|
||||
|
||||
# define DPT configuration based on URL
|
||||
checkpoint_url = name_to_url[model_name]
|
||||
config, image_size = get_dpt_config(model_name)
|
||||
# load original state_dict from URL
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
|
||||
# remove certain keys
|
||||
remove_ignore_keys_(state_dict)
|
||||
# rename keys
|
||||
rename_keys = create_rename_keys(config)
|
||||
for src, dest in rename_keys:
|
||||
rename_key(state_dict, src, dest)
|
||||
# read in qkv matrices
|
||||
read_in_q_k_v(state_dict, config)
|
||||
|
||||
# load HuggingFace model
|
||||
model = DPTForDepthEstimation(config)
|
||||
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
||||
print("Missing keys:", missing_keys)
|
||||
print("Unexpected keys:", unexpected_keys)
|
||||
assert missing_keys == []
|
||||
# assert unexpected_keys == ["pretrained.model.fc_norm.weight", "pretrained.model.fc_norm.bias"]
|
||||
model.eval()
|
||||
|
||||
# Check outputs on an image
|
||||
processor = DPTImageProcessor(
|
||||
size={"height": image_size, "width": image_size}, keep_aspect_ratio=True, ensure_multiple_of=32
|
||||
)
|
||||
|
||||
image = prepare_img()
|
||||
pixel_values = processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
print("First values of pixel values:", pixel_values[0, 0, :3, :3])
|
||||
print("Mean of pixel values:", pixel_values.mean().item())
|
||||
print("Shape of pixel values:", pixel_values.shape)
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize((image_size, image_size)),
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
)
|
||||
pixel_values = transforms(image).unsqueeze(0)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values)
|
||||
|
||||
predicted_depth = outputs.predicted_depth
|
||||
|
||||
print("Shape of predicted depth:", predicted_depth.shape)
|
||||
print("First values of predicted depth:", predicted_depth[0, :3, :3])
|
||||
|
||||
# assert logits
|
||||
# TODO there's still a small difference with the original logits
|
||||
if model_name == "dpt-beit-large-512":
|
||||
# OK, checked
|
||||
expected_shape = torch.Size([1, 512, 512])
|
||||
expected_slice = torch.tensor(
|
||||
[[2804.6260, 2792.5708, 2812.9263], [2772.0288, 2780.1118, 2796.2529], [2748.1094, 2766.6558, 2766.9834]]
|
||||
)
|
||||
elif model_name == "dpt-beit-large-384":
|
||||
# OK, checked
|
||||
expected_shape = torch.Size([1, 384, 384])
|
||||
expected_slice = torch.tensor(
|
||||
[[1783.2273, 1780.5729, 1792.6453], [1759.9817, 1765.5359, 1778.5002], [1739.1633, 1754.7903, 1757.1990]],
|
||||
)
|
||||
elif model_name == "dpt-beit-base-384":
|
||||
# OK, checked
|
||||
expected_shape = torch.Size([1, 384, 384])
|
||||
expected_slice = torch.tensor(
|
||||
[[2898.4482, 2891.3750, 2904.8079], [2858.6685, 2877.2615, 2894.4507], [2842.1235, 2854.1023, 2861.6328]],
|
||||
)
|
||||
|
||||
assert predicted_depth.shape == torch.Size(expected_shape)
|
||||
assert torch.allclose(predicted_depth[0, :3, :3], expected_slice)
|
||||
print("Looks ok!")
|
||||
|
||||
if pytorch_dump_folder_path is not None:
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model and processor to {pytorch_dump_folder_path}")
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
processor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if push_to_hub:
|
||||
print("Pushing model and processor to hub...")
|
||||
model.push_to_hub(repo_id=f"nielsr/{model_name}")
|
||||
processor.push_to_hub(repo_id=f"nielsr/{model_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
default="dpt-beit-large-512",
|
||||
type=str,
|
||||
choices=["dpt-beit-large-512", "dpt-beit-large-384", "dpt-beit-base-384"],
|
||||
help="Name of the model you'd like to convert.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Path to the output PyTorch model directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether to push the model to the hub after conversion.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_dpt_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
|
@ -44,7 +44,7 @@ def verify_out_features_out_indices(
|
||||
if not isinstance(out_indices, (list, tuple)):
|
||||
raise ValueError(f"out_indices must be a list or tuple, got {type(out_indices)}")
|
||||
if any(idx >= len(stage_names) for idx in out_indices):
|
||||
raise ValueError("out_indices must be valid indices for stage_names {stage_names}, got {out_indices}")
|
||||
raise ValueError(f"out_indices must be valid indices for stage_names {stage_names}, got {out_indices}")
|
||||
|
||||
if out_features is not None and out_indices is not None:
|
||||
if len(out_features) != len(out_indices):
|
||||
|
@ -1044,6 +1044,13 @@ class PretrainedBartModel(metaclass=DummyObject):
|
||||
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class BeitBackbone(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class BeitForImageClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -25,6 +25,7 @@ from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
|
||||
from ...test_backbone_common import BackboneTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
@ -35,7 +36,9 @@ if is_torch_available():
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_BACKBONE_MAPPING,
|
||||
MODEL_MAPPING,
|
||||
BeitBackbone,
|
||||
BeitForImageClassification,
|
||||
BeitForMaskedImageModeling,
|
||||
BeitForSemanticSegmentation,
|
||||
@ -63,7 +66,7 @@ class BeitModelTester:
|
||||
is_training=True,
|
||||
use_labels=True,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=2,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
@ -73,10 +76,11 @@ class BeitModelTester:
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
scope=None,
|
||||
out_indices=[0, 1, 2, 3],
|
||||
out_indices=[1, 2, 3, 4],
|
||||
out_features=["stage1", "stage2", "stage3", "stage4"],
|
||||
):
|
||||
self.parent = parent
|
||||
self.vocab_size = 100
|
||||
self.vocab_size = vocab_size
|
||||
self.batch_size = batch_size
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
@ -94,6 +98,7 @@ class BeitModelTester:
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
self.out_indices = out_indices
|
||||
self.out_features = out_features
|
||||
self.num_labels = num_labels
|
||||
|
||||
# in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
|
||||
@ -129,6 +134,7 @@ class BeitModelTester:
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
out_indices=self.out_indices,
|
||||
out_features=self.out_features,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
|
||||
@ -138,6 +144,38 @@ class BeitModelTester:
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_backbone(self, config, pixel_values, labels, pixel_labels):
|
||||
model = BeitBackbone(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
# verify hidden states
|
||||
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
|
||||
expected_height = expected_width = self.image_size // config.patch_size
|
||||
self.parent.assertListEqual(
|
||||
list(result.feature_maps[0].shape), [self.batch_size, self.hidden_size, expected_height, expected_width]
|
||||
)
|
||||
|
||||
# verify channels
|
||||
self.parent.assertEqual(len(model.channels), len(config.out_features))
|
||||
|
||||
# verify backbone works with out_features=None
|
||||
config.out_features = None
|
||||
model = BeitBackbone(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
# verify feature maps
|
||||
self.parent.assertEqual(len(result.feature_maps), 1)
|
||||
self.parent.assertListEqual(
|
||||
list(result.feature_maps[0].shape), [self.batch_size, self.hidden_size, expected_height, expected_width]
|
||||
)
|
||||
|
||||
# verify channels
|
||||
self.parent.assertEqual(len(model.channels), 1)
|
||||
|
||||
def create_and_check_for_masked_lm(self, config, pixel_values, labels, pixel_labels):
|
||||
model = BeitForMaskedImageModeling(config=config)
|
||||
model.to(torch_device)
|
||||
@ -192,7 +230,13 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
"""
|
||||
|
||||
all_model_classes = (
|
||||
(BeitModel, BeitForImageClassification, BeitForMaskedImageModeling, BeitForSemanticSegmentation)
|
||||
(
|
||||
BeitModel,
|
||||
BeitForImageClassification,
|
||||
BeitForMaskedImageModeling,
|
||||
BeitForSemanticSegmentation,
|
||||
BeitBackbone,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
@ -226,6 +270,10 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="BEiT does not support feedforward chunking yet")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@ -239,6 +287,10 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_backbone(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_backbone(*config_and_inputs)
|
||||
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||
@ -260,7 +312,11 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
# we don't test BeitForMaskedImageModeling
|
||||
if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]:
|
||||
if model_class in [
|
||||
*get_values(MODEL_MAPPING),
|
||||
*get_values(MODEL_FOR_BACKBONE_MAPPING),
|
||||
BeitForMaskedImageModeling,
|
||||
]:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
@ -281,7 +337,8 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
for model_class in self.all_model_classes:
|
||||
# we don't test BeitForMaskedImageModeling
|
||||
if (
|
||||
model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]
|
||||
model_class
|
||||
in [*get_values(MODEL_MAPPING), *get_values(MODEL_FOR_BACKBONE_MAPPING), BeitForMaskedImageModeling]
|
||||
or not model_class.supports_gradient_checkpointing
|
||||
):
|
||||
continue
|
||||
@ -487,3 +544,12 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
segmentation = image_processor.post_process_semantic_segmentation(outputs=outputs)
|
||||
expected_shape = torch.Size((160, 160))
|
||||
self.assertEqual(segmentation[0].shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
|
||||
class BeitBackboneTest(unittest.TestCase, BackboneTesterMixin):
|
||||
all_model_classes = (BeitBackbone,) if is_torch_available() else ()
|
||||
config_class = BeitConfig
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BeitModelTester(self)
|
||||
|
@ -963,6 +963,7 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
|
||||
"TensorFlowBenchmark",
|
||||
"TensorFlowBenchmarkArguments",
|
||||
"AutoBackbone",
|
||||
"BeitBackbone",
|
||||
"BitBackbone",
|
||||
"ConvNextBackbone",
|
||||
"ConvNextV2Backbone",
|
||||
|
Loading…
Reference in New Issue
Block a user