[housekeeping] Upgrade # type Python 2 syntax

cc @sshleifer
This commit is contained in:
Julien Chaumond 2020-04-23 10:39:24 -04:00
parent cb3c2212c7
commit a946b6b51b
2 changed files with 7 additions and 7 deletions

View File

@ -52,8 +52,8 @@ class PretrainedConfig(object):
torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
Is the model used with Torchscript (for PyTorch models).
"""
pretrained_config_archive_map = {} # type: Dict[str, str]
model_type = "" # type: str
pretrained_config_archive_map: Dict[str, str] = {}
model_type: str = ""
def __init__(self, **kwargs):
# Attributes with defaults

View File

@ -570,7 +570,7 @@ class SelfAttention(nn.Module):
need_weights=False,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time(SeqLen) x Batch x Channel"""
static_kv = self.encoder_decoder_attention # type: bool
static_kv: bool = self.encoder_decoder_attention
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
@ -666,7 +666,7 @@ class SelfAttention(nn.Module):
assert v is not None
v = torch.cat([prev_value, v], dim=1)
assert k is not None and v is not None
prev_key_padding_mask = saved_state.get("prev_key_padding_mask", None) # type: Optional[Tensor]
prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None)
key_padding_mask = self._cat_prev_key_padding_mask(
key_padding_mask, prev_key_padding_mask, bsz, k.size(1), static_kv
)
@ -798,7 +798,7 @@ class BartModel(PretrainedBartModel):
input_ids,
attention_mask=None,
decoder_input_ids=None,
encoder_outputs=None, # type: Tuple
encoder_outputs: Optional[Tuple] = None,
decoder_attention_mask=None,
decoder_cached_states=None,
use_cache=False,
@ -831,9 +831,9 @@ class BartModel(PretrainedBartModel):
use_cache=use_cache,
)
# Attention and hidden_states will be [] or None if they aren't needed
decoder_outputs = _filter_out_falsey_values(decoder_outputs) # type: tuple
decoder_outputs: Tuple = _filter_out_falsey_values(decoder_outputs)
assert isinstance(decoder_outputs[0], torch.Tensor)
encoder_outputs = _filter_out_falsey_values(encoder_outputs) # type: tuple
encoder_outputs: Tuple = _filter_out_falsey_values(encoder_outputs)
return decoder_outputs + encoder_outputs
def get_input_embeddings(self):