diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py
index a4d465b1d48..4f7540d0ff9 100755
--- a/src/transformers/modeling_outputs.py
+++ b/src/transformers/modeling_outputs.py
@@ -1281,34 +1281,6 @@ class ImageSuperResolutionOutput(ModelOutput):
     attentions: Optional[Tuple[torch.FloatTensor]] = None
 
 
-@dataclass
-class MaskedImageCompletionOutput(ModelOutput):
-    """
-    Base class for outputs of masked image completion / in-painting models.
-
-    Args:
-        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
-            Reconstruction loss.
-        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
-            Reconstructed / completed images.
-        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
-            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
-            one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
-            (also called feature maps) of the model at the output of each stage.
-        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
-            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
-            sequence_length)`.
-
-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
-            heads.
-    """
-
-    loss: Optional[torch.FloatTensor] = None
-    reconstruction: torch.FloatTensor = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
 @dataclass
 class Wav2Vec2BaseModelOutput(ModelOutput):
     """
diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py
index 154afdb211f..449eda3ee82 100644
--- a/src/transformers/models/vit/modeling_vit.py
+++ b/src/transformers/models/vit/modeling_vit.py
@@ -25,12 +25,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...modeling_outputs import (
-    BaseModelOutput,
-    BaseModelOutputWithPooling,
-    ImageClassifierOutput,
-    MaskedImageCompletionOutput,
-)
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
@@ -648,7 +643,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
         self.post_init()
 
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=MaskedImageCompletionOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         pixel_values: Optional[torch.Tensor] = None,
@@ -658,7 +653,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[tuple, MaskedImageCompletionOutput]:
+    ) -> Union[tuple, MaskedLMOutput]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
             Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
@@ -728,9 +723,9 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
             output = (reconstructed_pixel_values,) + outputs[1:]
             return ((masked_im_loss,) + output) if masked_im_loss is not None else output
 
-        return MaskedImageCompletionOutput(
+        return MaskedLMOutput(
             loss=masked_im_loss,
-            reconstruction=reconstructed_pixel_values,
+            logits=reconstructed_pixel_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
diff --git a/tests/models/vit/test_modeling_vit.py b/tests/models/vit/test_modeling_vit.py
index d7ae3c162f5..be509d460e6 100644
--- a/tests/models/vit/test_modeling_vit.py
+++ b/tests/models/vit/test_modeling_vit.py
@@ -134,7 +134,7 @@ class ViTModelTester:
         model.eval()
         result = model(pixel_values)
         self.parent.assertEqual(
-            result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
+            result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
         )
 
         # test greyscale images
@@ -145,7 +145,7 @@ class ViTModelTester:
         pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
         result = model(pixel_values)
 
-        self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size))
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
 
     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size
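For context (not part of the diff): after this change, `ViTForMaskedImageModeling` returns a plain `MaskedLMOutput`, so callers that previously read `outputs.reconstruction` now read the reconstructed pixel values from `outputs.logits`. A minimal usage sketch, assuming the public `google/vit-base-patch16-224-in21k` checkpoint (any ViT checkpoint with an image-modeling head would do):

```python
import torch
from transformers import ViTForMaskedImageModeling

# Checkpoint name is an assumption for illustration; substitute your own.
model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")
model.eval()

# Build dummy inputs from the model config: one image plus a random patch mask.
num_patches = (model.config.image_size // model.config.patch_size) ** 2
pixel_values = torch.rand(
    1, model.config.num_channels, model.config.image_size, model.config.image_size
)
# Boolean mask of shape (batch_size, num_patches); True marks a masked patch.
bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

with torch.no_grad():
    outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)

# Previously outputs.reconstruction; with this patch the same tensor lives in outputs.logits.
reconstructed_pixel_values = outputs.logits
print(reconstructed_pixel_values.shape)  # (1, num_channels, image_size, image_size)
```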