mirror of https://github.com/huggingface/transformers.git
Fix undefined variable

commit 1eb0953755
parent 901d319b7e
@@ -339,11 +339,10 @@ class XGLMAttention(nn.Module):
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-            dtype_attn_weights = attn_weights.dtype

         # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
-        if dtype_attn_weights == torch.float16:
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype_attn_weights)
+        if attn_weights.dtype == torch.float16:
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(attn_weights.dtype)
         else:
             attn_weights = nn.functional.softmax(attn_weights, dim=-1)

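The change removes the dtype_attn_weights helper, which was assigned only inside the attention-mask branch and was therefore undefined whenever no mask was passed, and reads attn_weights.dtype directly instead. Below is a minimal sketch of the resulting fp16-upcast softmax pattern, assuming illustrative shapes (bsz=2, num_heads=4, tgt_len=src_len=5) rather than the full XGLMAttention forward pass:

import torch
import torch.nn as nn

# Hypothetical shapes for illustration only, not the real XGLMAttention inputs.
bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
attn_weights = torch.randn(bsz * num_heads, tgt_len, src_len).half()

# Same pattern as the patched code: compute the softmax in fp32 when the
# weights are fp16 (half-precision softmax can lose precision or overflow),
# then cast the result back to the original dtype.
if attn_weights.dtype == torch.float16:
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(attn_weights.dtype)
else:
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

print(attn_weights.dtype)  # torch.float16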