Revert changes in logit size for semantic segmentation models (#15722)

* Revert changes in logit size for semantic segmentation models

* Address review comments
This commit is contained in:
Sylvain Gugger 2022-02-24 15:52:52 +01:00 committed by GitHub
parent d1fcc90abf
commit 35ecf99cc4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 50 additions and 101 deletions

View File

@ -822,8 +822,17 @@ class SemanticSegmentationModelOutput(ModelOutput):
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, height, width)`):
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
Classification scores for each pixel.
<Tip warning={true}>
The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
original image size as post-processing. You should always check your logits shape and resize as needed.
</Tip>
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, patch_size, hidden_size)`.

View File

@ -93,10 +93,6 @@ class BeitConfig(PretrainedConfig):
Whether to concatenate the output of the auxiliary head with the input before the classification layer.
semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
The index that is ignored by the loss function of the semantic segmentation model.
legacy_output (`bool`, *optional*, defaults to `False`):
Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`)
This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.
Example:
@ -145,7 +141,6 @@ class BeitConfig(PretrainedConfig):
auxiliary_num_convs=1,
auxiliary_concat_input=False,
semantic_loss_ignore_index=255,
legacy_output=False,
**kwargs
):
super().__init__(**kwargs)
@ -181,4 +176,3 @@ class BeitConfig(PretrainedConfig):
self.auxiliary_num_convs = auxiliary_num_convs
self.auxiliary_concat_input = auxiliary_concat_input
self.semantic_loss_ignore_index = semantic_loss_ignore_index
self.legacy_output = legacy_output

View File

@ -17,7 +17,6 @@
import collections.abc
import math
import warnings
from dataclasses import dataclass
import torch
@ -1121,8 +1120,11 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
def compute_loss(self, upsampled_logits, auxiliary_logits, labels):
def compute_loss(self, logits, auxiliary_logits, labels):
# upsample logits to the images' original size
upsampled_logits = nn.functional.interpolate(
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
)
if auxiliary_logits is not None:
upsampled_auxiliary_logits = nn.functional.interpolate(
auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
@ -1145,17 +1147,11 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
output_attentions=None,
output_hidden_states=None,
return_dict=None,
legacy_output=None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
legacy_output (`bool`, *optional*):
Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`). Will default
to `self.config.legacy_output`.
This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.
Returns:
@ -1181,14 +1177,6 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
legacy_output = legacy_output if legacy_output is not None else self.config.legacy_output
if not legacy_output:
warnings.warn(
"The output of this model has changed in v4.17.0 and the logits now have the same size as the inputs. "
"You can activate the previous behavior by passing `legacy_output=True` to this call or the "
"configuration of this model (only until v5, then that argument will be removed).",
FutureWarning,
)
outputs = self.beit(
pixel_values,
@ -1216,10 +1204,6 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
logits = self.decode_head(features)
upsampled_logits = nn.functional.interpolate(
logits, size=pixel_values.shape[-2:], mode="bilinear", align_corners=False
)
auxiliary_logits = None
if self.auxiliary_head is not None:
auxiliary_logits = self.auxiliary_head(features)
@ -1229,26 +1213,18 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
if self.config.num_labels == 1:
raise ValueError("The number of labels should be greater than one")
else:
loss = self.compute_loss(upsampled_logits, auxiliary_logits, labels)
loss = self.compute_loss(logits, auxiliary_logits, labels)
if not return_dict:
if output_hidden_states:
output = (logits if legacy_output else upsampled_logits,) + outputs[2:]
output = (logits,) + outputs[2:]
else:
output = (logits if legacy_output else upsampled_logits,) + outputs[3:]
output = (logits,) + outputs[3:]
return ((loss,) + output) if loss is not None else output
if legacy_output:
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
else:
return SemanticSegmentationModelOutput(
loss=loss,
logits=upsampled_logits,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
return SemanticSegmentationModelOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)

View File

@ -83,10 +83,6 @@ class SegformerConfig(PretrainedConfig):
required for the semantic segmentation model.
semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
The index that is ignored by the loss function of the semantic segmentation model.
legacy_output (`bool`, *optional*, defaults to `False`):
Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`)
This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.
Example:
@ -128,7 +124,6 @@ class SegformerConfig(PretrainedConfig):
is_encoder_decoder=False,
reshape_last_stage=True,
semantic_loss_ignore_index=255,
legacy_output=False,
**kwargs
):
super().__init__(**kwargs)
@ -154,4 +149,3 @@ class SegformerConfig(PretrainedConfig):
self.decoder_hidden_size = decoder_hidden_size
self.reshape_last_stage = reshape_last_stage
self.semantic_loss_ignore_index = semantic_loss_ignore_index
self.legacy_output = legacy_output

View File

@ -17,7 +17,6 @@
import collections
import math
import warnings
import torch
import torch.utils.checkpoint
@ -697,17 +696,11 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
output_attentions=None,
output_hidden_states=None,
return_dict=None,
legacy_output=None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
legacy_output (`bool`, *optional*):
Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`). Will default
to `self.config.legacy_output`.
This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.
Returns:
@ -732,14 +725,6 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
legacy_output = legacy_output if legacy_output is not None else self.config.legacy_output
if not legacy_output:
warnings.warn(
"The output of this model has changed in v4.17.0 and the logits now have the same size as the inputs. "
"You can activate the previous behavior by passing `legacy_output=True` to this call or the "
"configuration of this model (only until v5, then that argument will be removed).",
FutureWarning,
)
outputs = self.segformer(
pixel_values,
@ -752,37 +737,28 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
logits = self.decode_head(encoder_hidden_states)
upsampled_logits = nn.functional.interpolate(
logits, size=pixel_values.shape[-2:], mode="bilinear", align_corners=False
)
loss = None
if labels is not None:
if self.config.num_labels == 1:
raise ValueError("The number of labels should be greater than one")
else:
# upsample logits to the images' original size
upsampled_logits = nn.functional.interpolate(
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
)
loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
loss = loss_fct(upsampled_logits, labels)
if not return_dict:
if output_hidden_states:
output = (logits if legacy_output else upsampled_logits,) + outputs[1:]
output = (logits,) + outputs[1:]
else:
output = (logits if legacy_output else upsampled_logits,) + outputs[2:]
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
if legacy_output:
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
else:
return SemanticSegmentationModelOutput(
loss=loss,
logits=upsampled_logits,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
return SemanticSegmentationModelOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)

View File

@ -162,11 +162,11 @@ class BeitModelTester:
model.eval()
result = model(pixel_values)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
)
result = model(pixel_values, labels=pixel_labels)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
)
def prepare_config_and_inputs_for_common(self):
@ -533,14 +533,14 @@ class BeitModelIntegrationTest(unittest.TestCase):
logits = outputs.logits
# verify the logits
expected_shape = torch.Size((1, 150, 640, 640))
expected_shape = torch.Size((1, 150, 160, 160))
self.assertEqual(logits.shape, expected_shape)
expected_slice = torch.tensor(
[
[[-4.9225, -4.9225, -4.6066], [-4.9225, -4.9225, -4.6066], [-4.6675, -4.6675, -4.3617]],
[[-5.8168, -5.8168, -5.5163], [-5.8168, -5.8168, -5.5163], [-5.5728, -5.5728, -5.2842]],
[[-0.0078, -0.0078, 0.4926], [-0.0078, -0.0078, 0.4926], [0.3664, 0.3664, 0.8309]],
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
]
).to(torch_device)

View File

@ -135,11 +135,11 @@ class SegformerModelTester:
model.eval()
result = model(pixel_values)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
)
result = model(pixel_values, labels=labels)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
)
def prepare_config_and_inputs_for_common(self):
@ -363,14 +363,14 @@ class SegformerModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(pixel_values)
expected_shape = torch.Size((1, model.config.num_labels, 512, 512))
expected_shape = torch.Size((1, model.config.num_labels, 128, 128))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor(
[
[[-4.6309, -4.6309, -4.7425], [-4.6309, -4.6309, -4.7425], [-4.7011, -4.7011, -4.8136]],
[[-12.1391, -12.1391, -12.2858], [-12.1391, -12.1391, -12.2858], [-12.2309, -12.2309, -12.3758]],
[[-12.5134, -12.5134, -12.6328], [-12.5134, -12.5134, -12.6328], [-12.5576, -12.5576, -12.6865]],
[[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]],
[[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]],
[[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]],
]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-4))
@ -392,14 +392,14 @@ class SegformerModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
outputs = model(pixel_values)
expected_shape = torch.Size((1, model.config.num_labels, 512, 512))
expected_shape = torch.Size((1, model.config.num_labels, 128, 128))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor(
[
[[-13.5729, -13.5729, -13.6149], [-13.5729, -13.5729, -13.6149], [-13.6697, -13.6697, -13.7224]],
[[-17.1638, -17.1638, -17.0022], [-17.1638, -17.1638, -17.0022], [-17.1754, -17.1754, -17.0358]],
[[-3.6452, -3.6452, -3.5670], [-3.6452, -3.6452, -3.5670], [-3.5744, -3.5744, -3.5079]],
[[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]],
[[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]],
[[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]],
]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1))