From 35ecf99cc4e1046e5908f5933a4548e43a0f4f1d Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Thu, 24 Feb 2022 15:52:52 +0100
Subject: [PATCH] Revert changes in logit size for semantic segmentation models (#15722)

* Revert changes in logit size for semantic segmentation models

* Address review comments
---
 src/transformers/modeling_outputs.py           | 11 +++-
 .../models/beit/configuration_beit.py          |  6 ---
 src/transformers/models/beit/modeling_beit.py  | 50 +++++--------------
 .../segformer/configuration_segformer.py       |  6 ---
 .../models/segformer/modeling_segformer.py     | 46 ++++-------------
 tests/beit/test_modeling_beit.py               | 12 ++---
 tests/segformer/test_modeling_segformer.py     | 20 ++++----
 7 files changed, 50 insertions(+), 101 deletions(-)

diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py
index b0f3e01861f..b5979e23eb2 100644
--- a/src/transformers/modeling_outputs.py
+++ b/src/transformers/modeling_outputs.py
@@ -822,8 +822,17 @@ class SemanticSegmentationModelOutput(ModelOutput):
     Args:
         loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
             Classification (or regression if config.num_labels==1) loss.
-        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, height, width)`):
+        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
             Classification scores for each pixel.
+
+            <Tip warning={true}>
+
+            The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
+            to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
+            original image size as post-processing. You should always check your logits shape and resize as needed.
+
+            </Tip>
+
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
             shape `(batch_size, patch_size, hidden_size)`.
diff --git a/src/transformers/models/beit/configuration_beit.py b/src/transformers/models/beit/configuration_beit.py
index a611e1b453c..9a1dfa8c20f 100644
--- a/src/transformers/models/beit/configuration_beit.py
+++ b/src/transformers/models/beit/configuration_beit.py
@@ -93,10 +93,6 @@ class BeitConfig(PretrainedConfig):
             Whether to concatenate the output of the auxiliary head with the input before the classification layer.
         semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
             The index that is ignored by the loss function of the semantic segmentation model.
-        legacy_output (`bool`, *optional*, defaults to `False`):
-            Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`)
-
-            This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.

     Example:

@@ -145,7 +141,6 @@ class BeitConfig(PretrainedConfig):
         auxiliary_num_convs=1,
         auxiliary_concat_input=False,
         semantic_loss_ignore_index=255,
-        legacy_output=False,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -181,4 +176,3 @@ class BeitConfig(PretrainedConfig):
         self.auxiliary_num_convs = auxiliary_num_convs
         self.auxiliary_concat_input = auxiliary_concat_input
         self.semantic_loss_ignore_index = semantic_loss_ignore_index
-        self.legacy_output = legacy_output
diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py
index 9e9c3393a9d..69c213b8dd4 100755
--- a/src/transformers/models/beit/modeling_beit.py
+++ b/src/transformers/models/beit/modeling_beit.py
@@ -17,7 +17,6 @@

 import collections.abc
 import math
-import warnings
 from dataclasses import dataclass

 import torch
@@ -1121,8 +1120,11 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

-    def compute_loss(self, upsampled_logits, auxiliary_logits, labels):
+    def compute_loss(self, logits, auxiliary_logits, labels):
         # upsample logits to the images' original size
+        upsampled_logits = nn.functional.interpolate(
+            logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
+        )
         if auxiliary_logits is not None:
             upsampled_auxiliary_logits = nn.functional.interpolate(
                 auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
@@ -1145,17 +1147,11 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
-        legacy_output=None,
     ):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
             Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
-        legacy_output (`bool`, *optional*):
-            Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`). Will default
-            to `self.config.legacy_output`.
-
-            This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.

         Returns:

@@ -1181,14 +1177,6 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
-        legacy_output = legacy_output if legacy_output is not None else self.config.legacy_output
-        if not legacy_output:
-            warnings.warn(
-                "The output of this model has changed in v4.17.0 and the logits now have the same size as the inputs. "
-                "You can activate the previous behavior by passing `legacy_output=True` to this call or the "
-                "configuration of this model (only until v5, then that argument will be removed).",
-                FutureWarning,
-            )

         outputs = self.beit(
             pixel_values,
@@ -1216,10 +1204,6 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):

         logits = self.decode_head(features)

-        upsampled_logits = nn.functional.interpolate(
-            logits, size=pixel_values.shape[-2:], mode="bilinear", align_corners=False
-        )
-
         auxiliary_logits = None
         if self.auxiliary_head is not None:
             auxiliary_logits = self.auxiliary_head(features)
@@ -1229,26 +1213,18 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
             if self.config.num_labels == 1:
                 raise ValueError("The number of labels should be greater than one")
             else:
-                loss = self.compute_loss(upsampled_logits, auxiliary_logits, labels)
+                loss = self.compute_loss(logits, auxiliary_logits, labels)

         if not return_dict:
             if output_hidden_states:
-                output = (logits if legacy_output else upsampled_logits,) + outputs[2:]
+                output = (logits,) + outputs[2:]
             else:
-                output = (logits if legacy_output else upsampled_logits,) + outputs[3:]
+                output = (logits,) + outputs[3:]
             return ((loss,) + output) if loss is not None else output

-        if legacy_output:
-            return SequenceClassifierOutput(
-                loss=loss,
-                logits=logits,
-                hidden_states=outputs.hidden_states if output_hidden_states else None,
-                attentions=outputs.attentions,
-            )
-        else:
-            return SemanticSegmentationModelOutput(
-                loss=loss,
-                logits=upsampled_logits,
-                hidden_states=outputs.hidden_states if output_hidden_states else None,
-                attentions=outputs.attentions,
-            )
+        return SemanticSegmentationModelOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            attentions=outputs.attentions,
+        )
diff --git a/src/transformers/models/segformer/configuration_segformer.py b/src/transformers/models/segformer/configuration_segformer.py
index 0e0b3a9914e..d1790634e6e 100644
--- a/src/transformers/models/segformer/configuration_segformer.py
+++ b/src/transformers/models/segformer/configuration_segformer.py
@@ -83,10 +83,6 @@ class SegformerConfig(PretrainedConfig):
             required for the semantic segmentation model.
         semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
             The index that is ignored by the loss function of the semantic segmentation model.
-        legacy_output (`bool`, *optional*, defaults to `False`):
-            Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`)
-
-            This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.

     Example:

@@ -128,7 +124,6 @@ class SegformerConfig(PretrainedConfig):
         is_encoder_decoder=False,
         reshape_last_stage=True,
         semantic_loss_ignore_index=255,
-        legacy_output=False,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -154,4 +149,3 @@ class SegformerConfig(PretrainedConfig):
         self.decoder_hidden_size = decoder_hidden_size
         self.reshape_last_stage = reshape_last_stage
         self.semantic_loss_ignore_index = semantic_loss_ignore_index
-        self.legacy_output = legacy_output
diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py
index 196d741ce70..7fa9ccade80 100755
--- a/src/transformers/models/segformer/modeling_segformer.py
+++ b/src/transformers/models/segformer/modeling_segformer.py
@@ -17,7 +17,6 @@

 import collections
 import math
-import warnings

 import torch
 import torch.utils.checkpoint
@@ -697,17 +696,11 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
-        legacy_output=None,
     ):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
             Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
-        legacy_output (`bool`, *optional*):
-            Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`). Will default
-            to `self.config.legacy_output`.
-
-            This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.

         Returns:

@@ -732,14 +725,6 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
-        legacy_output = legacy_output if legacy_output is not None else self.config.legacy_output
-        if not legacy_output:
-            warnings.warn(
-                "The output of this model has changed in v4.17.0 and the logits now have the same size as the inputs. "
-                "You can activate the previous behavior by passing `legacy_output=True` to this call or the "
-                "configuration of this model (only until v5, then that argument will be removed).",
-                FutureWarning,
-            )

         outputs = self.segformer(
             pixel_values,
@@ -752,37 +737,28 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):

         logits = self.decode_head(encoder_hidden_states)

-        upsampled_logits = nn.functional.interpolate(
-            logits, size=pixel_values.shape[-2:], mode="bilinear", align_corners=False
-        )
-
         loss = None
         if labels is not None:
             if self.config.num_labels == 1:
                 raise ValueError("The number of labels should be greater than one")
             else:
                 # upsample logits to the images' original size
+                upsampled_logits = nn.functional.interpolate(
+                    logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
+                )
                 loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
                 loss = loss_fct(upsampled_logits, labels)

         if not return_dict:
             if output_hidden_states:
-                output = (logits if legacy_output else upsampled_logits,) + outputs[1:]
+                output = (logits,) + outputs[1:]
             else:
-                output = (logits if legacy_output else upsampled_logits,) + outputs[2:]
+                output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

-        if legacy_output:
-            return SequenceClassifierOutput(
-                loss=loss,
-                logits=logits,
-                hidden_states=outputs.hidden_states if output_hidden_states else None,
-                attentions=outputs.attentions,
-            )
-        else:
-            return SemanticSegmentationModelOutput(
-                loss=loss,
-                logits=upsampled_logits,
-                hidden_states=outputs.hidden_states if output_hidden_states else None,
-                attentions=outputs.attentions,
-            )
+        return SemanticSegmentationModelOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            attentions=outputs.attentions,
+        )
diff --git a/tests/beit/test_modeling_beit.py b/tests/beit/test_modeling_beit.py
index 3d929a29999..4dce642eba4 100644
--- a/tests/beit/test_modeling_beit.py
+++ b/tests/beit/test_modeling_beit.py
@@ -162,11 +162,11 @@ class BeitModelTester:
         model.eval()
         result = model(pixel_values)
         self.parent.assertEqual(
-            result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
+            result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
         )
         result = model(pixel_values, labels=pixel_labels)
         self.parent.assertEqual(
-            result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
+            result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
         )

     def prepare_config_and_inputs_for_common(self):
@@ -533,14 +533,14 @@ class BeitModelIntegrationTest(unittest.TestCase):
         logits = outputs.logits

         # verify the logits
-        expected_shape = torch.Size((1, 150, 640, 640))
+        expected_shape = torch.Size((1, 150, 160, 160))
         self.assertEqual(logits.shape, expected_shape)

         expected_slice = torch.tensor(
             [
-                [[-4.9225, -4.9225, -4.6066], [-4.9225, -4.9225, -4.6066], [-4.6675, -4.6675, -4.3617]],
-                [[-5.8168, -5.8168, -5.5163], [-5.8168, -5.8168, -5.5163], [-5.5728, -5.5728, -5.2842]],
-                [[-0.0078, -0.0078, 0.4926], [-0.0078, -0.0078, 0.4926], [0.3664, 0.3664, 0.8309]],
+                [[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
+                [[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
+                [[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
             ]
         ).to(torch_device)

diff --git a/tests/segformer/test_modeling_segformer.py b/tests/segformer/test_modeling_segformer.py
index b7321dd690f..8798a823ea3 100644
--- a/tests/segformer/test_modeling_segformer.py
+++ b/tests/segformer/test_modeling_segformer.py
@@ -135,11 +135,11 @@ class SegformerModelTester:
         model.eval()
         result = model(pixel_values)
         self.parent.assertEqual(
-            result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
+            result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
         )
         result = model(pixel_values, labels=labels)
         self.parent.assertEqual(
-            result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
+            result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
         )

     def prepare_config_and_inputs_for_common(self):
@@ -363,14 +363,14 @@ class SegformerModelIntegrationTest(unittest.TestCase):
         with torch.no_grad():
             outputs = model(pixel_values)

-        expected_shape = torch.Size((1, model.config.num_labels, 512, 512))
+        expected_shape = torch.Size((1, model.config.num_labels, 128, 128))
         self.assertEqual(outputs.logits.shape, expected_shape)

         expected_slice = torch.tensor(
             [
-                [[-4.6309, -4.6309, -4.7425], [-4.6309, -4.6309, -4.7425], [-4.7011, -4.7011, -4.8136]],
-                [[-12.1391, -12.1391, -12.2858], [-12.1391, -12.1391, -12.2858], [-12.2309, -12.2309, -12.3758]],
-                [[-12.5134, -12.5134, -12.6328], [-12.5134, -12.5134, -12.6328], [-12.5576, -12.5576, -12.6865]],
+                [[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]],
+                [[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]],
+                [[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]],
             ]
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-4))
@@ -392,14 +392,14 @@ class SegformerModelIntegrationTest(unittest.TestCase):
         with torch.no_grad():
             outputs = model(pixel_values)

-        expected_shape = torch.Size((1, model.config.num_labels, 512, 512))
+        expected_shape = torch.Size((1, model.config.num_labels, 128, 128))
         self.assertEqual(outputs.logits.shape, expected_shape)

         expected_slice = torch.tensor(
             [
-                [[-13.5729, -13.5729, -13.6149], [-13.5729, -13.5729, -13.6149], [-13.6697, -13.6697, -13.7224]],
-                [[-17.1638, -17.1638, -17.0022], [-17.1638, -17.1638, -17.0022], [-17.1754, -17.1754, -17.0358]],
-                [[-3.6452, -3.6452, -3.5670], [-3.6452, -3.6452, -3.5670], [-3.5744, -3.5744, -3.5079]],
+                [[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]],
+                [[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]],
+                [[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]],
             ]
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1))
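
Note (not part of the patch): after this revert, `BeitForSemanticSegmentation` and `SegformerForSemanticSegmentation` again return logits at a reduced resolution (`height / 4, width / 4` for SegFormer), and the updated `SemanticSegmentationModelOutput` docstring asks users to resize the logits themselves. The sketch below is one way to do that post-processing; it assumes the `nvidia/segformer-b0-finetuned-ade-512-512` checkpoint and uses a placeholder image path.

```python
# Minimal post-processing sketch (not part of the patch): upsample the reduced-size
# logits back to the original image size, as the updated docstring recommends.
import torch
from torch import nn
from PIL import Image
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

checkpoint = "nvidia/segformer-b0-finetuned-ade-512-512"
feature_extractor = SegformerFeatureExtractor.from_pretrained(checkpoint)
model = SegformerForSemanticSegmentation.from_pretrained(checkpoint)

image = Image.open("example.jpg")  # placeholder path: any RGB image
inputs = feature_extractor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Logits have shape (batch_size, num_labels, height / 4, width / 4) after this revert,
# so interpolate them to the original (height, width) before taking the per-pixel argmax.
upsampled_logits = nn.functional.interpolate(
    outputs.logits, size=image.size[::-1], mode="bilinear", align_corners=False
)
segmentation_map = upsampled_logits.argmax(dim=1)[0]  # (height, width) class indices
```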