mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-26 07:49:01 +06:00
remove potential UndefinedError
This commit is contained in:
parent
b915ba9dfe
commit
cb26b035c6
@ -81,8 +81,8 @@ def get_masks(slen, lengths, causal, padding_mask=None):
|
|||||||
mask = alen < lengths[:, None]
|
mask = alen < lengths[:, None]
|
||||||
|
|
||||||
# attention mask is the same as mask, or triangular inferior attention (causal)
|
# attention mask is the same as mask, or triangular inferior attention (causal)
|
||||||
|
bs = lengths.size(0)
|
||||||
if causal:
|
if causal:
|
||||||
bs = lengths.size(0)
|
|
||||||
attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
|
attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None]
|
||||||
else:
|
else:
|
||||||
attn_mask = mask
|
attn_mask = mask
|
||||||
|
Loading…
Reference in New Issue
Block a user