From 3b22bfbc6afbf7aa65ce0f255e3c75a0dd7524d3 Mon Sep 17 00:00:00 2001
From: Alara Dirik <8944735+alaradirik@users.noreply.github.com>
Date: Tue, 14 Mar 2023 16:55:18 +0300
Subject: [PATCH] Create MaskedImageCompletionOutput and fix ViT docs (#22152)

* create MaskedImageCompletionOutput

* fix bugs

* fix bugs
---
 src/transformers/modeling_outputs.py        | 28 +++++++++++++++++++++
 src/transformers/models/vit/modeling_vit.py | 15 +++++++----
 tests/models/vit/test_modeling_vit.py       |  4 +--
 3 files changed, 40 insertions(+), 7 deletions(-)

diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py
index 4f7540d0ff9..a4d465b1d48 100755
--- a/src/transformers/modeling_outputs.py
+++ b/src/transformers/modeling_outputs.py
@@ -1281,6 +1281,34 @@ class ImageSuperResolutionOutput(ModelOutput):
     attentions: Optional[Tuple[torch.FloatTensor]] = None
 
 
+@dataclass
+class MaskedImageCompletionOutput(ModelOutput):
+    """
+    Base class for outputs of masked image completion / in-painting models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Reconstruction loss.
+        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Reconstructed / completed images.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
+            (also called feature maps) of the model at the output of each stage.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    reconstruction: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
 @dataclass
 class Wav2Vec2BaseModelOutput(ModelOutput):
     """
diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py
index 449eda3ee82..154afdb211f 100644
--- a/src/transformers/models/vit/modeling_vit.py
+++ b/src/transformers/models/vit/modeling_vit.py
@@ -25,7 +25,12 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPooling,
+    ImageClassifierOutput,
+    MaskedImageCompletionOutput,
+)
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
@@ -643,7 +648,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
         self.post_init()
 
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=MaskedImageCompletionOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         pixel_values: Optional[torch.Tensor] = None,
@@ -653,7 +658,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[tuple, MaskedLMOutput]:
+    ) -> Union[tuple, MaskedImageCompletionOutput]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
             Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
@@ -723,9 +728,9 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
             output = (reconstructed_pixel_values,) + outputs[1:]
             return ((masked_im_loss,) + output) if masked_im_loss is not None else output
 
-        return MaskedLMOutput(
+        return MaskedImageCompletionOutput(
             loss=masked_im_loss,
-            logits=reconstructed_pixel_values,
+            reconstruction=reconstructed_pixel_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
diff --git a/tests/models/vit/test_modeling_vit.py b/tests/models/vit/test_modeling_vit.py
index be509d460e6..d7ae3c162f5 100644
--- a/tests/models/vit/test_modeling_vit.py
+++ b/tests/models/vit/test_modeling_vit.py
@@ -134,7 +134,7 @@ class ViTModelTester:
         model.eval()
         result = model(pixel_values)
         self.parent.assertEqual(
-            result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
+            result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
         )
 
         # test greyscale images
@@ -145,7 +145,7 @@ class ViTModelTester:
         pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
         result = model(pixel_values)
 
-        self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
+        self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size))
 
     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size