Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[T5] Bug correction & Refactor (#8518)
* fix bug
* T5 refactor
* refactor tf
* apply Sylvain's suggestions
This commit is contained in:
parent 42f63e3871
commit 42e2d02e44
@@ -75,7 +75,6 @@ class T5Config(PretrainedConfig):
def __init__(
self,
vocab_size=32128,
n_positions=512,
d_model=512,
d_kv=64,
d_ff=2048,
@@ -98,7 +97,6 @@ class T5Config(PretrainedConfig):
**kwargs,
)
self.vocab_size = vocab_size
self.n_positions = n_positions
self.d_model = d_model
self.d_kv = d_kv
self.d_ff = d_ff
@@ -112,10 +110,6 @@ class T5Config(PretrainedConfig):
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_factor = initializer_factor

@property
def max_position_embeddings(self):
return self.n_positions

@property
def hidden_size(self):
return self.d_model

@@ -17,8 +17,6 @@

import argparse

import torch

from transformers import T5Config, T5Model, load_tf_weights_in_t5
from transformers.utils import logging

@@ -37,7 +35,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du

# Save pytorch-model
print("Save PyTorch model to {}".format(pytorch_dump_path))
torch.save(model.state_dict(), pytorch_dump_path)
model.save_pretrained(pytorch_dump_path)


if __name__ == "__main__":

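The hunk above switches the conversion script from writing a raw state dict with torch.save() to calling model.save_pretrained(). As a rough illustration only (not part of the commit), the conversion entry point could be called directly as below; the checkpoint paths are placeholders and the module import path is assumed from the file name.

```python
# Sketch of invoking the updated conversion function directly (paths and import path
# are assumptions, not taken from this diff).
from transformers.convert_t5_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch

convert_tf_checkpoint_to_pytorch(
    tf_checkpoint_path="/path/to/t5/model.ckpt",  # placeholder
    config_file="/path/to/t5/config.json",        # placeholder
    pytorch_dump_path="/path/to/output_dir",      # placeholder
)
# With the change above, the dump directory is now written by model.save_pretrained()
# (config plus weights) instead of a bare state_dict from torch.save().
```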
@@ -115,12 +115,12 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
scope_names = [m_name]
if scope_names[0] in ["kernel", "scale", "embedding"]:
pointer = getattr(pointer, "weight")
# elif scope_names[0] == 'scale':
# pointer = getattr(pointer, 'weight')
# elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta':
# pointer = getattr(pointer, 'bias')
# elif scope_names[0] == 'squad':
# pointer = getattr(pointer, 'classifier')
elif scope_names[0] == "scale":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
pointer = getattr(pointer, "bias")
elif scope_names[0] == "squad":
pointer = getattr(pointer, "classifier")
else:
try:
pointer = getattr(pointer, scope_names[0])
@@ -147,7 +147,6 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
tf_weights.pop(txt_name, None)

logger.info("Weights not copied to PyTorch model: {}".format(", ".join(tf_weights.keys())))
# logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
return model

@@ -167,14 +166,15 @@ class T5LayerNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def forward(self, x):
def forward(self, hidden_states):
# layer norm should always be calculated in float32
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
x = x / torch.sqrt(variance + self.variance_epsilon)
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

# convert into float16 if necessary
if self.weight.dtype == torch.float16:
x = x.to(torch.float16)
return self.weight * x
hidden_states = hidden_states.to(torch.float16)
return self.weight * hidden_states


class T5DenseReluDense(nn.Module):
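For reference, the refactored forward is an RMS-style norm: no mean subtraction, variance computed in float32, then an optional cast back to float16 when the module runs in half precision. A minimal standalone paraphrase of the code in the hunk above (the class name is mine, not from the diff):

```python
import torch
from torch import nn


class T5LayerNormSketch(nn.Module):
    """Standalone paraphrase of the refactored T5LayerNorm (illustrative name)."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # RMS norm: square, mean over the hidden dimension, no mean subtraction.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back if the weights are fp16.
        if self.weight.dtype == torch.float16:
            hidden_states = hidden_states.to(torch.float16)
        return self.weight * hidden_states


x = torch.randn(2, 5, 8)
print(T5LayerNormSketch(8)(x).shape)  # torch.Size([2, 5, 8])
```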
@@ -185,11 +185,11 @@ class T5DenseReluDense(nn.Module):
self.dropout = nn.Dropout(config.dropout_rate)

def forward(self, hidden_states):
h = self.wi(hidden_states)
h = F.relu(h)
h = self.dropout(h)
h = self.wo(h)
return h
hidden_states = self.wi(hidden_states)
hidden_states = F.relu(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.wo(hidden_states)
return hidden_states


class T5LayerFF(nn.Module):
@@ -200,25 +200,24 @@ class T5LayerFF(nn.Module):
self.dropout = nn.Dropout(config.dropout_rate)

def forward(self, hidden_states):
norm_x = self.layer_norm(hidden_states)
y = self.DenseReluDense(norm_x)
layer_output = hidden_states + self.dropout(y)
return layer_output
forwarded_states = self.layer_norm(hidden_states)
forwarded_states = self.DenseReluDense(forwarded_states)
hidden_states = hidden_states + self.dropout(forwarded_states)
return hidden_states

class T5Attention(nn.Module):
def __init__(self, config: T5Config, has_relative_attention_bias=False, is_bidirectional=False):
def __init__(self, config: T5Config, has_relative_attention_bias=False):
super().__init__()
self.is_bidirectional = is_bidirectional
self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias

self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.d_model = config.d_model
self.d_kv = config.d_kv
self.key_value_proj_dim = config.d_kv
self.n_heads = config.num_heads
self.dropout = config.dropout_rate
self.inner_dim = self.n_heads * self.d_kv
self.inner_dim = self.n_heads * self.key_value_proj_dim

# Mesh TensorFlow initialization to avoid scaling before softmax
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -233,7 +232,9 @@ class T5Attention(nn.Module):
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, self.d_kv, self.pruned_heads)
heads, index = find_pruneable_heads_and_indices(
heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
)
# Prune linear layers
self.q = prune_linear_layer(self.q, index)
self.k = prune_linear_layer(self.k, index)
@@ -241,7 +242,7 @@ class T5Attention(nn.Module):
self.o = prune_linear_layer(self.o, index, dim=1)
# Update hyper params
self.n_heads = self.n_heads - len(heads)
self.inner_dim = self.d_kv * self.n_heads
self.inner_dim = self.key_value_proj_dim * self.n_heads
self.pruned_heads = self.pruned_heads.union(heads)

@staticmethod
@@ -266,49 +267,52 @@ class T5Attention(nn.Module):
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
ret = 0
n = -relative_position
relative_buckets = 0
if bidirectional:
num_buckets //= 2
ret += (n < 0).to(torch.long) * num_buckets # mtf.to_int32(mtf.less(n, 0)) * num_buckets
n = torch.abs(n)
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
n = torch.max(n, torch.zeros_like(n))
# now n is in the range [0, inf)
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
# now relative_position is in the range [0, inf)

# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = n < max_exact
is_small = relative_position < max_exact

# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
relative_postion_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
relative_postion_if_large = torch.min(
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
)

ret += torch.where(is_small, n, val_if_large)
return ret
relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)
return relative_buckets

def compute_bias(self, qlen, klen):
def compute_bias(self, query_length, key_length):
""" Compute binned relative position bias """
context_position = torch.arange(qlen, dtype=torch.long)[:, None]
memory_position = torch.arange(klen, dtype=torch.long)[None, :]
relative_position = memory_position - context_position # shape (qlen, klen)
rp_bucket = self._relative_position_bucket(
relative_position, # shape (qlen, klen)
bidirectional=self.is_bidirectional,
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
)
rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen)
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values

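The renaming above (n to relative_position, ret to relative_buckets, val_if_large to relative_postion_if_large) keeps the bucketing math identical; the only behavioral change is that `bidirectional` is now derived from `not self.is_decoder` rather than a separate `is_bidirectional` flag. A small usage sketch; the import path is an assumption and may differ between transformers versions:

```python
# Quick sanity check of the bucketing as a staticmethod on T5Attention.
import torch

from transformers.modeling_t5 import T5Attention  # assumed location of T5Attention

query_length, key_length = 4, 4
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = memory_position - context_position  # shape (query_length, key_length)

# Encoder-style call (bidirectional=True): positive and negative offsets get separate buckets.
buckets = T5Attention._relative_position_bucket(relative_position, bidirectional=True)
print(buckets.shape)  # torch.Size([4, 4]); integer bucket ids in [0, num_buckets)
```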
def forward(
self,
input,
hidden_states,
mask=None,
kv=None,
key_value_states=None,
position_bias=None,
past_key_value=None,
head_mask=None,
@@ -317,106 +321,113 @@ class T5Attention(nn.Module):
output_attentions=False,
):
"""
Self-attention (if kv is None) or attention over source sentence (provided by kv).
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
"""
# Input is (bs, qlen, dim)
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
# past_key_value[0] is (bs, n_heads, q_len - 1, dim_per_head)
bs, qlen, dim = input.size()
# Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2]

real_seq_length = seq_length

if past_key_value is not None:
assert self.is_decoder is True, "Encoder cannot cache past key value states"
assert (
len(past_key_value) == 2
), "past_key_value should have 2 past states: keys and values. Got {} past states".format(
len(past_key_value)
)
real_qlen = qlen + past_key_value[0].shape[2] if query_length is None else query_length
else:
real_qlen = qlen
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

if kv is None:
klen = real_qlen
else:
klen = kv.size(1)
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

def shape(x):
def shape(states):
""" projection """
return x.view(bs, -1, self.n_heads, self.d_kv).transpose(1, 2)
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

def unshape(x):
""" compute context """
return x.transpose(1, 2).contiguous().view(bs, -1, self.inner_dim)
def unshape(states):
""" reshape """
return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

q = shape(self.q(input)) # (bs, n_heads, qlen, dim_per_head)
def project(hidden_states, proj_layer, key_value_states, past_key_value):
""" projects hidden states correctly to key/query states """
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))

if kv is None:
k = shape(self.k(input)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v(input)) # (bs, n_heads, qlen, dim_per_head)
elif past_key_value is None:
k = v = kv
k = shape(self.k(k)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v(v)) # (bs, n_heads, qlen, dim_per_head)
if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
else:
# cross-attn
hidden_states = past_key_value
return hidden_states

if past_key_value is not None:
if kv is None:
k_, v_ = past_key_value
k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head)
v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head)
else:
k, v = past_key_value
# get query states
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)

if self.is_decoder and use_cache is True:
present_key_value_state = ((k, v),)
else:
present_key_value_state = (None,)
# get key/value states
key_states = project(
hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)

# (bs, n_heads, qlen, klen)
# compute scores
scores = torch.matmul(
q, k.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", q, k), compatible with onnx op>9
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

if position_bias is None:
if not self.has_relative_attention_bias:
raise ValueError("No position_bias provided and no weights to compute position_bias")
position_bias = self.compute_bias(real_qlen, klen)
position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
)
else:
position_bias = self.compute_bias(real_seq_length, key_length)

# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -qlen:, :]
position_bias = position_bias[:, :, -seq_length:, :]

if mask is not None:
position_bias = position_bias + mask # (bs, n_heads, qlen, klen)
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)

scores += position_bias
weights = F.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen)
weights = F.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen)
attn_weights = F.softmax(scores.float(), dim=-1).type_as(
scores
) # (batch_size, n_heads, seq_length, key_length)
attn_weights = F.dropout(
attn_weights, p=self.dropout, training=self.training
) # (batch_size, n_heads, seq_length, key_length)

# Mask heads if we want to
if head_mask is not None:
weights = weights * head_mask
attn_weights = attn_weights * head_mask

context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
context = unshape(context) # (bs, qlen, dim)
attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
attn_output = self.o(attn_output)

context = self.o(context)

outputs = (context,) + present_key_value_state
present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

if output_attentions:
outputs = outputs + (weights,)
if self.has_relative_attention_bias:
outputs = outputs + (position_bias,)
outputs = outputs + (attn_weights,)
return outputs

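The forward rewrite folds the old kv / past_key_value branching into a single `project` helper. Its decision logic, pulled out as a standalone sketch (the function name and the identity `shape` default are mine; the branch structure is copied from the hunk above):

```python
import torch


def project_sketch(hidden_states, proj_layer, key_value_states=None, past_key_value=None,
                   shape=lambda t: t):
    # shape() stands in for the head-splitting reshape used inside T5Attention.forward.
    if key_value_states is None:
        # self-attention: project the current hidden states
        states = shape(proj_layer(hidden_states))
    elif past_key_value is None:
        # cross-attention, first decoding step: project the encoder states
        states = shape(proj_layer(key_value_states))
    if past_key_value is not None:
        if key_value_states is None:
            # self-attention with cache: append the new projection to the cached states
            states = torch.cat([past_key_value, states], dim=2)
        else:
            # cross-attention with cache: the encoder projection never changes, reuse it
            states = past_key_value
    return states


# Self-attention path, no cache:
proj = torch.nn.Linear(8, 8, bias=False)
print(project_sketch(torch.randn(2, 5, 8), proj).shape)  # torch.Size([2, 5, 8])
```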
class T5LayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.SelfAttention = T5Attention(
config, has_relative_attention_bias=has_relative_attention_bias, is_bidirectional=not config.is_decoder
)
self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)

@@ -430,9 +441,9 @@ class T5LayerSelfAttention(nn.Module):
use_cache=False,
output_attentions=False,
):
norm_x = self.layer_norm(hidden_states)
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention(
norm_x,
normed_hidden_states,
mask=attention_mask,
position_bias=position_bias,
head_mask=head_mask,
@@ -440,25 +451,22 @@ class T5LayerSelfAttention(nn.Module):
use_cache=use_cache,
output_attentions=output_attentions,
)
y = attention_output[0]
layer_output = hidden_states + self.dropout(y)
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
hidden_states = hidden_states + self.dropout(attention_output[0])
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
return outputs


class T5LayerCrossAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
def __init__(self, config):
super().__init__()
self.EncDecAttention = T5Attention(
config, has_relative_attention_bias=has_relative_attention_bias, is_bidirectional=True
)
self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)

def forward(
self,
hidden_states,
kv,
key_value_states,
attention_mask=None,
position_bias=None,
head_mask=None,
@@ -467,11 +475,11 @@ class T5LayerCrossAttention(nn.Module):
query_length=None,
output_attentions=False,
):
norm_x = self.layer_norm(hidden_states)
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention(
norm_x,
normed_hidden_states,
mask=attention_mask,
kv=kv,
key_value_states=key_value_states,
position_bias=position_bias,
head_mask=head_mask,
past_key_value=past_key_value,
@@ -479,8 +487,7 @@ class T5LayerCrossAttention(nn.Module):
query_length=query_length,
output_attentions=output_attentions,
)
y = attention_output[0]
layer_output = hidden_states + self.dropout(y)
layer_output = hidden_states + self.dropout(attention_output[0])
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
return outputs

@@ -492,7 +499,7 @@ class T5Block(nn.Module):
self.layer = nn.ModuleList()
self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
if self.is_decoder:
self.layer.append(T5LayerCrossAttention(config, has_relative_attention_bias=has_relative_attention_bias))
self.layer.append(T5LayerCrossAttention(config))

self.layer.append(T5LayerFF(config))

@@ -550,7 +557,7 @@ class T5Block(nn.Module):

cross_attention_outputs = self.layer[1](
hidden_states,
kv=encoder_hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias,
head_mask=head_mask,
@@ -619,12 +626,12 @@ class T5PreTrainedModel(PreTrainedModel):
# Mesh TensorFlow attention initialization to avoid scaling before softmax
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
d_model = self.config.d_model
d_kv = self.config.d_kv
key_value_proj_dim = self.config.d_kv
n_heads = self.config.num_heads
module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * d_kv) ** -0.5))
module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * d_kv) ** -0.5))
module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
if module.has_relative_attention_bias:
module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))

@@ -775,20 +782,20 @@ class T5Stack(T5PreTrainedModel):
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
hidden_states, present_key_value_state = layer_outputs[:2]

if i == 0:
# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[3 if output_attentions else 2]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 3]
# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention weights),
# (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if use_cache:
present_key_value_states = present_key_value_states + (present_key_value_state,)

if output_attentions:
all_attentions = all_attentions + (layer_outputs[2],)
all_attentions = all_attentions + (layer_outputs[3],)
if self.is_decoder:
all_cross_attentions = all_cross_attentions + (layer_outputs[4 if i == 0 else 3],)
all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
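The simpler indexing in T5Stack works because T5Attention (and therefore T5Block) now always returns the position bias, instead of only when output_attentions is set. A small helper that mirrors the bookkeeping in the hunk above (the helper name is mine; the indices are copied from the diff):

```python
def read_shared_biases(layer_outputs, is_decoder, has_encoder_hidden_states, output_attentions):
    """Sketch of the new index bookkeeping in T5Stack (illustrative helper, not from the diff).

    After the refactor a decoder block returns, in order: hidden-states, present key/value
    states, self-attention position bias, (self-attention weights), cross-attention position
    bias, (cross-attention weights), where the parenthesized entries only appear when
    output_attentions=True.
    """
    position_bias = layer_outputs[2]  # always present now, so no output_attentions check
    encoder_decoder_position_bias = None
    if is_decoder and has_encoder_hidden_states:
        encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
    return position_bias, encoder_decoder_position_bias
```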
@@ -920,6 +927,12 @@ T5_INPUTS_DOCSTRING = r"""
T5_START_DOCSTRING,
)
class T5Model(T5PreTrainedModel):
authorized_missing_keys = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
]

def __init__(self, config: T5Config):
super().__init__(config)
self.shared = nn.Embedding(config.vocab_size, config.d_model)
@@ -1063,7 +1076,14 @@ class T5Model(T5PreTrainedModel):

@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
class T5ForConditionalGeneration(T5PreTrainedModel):
authorized_missing_keys = [r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", r"lm_head\.weight"]
authorized_missing_keys = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"lm_head\.weight",
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
]

def __init__(self, config):
super().__init__(config)

@@ -81,10 +81,10 @@ class TFT5LayerNorm(tf.keras.layers.Layer):
self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones")
super().build(input_shape)

def call(self, x):
variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
x = x * tf.math.rsqrt(variance + self.variance_epsilon)
return self.weight * x
def call(self, hidden_states):
variance = tf.math.reduce_mean(tf.math.square(hidden_states), axis=-1, keepdims=True)
hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states


class TFT5DenseReluDense(tf.keras.layers.Layer):
@@ -96,11 +96,11 @@ class TFT5DenseReluDense(tf.keras.layers.Layer):
self.act = tf.keras.activations.relu

def call(self, hidden_states, training=False):
h = self.wi(hidden_states)
h = self.act(h)
h = self.dropout(h, training=training)
h = self.wo(h)
return h
hidden_states = self.wi(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.wo(hidden_states)
return hidden_states


class TFT5LayerFF(tf.keras.layers.Layer):
@@ -111,18 +111,17 @@ class TFT5LayerFF(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

def call(self, hidden_states, training=False):
norm_x = self.layer_norm(hidden_states)
y = self.DenseReluDense(norm_x, training=training)
layer_output = hidden_states + self.dropout(y, training=training)
return layer_output
normed_hidden_states = self.layer_norm(hidden_states)
dense_output = self.DenseReluDense(normed_hidden_states, training=training)
hidden_states = hidden_states + self.dropout(dense_output, training=training)
return hidden_states


class TFT5Attention(tf.keras.layers.Layer):
NEW_ID = itertools.count()

def __init__(self, config, has_relative_attention_bias=False, is_bidirectional=False, **kwargs):
def __init__(self, config, has_relative_attention_bias=False, **kwargs):
super().__init__(**kwargs)
self.is_bidirectional = is_bidirectional
self.layer_id = next(TFT5Attention.NEW_ID)
self.is_decoder = config.is_decoder
self.use_cache = config.use_cache
@@ -131,9 +130,9 @@ class TFT5Attention(tf.keras.layers.Layer):

self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.d_model = config.d_model
self.d_kv = config.d_kv
self.key_value_proj_dim = config.d_kv
self.n_heads = config.num_heads
self.inner_dim = self.n_heads * self.d_kv
self.inner_dim = self.n_heads * self.key_value_proj_dim

# Mesh TensorFlow initialization to avoid scaling before softmax
self.q = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="q")
@@ -175,46 +174,48 @@ class TFT5Attention(tf.keras.layers.Layer):
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
ret = 0
n = -relative_position
relative_buckets = 0
# n = -relative_position
if bidirectional:
num_buckets //= 2
ret += tf.dtypes.cast(tf.math.less(n, 0), tf.int32) * num_buckets
n = tf.math.abs(n)
relative_buckets += tf.dtypes.cast(tf.math.greater(relative_position, 0), tf.int32) * num_buckets
relative_position = tf.math.abs(relative_position)
else:
n = tf.math.maximum(n, 0)
relative_position = -tf.math.minimum(relative_position, 0)
# now n is in the range [0, inf)
max_exact = num_buckets // 2
is_small = tf.math.less(n, max_exact)
val_if_large = max_exact + tf.dtypes.cast(
tf.math.log(tf.dtypes.cast(n, tf.float32) / max_exact)
is_small = tf.math.less(relative_position, max_exact)
relative_position_if_large = max_exact + tf.dtypes.cast(
tf.math.log(tf.dtypes.cast(relative_position, tf.float32) / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact),
tf.int32,
)
val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
ret += tf.where(is_small, n, val_if_large)
return ret
relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1)
relative_buckets += tf.where(is_small, relative_position, relative_position_if_large)
return relative_buckets

def compute_bias(self, qlen, klen):
def compute_bias(self, query_length, key_length):
""" Compute binned relative position bias """
context_position = tf.range(qlen)[:, None]
memory_position = tf.range(klen)[None, :]
relative_position = memory_position - context_position # shape (qlen, klen)
rp_bucket = self._relative_position_bucket(
context_position = tf.range(query_length)[:, None]
memory_position = tf.range(key_length)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position,
bidirectional=self.is_bidirectional,
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
)
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0) # shape (1, num_heads, qlen, klen)
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
values = tf.expand_dims(
tf.transpose(values, [2, 0, 1]), axis=0
) # shape (1, num_heads, query_length, key_length)
return values

def call(
self,
input,
hidden_states,
mask=None,
kv=None,
key_value_states=None,
position_bias=None,
past_key_value=None,
head_mask=None,
@@ -224,95 +225,108 @@ class TFT5Attention(tf.keras.layers.Layer):
output_attentions=False,
):
"""
Self-attention (if kv is None) or attention over source sentence (provided by kv).
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
"""
# Input is (bs, qlen, dim)
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
# past_key_value[0] is (bs, n_heads, q_len - 1, dim_per_head)
bs, qlen, dim = shape_list(input)
# Input is (batch_size, query_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = shape_list(hidden_states)[:2]

real_seq_length = seq_length

if past_key_value is not None:
assert self.is_decoder is True, "Encoder cannot cache past key value states"
assert (
len(past_key_value) == 2
), "past_key_value should have 2 past states: keys and values. Got {} past states".format(
len(past_key_value)
)
real_qlen = qlen + shape_list(past_key_value[0])[2] if query_length is None else query_length
else:
real_qlen = qlen
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

if kv is None:
klen = real_qlen
else:
klen = shape_list(kv)[1]
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

def shape(x):
def shape(hidden_states):
""" projection """
return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, self.d_kv)), perm=(0, 2, 1, 3))
return tf.transpose(
tf.reshape(hidden_states, (batch_size, -1, self.n_heads, self.key_value_proj_dim)), perm=(0, 2, 1, 3)
)

def unshape(x):
def unshape(hidden_states):
""" compute context """
return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.inner_dim))
return tf.reshape(tf.transpose(hidden_states, perm=(0, 2, 1, 3)), (batch_size, -1, self.inner_dim))

q = shape(self.q(input)) # (bs, n_heads, qlen, dim_per_head)
def project(hidden_states, proj_layer, key_value_states, past_key_value):
""" projects hidden states correctly to key/query states """
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))

if kv is None:
k = shape(self.k(input)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v(input)) # (bs, n_heads, qlen, dim_per_head)
elif past_key_value is None:
k = v = kv
k = shape(self.k(k)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v(v)) # (bs, n_heads, qlen, dim_per_head)
if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
hidden_states = tf.concat([past_key_value, hidden_states], axis=2)
else:
# cross-attn
hidden_states = past_key_value
return hidden_states

if past_key_value is not None:
if kv is None:
k_, v_ = past_key_value
k = tf.concat([k_, k], axis=2) # (bs, n_heads, klen, dim_per_head)
v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head)
else:
k, v = past_key_value
# get query
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, query_length, dim_per_head)

# get key/value
key_states = project(
hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)

# to cope with keras serialization
if self.is_decoder and cast_bool_to_primitive(use_cache, self.use_cache) is True:
present_key_value_state = ((k, v),)
present_key_value_state = (key_states, value_states)
else:
present_key_value_state = (None,)
present_key_value_state = None

scores = tf.einsum("bnqd,bnkd->bnqk", q, k) # (bs, n_heads, qlen, klen)
scores = tf.einsum(
"bnqd,bnkd->bnqk", query_states, key_states
) # (batch_size, n_heads, query_length, key_length)

if position_bias is None:
if not self.has_relative_attention_bias:
raise ValueError("No position_bias provided and no weights to compute position_bias")
position_bias = self.compute_bias(real_qlen, klen)
position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length), dtype=tf.float32)
else:
position_bias = self.compute_bias(real_seq_length, key_length)

# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -qlen:, :]
position_bias = position_bias[:, :, -seq_length:, :]

if mask is not None:
position_bias = position_bias + mask # (bs, n_heads, qlen, klen)
position_bias = position_bias + mask # (batch_size, n_heads, query_length, key_length)

scores += position_bias
weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
weights = tf.nn.softmax(scores, axis=-1) # (batch_size, n_heads, query_length, key_length)
weights = self.dropout(weights, training=training) # (batch_size, n_heads, query_length, key_length)

# Mask heads if we want to
if head_mask is not None:
weights = weights * head_mask

context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
context = unshape(context) # (bs, qlen, dim)
attn_output = tf.matmul(weights, value_states) # (batch_size, n_heads, query_length, dim_per_head)

context = self.o(context)
attn_output = self.o(unshape(attn_output))

outputs = (context,) + present_key_value_state
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

if output_attentions:
outputs = outputs + (weights,)
if self.has_relative_attention_bias:
outputs = outputs + (position_bias,)

return outputs


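After the rewrite the TF attention mirrors the PyTorch version: it always returns (attn_output, present_key_value_state, position_bias) plus the attention weights when requested, and present_key_value_state is now a plain (key_states, value_states) pair or None instead of a nested tuple. A quick shape check of the renamed einsum, a sketch with made-up dimensions, assuming TF 2.x is available:

```python
import tensorflow as tf

batch_size, n_heads, query_length, key_length, dim_per_head = 2, 8, 5, 7, 64
query_states = tf.random.normal((batch_size, n_heads, query_length, dim_per_head))
key_states = tf.random.normal((batch_size, n_heads, key_length, dim_per_head))

# Same contraction as in the hunk above: scores have shape
# (batch_size, n_heads, query_length, key_length).
scores = tf.einsum("bnqd,bnkd->bnqk", query_states, key_states)
print(scores.shape)  # (2, 8, 5, 7)
```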
@@ -322,7 +336,6 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
self.SelfAttention = TFT5Attention(
config,
has_relative_attention_bias=has_relative_attention_bias,
is_bidirectional=not config.is_decoder,
name="SelfAttention",
)
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
@@ -339,9 +352,9 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
output_attentions=False,
training=False,
):
norm_x = self.layer_norm(hidden_states)
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention(
norm_x,
normed_hidden_states,
mask=attention_mask,
position_bias=position_bias,
head_mask=head_mask,
@@ -350,19 +363,17 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
output_attentions=output_attentions,
training=training,
)
y = attention_output[0]
layer_output = hidden_states + self.dropout(y, training=training)
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
hidden_states = hidden_states + self.dropout(attention_output[0], training=training)
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
return outputs


class TFT5LayerCrossAttention(tf.keras.layers.Layer):
def __init__(self, config, has_relative_attention_bias=False, **kwargs):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.EncDecAttention = TFT5Attention(
config,
has_relative_attention_bias=has_relative_attention_bias,
is_bidirectional=True,
has_relative_attention_bias=False,
name="EncDecAttention",
)
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
@@ -371,7 +382,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
def call(
self,
hidden_states,
kv,
key_value_states,
attention_mask=None,
position_bias=None,
head_mask=None,
@@ -381,11 +392,11 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
output_attentions=False,
training=False,
):
norm_x = self.layer_norm(hidden_states)
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention(
norm_x,
normed_hidden_states,
mask=attention_mask,
kv=kv,
key_value_states=key_value_states,
position_bias=position_bias,
head_mask=head_mask,
past_key_value=past_key_value,
@@ -394,9 +405,8 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
output_attentions=output_attentions,
training=training,
)
y = attention_output[0]
layer_output = hidden_states + self.dropout(y, training=training)
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
hidden_states = hidden_states + self.dropout(attention_output[0], training=training)
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
return outputs


@@ -416,7 +426,6 @@ class TFT5Block(tf.keras.layers.Layer):
self.layer.append(
TFT5LayerCrossAttention(
config,
has_relative_attention_bias=has_relative_attention_bias,
name="layer_._1",
)
)
@@ -477,7 +486,7 @@ class TFT5Block(tf.keras.layers.Layer):

cross_attention_outputs = self.layer[1](
hidden_states,
kv=encoder_hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias,
head_mask=head_mask,
@@ -731,17 +740,18 @@ class TFT5MainLayer(tf.keras.layers.Layer):
# layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
hidden_states, present_key_value_state = layer_outputs[:2]
if i == 0:
# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[3 if output_attentions else 2]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 3]

# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, past_key_values, (self-attention weights),
# (self-attention position bias), (cross-attention position bias), (cross-attention weights),
position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
present_key_value_states = present_key_value_states + (present_key_value_state,)

if output_attentions:
all_attentions = all_attentions + (layer_outputs[2],)
all_attentions = all_attentions + (layer_outputs[3],)

hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)

@@ -461,7 +461,7 @@ class ModelTesterMixin:
inputs = self._prepare_for_class(inputs_dict, model_class).copy()
inputs["head_mask"] = head_mask

outputs = model(**inputs)
outputs = model(**inputs, return_dict=True)

# Test that we can get a gradient back for importance score computation
output = sum(t.sum() for t in outputs[0])

@@ -556,6 +556,32 @@ class T5ModelIntegrationTests(unittest.TestCase):
def tokenizer(self):
return T5Tokenizer.from_pretrained("t5-base")

@slow
def test_small_integration_test(self):
"""
For comparision run:
>>> import t5 # pip install t5==0.7.1
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary

>>> path_to_mtf_small_t5_checkpoint = '<fill_in>'
>>> path_to_mtf_small_spm_model_path = '<fill_in>'
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_checkpoint, batch_size=1, tpu=None)
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
"""

model = T5ForConditionalGeneration.from_pretrained("t5-small", return_dict=True).to(torch_device)
tokenizer = T5Tokenizer.from_pretrained("t5-small")

input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
labels = tokenizer("Hi I am", return_tensors="pt").input_ids

loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss
mtf_score = -(labels.shape[-1] * loss.item())

EXPECTED_SCORE = -19.0845
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)

@slow
def test_summarization(self):
model = self.model
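The new test compares against a Mesh TensorFlow score: the Hugging Face loss is the mean negative log-likelihood per target token, so negating the sum over all target tokens gives a quantity comparable to t5.models.MtfModel.score(), exactly as the hunk computes. A standalone paraphrase (downloading the t5-small weights is assumed):

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small", return_dict=True)

input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
labels = tokenizer("Hi I am", return_tensors="pt").input_ids

with torch.no_grad():
    loss = model(input_ids, labels=labels).loss  # mean NLL per target token
mtf_score = -(labels.shape[-1] * loss.item())    # total log-likelihood of the target
print(mtf_score)  # the test above expects roughly -19.0845
```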
@@ -567,8 +593,8 @@ class T5ModelIntegrationTests(unittest.TestCase):
ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
expected_summaries = [
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video at the crash site . "one can hear cries of \'My God\' in several languages," one magazine says .',
"the Palestinians become the 123rd member of the international criminal court . the accession was marked by a ceremony at the Hague, where the court is based . as members of the court, Palestinians may be subject to counter-charges as well .",
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," one magazine says .',
"the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a preliminary examination into the situation in the occupied Palestinian territory . as members of the court, Palestinians may be subject to counter-charges as well .",
"the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and implement a rigorous inspection regime .",
'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
]
@@ -311,8 +311,8 @@ class TFT5ModelIntegrationTests(unittest.TestCase):
ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
expected_summaries = [
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video at the crash site . "one can hear cries of \'My God\' in several languages," one magazine says .',
"the Palestinians become the 123rd member of the international criminal court . the accession was marked by a ceremony at the Hague, where the court is based . as members of the court, Palestinians may be subject to counter-charges as well .",
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," one magazine says .',
"the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a preliminary examination into the situation in the occupied Palestinian territory . as members of the court, Palestinians may be subject to counter-charges as well .",
"the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and implement a rigorous inspection regime .",
'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
]