diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index b600c2140ee..85585f3d5dd 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -15,7 +15,7 @@ import math -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -61,7 +61,13 @@ class BlipTextEmbeddings(nn.Module): self.config = config - def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: if input_ids is not None: input_shape = input_ids.size() else: @@ -134,14 +140,14 @@ class BlipTextSelfAttention(nn.Module): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) # If this is instantiated as a cross-attention module, the keys @@ -263,7 +269,7 @@ class BlipTextAttention(nn.Module): encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, - ): + ) -> Tuple[torch.Tensor]: self_outputs = self.self( hidden_states, attention_mask, @@ -324,14 +330,14 @@ class BlipTextLayer(nn.Module): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( @@ -382,17 +388,17 @@ class BlipTextEncoder(nn.Module): def forward( self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: if self.gradient_checkpointing and self.training: if use_cache: logger.warn( @@ -663,21 +669,21 @@ class BlipTextModel(BlipTextPreTrainedModel): def forward( self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - is_decoder=False, - ): + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + is_decoder: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: r""" encoder_hidden_states (`torch.FloatTensor`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if @@ -819,23 +825,23 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel): def forward( self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - return_logits=False, - is_decoder=True, - reduction="mean", - ): + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_logits: Optional[bool] = False, + is_decoder: Optional[bool] = True, + reduction: Optional[str] = "mean", + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" encoder_hidden_states (`torch.FloatTensor`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is