mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Standardize semantic segmentation models outputs (#15469)
* Standardize instance segmentation models outputs * Rename output * Update src/transformers/modeling_outputs.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Add legacy argument to the config and model forward * Update src/transformers/models/beit/modeling_beit.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Copy fix in Segformer Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent
31be2f45a9
commit
ac6aa10f23
@ -150,6 +150,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its
|
||||
|
||||
[[autodoc]] AutoModelForImageSegmentation
|
||||
|
||||
## AutoModelForSemanticSegmentation
|
||||
|
||||
[[autodoc]] AutoModelForSemanticSegmentation
|
||||
|
||||
## TFAutoModel
|
||||
|
||||
[[autodoc]] TFAutoModel
|
||||
|
@ -688,6 +688,7 @@ if is_torch_available():
|
||||
"AutoModelForObjectDetection",
|
||||
"AutoModelForPreTraining",
|
||||
"AutoModelForQuestionAnswering",
|
||||
"AutoModelForSemanticSegmentation",
|
||||
"AutoModelForSeq2SeqLM",
|
||||
"AutoModelForSequenceClassification",
|
||||
"AutoModelForSpeechSeq2Seq",
|
||||
@ -2797,6 +2798,7 @@ if TYPE_CHECKING:
|
||||
AutoModelForObjectDetection,
|
||||
AutoModelForPreTraining,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForSemanticSegmentation,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
|
@ -812,3 +812,32 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
|
||||
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
||||
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SemanticSegmentationModelOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of semantic segmentation models.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, height, width)`):
|
||||
Classification scores for each pixel.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
||||
shape `(batch_size, patch_size, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
@ -43,6 +43,7 @@ if is_torch_available():
|
||||
"MODEL_FOR_OBJECT_DETECTION_MAPPING",
|
||||
"MODEL_FOR_PRETRAINING_MAPPING",
|
||||
"MODEL_FOR_QUESTION_ANSWERING_MAPPING",
|
||||
"MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
|
||||
"MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
|
||||
"MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
|
||||
@ -65,6 +66,7 @@ if is_torch_available():
|
||||
"AutoModelForObjectDetection",
|
||||
"AutoModelForPreTraining",
|
||||
"AutoModelForQuestionAnswering",
|
||||
"AutoModelForSemanticSegmentation",
|
||||
"AutoModelForSeq2SeqLM",
|
||||
"AutoModelForSequenceClassification",
|
||||
"AutoModelForSpeechSeq2Seq",
|
||||
@ -155,6 +157,7 @@ if TYPE_CHECKING:
|
||||
MODEL_FOR_OBJECT_DETECTION_MAPPING,
|
||||
MODEL_FOR_PRETRAINING_MAPPING,
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
@ -177,6 +180,7 @@ if TYPE_CHECKING:
|
||||
AutoModelForObjectDetection,
|
||||
AutoModelForPreTraining,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoModelForSemanticSegmentation,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
|
@ -278,11 +278,20 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Do not add new models here, this class will be deprecated in the future.
|
||||
# Model for Image Segmentation mapping
|
||||
("detr", "DetrForSegmentation"),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Semantic Segmentation mapping
|
||||
("beit", "BeitForSemanticSegmentation"),
|
||||
("segformer", "SegformerForSemanticSegmentation"),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("vision-encoder-decoder", "VisionEncoderDecoderModel"),
|
||||
@ -603,6 +612,9 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
|
||||
)
|
||||
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
|
||||
)
|
||||
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
|
||||
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
|
||||
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
|
||||
@ -745,6 +757,15 @@ class AutoModelForImageSegmentation(_BaseAutoModelClass):
|
||||
AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation")
|
||||
|
||||
|
||||
class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
|
||||
|
||||
|
||||
AutoModelForSemanticSegmentation = auto_class_update(
|
||||
AutoModelForSemanticSegmentation, head_doc="semantic segmentation"
|
||||
)
|
||||
|
||||
|
||||
class AutoModelForObjectDetection(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
|
||||
|
||||
|
@ -93,6 +93,10 @@ class BeitConfig(PretrainedConfig):
|
||||
Whether to concatenate the output of the auxiliary head with the input before the classification layer.
|
||||
semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
|
||||
The index that is ignored by the loss function of the semantic segmentation model.
|
||||
legacy_output (`bool`, *optional*, defaults to `False`):
|
||||
Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`)
|
||||
|
||||
This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.
|
||||
|
||||
Example:
|
||||
|
||||
@ -141,6 +145,7 @@ class BeitConfig(PretrainedConfig):
|
||||
auxiliary_num_convs=1,
|
||||
auxiliary_concat_input=False,
|
||||
semantic_loss_ignore_index=255,
|
||||
legacy_output=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@ -176,3 +181,4 @@ class BeitConfig(PretrainedConfig):
|
||||
self.auxiliary_num_convs = auxiliary_num_convs
|
||||
self.auxiliary_concat_input = auxiliary_concat_input
|
||||
self.semantic_loss_ignore_index = semantic_loss_ignore_index
|
||||
self.legacy_output = legacy_output
|
||||
|
@ -17,6 +17,7 @@
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
@ -31,7 +32,13 @@ from ...file_utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
MaskedLMOutput,
|
||||
SemanticSegmentationModelOutput,
|
||||
SequenceClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import logging
|
||||
from .configuration_beit import BeitConfig
|
||||
@ -1114,11 +1121,8 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def compute_loss(self, logits, auxiliary_logits, labels):
|
||||
def compute_loss(self, upsampled_logits, auxiliary_logits, labels):
|
||||
# upsample logits to the images' original size
|
||||
upsampled_logits = nn.functional.interpolate(
|
||||
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
if auxiliary_logits is not None:
|
||||
upsampled_auxiliary_logits = nn.functional.interpolate(
|
||||
auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
|
||||
@ -1132,7 +1136,7 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
||||
return loss
|
||||
|
||||
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@replace_return_docstrings(output_type=SemanticSegmentationModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values=None,
|
||||
@ -1141,11 +1145,17 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
legacy_output=None,
|
||||
):
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
|
||||
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
|
||||
legacy_output (`bool`, *optional*):
|
||||
Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`). Will default
|
||||
to `self.config.legacy_output`.
|
||||
|
||||
This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.
|
||||
|
||||
Returns:
|
||||
|
||||
@ -1164,13 +1174,21 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
||||
|
||||
>>> inputs = feature_extractor(images=image, return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
>>> # logits are of shape (batch_size, num_labels, height/4, width/4)
|
||||
>>> # logits are of shape (batch_size, num_labels, height, width)
|
||||
>>> logits = outputs.logits
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
legacy_output = legacy_output if legacy_output is not None else self.config.legacy_output
|
||||
if not legacy_output:
|
||||
warnings.warn(
|
||||
"The output of this model has changed in v4.17.0 and the logits now have the same size as the inputs. "
|
||||
"You can activate the previous behavior by passing `legacy_output=True` to this call or the "
|
||||
"configuration of this model (only until v5, then that argument will be removed).",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
outputs = self.beit(
|
||||
pixel_values,
|
||||
@ -1197,6 +1215,11 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
||||
features[i] = ops[i](features[i])
|
||||
|
||||
logits = self.decode_head(features)
|
||||
|
||||
upsampled_logits = nn.functional.interpolate(
|
||||
logits, size=pixel_values.shape[-2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
auxiliary_logits = None
|
||||
if self.auxiliary_head is not None:
|
||||
auxiliary_logits = self.auxiliary_head(features)
|
||||
@ -1206,18 +1229,26 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
||||
if self.config.num_labels == 1:
|
||||
raise ValueError("The number of labels should be greater than one")
|
||||
else:
|
||||
loss = self.compute_loss(logits, auxiliary_logits, labels)
|
||||
loss = self.compute_loss(upsampled_logits, auxiliary_logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
if output_hidden_states:
|
||||
output = (logits,) + outputs[2:]
|
||||
output = (logits if legacy_output else upsampled_logits,) + outputs[2:]
|
||||
else:
|
||||
output = (logits,) + outputs[3:]
|
||||
output = (logits if legacy_output else upsampled_logits,) + outputs[3:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
if legacy_output:
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
else:
|
||||
return SemanticSegmentationModelOutput(
|
||||
loss=loss,
|
||||
logits=upsampled_logits,
|
||||
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
@ -83,6 +83,10 @@ class SegformerConfig(PretrainedConfig):
|
||||
required for the semantic segmentation model.
|
||||
semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
|
||||
The index that is ignored by the loss function of the semantic segmentation model.
|
||||
legacy_output (`bool`, *optional*, defaults to `False`):
|
||||
Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`)
|
||||
|
||||
This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.
|
||||
|
||||
Example:
|
||||
|
||||
@ -124,6 +128,7 @@ class SegformerConfig(PretrainedConfig):
|
||||
is_encoder_decoder=False,
|
||||
reshape_last_stage=True,
|
||||
semantic_loss_ignore_index=255,
|
||||
legacy_output=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
@ -149,3 +154,4 @@ class SegformerConfig(PretrainedConfig):
|
||||
self.decoder_hidden_size = decoder_hidden_size
|
||||
self.reshape_last_stage = reshape_last_stage
|
||||
self.semantic_loss_ignore_index = semantic_loss_ignore_index
|
||||
self.legacy_output = legacy_output
|
||||
|
@ -17,6 +17,7 @@
|
||||
|
||||
import collections
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -30,7 +31,7 @@ from ...file_utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutput, SequenceClassifierOutput
|
||||
from ...modeling_outputs import BaseModelOutput, SemanticSegmentationModelOutput, SequenceClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import logging
|
||||
from .configuration_segformer import SegformerConfig
|
||||
@ -688,7 +689,7 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@replace_return_docstrings(output_type=SemanticSegmentationModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values,
|
||||
@ -696,11 +697,17 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
legacy_output=None,
|
||||
):
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
|
||||
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
|
||||
legacy_output (`bool`, *optional*):
|
||||
Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`). Will default
|
||||
to `self.config.legacy_output`.
|
||||
|
||||
This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.
|
||||
|
||||
Returns:
|
||||
|
||||
@ -719,12 +726,20 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
|
||||
|
||||
>>> inputs = feature_extractor(images=image, return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
>>> logits = outputs.logits # shape (batch_size, num_labels, height/4, width/4)
|
||||
>>> logits = outputs.logits # shape (batch_size, num_labels, height, width)
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
legacy_output = legacy_output if legacy_output is not None else self.config.legacy_output
|
||||
if not legacy_output:
|
||||
warnings.warn(
|
||||
"The output of this model has changed in v4.17.0 and the logits now have the same size as the inputs. "
|
||||
"You can activate the previous behavior by passing `legacy_output=True` to this call or the "
|
||||
"configuration of this model (only until v5, then that argument will be removed).",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
outputs = self.segformer(
|
||||
pixel_values,
|
||||
@ -737,28 +752,37 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
|
||||
|
||||
logits = self.decode_head(encoder_hidden_states)
|
||||
|
||||
upsampled_logits = nn.functional.interpolate(
|
||||
logits, size=pixel_values.shape[-2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
if self.config.num_labels == 1:
|
||||
raise ValueError("The number of labels should be greater than one")
|
||||
else:
|
||||
# upsample logits to the images' original size
|
||||
upsampled_logits = nn.functional.interpolate(
|
||||
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
|
||||
loss = loss_fct(upsampled_logits, labels)
|
||||
|
||||
if not return_dict:
|
||||
if output_hidden_states:
|
||||
output = (logits,) + outputs[1:]
|
||||
output = (logits if legacy_output else upsampled_logits,) + outputs[1:]
|
||||
else:
|
||||
output = (logits,) + outputs[2:]
|
||||
output = (logits if legacy_output else upsampled_logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
if legacy_output:
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
else:
|
||||
return SemanticSegmentationModelOutput(
|
||||
loss=loss,
|
||||
logits=upsampled_logits,
|
||||
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
@ -474,6 +474,13 @@ class AutoModelForQuestionAnswering(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForSemanticSegmentation(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForSeq2SeqLM(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
@ -92,17 +92,20 @@ class BeitModelTester:
|
||||
self.initializer_range = initializer_range
|
||||
self.scope = scope
|
||||
self.out_indices = out_indices
|
||||
self.num_labels = num_labels
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
|
||||
labels = None
|
||||
pixel_labels = None
|
||||
if self.use_labels:
|
||||
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
pixel_labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values, labels
|
||||
return config, pixel_values, labels, pixel_labels
|
||||
|
||||
def get_config(self):
|
||||
return BeitConfig(
|
||||
@ -122,7 +125,7 @@ class BeitModelTester:
|
||||
out_indices=self.out_indices,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, pixel_values, labels):
|
||||
def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
|
||||
model = BeitModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@ -133,7 +136,7 @@ class BeitModelTester:
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
|
||||
|
||||
def create_and_check_for_masked_lm(self, config, pixel_values, labels):
|
||||
def create_and_check_for_masked_lm(self, config, pixel_values, labels, pixel_labels):
|
||||
model = BeitForMaskedImageModeling(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@ -144,7 +147,7 @@ class BeitModelTester:
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, self.vocab_size))
|
||||
|
||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||
def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels):
|
||||
config.num_labels = self.type_sequence_label_size
|
||||
model = BeitForImageClassification(config)
|
||||
model.to(torch_device)
|
||||
@ -152,13 +155,23 @@ class BeitModelTester:
|
||||
result = model(pixel_values, labels=labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
|
||||
|
||||
def create_and_check_for_image_segmentation(self, config, pixel_values, labels, pixel_labels):
|
||||
config.num_labels = self.num_labels
|
||||
model = BeitForSemanticSegmentation(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(
|
||||
result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
|
||||
)
|
||||
result = model(pixel_values, labels=pixel_labels)
|
||||
self.parent.assertEqual(
|
||||
result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
pixel_values,
|
||||
labels,
|
||||
) = config_and_inputs
|
||||
config, pixel_values, labels, pixel_labels = config_and_inputs
|
||||
inputs_dict = {"pixel_values": pixel_values}
|
||||
return config, inputs_dict
|
||||
|
||||
@ -217,6 +230,10 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_for_image_segmentation(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs)
|
||||
|
||||
def test_training(self):
|
||||
if not self.model_tester.is_training:
|
||||
return
|
||||
@ -516,14 +533,14 @@ class BeitModelIntegrationTest(unittest.TestCase):
|
||||
logits = outputs.logits
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 150, 160, 160))
|
||||
expected_shape = torch.Size((1, 150, 640, 640))
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[
|
||||
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
|
||||
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
|
||||
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
|
||||
[[-4.9225, -4.9225, -4.6066], [-4.9225, -4.9225, -4.6066], [-4.6675, -4.6675, -4.3617]],
|
||||
[[-5.8168, -5.8168, -5.5163], [-5.8168, -5.8168, -5.5163], [-5.5728, -5.5728, -5.2842]],
|
||||
[[-0.0078, -0.0078, 0.4926], [-0.0078, -0.0078, 0.4926], [0.3664, 0.3664, 0.8309]],
|
||||
]
|
||||
).to(torch_device)
|
||||
|
||||
|
@ -133,11 +133,11 @@ class SegformerModelTester:
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(
|
||||
result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
|
||||
result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
|
||||
)
|
||||
result = model(pixel_values, labels=labels)
|
||||
self.parent.assertEqual(
|
||||
result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
|
||||
result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
@ -245,6 +245,7 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
list(attentions[-1].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads[-1], expected_seq_len, expected_reduced_seq_len],
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
@ -255,7 +256,7 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
self.assertEqual(3, len(outputs))
|
||||
self.assertEqual(out_len + 1, len(outputs))
|
||||
|
||||
self_attentions = outputs.attentions
|
||||
|
||||
@ -357,16 +358,17 @@ class SegformerModelIntegrationTest(unittest.TestCase):
|
||||
encoded_inputs = feature_extractor(images=image, return_tensors="pt")
|
||||
pixel_values = encoded_inputs.pixel_values.to(torch_device)
|
||||
|
||||
outputs = model(pixel_values)
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values)
|
||||
|
||||
expected_shape = torch.Size((1, model.config.num_labels, 128, 128))
|
||||
expected_shape = torch.Size((1, model.config.num_labels, 512, 512))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[
|
||||
[[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]],
|
||||
[[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]],
|
||||
[[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]],
|
||||
[[-4.6309, -4.6309, -4.7425], [-4.6309, -4.6309, -4.7425], [-4.7011, -4.7011, -4.8136]],
|
||||
[[-12.1391, -12.1391, -12.2858], [-12.1391, -12.1391, -12.2858], [-12.2309, -12.2309, -12.3758]],
|
||||
[[-12.5134, -12.5134, -12.6328], [-12.5134, -12.5134, -12.6328], [-12.5576, -12.5576, -12.6865]],
|
||||
]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-4))
|
||||
@ -385,16 +387,17 @@ class SegformerModelIntegrationTest(unittest.TestCase):
|
||||
encoded_inputs = feature_extractor(images=image, return_tensors="pt")
|
||||
pixel_values = encoded_inputs.pixel_values.to(torch_device)
|
||||
|
||||
outputs = model(pixel_values)
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values)
|
||||
|
||||
expected_shape = torch.Size((1, model.config.num_labels, 128, 128))
|
||||
expected_shape = torch.Size((1, model.config.num_labels, 512, 512))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor(
|
||||
[
|
||||
[[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]],
|
||||
[[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]],
|
||||
[[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]],
|
||||
[[-13.5729, -13.5729, -13.6149], [-13.5729, -13.5729, -13.6149], [-13.6697, -13.6697, -13.7224]],
|
||||
[[-17.1638, -17.1638, -17.0022], [-17.1638, -17.1638, -17.0022], [-17.1754, -17.1754, -17.0358]],
|
||||
[[-3.6452, -3.6452, -3.5670], [-3.6452, -3.6452, -3.5670], [-3.5744, -3.5744, -3.5079]],
|
||||
]
|
||||
).to(torch_device)
|
||||
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1))
|
||||
|
@ -118,8 +118,6 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||
"PerceiverForMultimodalAutoencoding",
|
||||
"PerceiverForOpticalFlow",
|
||||
"SegformerDecodeHead",
|
||||
"SegformerForSemanticSegmentation",
|
||||
"BeitForSemanticSegmentation",
|
||||
"FlaxBeitForMaskedImageModeling",
|
||||
"BeitForMaskedImageModeling",
|
||||
"CLIPTextModel",
|
||||
|
Loading…
Reference in New Issue
Block a user