remove potential UndefinedError

This commit is contained in:
Rémi Louf 2019-10-17 17:52:32 +02:00
parent b915ba9dfe
commit cb26b035c6

View File

@ -81,8 +81,8 @@ def get_masks(slen, lengths, causal, padding_mask=None):
mask = alen < lengths[:, None]
# attention mask is the same as mask, or triangular inferior attention (causal)
bs = lengths.size(0)
if causal:
bs = lengths.size(0)
attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
else:
attn_mask = mask