mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-24 23:08:57 +06:00
Create MaskedImageCompletionOutput and fix ViT docs (#22152)
* create MaskedImageCompletionOutput * fix bugs * fix bugs
This commit is contained in:
parent
b45192ec47
commit
3b22bfbc6a
@ -1281,6 +1281,34 @@ class ImageSuperResolutionOutput(ModelOutput):
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaskedImageCompletionOutput(ModelOutput):
|
||||
"""
|
||||
Base class for outputs of masked image completion / in-painting models.
|
||||
|
||||
Args:
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Reconstruction loss.
|
||||
reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
Reconstructed / completed images.
|
||||
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
||||
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
||||
one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
|
||||
(also called feature maps) of the model at the output of each stage.
|
||||
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
||||
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
|
||||
sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
reconstruction: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Wav2Vec2BaseModelOutput(ModelOutput):
|
||||
"""
|
||||
|
@ -25,7 +25,12 @@ from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPooling,
|
||||
ImageClassifierOutput,
|
||||
MaskedImageCompletionOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
||||
from ...utils import (
|
||||
@ -643,7 +648,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
|
||||
self.post_init()
|
||||
|
||||
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||
@replace_return_docstrings(output_type=MaskedImageCompletionOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
@ -653,7 +658,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
interpolate_pos_encoding: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[tuple, MaskedLMOutput]:
|
||||
) -> Union[tuple, MaskedImageCompletionOutput]:
|
||||
r"""
|
||||
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
|
||||
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
||||
@ -723,9 +728,9 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
|
||||
output = (reconstructed_pixel_values,) + outputs[1:]
|
||||
return ((masked_im_loss,) + output) if masked_im_loss is not None else output
|
||||
|
||||
return MaskedLMOutput(
|
||||
return MaskedImageCompletionOutput(
|
||||
loss=masked_im_loss,
|
||||
logits=reconstructed_pixel_values,
|
||||
reconstruction=reconstructed_pixel_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
@ -134,7 +134,7 @@ class ViTModelTester:
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(
|
||||
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
|
||||
result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
|
||||
)
|
||||
|
||||
# test greyscale images
|
||||
@ -145,7 +145,7 @@ class ViTModelTester:
|
||||
|
||||
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
|
||||
result = model(pixel_values)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
|
||||
self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size))
|
||||
|
||||
def create_and_check_for_image_classification(self, config, pixel_values, labels):
|
||||
config.num_labels = self.type_sequence_label_size
|
||||
|
Loading…
Reference in New Issue
Block a user