Add TextNet (#34979)
* WIP * Add config and modeling for Fast model * Refactor modeling and add tests * More changes * WIP * Add tests * Add conversion script * Add conversion scripts, integration tests, image processor * Fix style and copies * Add fast model to init * Add fast model in docs and other places * Fix import of cv2 * Rename image processing method * Fix build * Fix Build * fix style and fix copies * Fix build * Fix build * Fix Build * Clean up docstrings * Fix Build * Fix Build * Fix Build * Fix build * Add test for image_processing_fast and add documentation tests * some refactorings * Fix failing tests * Incorporate PR feedbacks * Incorporate PR feedbacks * Incorporate PR feedbacks * Incorporate PR feedbacks * Incorporate PR feedbacks * Introduce TextNet * Fix failures * Refactor textnet model * Fix failures * Add cv2 to setup * Fix failures * Fix failures * Add CV2 dependency * Fix bugs * Fix build issue * Fix failures * Remove textnet from modeling fast * Fix build and other things * Fix build * some cleanups * some cleanups * Some more cleanups * Fix build * Incorporate PR feedbacks * More cleanup * More cleanup * More cleanup * Fix build * Remove all the references of fast model * More cleanup * Fix build * Incorporate PR feedbacks * Incorporate PR feedbacks * Incorporate PR feedbacks * Incorporate PR feedbacks * Incorporate PR feedbacks * Incorporate PR feedbacks * Incorporate PR feedbacks * Incorporate PR feedbacks * Incorporate PR feedbacks * Incorporate PR feedbacks * Fix Build * Fix build * Fix build * Fix build * Fix build * Fix build * Incorporate PR feedbacks * Fix style * Fix build * Incorporate PR feedbacks * Fix image processing mean and std * Incorporate PR feedbacks * fix build failure * Add assertion to image processor * Incorporate PR feedbacks * Incorporate PR feedbacks * fix style failures * fix build * Fix Imageclassification's linear layer, also introduce TextNetImageProcessor * Fix build * Fix build * Fix build * Fix build * Incorporate PR feedbacks * Incorporate PR feedbacks * Fix build * Incorporate PR feedbacks * Remove some script * Incorporate PR feedbacks * Incorporate PR feedbacks * Incorporate PR feedbacks * Incorporate PR feedbacks * Fix image processing in textnet * Incorporate PR Feedbacks * Fix CI failures * Fix failing test * Fix failing test * Fix failing test * Fix failing test * Fix failing test * Fix failing test * Add textnet to readme * Improve readability * Incorporate PR feedbacks * fix code style * fix key error and convert working * tvlt shouldn't be here * fix test modeling test * Fix tests, make fixup * Make fixup * Make fixup * Remove TEXTNET_PRETRAINED_MODEL_ARCHIVE_LIST * improve type annotation Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update tests/models/textnet/test_image_processing_textnet.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * improve type annotation Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * space typo Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * improve type annotation Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/models/textnet/configuration_textnet.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * make conv layer kernel sizes and strides default to None * Update src/transformers/models/textnet/modeling_textnet.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/models/textnet/modeling_textnet.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * fix keyword bug * add batch init and make fixup * Make fixup * 
Update integration test * Add figure * Update textnet.md * add testing and fix errors (classification, imgprocess) * fix error check * make fixup * make fixup * revert to original docstring * add make style * remove conflict for now * Update modeling_auto.py got a confusion in `timm_wrapper` - was giving some conflicts * Update tests/models/textnet/test_modeling_textnet.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/models/textnet/modeling_textnet.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update tests/models/textnet/test_modeling_textnet.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Update src/transformers/models/textnet/modeling_textnet.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * add changes * Update textnet.md * add doc * add authors hf ckpt + rename * add feedback: classifier/docs --------- Co-authored-by: raghavanone <opensourcemaniacfreak@gmail.com> Co-authored-by: jadechoghari <jadechoghari@users.noreply.huggingface.co> Co-authored-by: Niels <niels.rogge1@gmail.com> Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in: parent b05df6611e, commit 7176e06b52
@ -721,6 +721,8 @@
    title: Swin2SR
  - local: model_doc/table-transformer
    title: Table Transformer
  - local: model_doc/textnet
    title: TextNet
  - local: model_doc/timm_wrapper
    title: Timm Wrapper
  - local: model_doc/upernet
@ -326,6 +326,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Table Transformer](model_doc/table-transformer) | ✅ | ❌ | ❌ |
| [TAPAS](model_doc/tapas) | ✅ | ✅ | ❌ |
| [TAPEX](model_doc/tapex) | ✅ | ✅ | ✅ |
| [TextNet](model_doc/textnet) | ✅ | ❌ | ❌ |
| [Time Series Transformer](model_doc/time_series_transformer) | ✅ | ❌ | ❌ |
| [TimeSformer](model_doc/timesformer) | ✅ | ❌ | ❌ |
| [TimmWrapperModel](model_doc/timm_wrapper) | ✅ | ❌ | ❌ |
55  docs/source/en/model_doc/textnet.md  Normal file
@ -0,0 +1,55 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# TextNet

## Overview

The TextNet model was proposed in [FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation](https://arxiv.org/abs/2111.02394) by Zhe Chen, Jiahao Wang, Wenhai Wang, Guo Chen, Enze Xie, Ping Luo, Tong Lu. TextNet is a vision backbone for text detection tasks. It is the result of neural architecture search (NAS) over backbone candidates, using text detection as the reward task, so that the backbone provides powerful features for detecting text.

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/fast_architecture.png"
alt="drawing" width="600"/>

<small> TextNet backbone as part of FAST. Taken from the <a href="https://arxiv.org/abs/2111.02394">original paper</a>. </small>

This model was contributed by [Raghavan](https://huggingface.co/Raghavan), [jadechoghari](https://huggingface.co/jadechoghari) and [nielsr](https://huggingface.co/nielsr).

## Usage tips

TextNet is mainly used as a backbone network for the architecture search of text detection. Each stage of the backbone network is composed of a stride-2 convolution followed by searchable blocks.
Specifically, the layer-level candidate set is defined as {conv3×3, conv1×3, conv3×1, identity}. Because the 1×3 and 3×1 convolutions have asymmetric kernels and oriented structure priors, they help capture the features of extreme aspect-ratio and rotated text lines.

TextNet is the backbone of FAST, but it can also be used for efficient text/image classification. `TextNetForImageClassification` allows training an image classifier on top of pre-trained TextNet weights, as shown in the example below.

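The snippet below is a minimal, illustrative sketch rather than an official example: it assumes the `czczup/textnet-base` checkpoint referenced elsewhere in this PR and shows how the backbone and the classification head can be loaded through the usual `from_pretrained` API.

```python
import torch
import requests
from PIL import Image

from transformers import AutoImageProcessor, TextNetBackbone, TextNetForImageClassification

# Checkpoint name taken from this PR's docstrings; adjust if the hub repository differs.
checkpoint = "czczup/textnet-base"

processor = AutoImageProcessor.from_pretrained(checkpoint)
backbone = TextNetBackbone.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

inputs = processor(image, return_tensors="pt")
with torch.no_grad():
    outputs = backbone(**inputs)

# One feature map per requested stage, in (batch_size, num_channels, height, width) format.
for feature_map in outputs.feature_maps:
    print(feature_map.shape)

# The classification head pools the last feature map and applies a linear layer on top.
classifier = TextNetForImageClassification.from_pretrained(checkpoint, num_labels=2)
with torch.no_grad():
    logits = classifier(**inputs).logits
```
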
## TextNetConfig

[[autodoc]] TextNetConfig

## TextNetImageProcessor

[[autodoc]] TextNetImageProcessor
- preprocess

## TextNetModel

[[autodoc]] TextNetModel
- forward

## TextNetForImageClassification

[[autodoc]] TextNetForImageClassification
- forward

@ -789,6 +789,7 @@ _import_structure = {
|
||||
"TapasConfig",
|
||||
"TapasTokenizer",
|
||||
],
|
||||
"models.textnet": ["TextNetConfig"],
|
||||
"models.time_series_transformer": ["TimeSeriesTransformerConfig"],
|
||||
"models.timesformer": ["TimesformerConfig"],
|
||||
"models.timm_backbone": ["TimmBackboneConfig"],
|
||||
@ -1258,6 +1259,7 @@ else:
|
||||
_import_structure["models.siglip"].append("SiglipImageProcessor")
|
||||
_import_structure["models.superpoint"].extend(["SuperPointImageProcessor"])
|
||||
_import_structure["models.swin2sr"].append("Swin2SRImageProcessor")
|
||||
_import_structure["models.textnet"].extend(["TextNetImageProcessor"])
|
||||
_import_structure["models.tvp"].append("TvpImageProcessor")
|
||||
_import_structure["models.video_llava"].append("VideoLlavaImageProcessor")
|
||||
_import_structure["models.videomae"].extend(["VideoMAEFeatureExtractor", "VideoMAEImageProcessor"])
|
||||
@ -3584,6 +3586,14 @@ else:
|
||||
"load_tf_weights_in_tapas",
|
||||
]
|
||||
)
|
||||
_import_structure["models.textnet"].extend(
|
||||
[
|
||||
"TextNetBackbone",
|
||||
"TextNetForImageClassification",
|
||||
"TextNetModel",
|
||||
"TextNetPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.time_series_transformer"].extend(
|
||||
[
|
||||
"TimeSeriesTransformerForPrediction",
|
||||
@ -5813,6 +5823,7 @@ if TYPE_CHECKING:
|
||||
TapasConfig,
|
||||
TapasTokenizer,
|
||||
)
|
||||
from .models.textnet import TextNetConfig
|
||||
from .models.time_series_transformer import (
|
||||
TimeSeriesTransformerConfig,
|
||||
)
|
||||
@ -6293,6 +6304,7 @@ if TYPE_CHECKING:
|
||||
from .models.siglip import SiglipImageProcessor
|
||||
from .models.superpoint import SuperPointImageProcessor
|
||||
from .models.swin2sr import Swin2SRImageProcessor
|
||||
from .models.textnet import TextNetImageProcessor
|
||||
from .models.tvp import TvpImageProcessor
|
||||
from .models.video_llava import VideoLlavaImageProcessor
|
||||
from .models.videomae import VideoMAEFeatureExtractor, VideoMAEImageProcessor
|
||||
@ -8155,6 +8167,12 @@ if TYPE_CHECKING:
|
||||
TapasPreTrainedModel,
|
||||
load_tf_weights_in_tapas,
|
||||
)
|
||||
from .models.textnet import (
|
||||
TextNetBackbone,
|
||||
TextNetForImageClassification,
|
||||
TextNetModel,
|
||||
TextNetPreTrainedModel,
|
||||
)
|
||||
from .models.time_series_transformer import (
|
||||
TimeSeriesTransformerForPrediction,
|
||||
TimeSeriesTransformerModel,
|
||||
|
@ -252,6 +252,7 @@ from . import (
|
||||
t5,
|
||||
table_transformer,
|
||||
tapas,
|
||||
textnet,
|
||||
time_series_transformer,
|
||||
timesformer,
|
||||
timm_backbone,
|
||||
|
@ -279,6 +279,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("t5", "T5Config"),
|
||||
("table-transformer", "TableTransformerConfig"),
|
||||
("tapas", "TapasConfig"),
|
||||
("textnet", "TextNetConfig"),
|
||||
("time_series_transformer", "TimeSeriesTransformerConfig"),
|
||||
("timesformer", "TimesformerConfig"),
|
||||
("timm_backbone", "TimmBackboneConfig"),
|
||||
@ -610,6 +611,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("table-transformer", "Table Transformer"),
|
||||
("tapas", "TAPAS"),
|
||||
("tapex", "TAPEX"),
|
||||
("textnet", "TextNet"),
|
||||
("time_series_transformer", "Time Series Transformer"),
|
||||
("timesformer", "TimeSformer"),
|
||||
("timm_backbone", "TimmBackbone"),
|
||||
|
@ -257,6 +257,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("t5", "T5Model"),
|
||||
("table-transformer", "TableTransformerModel"),
|
||||
("tapas", "TapasModel"),
|
||||
("textnet", "TextNetModel"),
|
||||
("time_series_transformer", "TimeSeriesTransformerModel"),
|
||||
("timesformer", "TimesformerModel"),
|
||||
("timm_backbone", "TimmBackbone"),
|
||||
@ -703,6 +704,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("swiftformer", "SwiftFormerForImageClassification"),
|
||||
("swin", "SwinForImageClassification"),
|
||||
("swinv2", "Swinv2ForImageClassification"),
|
||||
("textnet", "TextNetForImageClassification"),
|
||||
("timm_wrapper", "TimmWrapperForImageClassification"),
|
||||
("van", "VanForImageClassification"),
|
||||
("vit", "ViTForImageClassification"),
|
||||
@ -1391,6 +1393,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
|
||||
("rt_detr_resnet", "RTDetrResNetBackbone"),
|
||||
("swin", "SwinBackbone"),
|
||||
("swinv2", "Swinv2Backbone"),
|
||||
("textnet", "TextNetBackbone"),
|
||||
("timm_backbone", "TimmBackbone"),
|
||||
("vitdet", "VitDetBackbone"),
|
||||
]
|
||||
|
28  src/transformers/models/textnet/__init__.py  Normal file
@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_textnet import *
    from .image_processing_textnet import *
    from .modeling_textnet import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
135  src/transformers/models/textnet/configuration_textnet.py  Normal file
@ -0,0 +1,135 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the Fast authors and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""TextNet model configuration"""
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
from transformers.utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class TextNetConfig(BackboneConfigMixin, PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`TextNetModel`]. It is used to instantiate a
TextNet model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the
[czczup/textnet-base](https://huggingface.co/czczup/textnet-base). Configuration objects inherit from
[`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`]
for more information.
|
||||
|
||||
Args:
|
||||
stem_kernel_size (`int`, *optional*, defaults to 3):
|
||||
The kernel size for the initial convolution layer.
|
||||
stem_stride (`int`, *optional*, defaults to 2):
|
||||
The stride for the initial convolution layer.
|
||||
stem_num_channels (`int`, *optional*, defaults to 3):
The number of input channels for the initial convolution layer.
stem_out_channels (`int`, *optional*, defaults to 64):
The number of output channels for the initial convolution layer.
|
||||
stem_act_func (`str`, *optional*, defaults to `"relu"`):
|
||||
The activation function for the initial convolution layer.
|
||||
image_size (`Tuple[int, int]`, *optional*, defaults to `[640, 640]`):
|
||||
The size (resolution) of each image.
|
||||
conv_layer_kernel_sizes (`List[List[List[int]]]`, *optional*):
|
||||
A list of stage-wise kernel sizes. If `None`, defaults to:
|
||||
`[[[3, 3], [3, 3], [3, 3]], [[3, 3], [1, 3], [3, 3], [3, 1]], [[3, 3], [3, 3], [3, 1], [1, 3]], [[3, 3], [3, 1], [1, 3], [3, 3]]]`.
|
||||
conv_layer_strides (`List[List[int]]`, *optional*):
|
||||
A list of stage-wise strides. If `None`, defaults to:
|
||||
`[[1, 2, 1], [2, 1, 1, 1], [2, 1, 1, 1], [2, 1, 1, 1]]`.
|
||||
hidden_sizes (`List[int]`, *optional*, defaults to `[64, 64, 128, 256, 512]`):
|
||||
Dimensionality (hidden size) at each stage.
|
||||
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the batch normalization layers.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
out_features (`List[str]`, *optional*):
|
||||
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
|
||||
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the
|
||||
corresponding stages. If unset and `out_indices` is unset, will default to the last stage.
|
||||
out_indices (`List[int]`, *optional*):
|
||||
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
|
||||
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
|
||||
If unset and `out_features` is unset, will default to the last stage.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import TextNetConfig, TextNetBackbone
|
||||
|
||||
>>> # Initializing a TextNetConfig
|
||||
>>> configuration = TextNetConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights)
|
||||
>>> model = TextNetBackbone(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "textnet"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stem_kernel_size=3,
|
||||
stem_stride=2,
|
||||
stem_num_channels=3,
|
||||
stem_out_channels=64,
|
||||
stem_act_func="relu",
|
||||
image_size=[640, 640],
|
||||
conv_layer_kernel_sizes=None,
|
||||
conv_layer_strides=None,
|
||||
hidden_sizes=[64, 64, 128, 256, 512],
|
||||
batch_norm_eps=1e-5,
|
||||
initializer_range=0.02,
|
||||
out_features=None,
|
||||
out_indices=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if conv_layer_kernel_sizes is None:
|
||||
conv_layer_kernel_sizes = [
|
||||
[[3, 3], [3, 3], [3, 3]],
|
||||
[[3, 3], [1, 3], [3, 3], [3, 1]],
|
||||
[[3, 3], [3, 3], [3, 1], [1, 3]],
|
||||
[[3, 3], [3, 1], [1, 3], [3, 3]],
|
||||
]
|
||||
if conv_layer_strides is None:
|
||||
conv_layer_strides = [[1, 2, 1], [2, 1, 1, 1], [2, 1, 1, 1], [2, 1, 1, 1]]
|
||||
|
||||
self.stem_kernel_size = stem_kernel_size
|
||||
self.stem_stride = stem_stride
|
||||
self.stem_num_channels = stem_num_channels
|
||||
self.stem_out_channels = stem_out_channels
|
||||
self.stem_act_func = stem_act_func
|
||||
|
||||
self.image_size = image_size
|
||||
self.conv_layer_kernel_sizes = conv_layer_kernel_sizes
|
||||
self.conv_layer_strides = conv_layer_strides
|
||||
|
||||
self.initializer_range = initializer_range
|
||||
self.hidden_sizes = hidden_sizes
|
||||
self.batch_norm_eps = batch_norm_eps
|
||||
|
||||
self.depths = [len(layer) for layer in self.conv_layer_kernel_sizes]
|
||||
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, 5)]
|
||||
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
|
||||
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["TextNetConfig"]
|
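As a small illustration of the `out_features` / `out_indices` arguments documented above, the following sketch (randomly initialized, no pretrained checkpoint assumed) configures the backbone to return feature maps from two intermediate stages; the expected shapes follow from the default strides and hidden sizes.

```python
import torch

from transformers import TextNetBackbone, TextNetConfig

# "stem" and "stage1" ... "stage4" are the valid stage names for this config.
config = TextNetConfig(out_features=["stage2", "stage4"])
model = TextNetBackbone(config)

pixel_values = torch.randn(1, 3, 640, 640)
outputs = model(pixel_values)

# With the default configuration this should print
# stage2 torch.Size([1, 128, 80, 80]) and stage4 torch.Size([1, 512, 20, 20]).
for name, feature_map in zip(config.out_features, outputs.feature_maps):
    print(name, feature_map.shape)
```
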
208  src/transformers/models/textnet/convert_textnet_to_hf.py  Normal file
@ -0,0 +1,208 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the Fast authors and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from transformers import TextNetBackbone, TextNetConfig, TextNetImageProcessor
|
||||
|
||||
|
||||
tiny_config_url = "https://raw.githubusercontent.com/czczup/FAST/main/config/fast/nas-configs/fast_tiny.config"
|
||||
small_config_url = "https://raw.githubusercontent.com/czczup/FAST/main/config/fast/nas-configs/fast_small.config"
|
||||
base_config_url = "https://raw.githubusercontent.com/czczup/FAST/main/config/fast/nas-configs/fast_base.config"
|
||||
|
||||
rename_key_mappings = {
|
||||
"module.backbone": "textnet",
|
||||
"first_conv": "stem",
|
||||
"bn": "batch_norm",
|
||||
"ver": "vertical",
|
||||
"hor": "horizontal",
|
||||
}
|
||||
|
||||
|
||||
def prepare_config(size_config_url, size):
|
||||
config_dict = json.loads(requests.get(size_config_url).text)
|
||||
|
||||
backbone_config = {}
|
||||
for stage_ix in range(1, 5):
|
||||
stage_config = config_dict[f"stage{stage_ix}"]
|
||||
|
||||
merged_dict = {}
|
||||
|
||||
# Iterate through the list of dictionaries
|
||||
for layer in stage_config:
|
||||
for key, value in layer.items():
|
||||
if key != "name":
|
||||
# Check if the key is already in the merged_dict
|
||||
if key in merged_dict:
|
||||
merged_dict[key].append(value)
|
||||
else:
|
||||
# If the key is not in merged_dict, create a new list with the value
|
||||
merged_dict[key] = [value]
|
||||
backbone_config[f"stage{stage_ix}"] = merged_dict
|
||||
|
||||
neck_in_channels = []
|
||||
neck_out_channels = []
|
||||
neck_kernel_size = []
|
||||
neck_stride = []
|
||||
neck_dilation = []
|
||||
neck_groups = []
|
||||
|
||||
for i in range(1, 5):
|
||||
layer_key = f"reduce_layer{i}"
|
||||
layer_dict = config_dict["neck"].get(layer_key)
|
||||
|
||||
if layer_dict:
|
||||
# Append values to the corresponding lists
|
||||
neck_in_channels.append(layer_dict["in_channels"])
|
||||
neck_out_channels.append(layer_dict["out_channels"])
|
||||
neck_kernel_size.append(layer_dict["kernel_size"])
|
||||
neck_stride.append(layer_dict["stride"])
|
||||
neck_dilation.append(layer_dict["dilation"])
|
||||
neck_groups.append(layer_dict["groups"])
|
||||
|
||||
textnet_config = TextNetConfig(
|
||||
stem_kernel_size=config_dict["first_conv"]["kernel_size"],
|
||||
stem_stride=config_dict["first_conv"]["stride"],
|
||||
stem_num_channels=config_dict["first_conv"]["in_channels"],
|
||||
stem_out_channels=config_dict["first_conv"]["out_channels"],
|
||||
stem_act_func=config_dict["first_conv"]["act_func"],
|
||||
conv_layer_kernel_sizes=[
|
||||
backbone_config["stage1"]["kernel_size"],
|
||||
backbone_config["stage2"]["kernel_size"],
|
||||
backbone_config["stage3"]["kernel_size"],
|
||||
backbone_config["stage4"]["kernel_size"],
|
||||
],
|
||||
conv_layer_strides=[
|
||||
backbone_config["stage1"]["stride"],
|
||||
backbone_config["stage2"]["stride"],
|
||||
backbone_config["stage3"]["stride"],
|
||||
backbone_config["stage4"]["stride"],
|
||||
],
|
||||
hidden_sizes=[
|
||||
config_dict["first_conv"]["out_channels"],
|
||||
backbone_config["stage1"]["out_channels"][-1],
|
||||
backbone_config["stage2"]["out_channels"][-1],
|
||||
backbone_config["stage3"]["out_channels"][-1],
|
||||
backbone_config["stage4"]["out_channels"][-1],
|
||||
],
|
||||
out_features=["stage1", "stage2", "stage3", "stage4"],
|
||||
out_indices=[1, 2, 3, 4],
|
||||
)
|
||||
|
||||
return textnet_config
|
||||
|
||||
|
||||
def convert_textnet_checkpoint(checkpoint_url, checkpoint_config_filename, pytorch_dump_folder_path):
|
||||
config_filepath = hf_hub_download(repo_id="Raghavan/fast_model_config_files", filename="fast_model_configs.json")
|
||||
|
||||
with open(config_filepath) as f:
|
||||
content = json.loads(f.read())
|
||||
|
||||
size = content[checkpoint_config_filename]["short_size"]
|
||||
|
||||
if "tiny" in content[checkpoint_config_filename]["config"]:
|
||||
config = prepare_config(tiny_config_url, size)
|
||||
expected_slice_backbone = torch.tensor(
|
||||
[0.0000, 0.0000, 0.0000, 0.0000, 0.5300, 0.0000, 0.0000, 0.0000, 0.0000, 1.1221]
|
||||
)
|
||||
elif "small" in content[checkpoint_config_filename]["config"]:
|
||||
config = prepare_config(small_config_url, size)
|
||||
expected_slice_backbone = torch.tensor(
|
||||
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1394]
|
||||
)
|
||||
else:
|
||||
config = prepare_config(base_config_url, size)
|
||||
expected_slice_backbone = torch.tensor(
|
||||
[0.9210, 0.6099, 0.0000, 0.0000, 0.0000, 0.0000, 3.2207, 2.6602, 1.8925, 0.0000]
|
||||
)
|
||||
|
||||
model = TextNetBackbone(config)
|
||||
textnet_image_processor = TextNetImageProcessor(size={"shortest_edge": size})
|
||||
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)["ema"]
|
||||
state_dict_changed = OrderedDict()
|
||||
for key in state_dict:
|
||||
if "backbone" in key:
|
||||
val = state_dict[key]
|
||||
new_key = key
|
||||
for search, replacement in rename_key_mappings.items():
|
||||
if search in new_key:
|
||||
new_key = new_key.replace(search, replacement)
|
||||
|
||||
pattern = r"textnet\.stage(\d)"
|
||||
|
||||
def adjust_stage(match):
|
||||
stage_number = int(match.group(1)) - 1
|
||||
return f"textnet.encoder.stages.{stage_number}.stage"
|
||||
|
||||
# Using regex to find and replace the pattern in the string
|
||||
new_key = re.sub(pattern, adjust_stage, new_key)
|
||||
state_dict_changed[new_key] = val
|
||||
model.load_state_dict(state_dict_changed)
|
||||
model.eval()
|
||||
|
||||
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
|
||||
|
||||
original_pixel_values = torch.tensor(
|
||||
[0.1939, 0.3481, 0.4166, 0.3309, 0.4508, 0.4679, 0.4851, 0.4851, 0.3309, 0.4337]
|
||||
)
|
||||
pixel_values = textnet_image_processor(image, return_tensors="pt").pixel_values
|
||||
|
||||
assert torch.allclose(original_pixel_values, pixel_values[0][0][3][:10], atol=1e-4)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(pixel_values)
|
||||
|
||||
assert torch.allclose(output["feature_maps"][-1][0][10][12][:10].detach(), expected_slice_backbone, atol=1e-3)
|
||||
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
textnet_image_processor.save_pretrained(pytorch_dump_folder_path)
|
||||
logging.info("The converted weights are saved here : " + pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_url",
|
||||
default="https://github.com/czczup/FAST/releases/download/release/fast_base_ic17mlt_640.pth",
|
||||
type=str,
|
||||
help="URL to the original PyTorch checkpoint (.pth file).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_config_filename",
|
||||
default="fast_base_ic17mlt_640.py",
|
||||
type=str,
|
||||
help="URL to the original PyTorch checkpoint (.pth file).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_textnet_checkpoint(
|
||||
args.checkpoint_url,
|
||||
args.checkpoint_config_filename,
|
||||
args.pytorch_dump_folder_path,
|
||||
)
|
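For reference, here is a hedged sketch of driving the conversion entry point above from Python; the checkpoint URL and config filename are the script's own argparse defaults, and the dump folder is a placeholder path.

```python
# Assumes the conversion script is importable from the installed package; it can
# equally be run as a standalone script with the same arguments on the command line.
from transformers.models.textnet.convert_textnet_to_hf import convert_textnet_checkpoint

convert_textnet_checkpoint(
    checkpoint_url="https://github.com/czczup/FAST/releases/download/release/fast_base_ic17mlt_640.pth",
    checkpoint_config_filename="fast_base_ic17mlt_640.py",
    pytorch_dump_folder_path="./textnet-base-converted",  # placeholder output folder
)
```
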
355  src/transformers/models/textnet/image_processing_textnet.py  Normal file
@ -0,0 +1,355 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the Fast authors and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Image processor class for TextNet."""
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from ...image_transforms import (
|
||||
convert_to_rgb,
|
||||
get_resize_output_image_size,
|
||||
resize,
|
||||
to_channel_dimension_format,
|
||||
)
|
||||
from ...image_utils import (
|
||||
IMAGENET_DEFAULT_MEAN,
|
||||
IMAGENET_DEFAULT_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
validate_kwargs,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, is_vision_available, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
|
||||
|
||||
class TextNetImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a TextNet image processor.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
|
||||
`do_resize` in the `preprocess` method.
|
||||
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
|
||||
Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
|
||||
the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
|
||||
method.
|
||||
size_divisor (`int`, *optional*, defaults to 32):
|
||||
Ensures height and width are rounded to a multiple of this value after resizing.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||
do_center_crop (`bool`, *optional*, defaults to `False`):
|
||||
Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
|
||||
`preprocess` method.
|
||||
crop_size (`Dict[str, int]` *optional*, defaults to 224):
|
||||
Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
|
||||
method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||
the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
||||
method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
|
||||
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
||||
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
|
||||
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
size: Dict[str, int] = None,
|
||||
size_divisor: int = 32,
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
do_center_crop: bool = False,
|
||||
crop_size: Dict[str, int] = None,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = IMAGENET_DEFAULT_MEAN,
|
||||
image_std: Optional[Union[float, List[float]]] = IMAGENET_DEFAULT_STD,
|
||||
do_convert_rgb: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
size = size if size is not None else {"shortest_edge": 224}
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.size_divisor = size_divisor
|
||||
self.resample = resample
|
||||
self.do_center_crop = do_center_crop
|
||||
self.crop_size = crop_size
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
|
||||
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
|
||||
self._valid_processor_keys = [
|
||||
"images",
|
||||
"do_resize",
|
||||
"size",
|
||||
"size_divisor",
|
||||
"resample",
|
||||
"do_center_crop",
|
||||
"crop_size",
|
||||
"do_rescale",
|
||||
"rescale_factor",
|
||||
"do_normalize",
|
||||
"image_mean",
|
||||
"image_std",
|
||||
"do_convert_rgb",
|
||||
"return_tensors",
|
||||
"data_format",
|
||||
"input_data_format",
|
||||
]
|
||||
|
||||
def resize(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
|
||||
resized to keep the input aspect ratio. Both the height and width are resized to be divisible by 32.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
Size of the output image.
|
||||
size_divisor (`int`, *optional*, defaults to `32`):
|
||||
Ensures height and width are rounded to a multiple of this value after resizing.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
||||
Resampling filter to use when resizing the image.
|
||||
data_format (`str` or `ChannelDimension`, *optional*):
|
||||
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format of the input image. If not provided, it will be inferred.
|
||||
default_to_square (`bool`, *optional*, defaults to `False`):
|
||||
The value to be passed to `get_size_dict` as `default_to_square` when computing the image size. If the
|
||||
`size` argument in `get_size_dict` is an `int`, it determines whether to default to a square image or
|
||||
not. Note that this attribute is not used in computing `crop_size` via calling `get_size_dict`.
|
||||
"""
|
||||
if "shortest_edge" in size:
|
||||
size = size["shortest_edge"]
|
||||
elif "height" in size and "width" in size:
|
||||
size = (size["height"], size["width"])
|
||||
else:
|
||||
raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
|
||||
|
||||
height, width = get_resize_output_image_size(
|
||||
image, size=size, input_data_format=input_data_format, default_to_square=False
|
||||
)
|
||||
if height % self.size_divisor != 0:
|
||||
height += self.size_divisor - (height % self.size_divisor)
|
||||
if width % self.size_divisor != 0:
|
||||
width += self.size_divisor - (width % self.size_divisor)
|
||||
|
||||
return resize(
|
||||
image,
|
||||
size=(height, width),
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: bool = None,
|
||||
size: Dict[str, int] = None,
|
||||
size_divisor: int = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_center_crop: bool = None,
|
||||
crop_size: int = None,
|
||||
do_rescale: bool = None,
|
||||
rescale_factor: float = None,
|
||||
do_normalize: bool = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Preprocess an image or batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
|
||||
the longest edge resized to keep the input aspect ratio.
|
||||
size_divisor (`int`, *optional*, defaults to `32`):
|
||||
Ensures height and width are rounded to a multiple of this value after resizing.
|
||||
resample (`int`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
||||
has an effect if `do_resize` is set to `True`.
|
||||
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
|
||||
Whether to center crop the image.
|
||||
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
|
||||
Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image.
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
||||
`True`.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, param_name="size", default_to_square=False)
|
||||
size_divisor = size_divisor if size_divisor is not None else self.size_divisor
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
|
||||
crop_size = crop_size if crop_size is not None else self.crop_size
|
||||
crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
|
||||
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
|
||||
|
||||
images = make_list_of_images(images)
|
||||
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
validate_preprocess_arguments(
|
||||
do_rescale=do_rescale,
|
||||
rescale_factor=rescale_factor,
|
||||
do_normalize=do_normalize,
|
||||
image_mean=image_mean,
|
||||
image_std=image_std,
|
||||
do_center_crop=do_center_crop,
|
||||
crop_size=crop_size,
|
||||
do_resize=do_resize,
|
||||
size=size,
|
||||
resample=resample,
|
||||
)
|
||||
|
||||
if do_convert_rgb:
|
||||
images = [convert_to_rgb(image) for image in images]
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if is_scaled_image(images[0]) and do_rescale:
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
all_images = []
|
||||
for image in images:
|
||||
if do_resize:
|
||||
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
|
||||
if do_center_crop:
|
||||
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
|
||||
|
||||
if do_rescale:
|
||||
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
|
||||
if do_normalize:
|
||||
image = self.normalize(
|
||||
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
|
||||
)
|
||||
|
||||
all_images.append(image)
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||
for image in all_images
|
||||
]
|
||||
|
||||
data = {"pixel_values": images}
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
|
||||
__all__ = ["TextNetImageProcessor"]
|
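To make the resize behaviour described in the docstrings above concrete, here is a small sketch (processor constructed directly, no checkpoint assumed): the shortest edge is scaled to `size["shortest_edge"]` and both output dimensions are then rounded up to a multiple of `size_divisor`.

```python
import numpy as np

from transformers import TextNetImageProcessor

processor = TextNetImageProcessor(size={"shortest_edge": 640})

# Dummy 600x450 RGB image (height x width x channels) with values in [0, 255].
image = np.random.randint(0, 256, (600, 450, 3), dtype=np.uint8)

inputs = processor(image, return_tensors="pt")

# The 450-pixel edge is scaled to 640, the 600-pixel edge to ~853 and then rounded
# up to the next multiple of 32, so pixel_values should have shape (1, 3, 864, 640).
print(inputs.pixel_values.shape)
```
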
487  src/transformers/models/textnet/modeling_textnet.py  Normal file
@ -0,0 +1,487 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the Fast authors and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch TextNet model."""
|
||||
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers import PreTrainedModel, add_start_docstrings
|
||||
from transformers.activations import ACT2CLS
|
||||
from transformers.modeling_outputs import (
|
||||
BackboneOutput,
|
||||
BaseModelOutputWithNoAttention,
|
||||
BaseModelOutputWithPoolingAndNoAttention,
|
||||
ImageClassifierOutputWithNoAttention,
|
||||
)
|
||||
from transformers.models.textnet.configuration_textnet import TextNetConfig
|
||||
from transformers.utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.backbone_utils import BackboneMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# General docstring
|
||||
_CONFIG_FOR_DOC = "TextNetConfig"
|
||||
_CHECKPOINT_FOR_DOC = "czczup/textnet-base"
|
||||
_EXPECTED_OUTPUT_SHAPE = [1, 512, 20, 27]
|
||||
|
||||
|
||||
class TextNetConvLayer(nn.Module):
|
||||
def __init__(self, config: TextNetConfig):
|
||||
super().__init__()
|
||||
|
||||
self.kernel_size = config.stem_kernel_size
|
||||
self.stride = config.stem_stride
|
||||
self.activation_function = config.stem_act_func
|
||||
|
||||
padding = (
|
||||
(config.stem_kernel_size[0] // 2, config.stem_kernel_size[1] // 2)
|
||||
if isinstance(config.stem_kernel_size, tuple)
|
||||
else config.stem_kernel_size // 2
|
||||
)
|
||||
|
||||
self.conv = nn.Conv2d(
|
||||
config.stem_num_channels,
|
||||
config.stem_out_channels,
|
||||
kernel_size=config.stem_kernel_size,
|
||||
stride=config.stem_stride,
|
||||
padding=padding,
|
||||
bias=False,
|
||||
)
|
||||
self.batch_norm = nn.BatchNorm2d(config.stem_out_channels, config.batch_norm_eps)
|
||||
|
||||
self.activation = nn.Identity()
|
||||
if self.activation_function is not None:
|
||||
self.activation = ACT2CLS[self.activation_function]()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
hidden_states = self.batch_norm(hidden_states)
|
||||
return self.activation(hidden_states)
|
||||
|
||||
|
||||
class TextNetRepConvLayer(nn.Module):
|
||||
r"""
|
||||
This layer supports re-parameterization by combining multiple convolutional branches
|
||||
(e.g., main convolution, vertical, horizontal, and identity branches) during training.
|
||||
At inference time, these branches can be collapsed into a single convolution for
|
||||
efficiency, as per the re-parameterization paradigm.
|
||||
|
||||
The "Rep" in the name stands for "re-parameterization" (introduced by RepVGG).
|
||||
"""
|
||||
|
||||
def __init__(self, config: TextNetConfig, in_channels: int, out_channels: int, kernel_size: int, stride: int):
|
||||
super().__init__()
|
||||
|
||||
self.num_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
|
||||
padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
|
||||
|
||||
self.activation_function = nn.ReLU()
|
||||
|
||||
self.main_conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
bias=False,
|
||||
)
|
||||
self.main_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps)
|
||||
|
||||
vertical_padding = ((kernel_size[0] - 1) // 2, 0)
|
||||
horizontal_padding = (0, (kernel_size[1] - 1) // 2)
|
||||
|
||||
if kernel_size[1] != 1:
|
||||
self.vertical_conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(kernel_size[0], 1),
|
||||
stride=stride,
|
||||
padding=vertical_padding,
|
||||
bias=False,
|
||||
)
|
||||
self.vertical_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps)
|
||||
else:
|
||||
self.vertical_conv, self.vertical_batch_norm = None, None
|
||||
|
||||
if kernel_size[0] != 1:
|
||||
self.horizontal_conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=(1, kernel_size[1]),
|
||||
stride=stride,
|
||||
padding=horizontal_padding,
|
||||
bias=False,
|
||||
)
|
||||
self.horizontal_batch_norm = nn.BatchNorm2d(num_features=out_channels, eps=config.batch_norm_eps)
|
||||
else:
|
||||
self.horizontal_conv, self.horizontal_batch_norm = None, None
|
||||
|
||||
self.rbr_identity = (
|
||||
nn.BatchNorm2d(num_features=in_channels, eps=config.batch_norm_eps)
|
||||
if out_channels == in_channels and stride == 1
|
||||
else None
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
main_outputs = self.main_conv(hidden_states)
|
||||
main_outputs = self.main_batch_norm(main_outputs)
|
||||
|
||||
# applies a convolution with a vertical kernel
|
||||
if self.vertical_conv is not None:
|
||||
vertical_outputs = self.vertical_conv(hidden_states)
|
||||
vertical_outputs = self.vertical_batch_norm(vertical_outputs)
|
||||
main_outputs = main_outputs + vertical_outputs
|
||||
|
||||
# applies a convolution with a horizontal kernel
|
||||
if self.horizontal_conv is not None:
|
||||
horizontal_outputs = self.horizontal_conv(hidden_states)
|
||||
horizontal_outputs = self.horizontal_batch_norm(horizontal_outputs)
|
||||
main_outputs = main_outputs + horizontal_outputs
|
||||
|
||||
if self.rbr_identity is not None:
|
||||
id_out = self.rbr_identity(hidden_states)
|
||||
main_outputs = main_outputs + id_out
|
||||
|
||||
return self.activation_function(main_outputs)
|
||||
|
||||
|
||||
class TextNetStage(nn.Module):
|
||||
def __init__(self, config: TextNetConfig, depth: int):
|
||||
super().__init__()
|
||||
kernel_size = config.conv_layer_kernel_sizes[depth]
|
||||
stride = config.conv_layer_strides[depth]
|
||||
|
||||
num_layers = len(kernel_size)
|
||||
stage_in_channel_size = config.hidden_sizes[depth]
|
||||
stage_out_channel_size = config.hidden_sizes[depth + 1]
|
||||
|
||||
in_channels = [stage_in_channel_size] + [stage_out_channel_size] * (num_layers - 1)
|
||||
out_channels = [stage_out_channel_size] * num_layers
|
||||
|
||||
stage = []
|
||||
for stage_config in zip(in_channels, out_channels, kernel_size, stride):
|
||||
stage.append(TextNetRepConvLayer(config, *stage_config))
|
||||
self.stage = nn.ModuleList(stage)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
for block in self.stage:
|
||||
hidden_state = block(hidden_state)
|
||||
return hidden_state
|
||||
|
||||
|
||||
class TextNetEncoder(nn.Module):
|
||||
def __init__(self, config: TextNetConfig):
|
||||
super().__init__()
|
||||
|
||||
stages = []
|
||||
num_stages = len(config.conv_layer_kernel_sizes)
|
||||
for stage_ix in range(num_stages):
|
||||
stages.append(TextNetStage(config, stage_ix))
|
||||
|
||||
self.stages = nn.ModuleList(stages)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_state: torch.Tensor,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> BaseModelOutputWithNoAttention:
|
||||
hidden_states = [hidden_state]
|
||||
for stage in self.stages:
|
||||
hidden_state = stage(hidden_state)
|
||||
hidden_states.append(hidden_state)
|
||||
|
||||
if not return_dict:
|
||||
output = (hidden_state,)
|
||||
return output + (hidden_states,) if output_hidden_states else output
|
||||
|
||||
return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
|
||||
|
||||
|
||||
TEXTNET_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
||||
behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`TextNetConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
TEXTNET_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
|
||||
[`TextNetImageProcessor.__call__`] for details.
|
||||
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
class TextNetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = TextNetConfig
    base_model_prefix = "textnet"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.BatchNorm2d):
            module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()

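# Note: `_init_weights` above is not called directly; each concrete model below calls
# `self.post_init()` at the end of its `__init__`, which, in current transformers, walks the
# module tree and applies this initialization to the Linear/Conv2d/BatchNorm2d layers before any
# pretrained weights are loaded on top.
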
@add_start_docstrings(
    "The bare Textnet model outputting raw features without any specific head on top.",
    TEXTNET_START_DOCSTRING,
)
class TextNetModel(TextNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.stem = TextNetConvLayer(config)
        self.encoder = TextNetEncoder(config)
        self.pooler = nn.AdaptiveAvgPool2d((2, 2))
        self.post_init()

    @add_start_docstrings_to_model_forward(TEXTNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
    ) -> Union[Tuple[Any, List[Any]], Tuple[Any], BaseModelOutputWithPoolingAndNoAttention]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        hidden_state = self.stem(pixel_values)

        encoder_outputs = self.encoder(
            hidden_state, output_hidden_states=output_hidden_states, return_dict=return_dict
        )

        last_hidden_state = encoder_outputs[0]
        pooled_output = self.pooler(last_hidden_state)

        if not return_dict:
            output = (last_hidden_state, pooled_output)
            return output + (encoder_outputs[1],) if output_hidden_states else output

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs[1] if output_hidden_states else None,
        )

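# Illustrative sketch of calling the bare model, assuming the `czczup/textnet-base` checkpoint
# used elsewhere in this file. Because the pooler above is `AdaptiveAvgPool2d((2, 2))`, the pooled
# output always has a 2x2 spatial size, while `last_hidden_state` keeps the encoder's final
# downsampled feature map.
#
#     import torch
#     from transformers import AutoImageProcessor, TextNetModel
#
#     processor = AutoImageProcessor.from_pretrained("czczup/textnet-base")
#     model = TextNetModel.from_pretrained("czczup/textnet-base")
#     inputs = processor(images=image, return_tensors="pt")  # `image` as in the classification example below
#     with torch.no_grad():
#         outputs = model(**inputs)
#     outputs.last_hidden_state.shape  # (batch_size, config.hidden_sizes[-1], h, w)
#     outputs.pooler_output.shape      # (batch_size, config.hidden_sizes[-1], 2, 2)
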
@add_start_docstrings(
    """
    TextNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    TEXTNET_START_DOCSTRING,
)
class TextNetForImageClassification(TextNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.textnet = TextNetModel(config)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()

        # classification head
        self.classifier = nn.ModuleList([self.avg_pool, self.flatten])

        # initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(TEXTNET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ImageClassifierOutputWithNoAttention, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> ImageClassifierOutputWithNoAttention:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:
        ```python
        >>> import torch
        >>> import requests
        >>> from transformers import TextNetForImageClassification, TextNetImageProcessor
        >>> from PIL import Image

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = TextNetImageProcessor.from_pretrained("czczup/textnet-base")
        >>> model = TextNetForImageClassification.from_pretrained("czczup/textnet-base")

        >>> inputs = processor(images=image, return_tensors="pt", size={"height": 640, "width": 640})
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        >>> outputs.logits.shape
        torch.Size([1, 2])
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.textnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
        last_hidden_state = outputs[0]
        for layer in self.classifier:
            last_hidden_state = layer(last_hidden_state)
        logits = self.fc(last_hidden_state)
        loss = None

        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return (loss,) + output if loss is not None else output

        return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)

@add_start_docstrings(
    """
    TextNet backbone, to be used with frameworks like DETR and MaskFormer.
    """,
    TEXTNET_START_DOCSTRING,
)
class TextNetBackbone(TextNetPreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        self.textnet = TextNetModel(config)
        self.num_features = config.hidden_sizes

        # initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(TEXTNET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
    ) -> Union[Tuple[Tuple], BackboneOutput]:
        """
        Returns:

        Examples:

        ```python
        >>> import torch
        >>> import requests
        >>> from PIL import Image
        >>> from transformers import AutoImageProcessor, AutoBackbone

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("czczup/textnet-base")
        >>> model = AutoBackbone.from_pretrained("czczup/textnet-base")

        >>> inputs = processor(image, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.textnet(pixel_values, output_hidden_states=True, return_dict=return_dict)

        hidden_states = outputs.hidden_states if return_dict else outputs[2]

        feature_maps = ()
        for idx, stage in enumerate(self.stage_names):
            if stage in self.out_features:
                feature_maps += (hidden_states[idx],)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                hidden_states = outputs.hidden_states if return_dict else outputs[2]
                output += (hidden_states,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=None,
        )


__all__ = ["TextNetBackbone", "TextNetModel", "TextNetPreTrainedModel", "TextNetForImageClassification"]
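# Illustrative sketch of consuming the backbone, assuming the same `czczup/textnet-base`
# checkpoint as the docstring example above; `out_features` selects which stages from
# `config.stage_names` are returned as `feature_maps`.
#
#     from transformers import AutoBackbone
#
#     backbone = AutoBackbone.from_pretrained(
#         "czczup/textnet-base", out_features=["stage1", "stage2", "stage3", "stage4"]
#     )
#     outputs = backbone(pixel_values)
#     [fm.shape for fm in outputs.feature_maps]  # one (batch, channels, h, w) map per requested stage
#     backbone.channels                          # channel count of each returned feature map
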
@@ -9158,6 +9158,34 @@ def load_tf_weights_in_tapas(*args, **kwargs):
    requires_backends(load_tf_weights_in_tapas, ["torch"])


class TextNetBackbone(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class TextNetForImageClassification(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class TextNetModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class TextNetPreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class TimeSeriesTransformerForPrediction(metaclass=DummyObject):
    _backends = ["torch"]

@@ -618,6 +618,13 @@ class Swin2SRImageProcessor(metaclass=DummyObject):
        requires_backends(self, ["vision"])


class TextNetImageProcessor(metaclass=DummyObject):
    _backends = ["vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])


class TvpImageProcessor(metaclass=DummyObject):
    _backends = ["vision"]

tests/models/textnet/__init__.py (new file, 0 lines)

tests/models/textnet/test_image_processing_textnet.py (new file, 126 lines)
@@ -0,0 +1,126 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


if is_vision_available():
    from transformers import TextNetImageProcessor


class TextNetImageProcessingTester(unittest.TestCase):
    def __init__(
        self,
        parent,
        batch_size=7,
        num_channels=3,
        image_size=18,
        min_resolution=30,
        max_resolution=400,
        do_resize=True,
        size=None,
        size_divisor=32,
        do_center_crop=True,
        crop_size=None,
        do_normalize=True,
        image_mean=[0.48145466, 0.4578275, 0.40821073],
        image_std=[0.26862954, 0.26130258, 0.27577711],
        do_convert_rgb=True,
    ):
        size = size if size is not None else {"shortest_edge": 20}
        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size
        self.size_divisor = size_divisor
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_convert_rgb = do_convert_rgb

    def prepare_image_processor_dict(self):
        return {
            "do_resize": self.do_resize,
            "size": self.size,
            "size_divisor": self.size_divisor,
            "do_center_crop": self.do_center_crop,
            "crop_size": self.crop_size,
            "do_normalize": self.do_normalize,
            "image_mean": self.image_mean,
            "image_std": self.image_std,
            "do_convert_rgb": self.do_convert_rgb,
        }

    def expected_output_image_shape(self, images):
        return self.num_channels, self.crop_size["height"], self.crop_size["width"]

    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
        return prepare_image_inputs(
            batch_size=self.batch_size,
            num_channels=self.num_channels,
            min_resolution=self.min_resolution,
            max_resolution=self.max_resolution,
            equal_resolution=equal_resolution,
            numpify=numpify,
            torchify=torchify,
        )

@require_torch
@require_vision
class TextNetImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    image_processing_class = TextNetImageProcessor if is_vision_available() else None

    def setUp(self):
        super().setUp()
        self.image_processor_tester = TextNetImageProcessingTester(self)

    @property
    def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)
        self.assertTrue(hasattr(image_processing, "do_resize"))
        self.assertTrue(hasattr(image_processing, "size"))
        self.assertTrue(hasattr(image_processing, "size_divisor"))
        self.assertTrue(hasattr(image_processing, "do_center_crop"))
        self.assertTrue(hasattr(image_processing, "center_crop"))
        self.assertTrue(hasattr(image_processing, "do_normalize"))
        self.assertTrue(hasattr(image_processing, "image_mean"))
        self.assertTrue(hasattr(image_processing, "image_std"))
        self.assertTrue(hasattr(image_processing, "do_convert_rgb"))

    def test_image_processor_from_dict_with_kwargs(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        self.assertEqual(image_processor.size, {"shortest_edge": 20})
        self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})

        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
        self.assertEqual(image_processor.size, {"shortest_edge": 42})
        self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
tests/models/textnet/test_modeling_textnet.py (new file, 348 lines)
@@ -0,0 +1,348 @@
# coding=utf-8
# Copyright 2024 the Fast authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch TextNet model."""

import unittest

import requests
from PIL import Image

from transformers import TextNetConfig
from transformers.models.textnet.image_processing_textnet import TextNetImageProcessor
from transformers.testing_utils import (
    require_torch,
    require_vision,
    slow,
    torch_device,
)
from transformers.utils import is_torch_available

from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers import TextNetBackbone, TextNetForImageClassification, TextNetModel

class TextNetConfigTester(ConfigTester):
    def create_and_test_config_common_properties(self):
        config = self.config_class(**self.inputs_dict)
        self.parent.assertTrue(hasattr(config, "hidden_sizes"))
        self.parent.assertTrue(hasattr(config, "num_attention_heads"))
        self.parent.assertTrue(hasattr(config, "num_encoder_blocks"))


class TextNetModelTester:
    def __init__(
        self,
        parent,
        stem_kernel_size=3,
        stem_stride=2,
        stem_in_channels=3,
        stem_out_channels=32,
        stem_act_func="relu",
        dropout_rate=0,
        ops_order="weight_bn_act",
        conv_layer_kernel_sizes=[
            [[3, 3]],
            [[3, 3]],
            [[3, 3]],
            [[3, 3]],
        ],
        conv_layer_strides=[
            [2],
            [2],
            [2],
            [2],
        ],
        out_features=["stage1", "stage2", "stage3", "stage4"],
        out_indices=[1, 2, 3, 4],
        batch_size=3,
        num_channels=3,
        image_size=[32, 32],
        is_training=True,
        use_labels=True,
        num_labels=3,
        hidden_sizes=[32, 32, 32, 32, 32],
    ):
        self.parent = parent
        self.stem_kernel_size = stem_kernel_size
        self.stem_stride = stem_stride
        self.stem_in_channels = stem_in_channels
        self.stem_out_channels = stem_out_channels
        self.act_func = stem_act_func
        self.dropout_rate = dropout_rate
        self.ops_order = ops_order
        self.conv_layer_kernel_sizes = conv_layer_kernel_sizes
        self.conv_layer_strides = conv_layer_strides

        self.out_features = out_features
        self.out_indices = out_indices

        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.is_training = is_training
        self.use_labels = use_labels
        self.num_labels = num_labels
        self.hidden_sizes = hidden_sizes

        self.num_stages = 5

    def get_config(self):
        return TextNetConfig(
            stem_kernel_size=self.stem_kernel_size,
            stem_stride=self.stem_stride,
            stem_num_channels=self.stem_in_channels,
            stem_out_channels=self.stem_out_channels,
            act_func=self.act_func,
            dropout_rate=self.dropout_rate,
            ops_order=self.ops_order,
            conv_layer_kernel_sizes=self.conv_layer_kernel_sizes,
            conv_layer_strides=self.conv_layer_strides,
            out_features=self.out_features,
            out_indices=self.out_indices,
            hidden_sizes=self.hidden_sizes,
            image_size=self.image_size,
        )

    def create_and_check_model(self, config, pixel_values, labels):
        model = TextNetModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)
        scale_h = self.image_size[0] // 32
        scale_w = self.image_size[1] // 32
        self.parent.assertEqual(
            result.last_hidden_state.shape,
            (self.batch_size, self.hidden_sizes[-1], scale_h, scale_w),
        )

    def create_and_check_for_image_classification(self, config, pixel_values, labels):
        config.num_labels = self.num_labels
        model = TextNetForImageClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values, labels=labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size[0], self.image_size[1]])

        labels = None
        if self.use_labels:
            labels = ids_tensor([self.batch_size], self.num_labels)

        config = self.get_config()

        return config, pixel_values, labels

    def create_and_check_backbone(self, config, pixel_values, labels):
        model = TextNetBackbone(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)

        # verify feature maps
        self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
        scale_h = self.image_size[0] // 32
        scale_w = self.image_size[1] // 32
        self.parent.assertListEqual(
            list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 8 * scale_h, 8 * scale_w]
        )

        # verify channels
        self.parent.assertEqual(len(model.channels), len(config.out_features))
        self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])

        # verify backbone works with out_features=None
        config.out_features = None
        model = TextNetBackbone(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)

        # verify feature maps
        self.parent.assertEqual(len(result.feature_maps), 1)
        scale_h = self.image_size[0] // 32
        scale_w = self.image_size[1] // 32
        self.parent.assertListEqual(
            list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[0], scale_h, scale_w]
        )

        # verify channels
        self.parent.assertEqual(len(model.channels), 1)
        self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values, labels = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict

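# Shape arithmetic behind the assertions above (with this tester's defaults): the stem has stride 2
# and each of the 4 stages strides by 2, so a 32x32 input is downsampled by a factor of 32 to a 1x1
# last hidden state (`scale_h = scale_w = 1`), while the "stage1" feature map checked in
# `create_and_check_backbone` is only downsampled by 4 and therefore 8x larger (8x8).
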
@require_torch
class TextNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some tests of test_modeling_common.py, as TextNet does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (TextNetModel, TextNetForImageClassification, TextNetBackbone) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"feature-extraction": TextNetModel, "image-classification": TextNetForImageClassification}
        if is_torch_available()
        else {}
    )

    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    has_attentions = False

    def setUp(self):
        self.model_tester = TextNetModelTester(self)
        self.config_tester = TextNetConfigTester(self, config_class=TextNetConfig, has_text_modality=False)

    @unittest.skip(reason="TextNet does not output attentions")
    def test_attention_outputs(self):
        pass

    @unittest.skip(reason="TextNet does not have input/output embeddings")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(reason="TextNet does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="TextNet does not support input and output embeddings")
    def test_model_common_attributes(self):
        pass

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_backbone(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_backbone(*config_and_inputs)

    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config=config)
            for name, module in model.named_modules():
                if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                    self.assertTrue(
                        torch.all(module.weight == 1),
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )
                    self.assertTrue(
                        torch.all(module.bias == 0),
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            self.assertEqual(len(hidden_states), self.model_tester.num_stages)

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [self.model_tester.image_size[0] // 2, self.model_tester.image_size[1] // 2],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        layers_type = ["preactivation", "bottleneck"]
        for model_class in self.all_model_classes:
            for layer_type in layers_type:
                config.layer_type = layer_type
                inputs_dict["output_hidden_states"] = True
                check_hidden_states_output(inputs_dict, config, model_class)

                # check that output_hidden_states also work using config
                del inputs_dict["output_hidden_states"]
                config.output_hidden_states = True

                check_hidden_states_output(inputs_dict, config, model_class)

    @unittest.skip(reason="TextNet does not use feedforward chunking")
    def test_feed_forward_chunking(self):
        pass

    def test_for_image_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_image_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        model_name = "czczup/textnet-base"
        model = TextNetModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

@require_torch
@require_vision
class TextNetModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference_no_head(self):
        processor = TextNetImageProcessor.from_pretrained("czczup/textnet-base")
        model = TextNetModel.from_pretrained("czczup/textnet-base").to(torch_device)

        # prepare image
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        image = Image.open(requests.get(url, stream=True).raw)
        inputs = processor(images=image, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            output = model(**inputs)

        # verify logits
        self.assertEqual(output.logits.shape, torch.Size([1, 2]))
        expected_slice_backbone = torch.tensor(
            [0.9210, 0.6099, 0.0000, 0.0000, 0.0000, 0.0000, 3.2207, 2.6602, 1.8925, 0.0000],
            device=torch_device,
        )
        self.assertTrue(torch.allclose(output.feature_maps[-1][0][10][12][:10], expected_slice_backbone, atol=1e-3))


@require_torch
# Copied from tests.models.bit.test_modeling_bit.BitBackboneTest with Bit->TextNet
class TextNetBackboneTest(BackboneTesterMixin, unittest.TestCase):
    all_model_classes = (TextNetBackbone,) if is_torch_available() else ()
    config_class = TextNetConfig

    has_attentions = False

    def setUp(self):
        self.model_tester = TextNetModelTester(self)
@@ -1020,6 +1020,7 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
    "ResNetBackbone",
    "SwinBackbone",
    "Swinv2Backbone",
    "TextNetBackbone",
    "TimmBackbone",
    "TimmBackboneConfig",
    "VitDetBackbone",