diff --git a/transformers/modeling_xlm.py b/transformers/modeling_xlm.py index f1df6f668fb..166b98de630 100644 --- a/transformers/modeling_xlm.py +++ b/transformers/modeling_xlm.py @@ -81,8 +81,8 @@ def get_masks(slen, lengths, causal, padding_mask=None): mask = alen < lengths[:, None] # attention mask is the same as mask, or triangular inferior attention (causal) + bs = lengths.size(0) if causal: - bs = lengths.size(0) attn_mask = alen[None, None, :].repeat(bs, slen, 1) <= alen[None, :, None] else: attn_mask = mask