[Blip] Fix blip output name (#24889)

* fix blip output name

* add property

* oops

* fix failing test
Younes Belkada 2023-07-18 19:30:27 +02:00 committed by GitHub
parent a9e067a45c
commit 5c5cb4eeb2
2 changed files with 26 additions and 6 deletions

src/transformers/models/blip/modeling_blip.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 """ PyTorch BLIP model."""
+import warnings
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple, Union
@@ -74,7 +75,7 @@ class BlipForConditionalGenerationModelOutput(ModelOutput):
     Args:
         loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
             Language modeling loss from the text decoder.
-        decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
             Prediction scores of the language modeling head of the text decoder model.
         image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*):
             The image embeddings obtained after applying the Vision Transformer model to the input image.
@@ -94,12 +95,21 @@ class BlipForConditionalGenerationModelOutput(ModelOutput):
     """

     loss: Optional[Tuple[torch.FloatTensor]] = None
-    decoder_logits: Optional[Tuple[torch.FloatTensor]] = None
+    logits: Optional[Tuple[torch.FloatTensor]] = None
     image_embeds: Optional[torch.FloatTensor] = None
     last_hidden_state: torch.FloatTensor = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None

+    @property
+    def decoder_logits(self):
+        warnings.warn(
+            "`decoder_logits` attribute is deprecated and will be removed in version 5 of Transformers."
+            " Please use the `logits` attribute to retrieve the final output instead.",
+            FutureWarning,
+        )
+        return self.logits
+

 @dataclass
 class BlipTextVisionModelOutput(ModelOutput):
@@ -1011,7 +1021,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel):

         return BlipForConditionalGenerationModelOutput(
             loss=outputs.loss,
-            decoder_logits=outputs.logits,
+            logits=outputs.logits,
             image_embeds=image_embeds,
             last_hidden_state=vision_outputs.last_hidden_state,
             hidden_states=vision_outputs.hidden_states,
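
The net effect on the PyTorch side can be checked directly. Below is a minimal sketch (the tensor shape and the direct construction of the output object are illustrative, not taken from the diff): populate the renamed `logits` field and confirm the deprecated `decoder_logits` alias still resolves to it while emitting a FutureWarning.

import warnings

import torch

from transformers.models.blip.modeling_blip import BlipForConditionalGenerationModelOutput

# Build the output the way BlipForConditionalGeneration now does,
# passing the renamed `logits` field (the shape here is arbitrary).
out = BlipForConditionalGenerationModelOutput(logits=torch.zeros(1, 4, 30524))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = out.decoder_logits  # old name, served by the deprecated property

assert legacy is out.logits
assert issubclass(caught[0].category, FutureWarning)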

src/transformers/models/blip/modeling_tf_blip.py

@@ -16,6 +16,7 @@
 from __future__ import annotations

+import warnings
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple, Union
@@ -84,7 +85,7 @@ class TFBlipForConditionalGenerationModelOutput(ModelOutput):
     Args:
         loss (`tf.Tensor`, *optional*, returned when `labels` is provided, `tf.Tensor` of shape `(1,)`):
             Language modeling loss from the text decoder.
-        decoder_logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
+        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
             Prediction scores of the language modeling head of the text decoder model.
         image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`, *optional*):
             The image embeddings obtained after applying the Vision Transformer model to the input image.
@@ -104,12 +105,21 @@ class TFBlipForConditionalGenerationModelOutput(ModelOutput):
     """

     loss: Tuple[tf.Tensor] | None = None
-    decoder_logits: Tuple[tf.Tensor] | None = None
+    logits: Tuple[tf.Tensor] | None = None
     image_embeds: tf.Tensor | None = None
     last_hidden_state: tf.Tensor = None
     hidden_states: Tuple[tf.Tensor] | None = None
     attentions: Tuple[tf.Tensor] | None = None

+    @property
+    def decoder_logits(self):
+        warnings.warn(
+            "`decoder_logits` attribute is deprecated and will be removed in version 5 of Transformers."
+            " Please use the `logits` attribute to retrieve the final output instead.",
+            FutureWarning,
+        )
+        return self.logits
+

 @dataclass
 class TFBlipTextVisionModelOutput(ModelOutput):
@@ -1078,7 +1088,7 @@ class TFBlipForConditionalGeneration(TFBlipPreTrainedModel):

         return TFBlipForConditionalGenerationModelOutput(
             loss=outputs.loss,
-            decoder_logits=outputs.logits,
+            logits=outputs.logits,
             image_embeds=image_embeds,
             last_hidden_state=vision_outputs.last_hidden_state,
             hidden_states=vision_outputs.hidden_states,
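
The TF output mirrors the same deprecation path; a quick sanity check under the same illustrative assumptions (arbitrary shape, direct construction):

import warnings

import tensorflow as tf

from transformers.models.blip.modeling_tf_blip import TFBlipForConditionalGenerationModelOutput

out = TFBlipForConditionalGenerationModelOutput(logits=tf.zeros((1, 4, 30524)))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # the deprecated alias emits a FutureWarning and returns `logits`
    assert out.decoder_logits is out.logits
assert issubclass(caught[0].category, FutureWarning)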