[SegFormer] TensorFlow port (#17910)

* add: segformer utils and img. classification.

* add: segmentation layer.

* feat: working implementation of segformer.

* chore: remove unused variable.

* add test, remaining modifications.

* remove: unnecessary files.

* add: rest of the files.

Co-authored-by: matt <rocketknight1@gmail.com>

* chore: remove ModuleList comment.

* chore: apply make style.

* chore: apply make fixup-copies.

* add  to check_repo.py

* add decode head to IGNORE_NON_TESTED

* chore: run make style.

* chore: PR comments.

* chore: minor changes to model doc.

* tests: reduction across samples.

* add a note on the space.

* sort importats.

* fix: reduction in loss computation.

* chore: align loss function with that of NER.

* chore: correct utils/documentation_tests.txt

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* chore: simplify the interpolation of logits in loss computation.

* chore: return transposed logits when return_dict=False.

* chore: add link to the tf fine-tuning repo.

* address pr comments.

* address niels's comments.

* remove from_pt=True since tf weights are in.

* remove comment from pt model.

* address niels's comments.

Co-authored-by: matt <rocketknight1@gmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Sayak Paul 2022-07-21 22:52:37 +05:30 committed by GitHub
parent 2c5747edfe
commit 561b9a8c00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 1554 additions and 10 deletions

View File

@ -278,7 +278,7 @@ Flax), PyTorch, and/or TensorFlow.
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
| SegFormer | ❌ | ❌ | ✅ | | ❌ |
| SegFormer | ❌ | ❌ | ✅ | | ❌ |
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
| Speech Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |

View File

@ -36,13 +36,14 @@ The figure below illustrates the architecture of SegFormer. Taken from the [orig
<img width="600" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/segformer_architecture.png"/>
This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/NVlabs/SegFormer).
This model was contributed by [nielsr](https://huggingface.co/nielsr). The TensorFlow version
of the model was contributed by [sayakpaul](https://huggingface.co/sayakpaul). The original code can be found [here](https://github.com/NVlabs/SegFormer).
Tips:
- SegFormer consists of a hierarchical Transformer encoder, and a lightweight all-MLP decode head.
- SegFormer consists of a hierarchical Transformer encoder, and a lightweight all-MLP decoder head.
[`SegformerModel`] is the hierarchical Transformer encoder (which in the paper is also referred to
as Mix Transformer or MiT). [`SegformerForSemanticSegmentation`] adds the all-MLP decode head on
as Mix Transformer or MiT). [`SegformerForSemanticSegmentation`] adds the all-MLP decoder head on
top to perform semantic segmentation of images. In addition, there's
[`SegformerForImageClassification`] which can be used to - you guessed it - classify images. The
authors of SegFormer first pre-trained the Transformer encoder on ImageNet-1k to classify images. Next, they throw
@ -51,6 +52,9 @@ Tips:
found on the [hub](https://huggingface.co/models?other=segformer).
- The quickest way to get started with SegFormer is by checking the [example notebooks](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/SegFormer) (which showcase both inference and
fine-tuning on custom data). One can also check out the [blog post](https://huggingface.co/blog/fine-tune-segformer) introducing SegFormer and illustrating how it can be fine-tuned on custom data.
- TensorFlow users should refer to [this repository](https://github.com/deep-diver/segformer-tf-transformers) that shows off-the-shelf inference and fine-tuning.
- One can also check out [this interactive demo on Hugging Face Spaces](https://huggingface.co/spaces/chansung/segformer-tf-transformers)
to try out a SegFormer model on custom images.
- SegFormer works on any input size, as it pads the input to be divisible by `config.patch_sizes`.
- One can use [`SegformerFeatureExtractor`] to prepare images and corresponding segmentation maps
for the model. Note that this feature extractor is fairly basic and does not include all data augmentations used in
@ -65,7 +69,8 @@ Tips:
used by [`SegformerForSemanticSegmentation`]). However, other datasets use the 0 index as
background class and include this class as part of all labels. In that case, `reduce_labels` should be set to
`False`, as loss should also be computed for the background class.
- As most models, SegFormer comes in different sizes, the details of which can be found in the table below.
- As most models, SegFormer comes in different sizes, the details of which can be found in the table below
(taken from Table 7 of the [original paper](https://arxiv.org/abs/2105.15203)).
| **Model variant** | **Depths** | **Hidden sizes** | **Decoder hidden size** | **Params (M)** | **ImageNet-1k Top 1** |
| :---------------: | ------------- | ------------------- | :---------------------: | :------------: | :-------------------: |
@ -76,6 +81,10 @@ Tips:
| MiT-b4 | [3, 8, 27, 3] | [64, 128, 320, 512] | 768 | 62.6 | 83.6 |
| MiT-b5 | [3, 6, 40, 3] | [64, 128, 320, 512] | 768 | 82.0 | 83.8 |
Note that MiT in the above table refers to the Mix Transformer encoder backbone introduced in SegFormer. For
SegFormer's results on the segmentation datasets like ADE20k, refer to the [paper](https://arxiv.org/abs/2105.15203).
## SegformerConfig
[[autodoc]] SegformerConfig
@ -104,3 +113,23 @@ Tips:
[[autodoc]] SegformerForSemanticSegmentation
- forward
## TFSegformerDecodeHead
[[autodoc]] TFSegformerDecodeHead
- call
## TFSegformerModel
[[autodoc]] TFSegformerModel
- call
## TFSegformerForImageClassification
[[autodoc]] TFSegformerForImageClassification
- call
## TFSegformerForSemanticSegmentation
[[autodoc]] TFSegformerForSemanticSegmentation
- call

View File

@ -2430,6 +2430,16 @@ else:
"TFRoFormerPreTrainedModel",
]
)
_import_structure["models.segformer"].extend(
[
"TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSegformerDecodeHead",
"TFSegformerForImageClassification",
"TFSegformerForSemanticSegmentation",
"TFSegformerModel",
"TFSegformerPreTrainedModel",
]
)
_import_structure["models.speech_to_text"].extend(
[
"TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -4789,6 +4799,14 @@ if TYPE_CHECKING:
TFRoFormerModel,
TFRoFormerPreTrainedModel,
)
from .models.segformer import (
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSegformerDecodeHead,
TFSegformerForImageClassification,
TFSegformerForSemanticSegmentation,
TFSegformerModel,
TFSegformerPreTrainedModel,
)
from .models.speech_to_text import (
TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSpeech2TextForConditionalGeneration,

View File

@ -68,6 +68,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("resnet", "TFResNetModel"),
("roberta", "TFRobertaModel"),
("roformer", "TFRoFormerModel"),
("segformer", "TFSegformerModel"),
("speech_to_text", "TFSpeech2TextModel"),
("swin", "TFSwinModel"),
("t5", "TFT5Model"),
@ -180,6 +181,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
("regnet", "TFRegNetForImageClassification"),
("resnet", "TFResNetForImageClassification"),
("segformer", "TFSegformerForImageClassification"),
("swin", "TFSwinForImageClassification"),
("vit", "TFViTForImageClassification"),
]
@ -189,6 +191,7 @@ TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
[
# Model for Semantic Segmentation mapping
("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
("segformer", "TFSegformerForSemanticSegmentation"),
]
)

View File

@ -17,7 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
is_vision_available,
)
_import_structure = {"configuration_segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig"]}
@ -46,6 +52,21 @@ else:
"SegformerPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_segformer"] = [
"TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSegformerDecodeHead",
"TFSegformerForImageClassification",
"TFSegformerForSemanticSegmentation",
"TFSegformerModel",
"TFSegformerPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_segformer import SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, SegformerConfig
@ -73,7 +94,20 @@ if TYPE_CHECKING:
SegformerModel,
SegformerPreTrainedModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_segformer import (
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSegformerDecodeHead,
TFSegformerForImageClassification,
TFSegformerForSemanticSegmentation,
TFSegformerModel,
TFSegformerPreTrainedModel,
)
else:
import sys

View File

@ -785,6 +785,8 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits # shape (batch_size, num_labels, height, width)
>>> logits.shape
(1, 150, 128, 128)
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
@ -804,7 +806,7 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
loss = None
if labels is not None:
if self.config.num_labels == 1:
if not self.config.num_labels > 1:
raise ValueError("The number of labels should be greater than one")
else:
# upsample logits to the images' original size

View File

@ -0,0 +1,878 @@
# coding=utf-8
# Copyright 2022 NVIDIA The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TensorFlow SegFormer model."""
import math
from typing import Dict, Optional, Tuple, Union
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_outputs import TFBaseModelOutput, TFSemanticSegmenterOutput, TFSequenceClassifierOutput
from ...modeling_tf_utils import TFPreTrainedModel, TFSequenceClassificationLoss, keras_serializable, unpack_inputs
from ...tf_utils import shape_list, stable_softmax
from ...utils import logging
from .configuration_segformer import SegformerConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "SegformerConfig"
_FEAT_EXTRACTOR_FOR_DOC = "SegformerFeatureExtractor"
# Base docstring
_CHECKPOINT_FOR_DOC = "nvidia/mit-b0"
_EXPECTED_OUTPUT_SHAPE = [1, 256, 16, 16]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "nvidia/mit-b0"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"nvidia/segformer-b0-finetuned-ade-512-512",
# See all SegFormer models at https://huggingface.co/models?filter=segformer
]
# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath with ConvNext->Segformer
class TFSegformerDropPath(tf.keras.layers.Layer):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
References:
(1) github.com:rwightman/pytorch-image-models
"""
def __init__(self, drop_path, **kwargs):
super().__init__(**kwargs)
self.drop_path = drop_path
def call(self, x, training=None):
if training:
keep_prob = 1 - self.drop_path
shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
random_tensor = tf.floor(random_tensor)
return (x / keep_prob) * random_tensor
return x
class TFSegformerOverlapPatchEmbeddings(tf.keras.layers.Layer):
"""Construct the overlapping patch embeddings."""
def __init__(self, patch_size, stride, hidden_size, **kwargs):
super().__init__(**kwargs)
self.padding = tf.keras.layers.ZeroPadding2D(padding=patch_size // 2)
self.proj = tf.keras.layers.Conv2D(
filters=hidden_size, kernel_size=patch_size, strides=stride, padding="VALID", name="proj"
)
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm")
def call(self, pixel_values: tf.Tensor) -> Tuple[tf.Tensor, int, int]:
embeddings = self.proj(self.padding(pixel_values))
height = shape_list(embeddings)[1]
width = shape_list(embeddings)[2]
hidden_dim = shape_list(embeddings)[3]
# (batch_size, height, width, num_channels) -> (batch_size, height*width, num_channels)
# this can be fed to a Transformer layer
embeddings = tf.reshape(embeddings, (-1, height * width, hidden_dim))
embeddings = self.layer_norm(embeddings)
return embeddings, height, width
class TFSegformerEfficientSelfAttention(tf.keras.layers.Layer):
"""SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
paper](https://arxiv.org/abs/2102.12122)."""
def __init__(
self,
config: SegformerConfig,
hidden_size: int,
num_attention_heads: int,
sequence_reduction_ratio: int,
**kwargs
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
if self.hidden_size % self.num_attention_heads != 0:
raise ValueError(
f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
f"heads ({self.num_attention_heads})"
)
self.attention_head_size = self.hidden_size // self.num_attention_heads
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
self.query = tf.keras.layers.Dense(self.all_head_size, name="query")
self.key = tf.keras.layers.Dense(self.all_head_size, name="key")
self.value = tf.keras.layers.Dense(self.all_head_size, name="value")
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
self.sr_ratio = sequence_reduction_ratio
if sequence_reduction_ratio > 1:
self.sr = tf.keras.layers.Conv2D(
filters=hidden_size, kernel_size=sequence_reduction_ratio, strides=sequence_reduction_ratio, name="sr"
)
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm")
def transpose_for_scores(self, tensor: tf.Tensor) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size]
# to [batch_size, seq_length, num_attention_heads, attention_head_size]
batch_size = shape_list(tensor)[0]
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
# Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size]
# to [batch_size, num_attention_heads, seq_length, attention_head_size]
return tf.transpose(tensor, perm=[0, 2, 1, 3])
def call(
self,
hidden_states: tf.Tensor,
height: int,
width: int,
output_attentions: bool = False,
training: bool = False,
) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
batch_size = shape_list(hidden_states)[0]
num_channels = shape_list(hidden_states)[2]
query_layer = self.transpose_for_scores(self.query(hidden_states))
if self.sr_ratio > 1:
# Reshape to (batch_size, height, width, num_channels)
hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels))
# Apply sequence reduction
hidden_states = self.sr(hidden_states)
# Reshape back to (batch_size, seq_len, num_channels)
hidden_states = tf.reshape(hidden_states, (batch_size, -1, num_channels))
hidden_states = self.layer_norm(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
scale = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
attention_scores = tf.divide(attention_scores, scale)
# Normalize the attention scores to probabilities.
attention_probs = stable_softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs, training=training)
context_layer = tf.matmul(attention_probs, value_layer)
context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
# (batch_size, seq_len_q, all_head_size)
context_layer = tf.reshape(context_layer, (batch_size, -1, self.all_head_size))
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
class TFSegformerSelfOutput(tf.keras.layers.Layer):
def __init__(self, config: SegformerConfig, hidden_size: int, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(hidden_size, name="dense")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class TFSegformerAttention(tf.keras.layers.Layer):
def __init__(
self,
config: SegformerConfig,
hidden_size: int,
num_attention_heads: int,
sequence_reduction_ratio: int,
**kwargs
):
super().__init__(**kwargs)
self.self = TFSegformerEfficientSelfAttention(
config=config,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
sequence_reduction_ratio=sequence_reduction_ratio,
name="self",
)
self.dense_output = TFSegformerSelfOutput(config, hidden_size=hidden_size, name="output")
def call(
self, hidden_states: tf.Tensor, height: int, width: int, output_attentions: bool = False
) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
self_outputs = self.self(hidden_states, height, width, output_attentions)
attention_output = self.dense_output(self_outputs[0])
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class TFSegformerDWConv(tf.keras.layers.Layer):
def __init__(self, dim: int = 768, **kwargs):
super().__init__(**kwargs)
self.depthwise_convolution = tf.keras.layers.Conv2D(
filters=dim, kernel_size=3, strides=1, padding="same", groups=dim, name="dwconv"
)
def call(self, hidden_states: tf.Tensor, height: int, width: int) -> tf.Tensor:
batch_size = shape_list(hidden_states)[0]
num_channels = shape_list(hidden_states)[-1]
hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels))
hidden_states = self.depthwise_convolution(hidden_states)
new_height = shape_list(hidden_states)[1]
new_width = shape_list(hidden_states)[2]
num_channels = shape_list(hidden_states)[3]
hidden_states = tf.reshape(hidden_states, (batch_size, new_height * new_width, num_channels))
return hidden_states
class TFSegformerMixFFN(tf.keras.layers.Layer):
def __init__(
self,
config: SegformerConfig,
in_features: int,
hidden_features: int = None,
out_features: int = None,
**kwargs
):
super().__init__(**kwargs)
out_features = out_features or in_features
self.dense1 = tf.keras.layers.Dense(hidden_features, name="dense1")
self.depthwise_convolution = TFSegformerDWConv(hidden_features, name="dwconv")
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
self.dense2 = tf.keras.layers.Dense(out_features, name="dense2")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def call(self, hidden_states: tf.Tensor, height: int, width: int) -> tf.Tensor:
hidden_states = self.dense1(hidden_states)
hidden_states = self.depthwise_convolution(hidden_states, height, width)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.dense2(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class TFSegformerLayer(tf.keras.layers.Layer):
"""This corresponds to the Block class in the original implementation."""
def __init__(
self,
config,
hidden_size: int,
num_attention_heads: int,
drop_path: float,
sequence_reduction_ratio: int,
mlp_ratio: int,
**kwargs
):
super().__init__(**kwargs)
self.layer_norm_1 = tf.keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm_1")
self.attention = TFSegformerAttention(
config,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
sequence_reduction_ratio=sequence_reduction_ratio,
name="attention",
)
self.drop_path = TFSegformerDropPath(drop_path) if drop_path > 0.0 else tf.keras.layers.Activation("linear")
self.layer_norm_2 = tf.keras.layers.LayerNormalization(epsilon=1e-05, name="layer_norm_2")
mlp_hidden_size = int(hidden_size * mlp_ratio)
self.mlp = TFSegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size, name="mlp")
def call(
self,
hidden_states: tf.Tensor,
height: int,
width: int,
output_attentions: bool = False,
training: bool = False,
) -> Tuple:
self_attention_outputs = self.attention(
self.layer_norm_1(hidden_states), # in Segformer, layernorm is applied before self-attention
height,
width,
output_attentions=output_attentions,
training=training,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
# first residual connection (with stochastic depth)
attention_output = self.drop_path(attention_output, training=training)
hidden_states = attention_output + hidden_states
mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)
# second residual connection (with stochastic depth)
mlp_output = self.drop_path(mlp_output, training=training)
layer_output = mlp_output + hidden_states
outputs = (layer_output,) + outputs
return outputs
class TFSegformerEncoder(tf.keras.layers.Layer):
def __init__(self, config: SegformerConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
# stochastic depth decay rule
drop_path_decays = [x.numpy() for x in tf.linspace(0.0, config.drop_path_rate, sum(config.depths))]
# patch embeddings
embeddings = []
for i in range(config.num_encoder_blocks):
embeddings.append(
TFSegformerOverlapPatchEmbeddings(
patch_size=config.patch_sizes[i],
stride=config.strides[i],
hidden_size=config.hidden_sizes[i],
name=f"patch_embeddings.{i}",
)
)
self.embeddings = embeddings
# Transformer blocks
blocks = []
cur = 0
for i in range(config.num_encoder_blocks):
# each block consists of layers
layers = []
if i != 0:
cur += config.depths[i - 1]
for j in range(config.depths[i]):
layers.append(
TFSegformerLayer(
config,
hidden_size=config.hidden_sizes[i],
num_attention_heads=config.num_attention_heads[i],
drop_path=drop_path_decays[cur + j],
sequence_reduction_ratio=config.sr_ratios[i],
mlp_ratio=config.mlp_ratios[i],
name=f"block.{i}.{j}",
)
)
blocks.append(layers)
self.block = blocks
# Layer norms
self.layer_norms = [
tf.keras.layers.LayerNormalization(epsilon=1e-05, name=f"layer_norm.{i}")
for i in range(config.num_encoder_blocks)
]
def call(
self,
pixel_values: tf.Tensor,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
training: bool = False,
) -> Union[Tuple, TFBaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
batch_size = shape_list(pixel_values)[0]
hidden_states = pixel_values
for idx, x in enumerate(zip(self.embeddings, self.block, self.layer_norms)):
embedding_layer, block_layer, norm_layer = x
# first, obtain patch embeddings
hidden_states, height, width = embedding_layer(hidden_states)
# second, send embeddings through blocks
# (each block consists of multiple layers i.e., list of layers)
for i, blk in enumerate(block_layer):
layer_outputs = blk(
hidden_states,
height,
width,
output_attentions,
training=training,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
# third, apply layer norm
hidden_states = norm_layer(hidden_states)
# fourth, optionally reshape back to (batch_size, height, width, num_channels)
if idx != len(self.embeddings) - 1 or (idx == len(self.embeddings) - 1 and self.config.reshape_last_stage):
num_channels = shape_list(hidden_states)[-1]
hidden_states = tf.reshape(hidden_states, (batch_size, height, width, num_channels))
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
)
@keras_serializable
class TFSegformerMainLayer(tf.keras.layers.Layer):
config_class = SegformerConfig
def __init__(self, config: SegformerConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
# hierarchical Transformer encoder
self.encoder = TFSegformerEncoder(config, name="encoder")
@unpack_inputs
def call(
self,
pixel_values: tf.Tensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[Tuple, TFBaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
# So change the input format from `NCHW` to `NHWC`.
# shape = (batch_size, in_height, in_width, in_channels=num_channels)
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
encoder_outputs = self.encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
sequence_output = encoder_outputs[0]
# Change to NCHW output format to have uniformity in the modules
sequence_output = tf.transpose(sequence_output, perm=[0, 3, 1, 2])
# Change the other hidden state outputs to NCHW as well
if output_hidden_states:
hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])
if not return_dict:
if tf.greater(len(encoder_outputs[1:]), 0):
transposed_encoder_outputs = tuple(tf.transpose(v, perm=[0, 3, 1, 2]) for v in encoder_outputs[1:][0])
return (sequence_output,) + (transposed_encoder_outputs,)
else:
return (sequence_output,) + encoder_outputs[1:]
return TFBaseModelOutput(
last_hidden_state=sequence_output,
hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class TFSegformerPreTrainedModel(TFPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = SegformerConfig
base_model_prefix = "segformer"
main_input_name = "pixel_values"
@property
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
"""
Dummy inputs to build the network.
Returns:
`Dict[str, tf.Tensor]`: The dummy inputs.
"""
VISION_DUMMY_INPUTS = tf.random.uniform(shape=(3, self.config.num_channels, 512, 512), dtype=tf.float32)
return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
@tf.function(
input_signature=[
{
"pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
}
]
)
def serving(self, inputs):
"""
Method used for serving the model.
Args:
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
return self.call(inputs)
SEGFORMER_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
behavior.
Parameters:
config ([`SegformerConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
SEGFORMER_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
[`AutoFeatureExtractor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
config will be used instead.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
used instead.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
eager mode, in graph mode the value will always be set to True.
training (`bool`, *optional*, defaults to `False``):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""
@add_start_docstrings(
"The bare SegFormer encoder (Mix-Transformer) outputting raw hidden-states without any specific head on top.",
SEGFORMER_START_DOCSTRING,
)
class TFSegformerModel(TFSegformerPreTrainedModel):
def __init__(self, config: SegformerConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.config = config
# hierarchical Transformer encoder
self.segformer = TFSegformerMainLayer(config, name="segformer")
@unpack_inputs
@add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFBaseModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def call(
self,
pixel_values: tf.Tensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[Tuple, TFBaseModelOutput]:
outputs = self.segformer(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
return outputs
@add_start_docstrings(
"""
SegFormer Model transformer with an image classification head on top (a linear layer on top of the final hidden
states) e.g. for ImageNet.
""",
SEGFORMER_START_DOCSTRING,
)
class TFSegformerForImageClassification(TFSegformerPreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config: SegformerConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.segformer = TFSegformerMainLayer(config, name="segformer")
# Classifier head
self.classifier = tf.keras.layers.Dense(config.num_labels, name="classifier")
@unpack_inputs
@add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=TFSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def call(
self,
pixel_values: Optional[tf.Tensor] = None,
labels: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TFSequenceClassifierOutput]:
outputs = self.segformer(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
# convert last hidden states to (batch_size, height*width, hidden_size)
batch_size = shape_list(sequence_output)[0]
sequence_output = tf.transpose(sequence_output, perm=[0, 2, 3, 1])
sequence_output = tf.reshape(sequence_output, (batch_size, -1, self.config.hidden_sizes[-1]))
# global average pooling
sequence_output = tf.reduce_mean(sequence_output, axis=1)
logits = self.classifier(sequence_output)
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)
class TFSegformerMLP(tf.keras.layers.Layer):
"""
Linear Embedding.
"""
def __init__(self, config: SegformerConfig, **kwargs):
super().__init__(**kwargs)
self.proj = tf.keras.layers.Dense(config.decoder_hidden_size, name="proj")
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
height = shape_list(hidden_states)[1]
width = shape_list(hidden_states)[2]
hidden_dim = shape_list(hidden_states)[-1]
hidden_states = tf.reshape(hidden_states, (-1, height * width, hidden_dim))
hidden_states = self.proj(hidden_states)
return hidden_states
class TFSegformerDecodeHead(TFSegformerPreTrainedModel):
def __init__(self, config: SegformerConfig, **kwargs):
super().__init__(config, **kwargs)
# linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size
mlps = []
for i in range(config.num_encoder_blocks):
mlp = TFSegformerMLP(config, name=f"linear_c.{i}")
mlps.append(mlp)
self.mlps = mlps
# the following 3 layers implement the ConvModule of the original implementation
self.linear_fuse = tf.keras.layers.Conv2D(
filters=config.decoder_hidden_size, kernel_size=1, use_bias=False, name="linear_fuse"
)
self.batch_norm = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="batch_norm")
self.activation = tf.keras.layers.Activation("relu")
self.dropout = tf.keras.layers.Dropout(config.classifier_dropout_prob)
self.classifier = tf.keras.layers.Conv2D(filters=config.num_labels, kernel_size=1, name="classifier")
self.config = config
def call(self, encoder_hidden_states):
batch_size = shape_list(encoder_hidden_states[-1])[0]
all_hidden_states = ()
for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.mlps):
if self.config.reshape_last_stage is False and len(shape_list(encoder_hidden_state)) == 3:
height = tf.math.sqrt(tf.cast(shape_list(encoder_hidden_state)[1], tf.float32))
height = width = tf.cast(height, tf.int32)
encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1))
# unify channel dimension
encoder_hidden_state = tf.transpose(encoder_hidden_state, perm=[0, 2, 3, 1])
height = shape_list(encoder_hidden_state)[1]
width = shape_list(encoder_hidden_state)[2]
encoder_hidden_state = mlp(encoder_hidden_state)
encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1))
# upsample
temp_state = tf.transpose(encoder_hidden_states[0], perm=[0, 2, 3, 1])
upsample_resolution = shape_list(temp_state)[1:-1]
encoder_hidden_state = tf.image.resize(encoder_hidden_state, size=upsample_resolution, method="bilinear")
all_hidden_states += (encoder_hidden_state,)
hidden_states = self.linear_fuse(tf.concat(all_hidden_states[::-1], axis=-1))
hidden_states = self.batch_norm(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.dropout(hidden_states)
# logits of shape (batch_size, height/4, width/4, num_labels)
logits = self.classifier(hidden_states)
return logits
@add_start_docstrings(
"""SegFormer Model transformer with an all-MLP decode head on top e.g. for ADE20k, CityScapes.""",
SEGFORMER_START_DOCSTRING,
)
class TFSegformerForSemanticSegmentation(TFSegformerPreTrainedModel):
def __init__(self, config: SegformerConfig, **kwargs):
super().__init__(config, **kwargs)
self.segformer = TFSegformerMainLayer(config, name="segformer")
self.decode_head = TFSegformerDecodeHead(config, name="decode_head")
def hf_compute_loss(self, logits, labels):
# upsample logits to the images' original size
# `labels` is of shape (batch_size, height, width)
label_interp_shape = shape_list(labels)[1:]
upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
# compute weighted loss
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
def masked_loss(real, pred):
unmasked_loss = loss_fct(real, pred)
mask = tf.cast(real != self.config.semantic_loss_ignore_index, dtype=unmasked_loss.dtype)
masked_loss = unmasked_loss * mask
# Reduction strategy in the similar spirit with
# https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210
reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(mask)
return tf.reshape(reduced_masked_loss, (1,))
return masked_loss(labels, upsampled_logits)
@unpack_inputs
@add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFSemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
pixel_values: tf.Tensor,
labels: Optional[tf.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, TFSemanticSegmenterOutput]:
r"""
labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*):
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels > 1`, a (per-pixel) classification loss is computed
(Cross-Entropy).
Returns:
Examples:
```python
>>> from transformers import SegformerFeatureExtractor, TFSegformerForSemanticSegmentation
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
>>> model = TFSegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
>>> inputs = feature_extractor(images=image, return_tensors="tf")
>>> outputs = model(**inputs, training=False)
>>> # logits are of shape (batch_size, num_labels, height, width)
>>> logits = outputs.logits
>>> logits.shape
(1, 150, 128, 128)
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
outputs = self.segformer(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=True, # we need the intermediate hidden states
return_dict=return_dict,
)
encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
logits = self.decode_head(encoder_hidden_states)
loss = None
if labels is not None:
if not self.config.num_labels > 1:
raise ValueError("The number of labels should be greater than one")
else:
loss = self.hf_compute_loss(logits=logits, labels=labels)
# make logits of shape (batch_size, num_labels, height, width) to
# keep them consistent across APIs
logits = tf.transpose(logits, perm=[0, 3, 1, 2])
if not return_dict:
if output_hidden_states:
output = (logits,) + outputs[1:]
else:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFSemanticSegmenterOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)

View File

@ -1980,6 +1980,44 @@ class TFRoFormerPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFSegformerDecodeHead(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSegformerForImageClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSegformerForSemanticSegmentation(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSegformerModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSegformerPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST = None

View File

@ -18,7 +18,7 @@
import inspect
import unittest
from transformers import is_torch_available, is_vision_available
from transformers import SegformerConfig, is_torch_available, is_vision_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
@ -31,7 +31,6 @@ if is_torch_available():
from transformers import (
MODEL_MAPPING,
SegformerConfig,
SegformerForImageClassification,
SegformerForSemanticSegmentation,
SegformerModel,

View File

@ -0,0 +1,540 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow SegFormer model. """
import inspect
import unittest
from typing import List, Tuple
import numpy as np
from transformers import SegformerConfig
from transformers.file_utils import is_tf_available, is_vision_available
from transformers.testing_utils import require_tf, slow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
import tensorflow as tf
from transformers import TFSegformerForImageClassification, TFSegformerForSemanticSegmentation, TFSegformerModel
from transformers.models.segformer.modeling_tf_segformer import TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
from PIL import Image
from transformers import SegformerFeatureExtractor
class TFSegformerConfigTester(ConfigTester):
def create_and_test_config_common_properties(self):
config = self.config_class(**self.inputs_dict)
self.parent.assertTrue(hasattr(config, "hidden_sizes"))
self.parent.assertTrue(hasattr(config, "num_attention_heads"))
self.parent.assertTrue(hasattr(config, "num_encoder_blocks"))
class TFSegformerModelTester:
def __init__(
self,
parent,
batch_size=13,
image_size=64,
num_channels=3,
num_encoder_blocks=4,
depths=[2, 2, 2, 2],
sr_ratios=[8, 4, 2, 1],
hidden_sizes=[16, 32, 64, 128],
downsampling_rates=[1, 4, 8, 16],
num_attention_heads=[1, 2, 4, 8],
is_training=True,
use_labels=True,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
initializer_range=0.02,
num_labels=3,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.num_channels = num_channels
self.num_encoder_blocks = num_encoder_blocks
self.sr_ratios = sr_ratios
self.depths = depths
self.hidden_sizes = hidden_sizes
self.downsampling_rates = downsampling_rates
self.num_attention_heads = num_attention_heads
self.is_training = is_training
self.use_labels = use_labels
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.num_labels = num_labels
self.scope = scope
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return SegformerConfig(
image_size=self.image_size,
num_channels=self.num_channels,
num_encoder_blocks=self.num_encoder_blocks,
depths=self.depths,
hidden_sizes=self.hidden_sizes,
num_attention_heads=self.num_attention_heads,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
initializer_range=self.initializer_range,
num_labels=self.num_labels,
)
def create_and_check_model(self, config, pixel_values, labels):
model = TFSegformerModel(config=config)
result = model(pixel_values, training=False)
expected_height = expected_width = self.image_size // (self.downsampling_rates[-1] * 2)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.hidden_sizes[-1], expected_height, expected_width)
)
def create_and_check_for_image_segmentation(self, config, pixel_values, labels):
config.num_labels = self.num_labels
model = TFSegformerForSemanticSegmentation(config)
result = model(pixel_values, training=False)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
)
result = model(pixel_values, labels=labels, training=False)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
def prepare_config_and_inputs_for_keras_fit(self, for_segmentation: bool = False):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, seg_labels = config_and_inputs
if for_segmentation:
inputs_dict = {"pixel_values": pixel_values, "labels": seg_labels}
else:
inputs_dict = {"pixel_values": pixel_values, "labels": tf.zeros((self.batch_size))}
return config, inputs_dict
@require_tf
class TFSegformerModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (
(TFSegformerModel, TFSegformerForImageClassification, TFSegformerForSemanticSegmentation)
if is_tf_available()
else ()
)
test_head_masking = False
test_onnx = False
test_pruning = False
test_resize_embeddings = False
def setUp(self):
self.model_tester = TFSegformerModelTester(self)
self.config_tester = TFSegformerConfigTester(self, config_class=SegformerConfig, has_text_modality=False)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip("SegFormer does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip("SegFormer does not have get_input_embeddings method and get_output_embeddings methods")
def test_model_common_attributes(self):
pass
@unittest.skip("Test was written for TF 1.x and isn't really relevant here")
def test_compile_tf_model(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
expected_num_attentions = sum(self.model_tester.depths)
self.assertEqual(len(attentions), expected_num_attentions)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), expected_num_attentions)
# verify the first attentions (first block, first layer)
expected_seq_len = (self.model_tester.image_size // 4) ** 2
expected_reduced_seq_len = (self.model_tester.image_size // (4 * self.model_tester.sr_ratios[0])) ** 2
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads[0], expected_seq_len, expected_reduced_seq_len],
)
# verify the last attentions (last block, last layer)
expected_seq_len = (self.model_tester.image_size // 32) ** 2
expected_reduced_seq_len = (self.model_tester.image_size // (32 * self.model_tester.sr_ratios[-1])) ** 2
self.assertListEqual(
list(attentions[-1].shape[-3:]),
[self.model_tester.num_attention_heads[-1], expected_seq_len, expected_reduced_seq_len],
)
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
self.assertEqual(out_len + 1, len(outputs))
self_attentions = outputs.attentions
self.assertEqual(len(self_attentions), expected_num_attentions)
# verify the first attentions (first block, first layer)
expected_seq_len = (self.model_tester.image_size // 4) ** 2
expected_reduced_seq_len = (self.model_tester.image_size // (4 * self.model_tester.sr_ratios[0])) ** 2
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads[0], expected_seq_len, expected_reduced_seq_len],
)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
expected_num_layers = self.model_tester.num_encoder_blocks
self.assertEqual(len(hidden_states), expected_num_layers)
# verify the first hidden states (first block)
self.assertListEqual(
list(hidden_states[0].shape[-3:]),
[
self.model_tester.hidden_sizes[0],
self.model_tester.image_size // 4,
self.model_tester.image_size // 4,
],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
all(tf.equal(tuple_object, dict_object)),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}"
),
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
if self.has_attentions:
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
# todo: incorporate label support for semantic segmentation in `test_modeling_tf_common.py`.
def test_dataset_conversion(self):
gpus = tf.config.list_physical_devices("GPU")
# Grouped convs aren't supported on CPUs for backprop.
if len(gpus) >= 1:
super().test_dataset_conversion()
def test_keras_fit(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
gpus = tf.config.list_physical_devices("GPU")
def apply(model):
if getattr(model, "hf_compute_loss", None):
model_weights = model.get_weights()
# Test that model correctly compute the loss with kwargs
for_segmentation = True if model_class.__name__ == "TFSegformerForSemanticSegmentation" else False
_, prepared_for_class = self.model_tester.prepare_config_and_inputs_for_keras_fit(
for_segmentation=for_segmentation
)
label_names = {"labels"}
self.assertGreater(len(label_names), 0, msg="No matching label names found!")
labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
self.assertGreater(len(inputs_minus_labels), 0)
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)
# Make sure the model fits without crashing regardless of where we pass the labels
history1 = model.fit(
prepared_for_class,
validation_data=prepared_for_class,
steps_per_epoch=1,
validation_steps=1,
shuffle=False,
)
val_loss1 = history1.history["val_loss"][0]
# We reinitialize the model here even though our learning rate was zero
# because BatchNorm updates weights by means other than gradient descent.
model.set_weights(model_weights)
history2 = model.fit(
inputs_minus_labels,
labels,
validation_data=(inputs_minus_labels, labels),
steps_per_epoch=1,
validation_steps=1,
shuffle=False,
)
val_loss2 = history2.history["val_loss"][0]
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
for model_class in self.all_model_classes:
# Since `TFSegformerModel` cannot operate with the default `fit()` method.
if model_class.__name__ != "TFSegformerModel":
# Grouped convs and backprop with them isn't supported on CPUs.
model = model_class(config)
if len(gpus) > 1:
apply(model)
def test_loss_computation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def apply(model):
for_segmentation = True if model_class.__name__ == "TFSegformerForSemanticSegmentation" else False
# The number of elements in the loss should be the same as the number of elements in the label
_, prepared_for_class = self.model_tester.prepare_config_and_inputs_for_keras_fit(
for_segmentation=for_segmentation
)
added_label = prepared_for_class[
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
]
loss_size = tf.size(added_label)
# Test that model correctly compute the loss with kwargs
possible_input_names = {"input_ids", "pixel_values", "input_features"}
input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
model_input = prepared_for_class.pop(input_name)
loss = model(model_input, **prepared_for_class)[0]
if model_class.__name__ == "TFSegformerForSemanticSegmentation":
# Semantic segmentation loss is computed similarly as
# https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210.
self.assertEqual(loss.shape, (1,))
else:
self.assertEqual(loss.shape, [loss_size])
# Test that model correctly compute the loss with a dict
_, prepared_for_class = self.model_tester.prepare_config_and_inputs_for_keras_fit(
for_segmentation=for_segmentation
)
loss = model(**prepared_for_class)[0]
if model_class.__name__ == "TFSegformerForSemanticSegmentation":
self.assertEqual(loss.shape, (1,))
else:
self.assertEqual(loss.shape, [loss_size])
# Test that model correctly compute the loss with a tuple
label_keys = prepared_for_class.keys() - inputs_dict.keys()
signature = inspect.signature(model.call).parameters
signature_names = list(signature.keys())
# Create a dictionary holding the location of the tensors in the tuple
tuple_index_mapping = {0: input_name}
for label_key in label_keys:
label_key_index = signature_names.index(label_key)
tuple_index_mapping[label_key_index] = label_key
sorted_tuple_index_mapping = sorted(tuple_index_mapping.items())
# Initialize a list with their default values, update the values and convert to a tuple
list_input = []
for name in signature_names:
if name != "kwargs":
list_input.append(signature[name].default)
for index, value in sorted_tuple_index_mapping:
list_input[index] = prepared_for_class[value]
tuple_input = tuple(list_input)
# Send to model
loss = model(tuple_input[:-1])[0]
if model_class.__name__ == "TFSegformerForSemanticSegmentation":
self.assertEqual(loss.shape, (1,))
else:
self.assertEqual(loss.shape, [loss_size])
for model_class in self.all_model_classes:
# Since `TFSegformerModel` won't have labels against which we
# could compute loss.
if model_class.__name__ != "TFSegformerModel":
model = model_class(config)
apply(model)
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
# We override with a slightly higher tol value, as semseg models tend to diverge a bit more
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
@slow
def test_model_from_pretrained(self):
for model_name in TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFSegformerModel.from_pretrained(model_name)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_tf
class TFSegformerModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_image_segmentation_ade(self):
# only resize + normalize
feature_extractor = SegformerFeatureExtractor(
image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
)
model = TFSegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
image = prepare_img()
encoded_inputs = feature_extractor(images=image, return_tensors="tf")
pixel_values = encoded_inputs.pixel_values
outputs = model(pixel_values, training=False)
expected_shape = tf.TensorShape((1, model.config.num_labels, 128, 128))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = tf.constant(
[
[[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]],
[[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]],
[[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]],
]
)
tf.debugging.assert_near(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-4)
@slow
def test_inference_image_segmentation_city(self):
# only resize + normalize
feature_extractor = SegformerFeatureExtractor(
image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
)
model = TFSegformerForSemanticSegmentation.from_pretrained(
"nvidia/segformer-b1-finetuned-cityscapes-1024-1024"
)
image = prepare_img()
encoded_inputs = feature_extractor(images=image, return_tensors="tf")
pixel_values = encoded_inputs.pixel_values
outputs = model(pixel_values, training=False)
expected_shape = tf.TensorShape((1, model.config.num_labels, 128, 128))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = tf.constant(
[
[[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]],
[[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]],
[[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]],
]
)
tf.debugging.assert_near(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1)

View File

@ -98,6 +98,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
"FlaxBartForCausalLM", # Building part of bigger (tested) model.
"FlaxBertForCausalLM", # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
"OPTDecoderWrapper",
"TFSegformerDecodeHead", # Not a regular model.
]
# Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't
@ -137,6 +138,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"PerceiverForMultimodalAutoencoding",
"PerceiverForOpticalFlow",
"SegformerDecodeHead",
"TFSegformerDecodeHead",
"FlaxBeitForMaskedImageModeling",
"PLBartEncoder",
"PLBartDecoder",

View File

@ -64,6 +64,7 @@ src/transformers/models/sew_d/modeling_sew_d.py
src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
src/transformers/models/speech_to_text/modeling_speech_to_text.py
src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
src/transformers/models/segformer/modeling_tf_segformer.py
src/transformers/models/swin/modeling_swin.py
src/transformers/models/trocr/modeling_trocr.py
src/transformers/models/unispeech/modeling_unispeech.py