[T5] Bug correction & Refactor (#8518)

* fix bug

* T5 refactor

* refactor tf

* apply Sylvain's suggestions
Patrick von Platen 2020-11-13 16:57:31 +01:00 committed by GitHub
parent 42f63e3871
commit 42e2d02e44
7 changed files with 310 additions and 262 deletions

View File

@@ -75,7 +75,6 @@ class T5Config(PretrainedConfig):
     def __init__(
         self,
         vocab_size=32128,
-        n_positions=512,
         d_model=512,
         d_kv=64,
         d_ff=2048,
@@ -98,7 +97,6 @@ class T5Config(PretrainedConfig):
             **kwargs,
         )
         self.vocab_size = vocab_size
-        self.n_positions = n_positions
         self.d_model = d_model
         self.d_kv = d_kv
         self.d_ff = d_ff
@@ -112,10 +110,6 @@ class T5Config(PretrainedConfig):
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_factor = initializer_factor

-    @property
-    def max_position_embeddings(self):
-        return self.n_positions
-
     @property
     def hidden_size(self):
         return self.d_model
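For reference, a minimal sketch (assuming a transformers version that includes this change) of what the configuration change means in practice: T5 uses bucketed relative position biases rather than absolute position embeddings, so `T5Config` is now defined purely by its model dimensions and no longer carries an `n_positions` / `max_position_embeddings` attribute.

    from transformers import T5Config

    # Construct a config from the remaining size parameters only.
    config = T5Config(vocab_size=32128, d_model=512, d_kv=64, d_ff=2048)
    print(config.hidden_size)              # 512, still aliased to d_model via the remaining property
    print(hasattr(config, "n_positions"))  # False after this change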

View File

@@ -17,8 +17,6 @@
 import argparse

-import torch
-
 from transformers import T5Config, T5Model, load_tf_weights_in_t5
 from transformers.utils import logging
@@ -37,7 +35,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
     # Save pytorch-model
     print("Save PyTorch model to {}".format(pytorch_dump_path))
-    torch.save(model.state_dict(), pytorch_dump_path)
+    model.save_pretrained(pytorch_dump_path)


 if __name__ == "__main__":
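A small usage sketch of what the new conversion output enables (the path below is a placeholder): the script now writes a standard `save_pretrained` directory instead of a bare state-dict file, so the converted checkpoint can be reloaded directly with `from_pretrained`.

    from transformers import T5Model

    pytorch_dump_path = "/tmp/t5-converted"  # hypothetical output directory passed to the script
    model = T5Model.from_pretrained(pytorch_dump_path)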

View File

@@ -115,12 +115,12 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
                 scope_names = [m_name]
             if scope_names[0] in ["kernel", "scale", "embedding"]:
                 pointer = getattr(pointer, "weight")
-            # elif scope_names[0] == 'scale':
-            #     pointer = getattr(pointer, 'weight')
-            # elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta':
-            #     pointer = getattr(pointer, 'bias')
-            # elif scope_names[0] == 'squad':
-            #     pointer = getattr(pointer, 'classifier')
+            elif scope_names[0] == "scale":
+                pointer = getattr(pointer, "weight")
+            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
+                pointer = getattr(pointer, "bias")
+            elif scope_names[0] == "squad":
+                pointer = getattr(pointer, "classifier")
             else:
                 try:
                     pointer = getattr(pointer, scope_names[0])
@@ -147,7 +147,6 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
        tf_weights.pop(txt_name, None)

    logger.info("Weights not copied to PyTorch model: {}".format(", ".join(tf_weights.keys())))
-    # logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
    return model
@@ -167,14 +166,15 @@ class T5LayerNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps

-    def forward(self, x):
+    def forward(self, hidden_states):
         # layer norm should always be calculated in float32
-        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        x = x / torch.sqrt(variance + self.variance_epsilon)
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

+        # convert into float16 if necessary
         if self.weight.dtype == torch.float16:
-            x = x.to(torch.float16)
-        return self.weight * x
+            hidden_states = hidden_states.to(torch.float16)
+        return self.weight * hidden_states


 class T5DenseReluDense(nn.Module):
@@ -185,11 +185,11 @@ class T5DenseReluDense(nn.Module):
         self.dropout = nn.Dropout(config.dropout_rate)

     def forward(self, hidden_states):
-        h = self.wi(hidden_states)
-        h = F.relu(h)
-        h = self.dropout(h)
-        h = self.wo(h)
-        return h
+        hidden_states = self.wi(hidden_states)
+        hidden_states = F.relu(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.wo(hidden_states)
+        return hidden_states


 class T5LayerFF(nn.Module):
@@ -200,25 +200,24 @@ class T5LayerFF(nn.Module):
         self.dropout = nn.Dropout(config.dropout_rate)

     def forward(self, hidden_states):
-        norm_x = self.layer_norm(hidden_states)
-        y = self.DenseReluDense(norm_x)
-        layer_output = hidden_states + self.dropout(y)
-        return layer_output
+        forwarded_states = self.layer_norm(hidden_states)
+        forwarded_states = self.DenseReluDense(forwarded_states)
+        hidden_states = hidden_states + self.dropout(forwarded_states)
+        return hidden_states


 class T5Attention(nn.Module):
-    def __init__(self, config: T5Config, has_relative_attention_bias=False, is_bidirectional=False):
+    def __init__(self, config: T5Config, has_relative_attention_bias=False):
         super().__init__()
-        self.is_bidirectional = is_bidirectional
         self.is_decoder = config.is_decoder
         self.has_relative_attention_bias = has_relative_attention_bias
         self.relative_attention_num_buckets = config.relative_attention_num_buckets
         self.d_model = config.d_model
-        self.d_kv = config.d_kv
+        self.key_value_proj_dim = config.d_kv
         self.n_heads = config.num_heads
         self.dropout = config.dropout_rate
-        self.inner_dim = self.n_heads * self.d_kv
+        self.inner_dim = self.n_heads * self.key_value_proj_dim

         # Mesh TensorFlow initialization to avoid scaling before softmax
         self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -233,7 +232,9 @@ class T5Attention(nn.Module):
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, self.d_kv, self.pruned_heads)
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
+        )
         # Prune linear layers
         self.q = prune_linear_layer(self.q, index)
         self.k = prune_linear_layer(self.k, index)
@@ -241,7 +242,7 @@ class T5Attention(nn.Module):
         self.o = prune_linear_layer(self.o, index, dim=1)
         # Update hyper params
         self.n_heads = self.n_heads - len(heads)
-        self.inner_dim = self.d_kv * self.n_heads
+        self.inner_dim = self.key_value_proj_dim * self.n_heads
         self.pruned_heads = self.pruned_heads.union(heads)

     @staticmethod
@@ -266,49 +267,52 @@ class T5Attention(nn.Module):
         Returns:
             a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
         """
-        ret = 0
-        n = -relative_position
+        relative_buckets = 0
         if bidirectional:
             num_buckets //= 2
-            ret += (n < 0).to(torch.long) * num_buckets  # mtf.to_int32(mtf.less(n, 0)) * num_buckets
-            n = torch.abs(n)
+            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
+            relative_position = torch.abs(relative_position)
         else:
-            n = torch.max(n, torch.zeros_like(n))
-        # now n is in the range [0, inf)
+            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
+        # now relative_position is in the range [0, inf)

         # half of the buckets are for exact increments in positions
         max_exact = num_buckets // 2
-        is_small = n < max_exact
+        is_small = relative_position < max_exact

         # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
-        val_if_large = max_exact + (
-            torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+        relative_postion_if_large = max_exact + (
+            torch.log(relative_position.float() / max_exact)
+            / math.log(max_distance / max_exact)
+            * (num_buckets - max_exact)
         ).to(torch.long)
-        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+        relative_postion_if_large = torch.min(
+            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
+        )

-        ret += torch.where(is_small, n, val_if_large)
-        return ret
+        relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large)
+        return relative_buckets

-    def compute_bias(self, qlen, klen):
+    def compute_bias(self, query_length, key_length):
         """ Compute binned relative position bias """
-        context_position = torch.arange(qlen, dtype=torch.long)[:, None]
-        memory_position = torch.arange(klen, dtype=torch.long)[None, :]
-        relative_position = memory_position - context_position  # shape (qlen, klen)
-        rp_bucket = self._relative_position_bucket(
-            relative_position,  # shape (qlen, klen)
-            bidirectional=self.is_bidirectional,
+        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
+        relative_position = memory_position - context_position  # shape (query_length, key_length)
+        relative_position_bucket = self._relative_position_bucket(
+            relative_position,  # shape (query_length, key_length)
+            bidirectional=(not self.is_decoder),
             num_buckets=self.relative_attention_num_buckets,
         )
-        rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
-        values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
-        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, qlen, klen)
+        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
+        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values

     def forward(
         self,
-        input,
+        hidden_states,
         mask=None,
-        kv=None,
+        key_value_states=None,
         position_bias=None,
         past_key_value=None,
         head_mask=None,
@@ -317,106 +321,113 @@ class T5Attention(nn.Module):
         output_attentions=False,
     ):
         """
-        Self-attention (if kv is None) or attention over source sentence (provided by kv).
+        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
         """
-        # Input is (bs, qlen, dim)
-        # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
-        # past_key_value[0] is (bs, n_heads, q_len - 1, dim_per_head)
-        bs, qlen, dim = input.size()
+        # Input is (batch_size, seq_length, dim)
+        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
+        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
+        batch_size, seq_length = hidden_states.shape[:2]
+
+        real_seq_length = seq_length

         if past_key_value is not None:
-            assert self.is_decoder is True, "Encoder cannot cache past key value states"
             assert (
                 len(past_key_value) == 2
             ), "past_key_value should have 2 past states: keys and values. Got {} past states".format(
                 len(past_key_value)
             )
-            real_qlen = qlen + past_key_value[0].shape[2] if query_length is None else query_length
-        else:
-            real_qlen = qlen
+            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

-        if kv is None:
-            klen = real_qlen
-        else:
-            klen = kv.size(1)
+        key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

-        def shape(x):
+        def shape(states):
             """ projection """
-            return x.view(bs, -1, self.n_heads, self.d_kv).transpose(1, 2)
+            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

-        def unshape(x):
-            """ compute context """
-            return x.transpose(1, 2).contiguous().view(bs, -1, self.inner_dim)
+        def unshape(states):
+            """ reshape """
+            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

-        q = shape(self.q(input))  # (bs, n_heads, qlen, dim_per_head)
+        def project(hidden_states, proj_layer, key_value_states, past_key_value):
+            """ projects hidden states correctly to key/query states """
+            if key_value_states is None:
+                # self-attn
+                # (batch_size, n_heads, seq_length, dim_per_head)
+                hidden_states = shape(proj_layer(hidden_states))
+            elif past_key_value is None:
+                # cross-attn
+                # (batch_size, n_heads, seq_length, dim_per_head)
+                hidden_states = shape(proj_layer(key_value_states))

-        if kv is None:
-            k = shape(self.k(input))  # (bs, n_heads, qlen, dim_per_head)
-            v = shape(self.v(input))  # (bs, n_heads, qlen, dim_per_head)
-        elif past_key_value is None:
-            k = v = kv
-            k = shape(self.k(k))  # (bs, n_heads, qlen, dim_per_head)
-            v = shape(self.v(v))  # (bs, n_heads, qlen, dim_per_head)
+            if past_key_value is not None:
+                if key_value_states is None:
+                    # self-attn
+                    # (batch_size, n_heads, key_length, dim_per_head)
+                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+                else:
+                    # cross-attn
+                    hidden_states = past_key_value
+            return hidden_states

-        if past_key_value is not None:
-            if kv is None:
-                k_, v_ = past_key_value
-                k = torch.cat([k_, k], dim=2)  # (bs, n_heads, klen, dim_per_head)
-                v = torch.cat([v_, v], dim=2)  # (bs, n_heads, klen, dim_per_head)
-            else:
-                k, v = past_key_value
+        # get query states
+        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

-        if self.is_decoder and use_cache is True:
-            present_key_value_state = ((k, v),)
-        else:
-            present_key_value_state = (None,)
+        # get key/value states
+        key_states = project(
+            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
+        )
+        value_states = project(
+            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
+        )

-        # (bs, n_heads, qlen, klen)
+        # compute scores
         scores = torch.matmul(
-            q, k.transpose(3, 2)
-        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", q, k), compatible with onnx op>9
+            query_states, key_states.transpose(3, 2)
+        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

         if position_bias is None:
             if not self.has_relative_attention_bias:
-                raise ValueError("No position_bias provided and no weights to compute position_bias")
-            position_bias = self.compute_bias(real_qlen, klen)
+                position_bias = torch.zeros(
+                    (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
+                )
+            else:
+                position_bias = self.compute_bias(real_seq_length, key_length)

             # if key and values are already calculated
             # we want only the last query position bias
             if past_key_value is not None:
-                position_bias = position_bias[:, :, -qlen:, :]
+                position_bias = position_bias[:, :, -seq_length:, :]

             if mask is not None:
-                position_bias = position_bias + mask  # (bs, n_heads, qlen, klen)
+                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

         scores += position_bias
-        weights = F.softmax(scores.float(), dim=-1).type_as(scores)  # (bs, n_heads, qlen, klen)
-        weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)
+        attn_weights = F.softmax(scores.float(), dim=-1).type_as(
+            scores
+        )  # (batch_size, n_heads, seq_length, key_length)
+        attn_weights = F.dropout(
+            attn_weights, p=self.dropout, training=self.training
+        )  # (batch_size, n_heads, seq_length, key_length)

         # Mask heads if we want to
         if head_mask is not None:
-            weights = weights * head_mask
+            attn_weights = attn_weights * head_mask

-        context = torch.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
-        context = unshape(context)  # (bs, qlen, dim)
-        context = self.o(context)
+        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
+        attn_output = self.o(attn_output)

-        outputs = (context,) + present_key_value_state
+        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
+        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

         if output_attentions:
-            outputs = outputs + (weights,)
-        if self.has_relative_attention_bias:
-            outputs = outputs + (position_bias,)
+            outputs = outputs + (attn_weights,)
         return outputs


 class T5LayerSelfAttention(nn.Module):
     def __init__(self, config, has_relative_attention_bias=False):
         super().__init__()
-        self.SelfAttention = T5Attention(
-            config, has_relative_attention_bias=has_relative_attention_bias, is_bidirectional=not config.is_decoder
-        )
+        self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
         self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout_rate)
@@ -430,9 +441,9 @@ class T5LayerSelfAttention(nn.Module):
         use_cache=False,
         output_attentions=False,
     ):
-        norm_x = self.layer_norm(hidden_states)
+        normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.SelfAttention(
-            norm_x,
+            normed_hidden_states,
             mask=attention_mask,
             position_bias=position_bias,
             head_mask=head_mask,
@@ -440,25 +451,22 @@ class T5LayerSelfAttention(nn.Module):
             use_cache=use_cache,
             output_attentions=output_attentions,
         )
-        y = attention_output[0]
-        layer_output = hidden_states + self.dropout(y)
-        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
+        hidden_states = hidden_states + self.dropout(attention_output[0])
+        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs


 class T5LayerCrossAttention(nn.Module):
-    def __init__(self, config, has_relative_attention_bias=False):
+    def __init__(self, config):
         super().__init__()
-        self.EncDecAttention = T5Attention(
-            config, has_relative_attention_bias=has_relative_attention_bias, is_bidirectional=True
-        )
+        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
         self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout_rate)

     def forward(
         self,
         hidden_states,
-        kv,
+        key_value_states,
         attention_mask=None,
         position_bias=None,
         head_mask=None,
@@ -467,11 +475,11 @@ class T5LayerCrossAttention(nn.Module):
         query_length=None,
         output_attentions=False,
     ):
-        norm_x = self.layer_norm(hidden_states)
+        normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.EncDecAttention(
-            norm_x,
+            normed_hidden_states,
             mask=attention_mask,
-            kv=kv,
+            key_value_states=key_value_states,
             position_bias=position_bias,
             head_mask=head_mask,
             past_key_value=past_key_value,
@@ -479,8 +487,7 @@ class T5LayerCrossAttention(nn.Module):
             query_length=query_length,
             output_attentions=output_attentions,
         )
-        y = attention_output[0]
-        layer_output = hidden_states + self.dropout(y)
+        layer_output = hidden_states + self.dropout(attention_output[0])
         outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
         return outputs
@@ -492,7 +499,7 @@ class T5Block(nn.Module):
         self.layer = nn.ModuleList()
         self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
         if self.is_decoder:
-            self.layer.append(T5LayerCrossAttention(config, has_relative_attention_bias=has_relative_attention_bias))
+            self.layer.append(T5LayerCrossAttention(config))

         self.layer.append(T5LayerFF(config))
@@ -550,7 +557,7 @@ class T5Block(nn.Module):
             cross_attention_outputs = self.layer[1](
                 hidden_states,
-                kv=encoder_hidden_states,
+                key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
                 position_bias=encoder_decoder_position_bias,
                 head_mask=head_mask,
@@ -619,12 +626,12 @@ class T5PreTrainedModel(PreTrainedModel):
             # Mesh TensorFlow attention initialization to avoid scaling before softmax
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
             d_model = self.config.d_model
-            d_kv = self.config.d_kv
+            key_value_proj_dim = self.config.d_kv
             n_heads = self.config.num_heads
-            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * d_kv) ** -0.5))
+            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
             module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
             module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
-            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * d_kv) ** -0.5))
+            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
             if module.has_relative_attention_bias:
                 module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
@@ -775,20 +782,20 @@ class T5Stack(T5PreTrainedModel):
             # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
             hidden_states, present_key_value_state = layer_outputs[:2]

-            if i == 0:
-                # We share the position biases between the layers - the first layer store them
-                # layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
-                position_bias = layer_outputs[3 if output_attentions else 2]
-                if self.is_decoder and encoder_hidden_states is not None:
-                    encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 3]
+            # We share the position biases between the layers - the first layer store them
+            # layer_outputs = hidden-states, key-value-states (self-attention weights),
+            # (self-attention position bias), (cross-attention weights), (cross-attention position bias)
+            position_bias = layer_outputs[2]
+            if self.is_decoder and encoder_hidden_states is not None:
+                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
             # append next layer key value states
             if use_cache:
                 present_key_value_states = present_key_value_states + (present_key_value_state,)

             if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[2],)
+                all_attentions = all_attentions + (layer_outputs[3],)
                 if self.is_decoder:
-                    all_cross_attentions = all_cross_attentions + (layer_outputs[4 if i == 0 else 3],)
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.dropout(hidden_states)
@@ -920,6 +927,12 @@ T5_INPUTS_DOCSTRING = r"""
     T5_START_DOCSTRING,
 )
 class T5Model(T5PreTrainedModel):
+    authorized_missing_keys = [
+        r"encoder\.embed_tokens\.weight",
+        r"decoder\.embed_tokens\.weight",
+        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
+    ]
+
     def __init__(self, config: T5Config):
         super().__init__(config)
         self.shared = nn.Embedding(config.vocab_size, config.d_model)
@@ -1063,7 +1076,14 @@ class T5Model(T5PreTrainedModel):

 @add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
 class T5ForConditionalGeneration(T5PreTrainedModel):
-    authorized_missing_keys = [r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", r"lm_head\.weight"]
+    authorized_missing_keys = [
+        r"encoder\.embed_tokens\.weight",
+        r"decoder\.embed_tokens\.weight",
+        r"lm_head\.weight",
+        r"encoder\.embed_tokens\.weight",
+        r"decoder\.embed_tokens\.weight",
+        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
+    ]

     def __init__(self, config):
         super().__init__(config)
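A standalone sketch of the bucketing logic from the diff above, mirrored outside the class so the new sign handling is easy to inspect: the refactor drops the intermediate `n = -relative_position` and tests `relative_position > 0` directly. Values and defaults below are illustrative only.

    import math
    import torch

    def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            # positions to the right get buckets offset by the halved num_buckets
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # half of the buckets cover exact offsets, the rest grow logarithmically up to max_distance
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )
        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    # toy 4x4 relative-position matrix, as compute_bias builds it
    context_position = torch.arange(4)[:, None]
    memory_position = torch.arange(4)[None, :]
    print(relative_position_bucket(memory_position - context_position))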

View File

@@ -81,10 +81,10 @@ class TFT5LayerNorm(tf.keras.layers.Layer):
         self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones")
         super().build(input_shape)

-    def call(self, x):
-        variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
-        x = x * tf.math.rsqrt(variance + self.variance_epsilon)
-        return self.weight * x
+    def call(self, hidden_states):
+        variance = tf.math.reduce_mean(tf.math.square(hidden_states), axis=-1, keepdims=True)
+        hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states


 class TFT5DenseReluDense(tf.keras.layers.Layer):
@@ -96,11 +96,11 @@ class TFT5DenseReluDense(tf.keras.layers.Layer):
         self.act = tf.keras.activations.relu

     def call(self, hidden_states, training=False):
-        h = self.wi(hidden_states)
-        h = self.act(h)
-        h = self.dropout(h, training=training)
-        h = self.wo(h)
-        return h
+        hidden_states = self.wi(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.dropout(hidden_states, training=training)
+        hidden_states = self.wo(hidden_states)
+        return hidden_states


 class TFT5LayerFF(tf.keras.layers.Layer):
@@ -111,18 +111,17 @@ class TFT5LayerFF(tf.keras.layers.Layer):
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

     def call(self, hidden_states, training=False):
-        norm_x = self.layer_norm(hidden_states)
-        y = self.DenseReluDense(norm_x, training=training)
-        layer_output = hidden_states + self.dropout(y, training=training)
-        return layer_output
+        normed_hidden_states = self.layer_norm(hidden_states)
+        dense_output = self.DenseReluDense(normed_hidden_states, training=training)
+        hidden_states = hidden_states + self.dropout(dense_output, training=training)
+        return hidden_states


 class TFT5Attention(tf.keras.layers.Layer):
     NEW_ID = itertools.count()

-    def __init__(self, config, has_relative_attention_bias=False, is_bidirectional=False, **kwargs):
+    def __init__(self, config, has_relative_attention_bias=False, **kwargs):
         super().__init__(**kwargs)
-        self.is_bidirectional = is_bidirectional
         self.layer_id = next(TFT5Attention.NEW_ID)
         self.is_decoder = config.is_decoder
         self.use_cache = config.use_cache
@@ -131,9 +130,9 @@ class TFT5Attention(tf.keras.layers.Layer):
         self.relative_attention_num_buckets = config.relative_attention_num_buckets
         self.d_model = config.d_model
-        self.d_kv = config.d_kv
+        self.key_value_proj_dim = config.d_kv
         self.n_heads = config.num_heads
-        self.inner_dim = self.n_heads * self.d_kv
+        self.inner_dim = self.n_heads * self.key_value_proj_dim

         # Mesh TensorFlow initialization to avoid scaling before softmax
         self.q = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="q")
@@ -175,46 +174,48 @@ class TFT5Attention(tf.keras.layers.Layer):
         Returns:
             a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
         """
-        ret = 0
-        n = -relative_position
+        relative_buckets = 0
+        # n = -relative_position
         if bidirectional:
             num_buckets //= 2
-            ret += tf.dtypes.cast(tf.math.less(n, 0), tf.int32) * num_buckets
-            n = tf.math.abs(n)
+            relative_buckets += tf.dtypes.cast(tf.math.greater(relative_position, 0), tf.int32) * num_buckets
+            relative_position = tf.math.abs(relative_position)
         else:
-            n = tf.math.maximum(n, 0)
+            relative_position = -tf.math.minimum(relative_position, 0)
         # now n is in the range [0, inf)
         max_exact = num_buckets // 2
-        is_small = tf.math.less(n, max_exact)
-        val_if_large = max_exact + tf.dtypes.cast(
-            tf.math.log(tf.dtypes.cast(n, tf.float32) / max_exact)
+        is_small = tf.math.less(relative_position, max_exact)
+        relative_position_if_large = max_exact + tf.dtypes.cast(
+            tf.math.log(tf.dtypes.cast(relative_position, tf.float32) / max_exact)
             / math.log(max_distance / max_exact)
             * (num_buckets - max_exact),
             tf.int32,
         )
-        val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
-        ret += tf.where(is_small, n, val_if_large)
-        return ret
+        relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1)
+        relative_buckets += tf.where(is_small, relative_position, relative_position_if_large)
+        return relative_buckets

-    def compute_bias(self, qlen, klen):
+    def compute_bias(self, query_length, key_length):
         """ Compute binned relative position bias """
-        context_position = tf.range(qlen)[:, None]
-        memory_position = tf.range(klen)[None, :]
-        relative_position = memory_position - context_position  # shape (qlen, klen)
-        rp_bucket = self._relative_position_bucket(
+        context_position = tf.range(query_length)[:, None]
+        memory_position = tf.range(key_length)[None, :]
+        relative_position = memory_position - context_position  # shape (query_length, key_length)
+        relative_position_bucket = self._relative_position_bucket(
             relative_position,
-            bidirectional=self.is_bidirectional,
+            bidirectional=(not self.is_decoder),
             num_buckets=self.relative_attention_num_buckets,
         )
-        values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
-        values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0)  # shape (1, num_heads, qlen, klen)
+        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = tf.expand_dims(
+            tf.transpose(values, [2, 0, 1]), axis=0
+        )  # shape (1, num_heads, query_length, key_length)
         return values

     def call(
         self,
-        input,
+        hidden_states,
         mask=None,
-        kv=None,
+        key_value_states=None,
         position_bias=None,
         past_key_value=None,
         head_mask=None,
@@ -224,95 +225,108 @@ class TFT5Attention(tf.keras.layers.Layer):
         output_attentions=False,
     ):
         """
-        Self-attention (if kv is None) or attention over source sentence (provided by kv).
+        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
         """
-        # Input is (bs, qlen, dim)
-        # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
-        # past_key_value[0] is (bs, n_heads, q_len - 1, dim_per_head)
-        bs, qlen, dim = shape_list(input)
+        # Input is (batch_size, query_length, dim)
+        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
+        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
+        batch_size, seq_length = shape_list(hidden_states)[:2]
+
+        real_seq_length = seq_length

         if past_key_value is not None:
-            assert self.is_decoder is True, "Encoder cannot cache past key value states"
             assert (
                 len(past_key_value) == 2
             ), "past_key_value should have 2 past states: keys and values. Got {} past states".format(
                 len(past_key_value)
             )
-            real_qlen = qlen + shape_list(past_key_value[0])[2] if query_length is None else query_length
-        else:
-            real_qlen = qlen
+            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

-        if kv is None:
-            klen = real_qlen
-        else:
-            klen = shape_list(kv)[1]
+        key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

-        def shape(x):
+        def shape(hidden_states):
             """ projection """
-            return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, self.d_kv)), perm=(0, 2, 1, 3))
+            return tf.transpose(
+                tf.reshape(hidden_states, (batch_size, -1, self.n_heads, self.key_value_proj_dim)), perm=(0, 2, 1, 3)
+            )

-        def unshape(x):
+        def unshape(hidden_states):
             """ compute context """
-            return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.inner_dim))
+            return tf.reshape(tf.transpose(hidden_states, perm=(0, 2, 1, 3)), (batch_size, -1, self.inner_dim))

-        q = shape(self.q(input))  # (bs, n_heads, qlen, dim_per_head)
+        def project(hidden_states, proj_layer, key_value_states, past_key_value):
+            """ projects hidden states correctly to key/query states """
+            if key_value_states is None:
+                # self-attn
+                # (batch_size, n_heads, seq_length, dim_per_head)
+                hidden_states = shape(proj_layer(hidden_states))
+            elif past_key_value is None:
+                # cross-attn
+                # (batch_size, n_heads, seq_length, dim_per_head)
+                hidden_states = shape(proj_layer(key_value_states))

-        if kv is None:
-            k = shape(self.k(input))  # (bs, n_heads, qlen, dim_per_head)
-            v = shape(self.v(input))  # (bs, n_heads, qlen, dim_per_head)
-        elif past_key_value is None:
-            k = v = kv
-            k = shape(self.k(k))  # (bs, n_heads, qlen, dim_per_head)
-            v = shape(self.v(v))  # (bs, n_heads, qlen, dim_per_head)
+            if past_key_value is not None:
+                if key_value_states is None:
+                    # self-attn
+                    # (batch_size, n_heads, key_length, dim_per_head)
+                    hidden_states = tf.concat([past_key_value, hidden_states], axis=2)
+                else:
+                    # cross-attn
+                    hidden_states = past_key_value
+            return hidden_states

-        if past_key_value is not None:
-            if kv is None:
-                k_, v_ = past_key_value
-                k = tf.concat([k_, k], axis=2)  # (bs, n_heads, klen, dim_per_head)
-                v = tf.concat([v_, v], axis=2)  # (bs, n_heads, klen, dim_per_head)
-            else:
-                k, v = past_key_value
+        # get query
+        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, query_length, dim_per_head)
+
+        # get key/value
+        key_states = project(
+            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
+        )
+        value_states = project(
+            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
+        )

         # to cope with keras serialization
         if self.is_decoder and cast_bool_to_primitive(use_cache, self.use_cache) is True:
-            present_key_value_state = ((k, v),)
+            present_key_value_state = (key_states, value_states)
         else:
-            present_key_value_state = (None,)
+            present_key_value_state = None

-        scores = tf.einsum("bnqd,bnkd->bnqk", q, k)  # (bs, n_heads, qlen, klen)
+        scores = tf.einsum(
+            "bnqd,bnkd->bnqk", query_states, key_states
+        )  # (batch_size, n_heads, query_length, key_length)

         if position_bias is None:
             if not self.has_relative_attention_bias:
-                raise ValueError("No position_bias provided and no weights to compute position_bias")
-            position_bias = self.compute_bias(real_qlen, klen)
+                position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length), dtype=tf.float32)
+            else:
+                position_bias = self.compute_bias(real_seq_length, key_length)

             # if key and values are already calculated
             # we want only the last query position bias
             if past_key_value is not None:
-                position_bias = position_bias[:, :, -qlen:, :]
+                position_bias = position_bias[:, :, -seq_length:, :]

             if mask is not None:
-                position_bias = position_bias + mask  # (bs, n_heads, qlen, klen)
+                position_bias = position_bias + mask  # (batch_size, n_heads, query_length, key_length)

         scores += position_bias
-        weights = tf.nn.softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
-        weights = self.dropout(weights, training=training)  # (bs, n_heads, qlen, klen)
+        weights = tf.nn.softmax(scores, axis=-1)  # (batch_size, n_heads, query_length, key_length)
+        weights = self.dropout(weights, training=training)  # (batch_size, n_heads, query_length, key_length)

         # Mask heads if we want to
         if head_mask is not None:
             weights = weights * head_mask

-        context = tf.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
-        context = unshape(context)  # (bs, qlen, dim)
-        context = self.o(context)
+        attn_output = tf.matmul(weights, value_states)  # (batch_size, n_heads, query_length, dim_per_head)
+
+        attn_output = self.o(unshape(attn_output))

-        outputs = (context,) + present_key_value_state
+        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

         if output_attentions:
             outputs = outputs + (weights,)
-        if self.has_relative_attention_bias:
-            outputs = outputs + (position_bias,)
+
         return outputs
@@ -322,7 +336,6 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
         self.SelfAttention = TFT5Attention(
             config,
             has_relative_attention_bias=has_relative_attention_bias,
-            is_bidirectional=not config.is_decoder,
             name="SelfAttention",
         )
         self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
@@ -339,9 +352,9 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
         output_attentions=False,
         training=False,
     ):
-        norm_x = self.layer_norm(hidden_states)
+        normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.SelfAttention(
-            norm_x,
+            normed_hidden_states,
             mask=attention_mask,
             position_bias=position_bias,
             head_mask=head_mask,
@@ -350,19 +363,17 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
             output_attentions=output_attentions,
             training=training,
         )
-        y = attention_output[0]
-        layer_output = hidden_states + self.dropout(y, training=training)
-        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
+        hidden_states = hidden_states + self.dropout(attention_output[0], training=training)
+        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
         return outputs


 class TFT5LayerCrossAttention(tf.keras.layers.Layer):
-    def __init__(self, config, has_relative_attention_bias=False, **kwargs):
+    def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.EncDecAttention = TFT5Attention(
             config,
-            has_relative_attention_bias=has_relative_attention_bias,
-            is_bidirectional=True,
+            has_relative_attention_bias=False,
             name="EncDecAttention",
         )
         self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
@@ -371,7 +382,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
     def call(
         self,
         hidden_states,
-        kv,
+        key_value_states,
         attention_mask=None,
         position_bias=None,
         head_mask=None,
@@ -381,11 +392,11 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
         output_attentions=False,
         training=False,
     ):
-        norm_x = self.layer_norm(hidden_states)
+        normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.EncDecAttention(
-            norm_x,
+            normed_hidden_states,
             mask=attention_mask,
-            kv=kv,
+            key_value_states=key_value_states,
             position_bias=position_bias,
             head_mask=head_mask,
             past_key_value=past_key_value,
@@ -394,9 +405,8 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
             output_attentions=output_attentions,
             training=training,
         )
-        y = attention_output[0]
-        layer_output = hidden_states + self.dropout(y, training=training)
-        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
+        hidden_states = hidden_states + self.dropout(attention_output[0], training=training)
+        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
         return outputs
@@ -416,7 +426,6 @@ class TFT5Block(tf.keras.layers.Layer):
             self.layer.append(
                 TFT5LayerCrossAttention(
                     config,
-                    has_relative_attention_bias=has_relative_attention_bias,
                     name="layer_._1",
                 )
             )
@@ -477,7 +486,7 @@ class TFT5Block(tf.keras.layers.Layer):
             cross_attention_outputs = self.layer[1](
                 hidden_states,
-                kv=encoder_hidden_states,
+                key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
                 position_bias=encoder_decoder_position_bias,
                 head_mask=head_mask,
@@ -731,17 +740,18 @@ class TFT5MainLayer(tf.keras.layers.Layer):
             # layer_outputs is a tuple with:
             # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
             hidden_states, present_key_value_state = layer_outputs[:2]
-            if i == 0:
-                # We share the position biases between the layers - the first layer store them
-                # layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
-                position_bias = layer_outputs[3 if output_attentions else 2]
-                if self.is_decoder and encoder_hidden_states is not None:
-                    encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 3]
+
+            # We share the position biases between the layers - the first layer store them
+            # layer_outputs = hidden-states, past_key_values, (self-attention weights),
+            # (self-attention position bias), (cross-attention position bias), (cross-attention weights),
+            position_bias = layer_outputs[2]
+            if self.is_decoder and encoder_hidden_states is not None:
+                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
             # append next layer key value states
             present_key_value_states = present_key_value_states + (present_key_value_state,)

             if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[2],)
+                all_attentions = all_attentions + (layer_outputs[3],)

         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.dropout(hidden_states, training=training)
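Both the PyTorch and the TF refactors above give the attention layers a fixed output layout, which is why the stacks can now index the shared position bias at a constant slot instead of branching on `output_attentions`. A toy sketch of that layout, with strings standing in for tensors:

    # Toy stand-ins for tensors, just to show the fixed tuple layout used after the refactor.
    attn_output, present_key_value_state, position_bias, attn_weights = "out", "kv", "bias", "weights"
    output_attentions = True

    outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
    if output_attentions:
        outputs = outputs + (attn_weights,)

    # The stack can always read the self-attention position bias at index 2,
    # and the attention weights (when requested) at index 3.
    assert outputs[2] == "bias"
    assert (outputs[3] if output_attentions else None) == "weights"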

View File

@@ -461,7 +461,7 @@ class ModelTesterMixin:
             inputs = self._prepare_for_class(inputs_dict, model_class).copy()
             inputs["head_mask"] = head_mask

-            outputs = model(**inputs)
+            outputs = model(**inputs, return_dict=True)

             # Test that we can get a gradient back for importance score computation
             output = sum(t.sum() for t in outputs[0])

View File

@@ -556,6 +556,32 @@ class T5ModelIntegrationTests(unittest.TestCase):
     def tokenizer(self):
         return T5Tokenizer.from_pretrained("t5-base")

+    @slow
+    def test_small_integration_test(self):
+        """
+        For comparision run:
+        >>> import t5  # pip install t5==0.7.1
+        >>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
+
+        >>> path_to_mtf_small_t5_checkpoint = '<fill_in>'
+        >>> path_to_mtf_small_spm_model_path = '<fill_in>'
+        >>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_checkpoint, batch_size=1, tpu=None)
+        >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
+        >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
+        """
+
+        model = T5ForConditionalGeneration.from_pretrained("t5-small", return_dict=True).to(torch_device)
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+
+        input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
+        labels = tokenizer("Hi I am", return_tensors="pt").input_ids
+
+        loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss
+        mtf_score = -(labels.shape[-1] * loss.item())
+
+        EXPECTED_SCORE = -19.0845
+        self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
+
     @slow
     def test_summarization(self):
         model = self.model
@@ -567,8 +593,8 @@ class T5ModelIntegrationTests(unittest.TestCase):
         ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'

         expected_summaries = [
-            'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video at the crash site . "one can hear cries of \'My God\' in several languages," one magazine says .',
-            "the Palestinians become the 123rd member of the international criminal court . the accession was marked by a ceremony at the Hague, where the court is based . as members of the court, Palestinians may be subject to counter-charges as well .",
+            'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," one magazine says .',
+            "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a preliminary examination into the situation in the occupied Palestinian territory . as members of the court, Palestinians may be subject to counter-charges as well .",
             "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and implement a rigorous inspection regime .",
             'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
         ]
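The new integration test converts the Hugging Face loss into a Mesh TensorFlow style score: the returned loss is (to my understanding of this version) the mean cross-entropy per target token, so multiplying by the number of target tokens recovers, up to a sign flip, the summed log-likelihood that `t5.models.MtfModel.score` reports. A small sketch with made-up numbers:

    # Hypothetical numbers, for illustration only; the real values come from the test above.
    num_target_tokens = 7          # stands in for labels.shape[-1]
    mean_loss_per_token = 2.7263   # stands in for model(...).loss.item()

    mtf_style_score = -(num_target_tokens * mean_loss_per_token)
    print(mtf_style_score)  # roughly -19.08 with these made-up inputs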

View File

@@ -311,8 +311,8 @@ class TFT5ModelIntegrationTests(unittest.TestCase):
         ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'

         expected_summaries = [
-            'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video at the crash site . "one can hear cries of \'My God\' in several languages," one magazine says .',
-            "the Palestinians become the 123rd member of the international criminal court . the accession was marked by a ceremony at the Hague, where the court is based . as members of the court, Palestinians may be subject to counter-charges as well .",
+            'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," one magazine says .',
+            "the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a preliminary examination into the situation in the occupied Palestinian territory . as members of the court, Palestinians may be subject to counter-charges as well .",
             "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and implement a rigorous inspection regime .",
             'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
         ]