Fix matmul inputs dtype (#18585)

Jingya HUANG authored on 2022-08-17 15:59:43 +02:00, committed by GitHub
parent c99e984657
commit 86d0b26d6c
3 changed files with 16 additions and 13 deletions
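
For context on the diff below: torch.matmul and torch.bmm require both operands to share a dtype, so a half-precision activation multiplied against a float32 projection weight fails at runtime. Here is a minimal sketch of that failure and of the kind of cast this patch applies; the tensor names and shapes are illustrative, not taken from the model code.

    import torch

    # Minimal sketch, not part of the patch: mixing dtypes in matmul raises.
    weight = torch.randn(8, 8, dtype=torch.float32)  # stands in for a projection weight
    x = torch.randn(2, 4, 8, dtype=torch.float16)    # stands in for fp16 hidden states

    try:
        torch.matmul(x, weight.t())
    except RuntimeError as err:
        print("mismatched dtypes:", err)

    # Casting the input to the weight's dtype restores a well-defined product;
    # .to() here stands in for the torch.tensor(..., dtype=...) form the patch uses.
    out = torch.matmul(x.to(weight.dtype), weight.t())
    print(out.dtype)  # torch.float32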

View File

@@ -14,7 +14,6 @@
 # limitations under the License.
 """ PyTorch DeBERTa model."""
 
-import math
 from collections.abc import Sequence
 from typing import Optional, Tuple, Union
@@ -640,8 +639,8 @@ class DisentangledSelfAttention(nn.Module):
             qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)]
             qkvb = [None] * 3
 
-            q = linear(qkvw[0], qkvb[0], query_states)
-            k, v = [linear(qkvw[i], qkvb[i], hidden_states) for i in range(1, 3)]
+            q = linear(qkvw[0], qkvb[0], torch.tensor(query_states, dtype=qkvw[0].dtype))
+            k, v = [linear(qkvw[i], qkvb[i], torch.tensor(hidden_states, dtype=qkvw[i].dtype)) for i in range(1, 3)]
             query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]
 
         query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
@@ -650,8 +649,8 @@ class DisentangledSelfAttention(nn.Module):
         rel_att = None
         # Take the dot product between "query" and "key" to get the raw attention scores.
         scale_factor = 1 + len(self.pos_att_type)
-        scale = math.sqrt(query_layer.size(-1) * scale_factor)
-        query_layer = query_layer / scale
+        scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
+        query_layer = query_layer / torch.tensor(scale, dtype=query_layer.dtype)
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
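
The hunk above replaces math.sqrt with a tensor-valued scale and casts that scale to the query dtype before dividing. A standalone sketch of the pattern follows; the shapes and scale_factor are illustrative.

    import torch

    # Standalone sketch of the scaling pattern above; values are illustrative.
    query_layer = torch.randn(8, 16, 64, dtype=torch.float16)  # stands in for fp16 queries
    scale_factor = 3  # 1 + len(pos_att_type) when pos_att_type is ["c2p", "p2c"]

    # The scale itself is built as a float32 tensor...
    scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
    # ...and cast to the query dtype before the division, so the scaled queries
    # keep the dtype the following matmul against key_layer expects.
    query_layer = query_layer / torch.tensor(scale, dtype=query_layer.dtype)
    print(scale.dtype, query_layer.dtype)  # torch.float32 torch.float16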
@@ -711,13 +710,13 @@ class DisentangledSelfAttention(nn.Module):
         if "p2c" in self.pos_att_type:
             pos_query_layer = self.pos_q_proj(rel_embeddings)
             pos_query_layer = self.transpose_for_scores(pos_query_layer)
-            pos_query_layer /= math.sqrt(pos_query_layer.size(-1) * scale_factor)
+            pos_query_layer /= torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
             if query_layer.size(-2) != key_layer.size(-2):
                 r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device)
             else:
                 r_pos = relative_pos
             p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
-            p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
+            p2c_att = torch.matmul(key_layer, torch.tensor(pos_query_layer.transpose(-1, -2), dtype=key_layer.dtype))
             p2c_att = torch.gather(
                 p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
             ).transpose(-1, -2)
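
A note on the cast idiom used throughout this patch: torch.tensor(t, dtype=d) copies t into a fresh tensor of dtype d, which for these matmuls has the same forward effect as t.to(d); the copy is also detached from the autograd graph, and recent PyTorch releases warn that t.clone().detach() is preferred when the source is already a tensor. A small sketch with illustrative values:

    import torch

    # Illustrative sketch of the cast idiom used in the hunks above.
    t = torch.randn(3, 3, dtype=torch.float32, requires_grad=True)

    cast = torch.tensor(t, dtype=torch.float16)  # copy + cast; PyTorch may warn,
                                                 # suggesting t.clone().detach()
    print(cast.dtype, cast.requires_grad)        # torch.float16 False

    # t.to(torch.float16) is the in-graph equivalent: same dtype, but it keeps
    # the autograd connection to t.
    print(t.to(torch.float16).requires_grad)     # True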

View File

@@ -717,7 +717,9 @@ class DisentangledSelfAttention(nn.Module):
         if "p2c" in self.pos_att_type:
             scale_factor += 1
         scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
-        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
+        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
+            scale, dtype=query_layer.dtype
+        )
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
             rel_att = self.disentangled_attention_bias(
@@ -799,7 +801,7 @@ class DisentangledSelfAttention(nn.Module):
                 dim=-1,
                 index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
             )
-            score += c2p_att / scale
+            score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
 
         # position->content
         if "p2c" in self.pos_att_type:
@@ -822,7 +824,7 @@ class DisentangledSelfAttention(nn.Module):
                 dim=-1,
                 index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
             ).transpose(-1, -2)
-            score += p2c_att / scale
+            score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype)
 
         return score
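
This file and the next apply the same pattern: the bmm-based raw attention scores and the disentangled-attention bias terms are divided by the scale only after it is cast to the dtype of the tensor being scaled. A standalone sketch of that accumulation follows; shapes are illustrative, and float32 is used throughout so the sketch runs without fp16 kernels.

    import torch

    # Standalone sketch of the score scaling and accumulation above.
    query_layer = torch.randn(8, 16, 64)
    key_layer = torch.randn(8, 16, 64)
    c2p_att = torch.randn(8, 16, 16)   # stands in for the content->position term
    scale_factor = 3                   # 1 (content->content) + c2p + p2c

    scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
    attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
        scale, dtype=query_layer.dtype
    )

    score = torch.zeros_like(attention_scores)
    score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
    print(attention_scores.dtype, score.dtype)  # torch.float32 torch.float32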

View File

@@ -791,7 +791,9 @@ class DisentangledSelfAttention(nn.Module):
         if "p2c" in self.pos_att_type:
             scale_factor += 1
         scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
-        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
+        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
+            scale, dtype=query_layer.dtype
+        )
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
             rel_att = self.disentangled_attention_bias(
@@ -873,7 +875,7 @@ class DisentangledSelfAttention(nn.Module):
                 dim=-1,
                 index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
             )
-            score += c2p_att / scale
+            score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
 
         # position->content
         if "p2c" in self.pos_att_type:
@@ -896,7 +898,7 @@ class DisentangledSelfAttention(nn.Module):
                 dim=-1,
                 index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
             ).transpose(-1, -2)
-            score += p2c_att / scale
+            score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype)
 
         return score