Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-27 08:18:58 +06:00
[housekeeping] Upgrade `# type` Python 2 syntax
cc @sshleifer
This commit is contained in:
commit a946b6b51b
parent cb3c2212c7
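For context, every hunk in this patch applies the same mechanical change: the PEP 526 upgrade that moves a type out of a Python 2-style `# type:` comment and onto the assignment itself. A minimal standalone sketch of the before/after pattern (the names below are illustrative, not taken from the patch):

from typing import Dict, Optional

# Python 2-era form: the type lives in a comment that only static
# checkers such as mypy ever read.
archive_map = {}  # type: Dict[str, str]

# PEP 526 form (Python 3.6+): the same type is written as a variable
# annotation and is also visible at runtime via __annotations__.
archive_map_annotated: Dict[str, str] = {}
label: Optional[str] = None  # Optional[...] covers a None default

Both forms mean the same thing to a static checker; the annotated form is simply the native Python 3 spelling.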
@@ -52,8 +52,8 @@ class PretrainedConfig(object):
         torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Is the model used with Torchscript (for PyTorch models).
     """
-    pretrained_config_archive_map = {}  # type: Dict[str, str]
-    model_type = ""  # type: str
+    pretrained_config_archive_map: Dict[str, str] = {}
+    model_type: str = ""
 
     def __init__(self, **kwargs):
         # Attributes with defaults
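A side note on this first hunk: these are class attributes, so the new spelling is a class-level variable annotation; it reads the same to mypy and additionally records the types on the class object. A small sketch under assumed names (the real class and its imports live in configuration_utils.py):

from typing import Dict

class ExampleConfig:
    # Previously: archive_map = {}  # type: Dict[str, str]
    archive_map: Dict[str, str] = {}
    model_type: str = ""

print(ExampleConfig.__annotations__)
# {'archive_map': typing.Dict[str, str], 'model_type': <class 'str'>}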
@@ -570,7 +570,7 @@ class SelfAttention(nn.Module):
         need_weights=False,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """Input shape: Time(SeqLen) x Batch x Channel"""
-        static_kv = self.encoder_decoder_attention  # type: bool
+        static_kv: bool = self.encoder_decoder_attention
         tgt_len, bsz, embed_dim = query.size()
         assert embed_dim == self.embed_dim
         assert list(query.size()) == [tgt_len, bsz, embed_dim]
@@ -666,7 +666,7 @@ class SelfAttention(nn.Module):
                 assert v is not None
                 v = torch.cat([prev_value, v], dim=1)
             assert k is not None and v is not None
-            prev_key_padding_mask = saved_state.get("prev_key_padding_mask", None)  # type: Optional[Tensor]
+            prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None)
             key_padding_mask = self._cat_prev_key_padding_mask(
                 key_padding_mask, prev_key_padding_mask, bsz, k.size(1), static_kv
             )
@@ -798,7 +798,7 @@ class BartModel(PretrainedBartModel):
         input_ids,
         attention_mask=None,
         decoder_input_ids=None,
-        encoder_outputs=None,  # type: Tuple
+        encoder_outputs: Optional[Tuple] = None,
         decoder_attention_mask=None,
         decoder_cached_states=None,
         use_cache=False,
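One change in this hunk goes slightly beyond a mechanical rewrite: the old comment declared the parameter as `# type: Tuple` even though its default is None, and the new annotation makes that explicit as Optional[Tuple]. A short sketch of the pattern, with a hypothetical function name standing in for the real forward signature:

from typing import Optional, Tuple

def forward_sketch(encoder_outputs: Optional[Tuple] = None) -> Tuple:
    # The default is None, so the annotation has to admit None as well;
    # Optional[Tuple] is shorthand for Union[Tuple, None].
    if encoder_outputs is None:
        encoder_outputs = ()
    return encoder_outputs

print(forward_sketch())        # ()
print(forward_sketch((1, 2)))  # (1, 2)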
@@ -831,9 +831,9 @@ class BartModel(PretrainedBartModel):
             use_cache=use_cache,
         )
         # Attention and hidden_states will be [] or None if they aren't needed
-        decoder_outputs = _filter_out_falsey_values(decoder_outputs)  # type: tuple
+        decoder_outputs: Tuple = _filter_out_falsey_values(decoder_outputs)
         assert isinstance(decoder_outputs[0], torch.Tensor)
-        encoder_outputs = _filter_out_falsey_values(encoder_outputs)  # type: tuple
+        encoder_outputs: Tuple = _filter_out_falsey_values(encoder_outputs)
         return decoder_outputs + encoder_outputs

     def get_input_embeddings(self):
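Also visible in this last hunk: the comment form used the builtin lowercase tuple, while the new annotation uses typing.Tuple, matching the Tuple already used in the signatures earlier in the diff. A minimal sketch of the annotated-local pattern, with a made-up stand-in for the helper (_filter_out_falsey_values itself is not reproduced here):

from typing import Tuple

def _filter_truthy_sketch(tup) -> Tuple:
    # Illustrative stand-in: keep only truthy members of the tuple.
    return tuple(x for x in tup if x)

# PEP 526 annotation on a local assignment, replacing
# "decoder_outputs = ...  # type: tuple".
decoder_outputs: Tuple = _filter_truthy_sketch((None, [], "hidden_state"))
print(decoder_outputs)  # ('hidden_state',)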