mirror of https://github.com/huggingface/transformers.git
added type hints for blip_text pytorch model (#23071)
* added type hints for blip_text pytorch model
* updated type hints for blip_text pytorch model
parent b8648290d2
commit 85e3d7b6a0
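The commit applies the usual Transformers annotation convention: tensor arguments that default to `None` become `Optional[torch.Tensor]` (or `Optional[torch.FloatTensor]`), boolean flags become `Optional[bool]`, cached key/values become nested tuples of tensors, and each `forward` gains an explicit return annotation. A minimal standalone sketch of that pattern (toy module for illustration, not code from this PR):

from typing import Optional, Tuple

import torch
from torch import nn


class ToySelfAttention(nn.Module):
    # Toy module illustrating the annotation style this commit applies.
    def __init__(self, hidden_size: int = 32) -> None:
        super().__init__()
        self.query = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # Optional tensors keep their None defaults; flags keep their booleans.
        scores = self.query(hidden_states)
        if attention_mask is not None:
            scores = scores + attention_mask
        return (scores,)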
@@ -15,7 +15,7 @@


 import math
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
@@ -61,7 +61,13 @@ class BlipTextEmbeddings(nn.Module):

         self.config = config

-    def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0):
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values_length: int = 0,
+    ) -> torch.Tensor:
         if input_ids is not None:
             input_shape = input_ids.size()
         else:
@@ -134,14 +140,14 @@ class BlipTextSelfAttention(nn.Module):

     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_value=None,
-        output_attentions=False,
-    ):
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
         mixed_query_layer = self.query(hidden_states)

         # If this is instantiated as a cross-attention module, the keys
@@ -263,7 +269,7 @@ class BlipTextAttention(nn.Module):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
-    ):
+    ) -> Tuple[torch.Tensor]:
         self_outputs = self.self(
             hidden_states,
             attention_mask,
@@ -324,14 +330,14 @@ class BlipTextLayer(nn.Module):

     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_value=None,
-        output_attentions=False,
-    ):
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
         self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
         self_attention_outputs = self.attention(
@@ -382,17 +388,17 @@ class BlipTextEncoder(nn.Module):

     def forward(
         self,
-        hidden_states,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
-    ):
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warn(
@@ -663,21 +669,21 @@ class BlipTextModel(BlipTextPreTrainedModel):

     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        is_decoder=False,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        is_decoder: Optional[bool] = False,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor`, *optional*):
             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -819,23 +825,23 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel):

     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        labels=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        return_logits=False,
-        is_decoder=True,
-        reduction="mean",
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        return_logits: Optional[bool] = False,
+        is_decoder: Optional[bool] = True,
+        reduction: Optional[str] = "mean",
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor`, *optional*): Sequence of
             hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is
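For a quick sanity check that the hints landed, one can introspect the annotated `forward` after installing a `transformers` build containing this commit. A sketch; the module path is taken from the file this diff touches, and the exact printed representations vary by Python version:

import typing

from transformers.models.blip.modeling_blip_text import BlipTextSelfAttention

# Resolve the annotations added in this commit; unannotated
# parameters such as `self` simply do not appear in the mapping.
hints = typing.get_type_hints(BlipTextSelfAttention.forward)
print(hints["hidden_states"])   # <class 'torch.Tensor'>
print(hints["past_key_value"])  # typing.Optional[typing.Tuple[typing.Tuple[torch.FloatTensor]]]
print(hints["return"])          # typing.Tuple[torch.Tensor]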