Mirror of https://github.com/huggingface/transformers.git
Fix matmul inputs dtype (#18585)

commit 86d0b26d6c (parent c99e984657)
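The commit casts the non-weight operand of several `matmul`/`bmm` calls (and the attention scale) to the dtype of the tensor it is combined with. `torch.matmul` and `torch.bmm` do not promote across `float16`/`float32`, so feeding a float32 hidden state into half-precision weights typically fails at runtime. Below is a minimal sketch of the failure mode and the cast; it is not taken from the commit, the shapes and dtypes are illustrative, and half-precision matmul on CPU needs a reasonably recent PyTorch build.

```python
import torch

w = torch.randn(64, 64, dtype=torch.float16)    # weight slice loaded in half precision
x = torch.randn(2, 8, 64, dtype=torch.float32)  # an input that stayed in float32

try:
    torch.matmul(x, w.t())                      # mixed Half/Float operands
except RuntimeError as err:
    print(err)  # typically "expected scalar type Half but found Float"; message varies by backend

# Casting the input to the weight's dtype, as the diff below does, keeps both operands aligned.
out = torch.matmul(torch.tensor(x, dtype=w.dtype), w.t())
print(out.dtype)  # torch.float16
```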
@@ -14,7 +14,6 @@
 # limitations under the License.
 """ PyTorch DeBERTa model."""
 
-import math
 from collections.abc import Sequence
 from typing import Optional, Tuple, Union
 
@@ -640,8 +639,8 @@ class DisentangledSelfAttention(nn.Module):
             qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)]
             qkvb = [None] * 3
 
-            q = linear(qkvw[0], qkvb[0], query_states)
-            k, v = [linear(qkvw[i], qkvb[i], hidden_states) for i in range(1, 3)]
+            q = linear(qkvw[0], qkvb[0], torch.tensor(query_states, dtype=qkvw[0].dtype))
+            k, v = [linear(qkvw[i], qkvb[i], torch.tensor(hidden_states, dtype=qkvw[i].dtype)) for i in range(1, 3)]
             query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]
 
         query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
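In this hunk `query_states` and `hidden_states` are cast to the dtype of the packed weight slices before the local `linear` helper multiplies them. A minimal sketch of the pattern, assuming a helper along the lines of the `linear` closure defined just above this hunk in `forward` (reconstructed here, so treat its exact body as an assumption):

```python
import torch

def linear(w, b, x):
    # plain matmul against the transposed weight, with an optional bias
    return torch.matmul(x, w.t()) if b is None else torch.matmul(x, w.t()) + b.t()

w = torch.randn(64, 64, dtype=torch.float16)
x = torch.randn(2, 8, 64)  # float32 by default

# torch.tensor(x, dtype=w.dtype) builds a detached copy of x in the weight dtype, so both
# matmul operands agree; x.to(w.dtype) would be the lighter cast (no copy when the dtype
# already matches, and it preserves the autograd history).
q = linear(w, None, torch.tensor(x, dtype=w.dtype))
print(q.dtype)  # torch.float16
```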
@@ -650,8 +649,8 @@ class DisentangledSelfAttention(nn.Module):
         rel_att = None
         # Take the dot product between "query" and "key" to get the raw attention scores.
         scale_factor = 1 + len(self.pos_att_type)
-        scale = math.sqrt(query_layer.size(-1) * scale_factor)
-        query_layer = query_layer / scale
+        scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
+        query_layer = query_layer / torch.tensor(scale, dtype=query_layer.dtype)
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
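With this change the scale becomes a 0-dim float32 tensor (`torch.sqrt` takes a tensor, unlike `math.sqrt`), and the query is divided by that value re-cast to the query's own dtype so the quotient keeps the query dtype. A small worked example with illustrative sizes (head size 64, `c2p` and `p2c` enabled, so `scale_factor = 3`):

```python
import torch

head_dim, scale_factor = 64, 3
query_layer = torch.randn(12, 32, head_dim, dtype=torch.float16)

scale = torch.sqrt(torch.tensor(head_dim, dtype=torch.float) * scale_factor)
query_layer = query_layer / torch.tensor(scale, dtype=query_layer.dtype)

print(scale.item())       # ~13.856, i.e. sqrt(64 * 3)
print(query_layer.dtype)  # torch.float16
```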
@@ -711,13 +710,13 @@ class DisentangledSelfAttention(nn.Module):
         if "p2c" in self.pos_att_type:
             pos_query_layer = self.pos_q_proj(rel_embeddings)
             pos_query_layer = self.transpose_for_scores(pos_query_layer)
-            pos_query_layer /= math.sqrt(pos_query_layer.size(-1) * scale_factor)
+            pos_query_layer /= torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
             if query_layer.size(-2) != key_layer.size(-2):
                 r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device)
             else:
                 r_pos = relative_pos
             p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
-            p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
+            p2c_att = torch.matmul(key_layer, torch.tensor(pos_query_layer.transpose(-1, -2), dtype=key_layer.dtype))
             p2c_att = torch.gather(
                 p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
             ).transpose(-1, -2)
@@ -717,7 +717,9 @@ class DisentangledSelfAttention(nn.Module):
         if "p2c" in self.pos_att_type:
             scale_factor += 1
         scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
-        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
+        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
+            scale, dtype=query_layer.dtype
+        )
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
             rel_att = self.disentangled_attention_bias(
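The remaining hunks apply the same treatment to the `torch.bmm` variant of the attention scores and to the `c2p`/`p2c` bias terms; the per-file headers were lost in this rendering, but the repeated near-identical hunks look like the DeBERTa-v2 style copies of `DisentangledSelfAttention` (treat that attribution as an assumption). `torch.bmm` has the same matching-dtype requirement as `torch.matmul`, hence the divisor is wrapped in a tensor cast to the query dtype. A minimal sketch with illustrative shapes (half-precision bmm on CPU needs a recent PyTorch; on CUDA it is routine):

```python
import torch

query_layer = torch.randn(12, 32, 64, dtype=torch.float16)  # illustrative (batch*heads, seq, head_dim)
key_layer = torch.randn(12, 32, 64, dtype=torch.float16)
scale = torch.sqrt(torch.tensor(64, dtype=torch.float) * 3)  # 0-dim float32 tensor

# dividing by the scale cast to the query dtype keeps the scores in float16
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
    scale, dtype=query_layer.dtype
)
print(attention_scores.shape, attention_scores.dtype)  # torch.Size([12, 32, 32]) torch.float16
```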
@@ -799,7 +801,7 @@ class DisentangledSelfAttention(nn.Module):
                 dim=-1,
                 index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
             )
-            score += c2p_att / scale
+            score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
 
         # position->content
         if "p2c" in self.pos_att_type:
@@ -822,7 +824,7 @@ class DisentangledSelfAttention(nn.Module):
                 dim=-1,
                 index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
             ).transpose(-1, -2)
-            score += p2c_att / scale
+            score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype)
 
         return score
 
@@ -791,7 +791,9 @@ class DisentangledSelfAttention(nn.Module):
         if "p2c" in self.pos_att_type:
             scale_factor += 1
         scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
-        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
+        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
+            scale, dtype=query_layer.dtype
+        )
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
             rel_att = self.disentangled_attention_bias(
@@ -873,7 +875,7 @@ class DisentangledSelfAttention(nn.Module):
                 dim=-1,
                 index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
             )
-            score += c2p_att / scale
+            score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
 
         # position->content
         if "p2c" in self.pos_att_type:
@@ -896,7 +898,7 @@ class DisentangledSelfAttention(nn.Module):
                 dim=-1,
                 index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
             ).transpose(-1, -2)
-            score += p2c_att / scale
+            score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype)
 
         return score
 