mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-26 07:49:01 +06:00
[housekeeping] Upgrade # type
Python 2 syntax
cc @sshleifer
This commit is contained in:
parent
cb3c2212c7
commit
a946b6b51b
@ -52,8 +52,8 @@ class PretrainedConfig(object):
|
||||
torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Is the model used with Torchscript (for PyTorch models).
|
||||
"""
|
||||
pretrained_config_archive_map = {} # type: Dict[str, str]
|
||||
model_type = "" # type: str
|
||||
pretrained_config_archive_map: Dict[str, str] = {}
|
||||
model_type: str = ""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Attributes with defaults
|
||||
|
@ -570,7 +570,7 @@ class SelfAttention(nn.Module):
|
||||
need_weights=False,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
"""Input shape: Time(SeqLen) x Batch x Channel"""
|
||||
static_kv = self.encoder_decoder_attention # type: bool
|
||||
static_kv: bool = self.encoder_decoder_attention
|
||||
tgt_len, bsz, embed_dim = query.size()
|
||||
assert embed_dim == self.embed_dim
|
||||
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
||||
@ -666,7 +666,7 @@ class SelfAttention(nn.Module):
|
||||
assert v is not None
|
||||
v = torch.cat([prev_value, v], dim=1)
|
||||
assert k is not None and v is not None
|
||||
prev_key_padding_mask = saved_state.get("prev_key_padding_mask", None) # type: Optional[Tensor]
|
||||
prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None)
|
||||
key_padding_mask = self._cat_prev_key_padding_mask(
|
||||
key_padding_mask, prev_key_padding_mask, bsz, k.size(1), static_kv
|
||||
)
|
||||
@ -798,7 +798,7 @@ class BartModel(PretrainedBartModel):
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
decoder_input_ids=None,
|
||||
encoder_outputs=None, # type: Tuple
|
||||
encoder_outputs: Optional[Tuple] = None,
|
||||
decoder_attention_mask=None,
|
||||
decoder_cached_states=None,
|
||||
use_cache=False,
|
||||
@ -831,9 +831,9 @@ class BartModel(PretrainedBartModel):
|
||||
use_cache=use_cache,
|
||||
)
|
||||
# Attention and hidden_states will be [] or None if they aren't needed
|
||||
decoder_outputs = _filter_out_falsey_values(decoder_outputs) # type: tuple
|
||||
decoder_outputs: Tuple = _filter_out_falsey_values(decoder_outputs)
|
||||
assert isinstance(decoder_outputs[0], torch.Tensor)
|
||||
encoder_outputs = _filter_out_falsey_values(encoder_outputs) # type: tuple
|
||||
encoder_outputs: Tuple = _filter_out_falsey_values(encoder_outputs)
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
def get_input_embeddings(self):
|
||||
|
Loading…
Reference in New Issue
Block a user