Fix undefined variable

This commit is contained in:
ydshieh 2022-09-28 20:59:21 +02:00
parent 901d319b7e
commit 1eb0953755

View File

@ -339,11 +339,10 @@ class XGLMAttention(nn.Module):
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
dtype_attn_weights = attn_weights.dtype
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
if dtype_attn_weights == torch.float16:
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype_attn_weights)
if attn_weights.dtype == torch.float16:
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(attn_weights.dtype)
else:
attn_weights = nn.functional.softmax(attn_weights, dim=-1)