mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[SegFormer] TensorFlow port (#17910)
* add: segformer utils and img. classification.
* add: segmentation layer.
* feat: working implementation of segformer.
* chore: remove unused variable.
* add test, remaining modifications.
* remove: unnecessary files.
* add: rest of the files.

Co-authored-by: matt <rocketknight1@gmail.com>

* chore: remove ModuleList comment.
* chore: apply make style.
* chore: apply make fixup-copies.
* add to check_repo.py
* add decode head to IGNORE_NON_TESTED
* chore: run make style.
* chore: PR comments.
* chore: minor changes to model doc.
* tests: reduction across samples.
* add a note on the space.
* sort imports.
* fix: reduction in loss computation.
* chore: align loss function with that of NER.
* chore: correct utils/documentation_tests.txt

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* chore: simplify the interpolation of logits in loss computation.
* chore: return transposed logits when return_dict=False.
* chore: add link to the tf fine-tuning repo.
* address pr comments.
* address niels's comments.
* remove from_pt=True since tf weights are in.
* remove comment from pt model.
* address niels's comments.

Co-authored-by: matt <rocketknight1@gmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
parent 2c5747edfe
commit 561b9a8c00
docs/source/en/index.mdx (framework support table; columns: slow tokenizer, fast tokenizer, PyTorch, TensorFlow, Flax)
@@ -278,7 +278,7 @@ Flax), PyTorch, and/or TensorFlow.
 | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
 | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
 | RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
-| SegFormer | ❌ | ❌ | ✅ | ❌ | ❌ |
+| SegFormer | ❌ | ❌ | ✅ | ✅ | ❌ |
 | SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
 | SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
 | Speech Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |
docs/source/en/model_doc/segformer.mdx
@@ -36,13 +36,14 @@ The figure below illustrates the architecture of SegFormer. Taken from the [original paper](https://arxiv.org/abs/2105.15203).

 <img width="600" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/segformer_architecture.png"/>

-This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/NVlabs/SegFormer).
+This model was contributed by [nielsr](https://huggingface.co/nielsr). The TensorFlow version
+of the model was contributed by [sayakpaul](https://huggingface.co/sayakpaul). The original code can be found [here](https://github.com/NVlabs/SegFormer).

 Tips:

-- SegFormer consists of a hierarchical Transformer encoder, and a lightweight all-MLP decode head.
+- SegFormer consists of a hierarchical Transformer encoder, and a lightweight all-MLP decoder head.
 [`SegformerModel`] is the hierarchical Transformer encoder (which in the paper is also referred to
-as Mix Transformer or MiT). [`SegformerForSemanticSegmentation`] adds the all-MLP decode head on
+as Mix Transformer or MiT). [`SegformerForSemanticSegmentation`] adds the all-MLP decoder head on
 top to perform semantic segmentation of images. In addition, there's
 [`SegformerForImageClassification`] which can be used to - you guessed it - classify images. The
 authors of SegFormer first pre-trained the Transformer encoder on ImageNet-1k to classify images. Next, they throw
@@ -51,6 +52,9 @@ Tips:
 found on the [hub](https://huggingface.co/models?other=segformer).
 - The quickest way to get started with SegFormer is by checking the [example notebooks](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/SegFormer) (which showcase both inference and
 fine-tuning on custom data). One can also check out the [blog post](https://huggingface.co/blog/fine-tune-segformer) introducing SegFormer and illustrating how it can be fine-tuned on custom data.
+- TensorFlow users should refer to [this repository](https://github.com/deep-diver/segformer-tf-transformers) that shows off-the-shelf inference and fine-tuning.
+- One can also check out [this interactive demo on Hugging Face Spaces](https://huggingface.co/spaces/chansung/segformer-tf-transformers)
+to try out a SegFormer model on custom images.
 - SegFormer works on any input size, as it pads the input to be divisible by `config.patch_sizes`.
 - One can use [`SegformerFeatureExtractor`] to prepare images and corresponding segmentation maps
 for the model. Note that this feature extractor is fairly basic and does not include all data augmentations used in
@@ -65,7 +69,8 @@ Tips:
 used by [`SegformerForSemanticSegmentation`]). However, other datasets use the 0 index as
 background class and include this class as part of all labels. In that case, `reduce_labels` should be set to
 `False`, as loss should also be computed for the background class.
-- As most models, SegFormer comes in different sizes, the details of which can be found in the table below.
+- As most models, SegFormer comes in different sizes, the details of which can be found in the table below
+(taken from Table 7 of the [original paper](https://arxiv.org/abs/2105.15203)).

 | **Model variant** | **Depths**    | **Hidden sizes**    | **Decoder hidden size** | **Params (M)** | **ImageNet-1k Top 1** |
 | :---------------: | ------------- | ------------------- | :---------------------: | :------------: | :-------------------: |
@@ -76,6 +81,10 @@ Tips:
 | MiT-b4            | [3, 8, 27, 3] | [64, 128, 320, 512] | 768                     | 62.6           | 83.6                  |
 | MiT-b5            | [3, 6, 40, 3] | [64, 128, 320, 512] | 768                     | 82.0           | 83.8                  |

+Note that MiT in the above table refers to the Mix Transformer encoder backbone introduced in SegFormer. For
+SegFormer's results on the segmentation datasets like ADE20k, refer to the [paper](https://arxiv.org/abs/2105.15203).

 ## SegformerConfig

 [[autodoc]] SegformerConfig
@@ -104,3 +113,23 @@ Tips:

 [[autodoc]] SegformerForSemanticSegmentation
     - forward
+
+## TFSegformerDecodeHead
+
+[[autodoc]] TFSegformerDecodeHead
+    - call
+
+## TFSegformerModel
+
+[[autodoc]] TFSegformerModel
+    - call
+
+## TFSegformerForImageClassification
+
+[[autodoc]] TFSegformerForImageClassification
+    - call
+
+## TFSegformerForSemanticSegmentation
+
+[[autodoc]] TFSegformerForSemanticSegmentation
+    - call
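For a quick start with the TF port documented above, a minimal inference sketch (it mirrors the PyTorch example in the docs and the docstring of the new TF model; the checkpoint name is the one used throughout this PR):

import requests
from PIL import Image
from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = TFSegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

inputs = feature_extractor(images=image, return_tensors="tf")
outputs = model(**inputs, training=False)
logits = outputs.logits  # (1, 150, 128, 128): (batch, num_labels, height/4, width/4)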
src/transformers/__init__.py
@@ -2430,6 +2430,16 @@ else:
             "TFRoFormerPreTrainedModel",
         ]
     )
+    _import_structure["models.segformer"].extend(
+        [
+            "TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "TFSegformerDecodeHead",
+            "TFSegformerForImageClassification",
+            "TFSegformerForSemanticSegmentation",
+            "TFSegformerModel",
+            "TFSegformerPreTrainedModel",
+        ]
+    )
     _import_structure["models.speech_to_text"].extend(
         [
             "TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -4789,6 +4799,14 @@ if TYPE_CHECKING:
             TFRoFormerModel,
             TFRoFormerPreTrainedModel,
         )
+        from .models.segformer import (
+            TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+            TFSegformerDecodeHead,
+            TFSegformerForImageClassification,
+            TFSegformerForSemanticSegmentation,
+            TFSegformerModel,
+            TFSegformerPreTrainedModel,
+        )
        from .models.speech_to_text import (
            TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFSpeech2TextForConditionalGeneration,
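Once the lazy import structure above is in place, the TF classes become importable from the package root whenever TensorFlow is installed; a quick smoke test (assuming a TF-enabled environment):

from transformers import (
    TFSegformerForImageClassification,
    TFSegformerForSemanticSegmentation,
    TFSegformerModel,
)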
src/transformers/models/auto/modeling_tf_auto.py
@@ -68,6 +68,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
         ("resnet", "TFResNetModel"),
         ("roberta", "TFRobertaModel"),
         ("roformer", "TFRoFormerModel"),
+        ("segformer", "TFSegformerModel"),
         ("speech_to_text", "TFSpeech2TextModel"),
         ("swin", "TFSwinModel"),
         ("t5", "TFT5Model"),
@@ -180,6 +181,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
         ("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
         ("regnet", "TFRegNetForImageClassification"),
         ("resnet", "TFResNetForImageClassification"),
+        ("segformer", "TFSegformerForImageClassification"),
         ("swin", "TFSwinForImageClassification"),
         ("vit", "TFViTForImageClassification"),
     ]
@@ -189,6 +191,7 @@ TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
     [
         # Model for Semantic Segmentation mapping
         ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
+        ("segformer", "TFSegformerForSemanticSegmentation"),
     ]
 )
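With the three mappings registered, the auto classes resolve "segformer" checkpoints to the TF implementations; a sketch, assuming the standard auto-class naming:

from transformers import TFAutoModelForSemanticSegmentation

model = TFAutoModelForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
# resolves to TFSegformerForSemanticSegmentation via TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES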
src/transformers/models/segformer/__init__.py
@@ -17,7 +17,13 @@
 # limitations under the License.
 from typing import TYPE_CHECKING

-from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
+from ...utils import (
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    is_tf_available,
+    is_torch_available,
+    is_vision_available,
+)


 _import_structure = {"configuration_segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig"]}
@@ -46,6 +52,21 @@ else:
         "SegformerPreTrainedModel",
     ]

+try:
+    if not is_tf_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["modeling_tf_segformer"] = [
+        "TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "TFSegformerDecodeHead",
+        "TFSegformerForImageClassification",
+        "TFSegformerForSemanticSegmentation",
+        "TFSegformerModel",
+        "TFSegformerPreTrainedModel",
+    ]
+

 if TYPE_CHECKING:
     from .configuration_segformer import SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, SegformerConfig
@@ -73,7 +94,20 @@ if TYPE_CHECKING:
         SegformerModel,
         SegformerPreTrainedModel,
     )

+    try:
+        if not is_tf_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .modeling_tf_segformer import (
+            TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+            TFSegformerDecodeHead,
+            TFSegformerForImageClassification,
+            TFSegformerForSemanticSegmentation,
+            TFSegformerModel,
+            TFSegformerPreTrainedModel,
+        )

 else:
     import sys
src/transformers/models/segformer/modeling_segformer.py
@@ -785,6 +785,8 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
         >>> inputs = feature_extractor(images=image, return_tensors="pt")
         >>> outputs = model(**inputs)
         >>> logits = outputs.logits  # shape (batch_size, num_labels, height, width)
+        >>> logits.shape
+        (1, 150, 128, 128)
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         output_hidden_states = (
@@ -804,7 +806,7 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):

         loss = None
         if labels is not None:
-            if self.config.num_labels == 1:
+            if not self.config.num_labels > 1:
                 raise ValueError("The number of labels should be greater than one")
             else:
                 # upsample logits to the images' original size
src/transformers/models/segformer/modeling_tf_segformer.py (new file, 878 lines)
@@ -0,0 +1,878 @@
# coding=utf-8
# Copyright 2022 NVIDIA The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TensorFlow SegFormer model."""

import math
from typing import Dict, Optional, Tuple, Union

import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...file_utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...modeling_tf_outputs import TFBaseModelOutput, TFSemanticSegmenterOutput, TFSequenceClassifierOutput
from ...modeling_tf_utils import TFPreTrainedModel, TFSequenceClassificationLoss, keras_serializable, unpack_inputs
from ...tf_utils import shape_list, stable_softmax
from ...utils import logging
from .configuration_segformer import SegformerConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "SegformerConfig"
_FEAT_EXTRACTOR_FOR_DOC = "SegformerFeatureExtractor"

# Base docstring
_CHECKPOINT_FOR_DOC = "nvidia/mit-b0"
_EXPECTED_OUTPUT_SHAPE = [1, 256, 16, 16]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "nvidia/mit-b0"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"

TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "nvidia/segformer-b0-finetuned-ade-512-512",
    # See all SegFormer models at https://huggingface.co/models?filter=segformer
]


# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath with ConvNext->Segformer
class TFSegformerDropPath(tf.keras.layers.Layer):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    References:
        (1) github.com:rwightman/pytorch-image-models
    """

    def __init__(self, drop_path, **kwargs):
        super().__init__(**kwargs)
        self.drop_path = drop_path

    def call(self, x, training=None):
        if training:
            keep_prob = 1 - self.drop_path
            shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
            random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
            random_tensor = tf.floor(random_tensor)
            return (x / keep_prob) * random_tensor
        return x
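For readers new to stochastic depth, a quick illustration of the layer above (tensor shapes chosen only for the example): at inference time it is the identity, while during training each sample in the batch either survives rescaled by 1/keep_prob or is zeroed, so the expectation matches the identity path.

import tensorflow as tf

layer = TFSegformerDropPath(drop_path=0.2)
x = tf.ones((4, 16, 32))
y_eval = layer(x, training=False)   # identity: y_eval == x
y_train = layer(x, training=True)   # per sample, either 0 or x / 0.8, so E[y_train] == x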
class TFSegformerOverlapPatchEmbeddings(tf.keras.layers.Layer):
    """Construct the overlapping patch embeddings."""

    def __init__(self, patch_size, stride, hidden_size, **kwargs):
        super().__init__(**kwargs)
        self.padding = tf.keras.layers.ZeroPadding2D(padding=patch_size // 2)
        self.proj = tf.keras.layers.Conv2D(
            filters=hidden_size, kernel_size=patch_size, strides=stride, padding="VALID", name="proj"
        )

        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm")

    def call(self, pixel_values: tf.Tensor) -> Tuple[tf.Tensor, int, int]:
        embeddings = self.proj(self.padding(pixel_values))
        height = shape_list(embeddings)[1]
        width = shape_list(embeddings)[2]
        hidden_dim = shape_list(embeddings)[3]
        # (batch_size, height, width, num_channels) -> (batch_size, height*width, num_channels)
        # this can be fed to a Transformer layer
        embeddings = tf.reshape(embeddings, (-1, height * width, hidden_dim))
        embeddings = self.layer_norm(embeddings)
        return embeddings, height, width
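The ZeroPadding2D(patch_size // 2) followed by a VALID convolution above reproduces PyTorch's padded Conv2d, so the output spatial size is floor((H + 2*(k//2) - k)/s) + 1. A worked example with the paper's stage-1 values (patch_size=7, stride=4; the 512-pixel input is assumed for illustration):

H, k, s = 512, 7, 4
p = k // 2
out = (H + 2 * p - k) // s + 1  # -> 128, i.e. the 4x downsampling of stage 1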
class TFSegformerEfficientSelfAttention(tf.keras.layers.Layer):
    """SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
    paper](https://arxiv.org/abs/2102.12122)."""

    def __init__(
        self,
        config: SegformerConfig,
        hidden_size: int,
        num_attention_heads: int,
        sequence_reduction_ratio: int,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads

        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                f"heads ({self.num_attention_heads})"
            )

        self.attention_head_size = self.hidden_size // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)

        self.query = tf.keras.layers.Dense(self.all_head_size, name="query")
        self.key = tf.keras.layers.Dense(self.all_head_size, name="key")
        self.value = tf.keras.layers.Dense(self.all_head_size, name="value")

        self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)

        self.sr_ratio = sequence_reduction_ratio
        if sequence_reduction_ratio > 1:
            self.sr = tf.keras.layers.Conv2D(
                filters=hidden_size, kernel_size=sequence_reduction_ratio, strides=sequence_reduction_ratio, name="sr"
            )
            self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm")

    def transpose_for_scores(self, tensor: tf.Tensor) -> tf.Tensor:
        # Reshape from [batch_size, seq_length, all_head_size]
        # to [batch_size, seq_length, num_attention_heads, attention_head_size]
        batch_size = shape_list(tensor)[0]
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))

        # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size]
        # to [batch_size, num_attention_heads, seq_length, attention_head_size]
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def call(
        self,
        hidden_states: tf.Tensor,
        height: int,
        width: int,
        output_attentions: bool = False,
        training: bool = False,
    ) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
        batch_size = shape_list(hidden_states)[0]
        num_channels = shape_list(hidden_states)[2]

        query_layer = self.transpose_for_scores(self.query(hidden_states))

        if self.sr_ratio > 1:
            # Reshape to (batch_size, height, width, num_channels)
            hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels))
            # Apply sequence reduction
            hidden_states = self.sr(hidden_states)
            # Reshape back to (batch_size, seq_len, num_channels)
            hidden_states = tf.reshape(hidden_states, (batch_size, -1, num_channels))
            hidden_states = self.layer_norm(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)

        scale = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
        attention_scores = tf.divide(attention_scores, scale)

        # Normalize the attention scores to probabilities.
        attention_probs = stable_softmax(logits=attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs, training=training)

        context_layer = tf.matmul(attention_probs, value_layer)

        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
        # (batch_size, seq_len_q, all_head_size)
        context_layer = tf.reshape(context_layer, (batch_size, -1, self.all_head_size))

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs
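The point of the reduction convolution above: keys and values are computed on a spatially downsampled copy of the sequence, so the score matrix becomes N x N/R^2 instead of N x N. A back-of-the-envelope check with the default stage-1 numbers (sr_ratio 8; the 512-pixel input is assumed for illustration):

N = 128 * 128        # stage-1 sequence length, i.e. (H/4) * (W/4) = 16384 tokens
R = 8                # sr_ratios[0] in the default config
keys = N // (R * R)  # 256 keys/values per query instead of 16384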
class TFSegformerSelfOutput(tf.keras.layers.Layer):
    def __init__(self, config: SegformerConfig, hidden_size: int, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(hidden_size, name="dense")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class TFSegformerAttention(tf.keras.layers.Layer):
    def __init__(
        self,
        config: SegformerConfig,
        hidden_size: int,
        num_attention_heads: int,
        sequence_reduction_ratio: int,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.self = TFSegformerEfficientSelfAttention(
            config=config,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            sequence_reduction_ratio=sequence_reduction_ratio,
            name="self",
        )
        self.dense_output = TFSegformerSelfOutput(config, hidden_size=hidden_size, name="output")

    def call(
        self, hidden_states: tf.Tensor, height: int, width: int, output_attentions: bool = False
    ) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
        self_outputs = self.self(hidden_states, height, width, output_attentions)

        attention_output = self.dense_output(self_outputs[0])
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class TFSegformerDWConv(tf.keras.layers.Layer):
    def __init__(self, dim: int = 768, **kwargs):
        super().__init__(**kwargs)
        self.depthwise_convolution = tf.keras.layers.Conv2D(
            filters=dim, kernel_size=3, strides=1, padding="same", groups=dim, name="dwconv"
        )

    def call(self, hidden_states: tf.Tensor, height: int, width: int) -> tf.Tensor:
        batch_size = shape_list(hidden_states)[0]
        num_channels = shape_list(hidden_states)[-1]
        hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels))
        hidden_states = self.depthwise_convolution(hidden_states)

        new_height = shape_list(hidden_states)[1]
        new_width = shape_list(hidden_states)[2]
        num_channels = shape_list(hidden_states)[3]
        hidden_states = tf.reshape(hidden_states, (batch_size, new_height * new_width, num_channels))
        return hidden_states


class TFSegformerMixFFN(tf.keras.layers.Layer):
    def __init__(
        self,
        config: SegformerConfig,
        in_features: int,
        hidden_features: int = None,
        out_features: int = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        out_features = out_features or in_features
        self.dense1 = tf.keras.layers.Dense(hidden_features, name="dense1")
        self.depthwise_convolution = TFSegformerDWConv(hidden_features, name="dwconv")
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.intermediate_act_fn = config.hidden_act
        self.dense2 = tf.keras.layers.Dense(out_features, name="dense2")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, hidden_states: tf.Tensor, height: int, width: int) -> tf.Tensor:
        hidden_states = self.dense1(hidden_states)
        hidden_states = self.depthwise_convolution(hidden_states, height, width)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense2(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class TFSegformerLayer(tf.keras.layers.Layer):
    """This corresponds to the Block class in the original implementation."""

    def __init__(
        self,
        config,
        hidden_size: int,
        num_attention_heads: int,
        drop_path: float,
        sequence_reduction_ratio: int,
        mlp_ratio: int,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.layer_norm_1 = tf.keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm_1")
        self.attention = TFSegformerAttention(
            config,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            sequence_reduction_ratio=sequence_reduction_ratio,
            name="attention",
        )
        self.drop_path = TFSegformerDropPath(drop_path) if drop_path > 0.0 else tf.keras.layers.Activation("linear")
        self.layer_norm_2 = tf.keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm_2")
        mlp_hidden_size = int(hidden_size * mlp_ratio)
        self.mlp = TFSegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size, name="mlp")

    def call(
        self,
        hidden_states: tf.Tensor,
        height: int,
        width: int,
        output_attentions: bool = False,
        training: bool = False,
    ) -> Tuple:
        self_attention_outputs = self.attention(
            self.layer_norm_1(hidden_states),  # in Segformer, layernorm is applied before self-attention
            height,
            width,
            output_attentions=output_attentions,
            training=training,
        )

        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection (with stochastic depth)
        attention_output = self.drop_path(attention_output, training=training)
        hidden_states = attention_output + hidden_states
        mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)

        # second residual connection (with stochastic depth)
        mlp_output = self.drop_path(mlp_output, training=training)
        layer_output = mlp_output + hidden_states

        outputs = (layer_output,) + outputs

        return outputs


class TFSegformerEncoder(tf.keras.layers.Layer):
    def __init__(self, config: SegformerConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config

        # stochastic depth decay rule
        drop_path_decays = [x.numpy() for x in tf.linspace(0.0, config.drop_path_rate, sum(config.depths))]

        # patch embeddings
        embeddings = []
        for i in range(config.num_encoder_blocks):
            embeddings.append(
                TFSegformerOverlapPatchEmbeddings(
                    patch_size=config.patch_sizes[i],
                    stride=config.strides[i],
                    hidden_size=config.hidden_sizes[i],
                    name=f"patch_embeddings.{i}",
                )
            )
        self.embeddings = embeddings

        # Transformer blocks
        blocks = []
        cur = 0
        for i in range(config.num_encoder_blocks):
            # each block consists of layers
            layers = []
            if i != 0:
                cur += config.depths[i - 1]
            for j in range(config.depths[i]):
                layers.append(
                    TFSegformerLayer(
                        config,
                        hidden_size=config.hidden_sizes[i],
                        num_attention_heads=config.num_attention_heads[i],
                        drop_path=drop_path_decays[cur + j],
                        sequence_reduction_ratio=config.sr_ratios[i],
                        mlp_ratio=config.mlp_ratios[i],
                        name=f"block.{i}.{j}",
                    )
                )
            blocks.append(layers)

        self.block = blocks

        # Layer norms
        self.layer_norms = [
            tf.keras.layers.LayerNormalization(epsilon=1e-05, name=f"layer_norm.{i}")
            for i in range(config.num_encoder_blocks)
        ]

    def call(
        self,
        pixel_values: tf.Tensor,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        training: bool = False,
    ) -> Union[Tuple, TFBaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        batch_size = shape_list(pixel_values)[0]

        hidden_states = pixel_values
        for idx, x in enumerate(zip(self.embeddings, self.block, self.layer_norms)):
            embedding_layer, block_layer, norm_layer = x
            # first, obtain patch embeddings
            hidden_states, height, width = embedding_layer(hidden_states)

            # second, send embeddings through blocks
            # (each block consists of multiple layers i.e., list of layers)
            for i, blk in enumerate(block_layer):
                layer_outputs = blk(
                    hidden_states,
                    height,
                    width,
                    output_attentions,
                    training=training,
                )
                hidden_states = layer_outputs[0]
                if output_attentions:
                    all_self_attentions = all_self_attentions + (layer_outputs[1],)

            # third, apply layer norm
            hidden_states = norm_layer(hidden_states)

            # fourth, optionally reshape back to (batch_size, height, width, num_channels)
            if idx != len(self.embeddings) - 1 or (idx == len(self.embeddings) - 1 and self.config.reshape_last_stage):
                num_channels = shape_list(hidden_states)[-1]
                hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels))

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return TFBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
        )


@keras_serializable
class TFSegformerMainLayer(tf.keras.layers.Layer):
    config_class = SegformerConfig

    def __init__(self, config: SegformerConfig, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        # hierarchical Transformer encoder
        self.encoder = TFSegformerEncoder(config, name="encoder")

    @unpack_inputs
    def call(
        self,
        pixel_values: tf.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[Tuple, TFBaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
        # So change the input format from `NCHW` to `NHWC`.
        # shape = (batch_size, in_height, in_width, in_channels=num_channels)
        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

        encoder_outputs = self.encoder(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = encoder_outputs[0]
        # Change to NCHW output format to have uniformity in the modules
        sequence_output = tf.transpose(sequence_output, perm=[0, 3, 1, 2])

        # Change the other hidden state outputs to NCHW as well
        if output_hidden_states:
            hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])

        if not return_dict:
            if tf.greater(len(encoder_outputs[1:]), 0):
                transposed_encoder_outputs = tuple(tf.transpose(v, perm=[0, 3, 1, 2]) for v in encoder_outputs[1:][0])
                return (sequence_output,) + (transposed_encoder_outputs,)
            else:
                return (sequence_output,) + encoder_outputs[1:]

        return TFBaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
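The two transposes in TFSegformerMainLayer define the port's layout contract: NCHW at the public boundary (matching the PyTorch model), NHWC internally, since tf.keras Conv2D only supports NHWC on CPU, as the comment in the code notes. A shape-only sketch:

import tensorflow as tf

pixel_values = tf.random.uniform((1, 3, 512, 512))    # NCHW in, as in the PT model
nhwc = tf.transpose(pixel_values, perm=(0, 2, 3, 1))  # (1, 512, 512, 3) for the convs
nchw = tf.transpose(nhwc, perm=(0, 3, 1, 2))          # back to (1, 3, 512, 512) on the way out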
class TFSegformerPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SegformerConfig
    base_model_prefix = "segformer"
    main_input_name = "pixel_values"

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        """
        Dummy inputs to build the network.

        Returns:
            `Dict[str, tf.Tensor]`: The dummy inputs.
        """
        VISION_DUMMY_INPUTS = tf.random.uniform(shape=(3, self.config.num_channels, 512, 512), dtype=tf.float32)
        return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}

    @tf.function(
        input_signature=[
            {
                "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
            }
        ]
    )
    def serving(self, inputs):
        """
        Method used for serving the model.

        Args:
            inputs (`Dict[str, tf.Tensor]`):
                The input of the saved model as a dictionary of tensors.
        """
        return self.call(inputs)


SEGFORMER_START_DOCSTRING = r"""
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`SegformerConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""

SEGFORMER_INPUTS_DOCSTRING = r"""

    Args:
        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, and each example must have the shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
            [`AutoFeatureExtractor.__call__`] for details.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.

        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""


@add_start_docstrings(
    "The bare SegFormer encoder (Mix-Transformer) outputting raw hidden-states without any specific head on top.",
    SEGFORMER_START_DOCSTRING,
)
class TFSegformerModel(TFSegformerPreTrainedModel):
    def __init__(self, config: SegformerConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.config = config

        # hierarchical Transformer encoder
        self.segformer = TFSegformerMainLayer(config, name="segformer")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def call(
        self,
        pixel_values: tf.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[Tuple, TFBaseModelOutput]:
        outputs = self.segformer(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        return outputs


@add_start_docstrings(
    """
    SegFormer Model transformer with an image classification head on top (a linear layer on top of the final hidden
    states) e.g. for ImageNet.
    """,
    SEGFORMER_START_DOCSTRING,
)
class TFSegformerForImageClassification(TFSegformerPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config: SegformerConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels
        self.segformer = TFSegformerMainLayer(config, name="segformer")

        # Classifier head
        self.classifier = tf.keras.layers.Dense(config.num_labels, name="classifier")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def call(
        self,
        pixel_values: Optional[tf.Tensor] = None,
        labels: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TFSequenceClassifierOutput]:
        outputs = self.segformer(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # convert last hidden states to (batch_size, height*width, hidden_size)
        batch_size = shape_list(sequence_output)[0]
        sequence_output = tf.transpose(sequence_output, perm=[0, 2, 3, 1])
        sequence_output = tf.reshape(sequence_output, (batch_size, -1, self.config.hidden_sizes[-1]))

        # global average pooling
        sequence_output = tf.reduce_mean(sequence_output, axis=1)

        logits = self.classifier(sequence_output)

        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )


class TFSegformerMLP(tf.keras.layers.Layer):
    """
    Linear Embedding.
    """

    def __init__(self, config: SegformerConfig, **kwargs):
        super().__init__(**kwargs)
        self.proj = tf.keras.layers.Dense(config.decoder_hidden_size, name="proj")

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        height = shape_list(hidden_states)[1]
        width = shape_list(hidden_states)[2]
        hidden_dim = shape_list(hidden_states)[-1]
        hidden_states = tf.reshape(hidden_states, (-1, height * width, hidden_dim))
        hidden_states = self.proj(hidden_states)
        return hidden_states


class TFSegformerDecodeHead(TFSegformerPreTrainedModel):
    def __init__(self, config: SegformerConfig, **kwargs):
        super().__init__(config, **kwargs)
        # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size
        mlps = []
        for i in range(config.num_encoder_blocks):
            mlp = TFSegformerMLP(config, name=f"linear_c.{i}")
            mlps.append(mlp)
        self.mlps = mlps

        # the following 3 layers implement the ConvModule of the original implementation
        self.linear_fuse = tf.keras.layers.Conv2D(
            filters=config.decoder_hidden_size, kernel_size=1, use_bias=False, name="linear_fuse"
        )
        self.batch_norm = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="batch_norm")
        self.activation = tf.keras.layers.Activation("relu")

        self.dropout = tf.keras.layers.Dropout(config.classifier_dropout_prob)
        self.classifier = tf.keras.layers.Conv2D(filters=config.num_labels, kernel_size=1, name="classifier")

        self.config = config

    def call(self, encoder_hidden_states):
        batch_size = shape_list(encoder_hidden_states[-1])[0]

        all_hidden_states = ()
        for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.mlps):
            if self.config.reshape_last_stage is False and len(shape_list(encoder_hidden_state)) == 3:
                height = tf.math.sqrt(tf.cast(shape_list(encoder_hidden_state)[1], tf.float32))
                height = width = tf.cast(height, tf.int32)
                encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1))

            # unify channel dimension
            encoder_hidden_state = tf.transpose(encoder_hidden_state, perm=[0, 2, 3, 1])
            height = shape_list(encoder_hidden_state)[1]
            width = shape_list(encoder_hidden_state)[2]
            encoder_hidden_state = mlp(encoder_hidden_state)
            encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1))

            # upsample
            temp_state = tf.transpose(encoder_hidden_states[0], perm=[0, 2, 3, 1])
            upsample_resolution = shape_list(temp_state)[1:-1]
            encoder_hidden_state = tf.image.resize(encoder_hidden_state, size=upsample_resolution, method="bilinear")
            all_hidden_states += (encoder_hidden_state,)

        hidden_states = self.linear_fuse(tf.concat(all_hidden_states[::-1], axis=-1))
        hidden_states = self.batch_norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # logits of shape (batch_size, height/4, width/4, num_labels)
        logits = self.classifier(hidden_states)

        return logits


@add_start_docstrings(
    """SegFormer Model transformer with an all-MLP decode head on top e.g. for ADE20k, CityScapes.""",
    SEGFORMER_START_DOCSTRING,
)
class TFSegformerForSemanticSegmentation(TFSegformerPreTrainedModel):
    def __init__(self, config: SegformerConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.segformer = TFSegformerMainLayer(config, name="segformer")
        self.decode_head = TFSegformerDecodeHead(config, name="decode_head")

    def hf_compute_loss(self, logits, labels):
        # upsample logits to the images' original size
        # `labels` is of shape (batch_size, height, width)
        label_interp_shape = shape_list(labels)[1:]

        upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
        # compute weighted loss
        loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")

        def masked_loss(real, pred):
            unmasked_loss = loss_fct(real, pred)
            mask = tf.cast(real != self.config.semantic_loss_ignore_index, dtype=unmasked_loss.dtype)
            masked_loss = unmasked_loss * mask
            # Reduction strategy in the similar spirit with
            # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210
            reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(mask)
            return tf.reshape(reduced_masked_loss, (1,))

        return masked_loss(labels, upsampled_logits)
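A toy trace of the masked reduction above (values invented for illustration; the ignore index is taken from the config, 255 by default): only pixels whose label differs from semantic_loss_ignore_index contribute, and the sum is normalized by the number of valid pixels rather than by all pixels.

import tensorflow as tf

real = tf.constant([[0.0, 255.0], [2.0, 1.0]])     # labels; one pixel carries the ignore index
per_pixel = tf.constant([[0.5, 9.0], [1.0, 2.0]])  # pretend per-pixel CE losses
mask = tf.cast(real != 255.0, per_pixel.dtype)
loss = tf.reduce_sum(per_pixel * mask) / tf.reduce_sum(mask)  # (0.5 + 1.0 + 2.0) / 3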
    @unpack_inputs
    @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        pixel_values: tf.Tensor,
        labels: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TFSemanticSegmenterOutput]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a (per-pixel) classification loss is computed
            (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
        >>> model = TFSegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

        >>> inputs = feature_extractor(images=image, return_tensors="tf")
        >>> outputs = model(**inputs, training=False)
        >>> # logits are of shape (batch_size, num_labels, height, width)
        >>> logits = outputs.logits
        >>> logits.shape
        (1, 150, 128, 128)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.segformer(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )

        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        logits = self.decode_head(encoder_hidden_states)

        loss = None
        if labels is not None:
            if not self.config.num_labels > 1:
                raise ValueError("The number of labels should be greater than one")
            else:
                loss = self.hf_compute_loss(logits=logits, labels=labels)

        # make logits of shape (batch_size, num_labels, height, width) to
        # keep them consistent across APIs
        logits = tf.transpose(logits, perm=[0, 3, 1, 2])

        if not return_dict:
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFSemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
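The model returns logits at 1/4 resolution in NCHW, so turning them into a full-resolution segmentation map takes a resize and an argmax. A sketch continuing the docstring example above (the 512x512 target size is assumed; in practice use the original image size):

import tensorflow as tf

logits = outputs.logits                           # (1, 150, 128, 128), NCHW
logits = tf.transpose(logits, perm=[0, 2, 3, 1])  # NHWC, as tf.image.resize expects
upsampled = tf.image.resize(logits, size=(512, 512), method="bilinear")
seg_map = tf.argmax(upsampled, axis=-1)           # (1, 512, 512) per-pixel class ids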
src/transformers/utils/dummy_tf_objects.py
@@ -1980,6 +1980,44 @@ class TFRoFormerPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["tf"])


+TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class TFSegformerDecodeHead(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFSegformerForImageClassification(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFSegformerForSemanticSegmentation(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFSegformerModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
+class TFSegformerPreTrainedModel(metaclass=DummyObject):
+    _backends = ["tf"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["tf"])
+
+
 TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST = None
tests/models/segformer/test_modeling_segformer.py
@@ -18,7 +18,7 @@
 import inspect
 import unittest

-from transformers import is_torch_available, is_vision_available
+from transformers import SegformerConfig, is_torch_available, is_vision_available
 from transformers.models.auto import get_values
 from transformers.testing_utils import require_torch, slow, torch_device

@@ -31,7 +31,6 @@ if is_torch_available():

     from transformers import (
         MODEL_MAPPING,
-        SegformerConfig,
         SegformerForImageClassification,
         SegformerForSemanticSegmentation,
         SegformerModel,
tests/models/segformer/test_modeling_tf_segformer.py (new file, 540 lines)
@@ -0,0 +1,540 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow SegFormer model. """

import inspect
import unittest
from typing import List, Tuple

import numpy as np

from transformers import SegformerConfig
from transformers.file_utils import is_tf_available, is_vision_available
from transformers.testing_utils import require_tf, slow

from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor


if is_tf_available():
    import tensorflow as tf

    from transformers import TFSegformerForImageClassification, TFSegformerForSemanticSegmentation, TFSegformerModel
    from transformers.models.segformer.modeling_tf_segformer import TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST

if is_vision_available():
    from PIL import Image

    from transformers import SegformerFeatureExtractor


class TFSegformerConfigTester(ConfigTester):
    def create_and_test_config_common_properties(self):
        config = self.config_class(**self.inputs_dict)
        self.parent.assertTrue(hasattr(config, "hidden_sizes"))
        self.parent.assertTrue(hasattr(config, "num_attention_heads"))
        self.parent.assertTrue(hasattr(config, "num_encoder_blocks"))


class TFSegformerModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        image_size=64,
        num_channels=3,
        num_encoder_blocks=4,
        depths=[2, 2, 2, 2],
        sr_ratios=[8, 4, 2, 1],
        hidden_sizes=[16, 32, 64, 128],
        downsampling_rates=[1, 4, 8, 16],
        num_attention_heads=[1, 2, 4, 8],
        is_training=True,
        use_labels=True,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        initializer_range=0.02,
        num_labels=3,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.num_channels = num_channels
        self.num_encoder_blocks = num_encoder_blocks
        self.sr_ratios = sr_ratios
        self.depths = depths
        self.hidden_sizes = hidden_sizes
        self.downsampling_rates = downsampling_rates
        self.num_attention_heads = num_attention_heads
        self.is_training = is_training
        self.use_labels = use_labels
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.scope = scope

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

        labels = None
        if self.use_labels:
            labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)

        config = self.get_config()
        return config, pixel_values, labels

    def get_config(self):
        return SegformerConfig(
            image_size=self.image_size,
            num_channels=self.num_channels,
            num_encoder_blocks=self.num_encoder_blocks,
            depths=self.depths,
            hidden_sizes=self.hidden_sizes,
            num_attention_heads=self.num_attention_heads,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            initializer_range=self.initializer_range,
            num_labels=self.num_labels,
        )

    def create_and_check_model(self, config, pixel_values, labels):
        model = TFSegformerModel(config=config)
        result = model(pixel_values, training=False)
        expected_height = expected_width = self.image_size // (self.downsampling_rates[-1] * 2)
        self.parent.assertEqual(
            result.last_hidden_state.shape, (self.batch_size, self.hidden_sizes[-1], expected_height, expected_width)
        )

    def create_and_check_for_image_segmentation(self, config, pixel_values, labels):
        config.num_labels = self.num_labels
        model = TFSegformerForSemanticSegmentation(config)
        result = model(pixel_values, training=False)
        self.parent.assertEqual(
            result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
        )
        result = model(pixel_values, labels=labels, training=False)
        self.parent.assertEqual(
            result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values, labels = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict

    def prepare_config_and_inputs_for_keras_fit(self, for_segmentation: bool = False):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values, seg_labels = config_and_inputs
        if for_segmentation:
            inputs_dict = {"pixel_values": pixel_values, "labels": seg_labels}
        else:
            inputs_dict = {"pixel_values": pixel_values, "labels": tf.zeros((self.batch_size))}
        return config, inputs_dict
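prepare_config_and_inputs_for_keras_fit packs the labels into the input dict, which is how transformers TF models consume them in model.fit: no loss is passed to compile, and the model's own hf_compute_loss runs inside train_step. A hypothetical end-to-end sketch of that flow with a deliberately tiny config (all hyperparameter values below are invented for the example):

import tensorflow as tf
from transformers import SegformerConfig, TFSegformerForSemanticSegmentation

config = SegformerConfig(
    num_labels=3, depths=[1, 1, 1, 1], hidden_sizes=[8, 8, 16, 16], num_attention_heads=[1, 1, 2, 2]
)
model = TFSegformerForSemanticSegmentation(config)
pixel_values = tf.random.uniform((2, 3, 64, 64))                    # NCHW, as elsewhere in the port
labels = tf.random.uniform((2, 64, 64), maxval=3, dtype=tf.int32)   # per-pixel class ids
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4))             # no loss arg: hf_compute_loss is used
model.fit({"pixel_values": pixel_values, "labels": labels}, epochs=1, steps_per_epoch=1)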
||||
|
||||
|
||||
@require_tf
class TFSegformerModelTest(TFModelTesterMixin, unittest.TestCase):
    all_model_classes = (
        (TFSegformerModel, TFSegformerForImageClassification, TFSegformerForSemanticSegmentation)
        if is_tf_available()
        else ()
    )

    test_head_masking = False
    test_onnx = False
    test_pruning = False
    test_resize_embeddings = False

    def setUp(self):
        self.model_tester = TFSegformerModelTester(self)
        self.config_tester = TFSegformerConfigTester(self, config_class=SegformerConfig, has_text_modality=False)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

@unittest.skip("SegFormer does not use inputs_embeds")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("SegFormer does not have get_input_embeddings method and get_output_embeddings methods")
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test was written for TF 1.x and isn't really relevant here")
|
||||
def test_compile_tf_model(self):
|
||||
pass
|
||||
|
||||
    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.call)
            # signature.parameters is an OrderedDict => arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions

            expected_num_attentions = sum(self.model_tester.depths)
            self.assertEqual(len(attentions), expected_num_attentions)

            # check that output_attentions also works using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions

            self.assertEqual(len(attentions), expected_num_attentions)

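            # SegFormer's efficient self-attention shrinks the key/value sequence by a
            # factor of sr_ratio ** 2 per block, hence the reduced sequence length.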
            # verify the first attentions (first block, first layer)
            expected_seq_len = (self.model_tester.image_size // 4) ** 2
            expected_reduced_seq_len = (self.model_tester.image_size // (4 * self.model_tester.sr_ratios[0])) ** 2
            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads[0], expected_seq_len, expected_reduced_seq_len],
            )

            # verify the last attentions (last block, last layer)
            expected_seq_len = (self.model_tester.image_size // 32) ** 2
            expected_reduced_seq_len = (self.model_tester.image_size // (32 * self.model_tester.sr_ratios[-1])) ** 2
            self.assertListEqual(
                list(attentions[-1].shape[-3:]),
                [self.model_tester.num_attention_heads[-1], expected_seq_len, expected_reduced_seq_len],
            )
            out_len = len(outputs)

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            self.assertEqual(out_len + 1, len(outputs))

            self_attentions = outputs.attentions

            self.assertEqual(len(self_attentions), expected_num_attentions)
            # verify the first attentions (first block, first layer)
            expected_seq_len = (self.model_tester.image_size // 4) ** 2
            expected_reduced_seq_len = (self.model_tester.image_size // (4 * self.model_tester.sr_ratios[0])) ** 2
            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads[0], expected_seq_len, expected_reduced_seq_len],
            )

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)

            outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.hidden_states

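            # One hidden state is returned per encoder block (stage), not per layer.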
            expected_num_layers = self.model_tester.num_encoder_blocks
            self.assertEqual(len(hidden_states), expected_num_layers)

            # verify the first hidden states (first block)
            self.assertListEqual(
                list(hidden_states[0].shape[-3:]),
                [
                    self.model_tester.hidden_sizes[0],
                    self.model_tester.image_size // 4,
                    self.model_tester.image_size // 4,
                ],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also works using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_model_outputs_equivalence(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

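        # `return_dict=False` yields a plain tuple; every element must match the
        # corresponding field of the `ModelOutput` elementwise.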
        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
            tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
            dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

            def recursive_check(tuple_object, dict_object):
                if isinstance(tuple_object, (List, Tuple)):
                    for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                        recursive_check(tuple_iterable_value, dict_iterable_value)
                elif tuple_object is None:
                    return
                else:
                    self.assertTrue(
                        all(tf.equal(tuple_object, dict_object)),
                        msg=(
                            "Tuple and dict output are not equal. Difference:"
                            f" {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}"
                        ),
                    )

            recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            model = model_class(config)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

            if self.has_attentions:
                tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
                dict_inputs = self._prepare_for_class(inputs_dict, model_class)
                check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})

    # TODO: incorporate label support for semantic segmentation in `test_modeling_tf_common.py`.

    def test_dataset_conversion(self):
        gpus = tf.config.list_physical_devices("GPU")
        # Grouped convs aren't supported on CPUs for backprop.
        if len(gpus) >= 1:
            super().test_dataset_conversion()

    def test_keras_fit(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        gpus = tf.config.list_physical_devices("GPU")

        def apply(model):
            if getattr(model, "hf_compute_loss", None):
                model_weights = model.get_weights()

                # Test that the model correctly computes the loss when labels are passed as kwargs
                for_segmentation = model_class.__name__ == "TFSegformerForSemanticSegmentation"
                _, prepared_for_class = self.model_tester.prepare_config_and_inputs_for_keras_fit(
                    for_segmentation=for_segmentation
                )

                label_names = {"labels"}
                self.assertGreater(len(label_names), 0, msg="No matching label names found!")
                labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
                inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
                self.assertGreater(len(inputs_minus_labels), 0)
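                # A zero learning rate keeps the weights fixed, so the two fits below
                # should yield (nearly) identical validation losses.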
                model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)

                # Make sure the model fits without crashing regardless of where we pass the labels
                history1 = model.fit(
                    prepared_for_class,
                    validation_data=prepared_for_class,
                    steps_per_epoch=1,
                    validation_steps=1,
                    shuffle=False,
                )
                val_loss1 = history1.history["val_loss"][0]

                # We reinitialize the model here even though our learning rate was zero
                # because BatchNorm updates weights by means other than gradient descent.
                model.set_weights(model_weights)
                history2 = model.fit(
                    inputs_minus_labels,
                    labels,
                    validation_data=(inputs_minus_labels, labels),
                    steps_per_epoch=1,
                    validation_steps=1,
                    shuffle=False,
                )
                val_loss2 = history2.history["val_loss"][0]
                self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))

        for model_class in self.all_model_classes:
            # Skip `TFSegformerModel`, since it cannot operate with the default `fit()` method.
            if model_class.__name__ != "TFSegformerModel":
                model = model_class(config)
                # Grouped convs and backprop with them aren't supported on CPUs.
                if len(gpus) >= 1:
                    apply(model)

    def test_loss_computation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def apply(model):
            for_segmentation = model_class.__name__ == "TFSegformerForSemanticSegmentation"
            # The number of elements in the loss should be the same as the number of elements in the label
            _, prepared_for_class = self.model_tester.prepare_config_and_inputs_for_keras_fit(
                for_segmentation=for_segmentation
            )
            added_label = prepared_for_class[
                sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
            ]
            loss_size = tf.size(added_label)

            # Test that the model correctly computes the loss when labels are passed as kwargs
            possible_input_names = {"input_ids", "pixel_values", "input_features"}
            input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
            model_input = prepared_for_class.pop(input_name)

            loss = model(model_input, **prepared_for_class)[0]

            if model_class.__name__ == "TFSegformerForSemanticSegmentation":
                # The semantic segmentation loss is computed similarly to
                # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210.
                self.assertEqual(loss.shape, (1,))
            else:
                self.assertEqual(loss.shape, [loss_size])

            # Test that the model correctly computes the loss when labels are passed as a dict
            _, prepared_for_class = self.model_tester.prepare_config_and_inputs_for_keras_fit(
                for_segmentation=for_segmentation
            )
            loss = model(**prepared_for_class)[0]

            if model_class.__name__ == "TFSegformerForSemanticSegmentation":
                self.assertEqual(loss.shape, (1,))
            else:
                self.assertEqual(loss.shape, [loss_size])

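            # When inputs are passed positionally, each tensor has to land at the index
            # of the matching argument in `model.call`'s signature.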
            # Test that the model correctly computes the loss when labels are passed as a tuple
            label_keys = prepared_for_class.keys() - inputs_dict.keys()
            signature = inspect.signature(model.call).parameters
            signature_names = list(signature.keys())

            # Create a dictionary holding the location of the tensors in the tuple
            tuple_index_mapping = {0: input_name}
            for label_key in label_keys:
                label_key_index = signature_names.index(label_key)
                tuple_index_mapping[label_key_index] = label_key
            sorted_tuple_index_mapping = sorted(tuple_index_mapping.items())
            # Initialize a list with their default values, update the values and convert to a tuple
            list_input = []

            for name in signature_names:
                if name != "kwargs":
                    list_input.append(signature[name].default)

            for index, value in sorted_tuple_index_mapping:
                list_input[index] = prepared_for_class[value]

            tuple_input = tuple(list_input)

            # Send to model
            loss = model(tuple_input[:-1])[0]
            if model_class.__name__ == "TFSegformerForSemanticSegmentation":
                self.assertEqual(loss.shape, (1,))
            else:
                self.assertEqual(loss.shape, [loss_size])

        for model_class in self.all_model_classes:
            # Skip `TFSegformerModel`, since there are no labels against which a loss
            # could be computed.
            if model_class.__name__ != "TFSegformerModel":
                model = model_class(config)
                apply(model)

    def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
        # We override with a slightly higher tol value, as semseg models tend to diverge a bit more
        super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)

    @slow
    def test_model_from_pretrained(self):
        for model_name in TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = TFSegformerModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


# We will verify our results on an image of cute cats
def prepare_img():
    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    return image


@require_tf
class TFSegformerModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference_image_segmentation_ade(self):
        # only resize + normalize
        feature_extractor = SegformerFeatureExtractor(
            image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
        )
        model = TFSegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

        image = prepare_img()
        encoded_inputs = feature_extractor(images=image, return_tensors="tf")
        pixel_values = encoded_inputs.pixel_values

        outputs = model(pixel_values, training=False)

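        # The decoder head predicts at 1/4 of the 512x512 input, hence the 128x128 logits.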
        expected_shape = tf.TensorShape((1, model.config.num_labels, 128, 128))
        self.assertEqual(outputs.logits.shape, expected_shape)

        expected_slice = tf.constant(
            [
                [[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]],
                [[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]],
                [[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]],
            ]
        )
        tf.debugging.assert_near(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-4)

    @slow
    def test_inference_image_segmentation_city(self):
        # only resize + normalize
        feature_extractor = SegformerFeatureExtractor(
            image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
        )
        model = TFSegformerForSemanticSegmentation.from_pretrained(
            "nvidia/segformer-b1-finetuned-cityscapes-1024-1024"
        )

        image = prepare_img()
        encoded_inputs = feature_extractor(images=image, return_tensors="tf")
        pixel_values = encoded_inputs.pixel_values

        outputs = model(pixel_values, training=False)

        expected_shape = tf.TensorShape((1, model.config.num_labels, 128, 128))
        self.assertEqual(outputs.logits.shape, expected_shape)

        expected_slice = tf.constant(
            [
                [[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]],
                [[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]],
                [[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]],
            ]
        )
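        # The tolerance is looser here: as noted in `check_pt_tf_outputs` above, semseg
        # outputs tend to diverge a bit more between the PT and TF implementations.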
        tf.debugging.assert_near(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1)
@@ -98,6 +98,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
    "FlaxBartForCausalLM",  # Building part of bigger (tested) model.
    "FlaxBertForCausalLM",  # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
    "OPTDecoderWrapper",
    "TFSegformerDecodeHead",  # Not a regular model.
]

# Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't
@@ -137,6 +138,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    "PerceiverForMultimodalAutoencoding",
    "PerceiverForOpticalFlow",
    "SegformerDecodeHead",
    "TFSegformerDecodeHead",
    "FlaxBeitForMaskedImageModeling",
    "PLBartEncoder",
    "PLBartDecoder",

@@ -64,6 +64,7 @@ src/transformers/models/sew_d/modeling_sew_d.py
src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
src/transformers/models/speech_to_text/modeling_speech_to_text.py
src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
src/transformers/models/segformer/modeling_tf_segformer.py
src/transformers/models/swin/modeling_swin.py
src/transformers/models/trocr/modeling_trocr.py
src/transformers/models/unispeech/modeling_unispeech.py