mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Deberta v2 code simplification (#15732)
* Removed spurious substraction * Fixed condition checking for attention type * Fixed sew_d copy of DeBERTa v2 attention * Removed unused `p2p` attention type from DebertaV2-class models * Fixed docs style
This commit is contained in:
parent
0a5ef036e6
commit
319cbbe191
@ -77,8 +77,8 @@ class DebertaV2Config(PretrainedConfig):
|
||||
position_biased_input (`bool`, *optional*, defaults to `False`):
|
||||
Whether add absolute position embedding to content embedding.
|
||||
pos_att_type (`List[str]`, *optional*):
|
||||
The type of relative position attention, it can be a combination of `["p2c", "c2p", "p2p"]`, e.g.
|
||||
`["p2c"]`, `["p2c", "c2p"]`, `["p2c", "c2p", 'p2p"]`.
|
||||
The type of relative position attention, it can be a combination of `["p2c", "c2p"]`, e.g. `["p2c"]`,
|
||||
`["p2c", "c2p"]`, `["p2c", "c2p"]`.
|
||||
layer_norm_eps (`float`, optional, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
"""
|
||||
|
@ -629,9 +629,9 @@ class DisentangledSelfAttention(nn.Module):
|
||||
self.pos_dropout = StableDropout(config.hidden_dropout_prob)
|
||||
|
||||
if not self.share_att_key:
|
||||
if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
|
||||
if "c2p" in self.pos_att_type:
|
||||
self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
|
||||
if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
|
||||
if "p2c" in self.pos_att_type:
|
||||
self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = StableDropout(config.attention_probs_dropout_prob)
|
||||
@ -692,8 +692,6 @@ class DisentangledSelfAttention(nn.Module):
|
||||
scale_factor += 1
|
||||
if "p2c" in self.pos_att_type:
|
||||
scale_factor += 1
|
||||
if "p2p" in self.pos_att_type:
|
||||
scale_factor += 1
|
||||
scale = math.sqrt(query_layer.size(-1) * scale_factor)
|
||||
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
|
||||
if self.relative_attention:
|
||||
@ -744,7 +742,7 @@ class DisentangledSelfAttention(nn.Module):
|
||||
att_span = self.pos_ebd_size
|
||||
relative_pos = relative_pos.long().to(query_layer.device)
|
||||
|
||||
rel_embeddings = rel_embeddings[self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :].unsqueeze(0)
|
||||
rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
|
||||
if self.share_att_key:
|
||||
pos_query_layer = self.transpose_for_scores(
|
||||
self.query_proj(rel_embeddings), self.num_attention_heads
|
||||
@ -753,13 +751,13 @@ class DisentangledSelfAttention(nn.Module):
|
||||
query_layer.size(0) // self.num_attention_heads, 1, 1
|
||||
)
|
||||
else:
|
||||
if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
|
||||
if "c2p" in self.pos_att_type:
|
||||
pos_key_layer = self.transpose_for_scores(
|
||||
self.pos_key_proj(rel_embeddings), self.num_attention_heads
|
||||
).repeat(
|
||||
query_layer.size(0) // self.num_attention_heads, 1, 1
|
||||
) # .split(self.all_head_size, dim=-1)
|
||||
if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
|
||||
if "p2c" in self.pos_att_type:
|
||||
pos_query_layer = self.transpose_for_scores(
|
||||
self.pos_query_proj(rel_embeddings), self.num_attention_heads
|
||||
).repeat(
|
||||
@ -780,7 +778,7 @@ class DisentangledSelfAttention(nn.Module):
|
||||
score += c2p_att / scale
|
||||
|
||||
# position->content
|
||||
if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
|
||||
if "p2c" in self.pos_att_type:
|
||||
scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
|
||||
if key_layer.size(-2) != query_layer.size(-2):
|
||||
r_pos = build_relative_position(
|
||||
@ -794,8 +792,6 @@ class DisentangledSelfAttention(nn.Module):
|
||||
r_pos = relative_pos
|
||||
|
||||
p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
|
||||
|
||||
if "p2c" in self.pos_att_type:
|
||||
p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
|
||||
p2c_att = torch.gather(
|
||||
p2c_att,
|
||||
@ -804,20 +800,6 @@ class DisentangledSelfAttention(nn.Module):
|
||||
).transpose(-1, -2)
|
||||
score += p2c_att / scale
|
||||
|
||||
# position->position
|
||||
if "p2p" in self.pos_att_type:
|
||||
pos_query = pos_query_layer[:, :, att_span:, :]
|
||||
p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2))
|
||||
p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:])
|
||||
p2p_att = torch.gather(
|
||||
p2p_att,
|
||||
dim=-1,
|
||||
index=c2p_pos.expand(
|
||||
[query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]
|
||||
),
|
||||
)
|
||||
score += p2p_att
|
||||
|
||||
return score
|
||||
|
||||
|
||||
|
@ -604,14 +604,14 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
|
||||
self.pos_dropout = TFDebertaV2StableDropout(config.hidden_dropout_prob, name="pos_dropout")
|
||||
|
||||
if not self.share_att_key:
|
||||
if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
|
||||
if "c2p" in self.pos_att_type:
|
||||
self.pos_proj = tf.keras.layers.Dense(
|
||||
self.all_head_size,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="pos_proj",
|
||||
use_bias=True,
|
||||
)
|
||||
if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
|
||||
if "p2c" in self.pos_att_type:
|
||||
self.pos_q_proj = tf.keras.layers.Dense(
|
||||
self.all_head_size,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
@ -679,8 +679,6 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
|
||||
scale_factor += 1
|
||||
if "p2c" in self.pos_att_type:
|
||||
scale_factor += 1
|
||||
if "p2p" in self.pos_att_type:
|
||||
scale_factor += 1
|
||||
scale = tf.math.sqrt(tf.cast(shape_list(query_layer)[-1] * scale_factor, tf.float32))
|
||||
attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, [0, 2, 1])) / scale
|
||||
if self.relative_attention:
|
||||
@ -749,12 +747,12 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
|
||||
[shape_list(query_layer)[0] // self.num_attention_heads, 1, 1],
|
||||
)
|
||||
else:
|
||||
if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
|
||||
if "c2p" in self.pos_att_type:
|
||||
pos_key_layer = tf.tile(
|
||||
self.transpose_for_scores(self.pos_key_proj(rel_embeddings), self.num_attention_heads),
|
||||
[shape_list(query_layer)[0] // self.num_attention_heads, 1, 1],
|
||||
) # .split(self.all_head_size, dim=-1)
|
||||
if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
|
||||
if "p2c" in self.pos_att_type:
|
||||
pos_query_layer = tf.tile(
|
||||
self.transpose_for_scores(self.pos_query_proj(rel_embeddings), self.num_attention_heads),
|
||||
[shape_list(query_layer)[0] // self.num_attention_heads, 1, 1],
|
||||
@ -777,7 +775,7 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
|
||||
score += c2p_att / scale
|
||||
|
||||
# position->content
|
||||
if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
|
||||
if "p2c" in self.pos_att_type:
|
||||
scale = tf.math.sqrt(tf.cast(shape_list(pos_query_layer)[-1] * scale_factor, tf.float32))
|
||||
if shape_list(key_layer)[-2] != shape_list(query_layer)[-2]:
|
||||
r_pos = build_relative_position(
|
||||
@ -792,7 +790,6 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
p2c_pos = tf.clip_by_value(-r_pos + att_span, 0, att_span * 2 - 1)
|
||||
|
||||
if "p2c" in self.pos_att_type:
|
||||
p2c_att = tf.matmul(key_layer, tf.transpose(pos_query_layer, [0, 2, 1]))
|
||||
p2c_att = tf.transpose(
|
||||
take_along_axis(
|
||||
@ -807,26 +804,6 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
|
||||
)
|
||||
score += p2c_att / scale
|
||||
|
||||
# position->position
|
||||
if "p2p" in self.pos_att_type:
|
||||
pos_query = pos_query_layer[:, :, att_span:, :]
|
||||
p2p_att = tf.matmul(pos_query, tf.transpose(pos_key_layer, [0, 2, 1]))
|
||||
p2p_att = tf.broadcast_to(shape_list(query_layer)[:2] + shape_list(p2p_att)[2:])
|
||||
p2p_att = take_along_axis(
|
||||
p2p_att,
|
||||
tf.broadcast_to(
|
||||
c2p_pos,
|
||||
[
|
||||
shape_list(query_layer)[0],
|
||||
shape_list(query_layer)[1],
|
||||
shape_list(query_layer)[2],
|
||||
shape_list(relative_pos)[-1],
|
||||
],
|
||||
),
|
||||
-1,
|
||||
)
|
||||
score += p2p_att
|
||||
|
||||
return score
|
||||
|
||||
|
||||
|
@ -66,8 +66,8 @@ class SEWDConfig(PretrainedConfig):
|
||||
position_biased_input (`bool`, *optional*, defaults to `False`):
|
||||
Whether to add absolute position embedding to content embedding.
|
||||
pos_att_type (`Tuple[str]`, *optional*, defaults to `("p2c", "c2p")`):
|
||||
The type of relative position attention, it can be a combination of `("p2c", "c2p", "p2p")`, e.g.
|
||||
`("p2c")`, `("p2c", "c2p")`, `("p2c", "c2p", 'p2p")`.
|
||||
The type of relative position attention, it can be a combination of `("p2c", "c2p")`, e.g. `("p2c")`,
|
||||
`("p2c", "c2p")`, `("p2c", "c2p")`.
|
||||
norm_rel_ebd (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
Whether to use layer norm in relative embedding (`"layer_norm"` if yes)
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_python"`):
|
||||
|
@ -703,9 +703,9 @@ class DisentangledSelfAttention(nn.Module):
|
||||
self.pos_dropout = StableDropout(config.activation_dropout)
|
||||
|
||||
if not self.share_att_key:
|
||||
if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
|
||||
if "c2p" in self.pos_att_type:
|
||||
self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
|
||||
if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
|
||||
if "p2c" in self.pos_att_type:
|
||||
self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = StableDropout(config.attention_dropout)
|
||||
@ -766,8 +766,6 @@ class DisentangledSelfAttention(nn.Module):
|
||||
scale_factor += 1
|
||||
if "p2c" in self.pos_att_type:
|
||||
scale_factor += 1
|
||||
if "p2p" in self.pos_att_type:
|
||||
scale_factor += 1
|
||||
scale = math.sqrt(query_layer.size(-1) * scale_factor)
|
||||
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
|
||||
if self.relative_attention:
|
||||
@ -818,7 +816,7 @@ class DisentangledSelfAttention(nn.Module):
|
||||
att_span = self.pos_ebd_size
|
||||
relative_pos = relative_pos.long().to(query_layer.device)
|
||||
|
||||
rel_embeddings = rel_embeddings[self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :].unsqueeze(0)
|
||||
rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
|
||||
if self.share_att_key:
|
||||
pos_query_layer = self.transpose_for_scores(
|
||||
self.query_proj(rel_embeddings), self.num_attention_heads
|
||||
@ -827,13 +825,13 @@ class DisentangledSelfAttention(nn.Module):
|
||||
query_layer.size(0) // self.num_attention_heads, 1, 1
|
||||
)
|
||||
else:
|
||||
if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
|
||||
if "c2p" in self.pos_att_type:
|
||||
pos_key_layer = self.transpose_for_scores(
|
||||
self.pos_key_proj(rel_embeddings), self.num_attention_heads
|
||||
).repeat(
|
||||
query_layer.size(0) // self.num_attention_heads, 1, 1
|
||||
) # .split(self.all_head_size, dim=-1)
|
||||
if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
|
||||
if "p2c" in self.pos_att_type:
|
||||
pos_query_layer = self.transpose_for_scores(
|
||||
self.pos_query_proj(rel_embeddings), self.num_attention_heads
|
||||
).repeat(
|
||||
@ -854,7 +852,7 @@ class DisentangledSelfAttention(nn.Module):
|
||||
score += c2p_att / scale
|
||||
|
||||
# position->content
|
||||
if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
|
||||
if "p2c" in self.pos_att_type:
|
||||
scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
|
||||
if key_layer.size(-2) != query_layer.size(-2):
|
||||
r_pos = build_relative_position(
|
||||
@ -868,8 +866,6 @@ class DisentangledSelfAttention(nn.Module):
|
||||
r_pos = relative_pos
|
||||
|
||||
p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
|
||||
|
||||
if "p2c" in self.pos_att_type:
|
||||
p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
|
||||
p2c_att = torch.gather(
|
||||
p2c_att,
|
||||
@ -878,20 +874,6 @@ class DisentangledSelfAttention(nn.Module):
|
||||
).transpose(-1, -2)
|
||||
score += p2c_att / scale
|
||||
|
||||
# position->position
|
||||
if "p2p" in self.pos_att_type:
|
||||
pos_query = pos_query_layer[:, :, att_span:, :]
|
||||
p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2))
|
||||
p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:])
|
||||
p2p_att = torch.gather(
|
||||
p2p_att,
|
||||
dim=-1,
|
||||
index=c2p_pos.expand(
|
||||
[query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]
|
||||
),
|
||||
)
|
||||
score += p2p_att
|
||||
|
||||
return score
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user