diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py
index df85a307aae..0be1b3ed255 100755
--- a/src/transformers/modeling_attn_mask_utils.py
+++ b/src/transformers/modeling_attn_mask_utils.py
@@ -193,7 +193,7 @@ class AttentionMaskConverter:
 
         expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
 
-        inverted_mask = 1.0 - expanded_mask
+        inverted_mask = torch.tensor(1.0, dtype=dtype) - expanded_mask
 
         return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
 
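
For reference, below is a minimal, runnable sketch of the logic this hunk touches. In the library this code lives on `AttentionMaskConverter`; the standalone `expand_mask` helper and the demo values here are illustrative assumptions, not part of the patch. The sketch shows where the patched subtraction sits and how padded positions end up at `torch.finfo(dtype).min`:

```python
# Sketch of the mask expansion/inversion logic from the hunk above.
# Assumption: extracted into a free function for demonstration; the real
# method is a staticmethod on AttentionMaskConverter.
from typing import Optional

import torch


def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    # [bsz, src_len] -> [bsz, 1, tgt_len, src_len], cast to the model dtype.
    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    # Patched line: subtract from a tensor with an explicit dtype rather than
    # from the Python float 1.0. Attended positions (mask == 1) become 0.0,
    # padded positions (mask == 0) become 1.0.
    inverted_mask = torch.tensor(1.0, dtype=dtype) - expanded_mask

    # Nonzero entries of inverted_mask mark padded positions; fill them with
    # the most negative representable value so softmax drives them to ~0.
    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


if __name__ == "__main__":
    attn_mask = torch.tensor([[1, 1, 0]])  # last position is padding
    out = expand_mask(attn_mask, torch.float16)
    print(out)  # padding column holds finfo(float16).min; the rest is 0.0
```

The observable result of `torch.tensor(1.0, dtype=dtype) - expanded_mask` is the same as `1.0 - expanded_mask`, since PyTorch's type promotion treats Python scalars weakly; the patch makes the constant's dtype explicit rather than implicit, which presumably matters for the tracing/export paths this file serves.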