[T5] Bug correction & Refactor (#8518)

* fix bug

* T5 refactor

* refactor tf

* apply Sylvain's suggestions
Patrick von Platen 2020-11-13 16:57:31 +01:00 committed by GitHub
parent 42f63e3871
commit 42e2d02e44
7 changed files with 310 additions and 262 deletions

View File

@ -75,7 +75,6 @@ class T5Config(PretrainedConfig):
def __init__(
self,
vocab_size=32128,
n_positions=512,
d_model=512,
d_kv=64,
d_ff=2048,
@ -98,7 +97,6 @@ class T5Config(PretrainedConfig):
**kwargs,
)
self.vocab_size = vocab_size
self.n_positions = n_positions
self.d_model = d_model
self.d_kv = d_kv
self.d_ff = d_ff
@ -112,10 +110,6 @@ class T5Config(PretrainedConfig):
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_factor = initializer_factor
@property
def max_position_embeddings(self):
return self.n_positions
@property
def hidden_size(self):
return self.d_model
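n_positions (and its max_position_embeddings alias) is dropped because T5 has no absolute position embeddings; input length is only constrained by memory and by the relative attention buckets configured via relative_attention_num_buckets. A minimal sketch of the cleaned-up config (values shown are the documented defaults):

from transformers import T5Config

config = T5Config()          # defaults: vocab_size=32128, d_model=512, d_kv=64, d_ff=2048
print(config.hidden_size)    # 512; the remaining property still aliases d_model
# config.n_positions / config.max_position_embeddings are no longer defined after this change,
# since T5 has no learned position table to size.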

View File

@ -17,8 +17,6 @@
import argparse
import torch
from transformers import T5Config, T5Model, load_tf_weights_in_t5
from transformers.utils import logging
@ -37,7 +35,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
# Save pytorch-model
print("Save PyTorch model to {}".format(pytorch_dump_path))
torch.save(model.state_dict(), pytorch_dump_path)
model.save_pretrained(pytorch_dump_path)
if __name__ == "__main__":
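The fix replaces the raw torch.save of the state dict with save_pretrained, so the dump directory also contains a config.json and can be reloaded with from_pretrained directly. A hedged sketch of the same flow outside the script (all paths are placeholders):

from transformers import T5Config, T5Model, load_tf_weights_in_t5

config = T5Config.from_json_file("/path/to/t5_config.json")       # placeholder path
model = T5Model(config)
load_tf_weights_in_t5(model, config, "/path/to/tf_checkpoint")    # placeholder path

# save_pretrained writes both the weights and config.json into the dump directory ...
model.save_pretrained("/path/to/pytorch_dump")
# ... which is what allows a plain from_pretrained on that directory afterwards
reloaded = T5Model.from_pretrained("/path/to/pytorch_dump")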

View File

@ -115,12 +115,12 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
scope_names = [m_name]
if scope_names[0] in ["kernel", "scale", "embedding"]:
pointer = getattr(pointer, "weight")
# elif scope_names[0] == 'scale':
# pointer = getattr(pointer, 'weight')
# elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta':
# pointer = getattr(pointer, 'bias')
# elif scope_names[0] == 'squad':
# pointer = getattr(pointer, 'classifier')
elif scope_names[0] == "scale":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
pointer = getattr(pointer, "bias")
elif scope_names[0] == "squad":
pointer = getattr(pointer, "classifier")
else:
try:
pointer = getattr(pointer, scope_names[0])
@ -147,7 +147,6 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
tf_weights.pop(txt_name, None)
logger.info("Weights not copied to PyTorch model: {}".format(", ".join(tf_weights.keys())))
# logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
return model
@ -167,14 +166,15 @@ class T5LayerNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
def forward(self, hidden_states):
# layer norm should always be calculated in float32
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
x = x / torch.sqrt(variance + self.variance_epsilon)
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into float16 if necessary
if self.weight.dtype == torch.float16:
x = x.to(torch.float16)
return self.weight * x
hidden_states = hidden_states.to(torch.float16)
return self.weight * hidden_states
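T5LayerNorm is an RMS-style norm: no mean subtraction and no bias, just a learned scale over the root mean square, which is why dividing by sqrt can be swapped for the cheaper multiply by torch.rsqrt. A quick numerical check of that equivalence:

import torch

hidden_states = torch.randn(2, 4, 8)
eps = 1e-6
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)

divide = hidden_states / torch.sqrt(variance + eps)      # old formulation
multiply = hidden_states * torch.rsqrt(variance + eps)   # refactored formulation
print(torch.allclose(divide, multiply))                  # expected: True (up to rounding)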
class T5DenseReluDense(nn.Module):
@ -185,11 +185,11 @@ class T5DenseReluDense(nn.Module):
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, hidden_states):
h = self.wi(hidden_states)
h = F.relu(h)
h = self.dropout(h)
h = self.wo(h)
return h
hidden_states = self.wi(hidden_states)
hidden_states = F.relu(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.wo(hidden_states)
return hidden_states
class T5LayerFF(nn.Module):
@ -200,25 +200,24 @@ class T5LayerFF(nn.Module):
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, hidden_states):
norm_x = self.layer_norm(hidden_states)
y = self.DenseReluDense(norm_x)
layer_output = hidden_states + self.dropout(y)
return layer_output
forwarded_states = self.layer_norm(hidden_states)
forwarded_states = self.DenseReluDense(forwarded_states)
hidden_states = hidden_states + self.dropout(forwarded_states)
return hidden_states
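The renaming makes the block's pre-norm residual structure explicit: layer-norm the input, run the sub-layer, apply dropout, then add the untouched input back. A generic sketch of that pattern (class name is illustrative, not part of the library):

import torch
import torch.nn as nn

class PreNormResidual(nn.Module):
    """Illustrative wrapper; T5LayerFF, T5LayerSelfAttention and T5LayerCrossAttention follow this shape."""

    def __init__(self, hidden_size, sublayer, dropout_rate=0.1, eps=1e-6):
        super().__init__()
        self.layer_norm = nn.LayerNorm(hidden_size, eps=eps)   # stand-in for T5LayerNorm
        self.sublayer = sublayer
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, hidden_states):
        forwarded_states = self.sublayer(self.layer_norm(hidden_states))
        return hidden_states + self.dropout(forwarded_states)

block = PreNormResidual(8, nn.Linear(8, 8))
print(block(torch.randn(2, 3, 8)).shape)   # torch.Size([2, 3, 8])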
class T5Attention(nn.Module):
def __init__(self, config: T5Config, has_relative_attention_bias=False, is_bidirectional=False):
def __init__(self, config: T5Config, has_relative_attention_bias=False):
super().__init__()
self.is_bidirectional = is_bidirectional
self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias
self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.d_model = config.d_model
self.d_kv = config.d_kv
self.key_value_proj_dim = config.d_kv
self.n_heads = config.num_heads
self.dropout = config.dropout_rate
self.inner_dim = self.n_heads * self.d_kv
self.inner_dim = self.n_heads * self.key_value_proj_dim
# Mesh TensorFlow initialization to avoid scaling before softmax
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
@ -233,7 +232,9 @@ class T5Attention(nn.Module):
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, self.d_kv, self.pruned_heads)
heads, index = find_pruneable_heads_and_indices(
heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
)
# Prune linear layers
self.q = prune_linear_layer(self.q, index)
self.k = prune_linear_layer(self.k, index)
@ -241,7 +242,7 @@ class T5Attention(nn.Module):
self.o = prune_linear_layer(self.o, index, dim=1)
# Update hyper params
self.n_heads = self.n_heads - len(heads)
self.inner_dim = self.d_kv * self.n_heads
self.inner_dim = self.key_value_proj_dim * self.n_heads
self.pruned_heads = self.pruned_heads.union(heads)
@staticmethod
@ -266,49 +267,52 @@ class T5Attention(nn.Module):
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
ret = 0
n = -relative_position
relative_buckets = 0
if bidirectional:
num_buckets //= 2
ret += (n < 0).to(torch.long) * num_buckets # mtf.to_int32(mtf.less(n, 0)) * num_buckets
n = torch.abs(n)
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
n = torch.max(n, torch.zeros_like(n))
# now n is in the range [0, inf)
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = n < max_exact
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
ret += torch.where(is_small, n, val_if_large)
return ret
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
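_relative_position_bucket maps each signed offset to an integer id: half the buckets encode exact small offsets, the rest cover logarithmically wider ranges up to max_distance, and in the bidirectional (encoder) case the sign gets its own half of the id space. A small sketch of calling it (the module path is an assumption for the time of this commit):

import torch
from transformers.modeling_t5 import T5Attention   # module path assumed at the time of this commit

context_position = torch.arange(4, dtype=torch.long)[:, None]
memory_position = torch.arange(4, dtype=torch.long)[None, :]
relative_position = memory_position - context_position          # shape (4, 4), values in [-3, 3]

buckets = T5Attention._relative_position_bucket(
    relative_position, bidirectional=True, num_buckets=32, max_distance=128
)
print(buckets.shape)                              # torch.Size([4, 4])
print(int(buckets.min()), int(buckets.max()))     # all ids lie in [0, num_buckets)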
def compute_bias(self, qlen, klen):
def compute_bias(self, query_length, key_length):
""" Compute binned relative position bias """
context_position = torch.arange(qlen, dtype=torch.long)[:, None]
memory_position = torch.arange(klen, dtype=torch.long)[None, :]
relative_position = memory_position - context_position # shape (qlen, klen)
rp_bucket = self._relative_position_bucket(
relative_position, # shape (qlen, klen)
bidirectional=self.is_bidirectional,
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
)
rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen)
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values
def forward(
self,
input,
hidden_states,
mask=None,
kv=None,
key_value_states=None,
position_bias=None,
past_key_value=None,
head_mask=None,
@ -317,106 +321,113 @@ class T5Attention(nn.Module):
output_attentions=False,
):
"""
Self-attention (if kv is None) or attention over source sentence (provided by kv).
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
"""
# Input is (bs, qlen, dim)
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
# past_key_value[0] is (bs, n_heads, q_len - 1, dim_per_head)
bs, qlen, dim = input.size()
# Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length
if past_key_value is not None:
assert self.is_decoder is True, "Encoder cannot cache past key value states"
assert (
len(past_key_value) == 2
), "past_key_value should have 2 past states: keys and values. Got {} past states".format(
len(past_key_value)
)
real_qlen = qlen + past_key_value[0].shape[2] if query_length is None else query_length
else:
real_qlen = qlen
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
if kv is None:
klen = real_qlen
else:
klen = kv.size(1)
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
def shape(x):
def shape(states):
""" projection """
return x.view(bs, -1, self.n_heads, self.d_kv).transpose(1, 2)
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
def unshape(x):
""" compute context """
return x.transpose(1, 2).contiguous().view(bs, -1, self.inner_dim)
def unshape(states):
""" reshape """
return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
q = shape(self.q(input)) # (bs, n_heads, qlen, dim_per_head)
def project(hidden_states, proj_layer, key_value_states, past_key_value):
""" projects hidden states correctly to key/query states """
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
if kv is None:
k = shape(self.k(input)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v(input)) # (bs, n_heads, qlen, dim_per_head)
elif past_key_value is None:
k = v = kv
k = shape(self.k(k)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v(v)) # (bs, n_heads, qlen, dim_per_head)
if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
else:
# cross-attn
hidden_states = past_key_value
return hidden_states
if past_key_value is not None:
if kv is None:
k_, v_ = past_key_value
k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head)
v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head)
else:
k, v = past_key_value
# get query states
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
if self.is_decoder and use_cache is True:
present_key_value_state = ((k, v),)
else:
present_key_value_state = (None,)
# get key/value states
key_states = project(
hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)
# (bs, n_heads, qlen, klen)
# compute scores
scores = torch.matmul(
q, k.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", q, k), compatible with onnx op>9
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
if position_bias is None:
if not self.has_relative_attention_bias:
raise ValueError("No position_bias provided and no weights to compute position_bias")
position_bias = self.compute_bias(real_qlen, klen)
position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
)
else:
position_bias = self.compute_bias(real_seq_length, key_length)
# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -qlen:, :]
position_bias = position_bias[:, :, -seq_length:, :]
if mask is not None:
position_bias = position_bias + mask # (bs, n_heads, qlen, klen)
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
scores += position_bias
weights = F.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen)
weights = F.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen)
attn_weights = F.softmax(scores.float(), dim=-1).type_as(
scores
) # (batch_size, n_heads, seq_length, key_length)
attn_weights = F.dropout(
attn_weights, p=self.dropout, training=self.training
) # (batch_size, n_heads, seq_length, key_length)
# Mask heads if we want to
if head_mask is not None:
weights = weights * head_mask
attn_weights = attn_weights * head_mask
context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
context = unshape(context) # (bs, qlen, dim)
attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
attn_output = self.o(attn_output)
context = self.o(context)
outputs = (context,) + present_key_value_state
present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
if output_attentions:
outputs = outputs + (weights,)
if self.has_relative_attention_bias:
outputs = outputs + (position_bias,)
outputs = outputs + (attn_weights,)
return outputs
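The rewritten forward folds all key/value handling into one project helper and always returns (attn_output, present_key_value_state, position_bias, optionally attn_weights), so callers can index the tuple without checking has_relative_attention_bias. A hedged sketch of incremental decoding with the cache (module path and direct layer construction are assumptions of this sketch):

import torch
from transformers import T5Config
from transformers.modeling_t5 import T5Attention   # module path assumed at the time of this commit

config = T5Config(is_decoder=True)
attn = T5Attention(config, has_relative_attention_bias=True)

prefix = torch.randn(1, 4, config.d_model)
attn_output, past_key_value, position_bias = attn(prefix, use_cache=True)[:3]
print(past_key_value[0].shape)  # keys:   (1, num_heads, 4, d_kv)
print(past_key_value[1].shape)  # values: (1, num_heads, 4, d_kv)

# feeding one new token reuses the cached keys/values instead of re-projecting the prefix
next_token = torch.randn(1, 1, config.d_model)
attn_output, past_key_value, _ = attn(next_token, past_key_value=past_key_value, use_cache=True)[:3]
print(past_key_value[0].shape)  # keys now cover 5 positions: (1, num_heads, 5, d_kv)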
class T5LayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.SelfAttention = T5Attention(
config, has_relative_attention_bias=has_relative_attention_bias, is_bidirectional=not config.is_decoder
)
self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
@ -430,9 +441,9 @@ class T5LayerSelfAttention(nn.Module):
use_cache=False,
output_attentions=False,
):
norm_x = self.layer_norm(hidden_states)
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention(
norm_x,
normed_hidden_states,
mask=attention_mask,
position_bias=position_bias,
head_mask=head_mask,
@ -440,25 +451,22 @@ class T5LayerSelfAttention(nn.Module):
use_cache=use_cache,
output_attentions=output_attentions,
)
y = attention_output[0]
layer_output = hidden_states + self.dropout(y)
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
hidden_states = hidden_states + self.dropout(attention_output[0])
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
return outputs
class T5LayerCrossAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
def __init__(self, config):
super().__init__()
self.EncDecAttention = T5Attention(
config, has_relative_attention_bias=has_relative_attention_bias, is_bidirectional=True
)
self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states,
kv,
key_value_states,
attention_mask=None,
position_bias=None,
head_mask=None,
@ -467,11 +475,11 @@ class T5LayerCrossAttention(nn.Module):
query_length=None,
output_attentions=False,
):
norm_x = self.layer_norm(hidden_states)
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention(
norm_x,
normed_hidden_states,
mask=attention_mask,
kv=kv,
key_value_states=key_value_states,
position_bias=position_bias,
head_mask=head_mask,
past_key_value=past_key_value,
@ -479,8 +487,7 @@ class T5LayerCrossAttention(nn.Module):
query_length=query_length,
output_attentions=output_attentions,
)
y = attention_output[0]
layer_output = hidden_states + self.dropout(y)
layer_output = hidden_states + self.dropout(attention_output[0])
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
return outputs
@ -492,7 +499,7 @@ class T5Block(nn.Module):
self.layer = nn.ModuleList()
self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
if self.is_decoder:
self.layer.append(T5LayerCrossAttention(config, has_relative_attention_bias=has_relative_attention_bias))
self.layer.append(T5LayerCrossAttention(config))
self.layer.append(T5LayerFF(config))
@ -550,7 +557,7 @@ class T5Block(nn.Module):
cross_attention_outputs = self.layer[1](
hidden_states,
kv=encoder_hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias,
head_mask=head_mask,
@ -619,12 +626,12 @@ class T5PreTrainedModel(PreTrainedModel):
# Mesh TensorFlow attention initialization to avoid scaling before softmax
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
d_model = self.config.d_model
d_kv = self.config.d_kv
key_value_proj_dim = self.config.d_kv
n_heads = self.config.num_heads
module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * d_kv) ** -0.5))
module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * d_kv) ** -0.5))
module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
if module.has_relative_attention_bias:
module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
@ -775,20 +782,20 @@ class T5Stack(T5PreTrainedModel):
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
hidden_states, present_key_value_state = layer_outputs[:2]
if i == 0:
# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[3 if output_attentions else 2]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 3]
# We share the position biases between the layers - the first layer stores them
# layer_outputs = hidden-states, key-value-states, (self-attention position bias),
# (self-attention weights), (cross-attention position bias), (cross-attention weights)
position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if use_cache:
present_key_value_states = present_key_value_states + (present_key_value_state,)
if output_attentions:
all_attentions = all_attentions + (layer_outputs[2],)
all_attentions = all_attentions + (layer_outputs[3],)
if self.is_decoder:
all_cross_attentions = all_cross_attentions + (layer_outputs[4 if i == 0 else 3],)
all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
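Only the first block in a stack owns a relative_attention_bias table; it computes position_bias once (now reliably at index 2 of its outputs) and every later block reuses that tensor. A hedged illustration with two bare attention layers (direct construction of the internal layers is just for the sketch):

import torch
from transformers import T5Config
from transformers.modeling_t5 import T5Attention   # module path assumed at the time of this commit

config = T5Config()
first_layer = T5Attention(config, has_relative_attention_bias=True)
later_layer = T5Attention(config, has_relative_attention_bias=False)

hidden_states = torch.randn(1, 6, config.d_model)

# the first layer computes the shared bias from its embedding table ...
_, _, position_bias = first_layer(hidden_states)[:3]
# ... and later layers receive it instead of recomputing it (they would fall back to zeros otherwise)
attn_output = later_layer(hidden_states, position_bias=position_bias)[0]
print(position_bias.shape, attn_output.shape)   # (1, num_heads, 6, 6) and (1, 6, d_model)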
@ -920,6 +927,12 @@ T5_INPUTS_DOCSTRING = r"""
T5_START_DOCSTRING,
)
class T5Model(T5PreTrainedModel):
authorized_missing_keys = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
]
def __init__(self, config: T5Config):
super().__init__(config)
self.shared = nn.Embedding(config.vocab_size, config.d_model)
@ -1063,7 +1076,14 @@ class T5Model(T5PreTrainedModel):
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
class T5ForConditionalGeneration(T5PreTrainedModel):
authorized_missing_keys = [r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", r"lm_head\.weight"]
authorized_missing_keys = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"lm_head\.weight",
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
]
def __init__(self, config):
super().__init__(config)
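authorized_missing_keys lists weight names that may be absent from a checkpoint without that being an error — here the tied embedding copies and the cross-attention relative_attention_bias touched by this refactor — so from_pretrained skips the "weights not initialized" warning for them. A hedged usage sketch:

from transformers import T5Model

# loading an official checkpoint does not warn about keys matching the patterns above
model = T5Model.from_pretrained("t5-small")
print(T5Model.authorized_missing_keys)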

View File

@ -81,10 +81,10 @@ class TFT5LayerNorm(tf.keras.layers.Layer):
self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones")
super().build(input_shape)
def call(self, x):
variance = tf.math.reduce_mean(tf.math.square(x), axis=-1, keepdims=True)
x = x * tf.math.rsqrt(variance + self.variance_epsilon)
return self.weight * x
def call(self, hidden_states):
variance = tf.math.reduce_mean(tf.math.square(hidden_states), axis=-1, keepdims=True)
hidden_states = hidden_states * tf.math.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
class TFT5DenseReluDense(tf.keras.layers.Layer):
@ -96,11 +96,11 @@ class TFT5DenseReluDense(tf.keras.layers.Layer):
self.act = tf.keras.activations.relu
def call(self, hidden_states, training=False):
h = self.wi(hidden_states)
h = self.act(h)
h = self.dropout(h, training=training)
h = self.wo(h)
return h
hidden_states = self.wi(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.wo(hidden_states)
return hidden_states
class TFT5LayerFF(tf.keras.layers.Layer):
@ -111,18 +111,17 @@ class TFT5LayerFF(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def call(self, hidden_states, training=False):
norm_x = self.layer_norm(hidden_states)
y = self.DenseReluDense(norm_x, training=training)
layer_output = hidden_states + self.dropout(y, training=training)
return layer_output
normed_hidden_states = self.layer_norm(hidden_states)
dense_output = self.DenseReluDense(normed_hidden_states, training=training)
hidden_states = hidden_states + self.dropout(dense_output, training=training)
return hidden_states
class TFT5Attention(tf.keras.layers.Layer):
NEW_ID = itertools.count()
def __init__(self, config, has_relative_attention_bias=False, is_bidirectional=False, **kwargs):
def __init__(self, config, has_relative_attention_bias=False, **kwargs):
super().__init__(**kwargs)
self.is_bidirectional = is_bidirectional
self.layer_id = next(TFT5Attention.NEW_ID)
self.is_decoder = config.is_decoder
self.use_cache = config.use_cache
@ -131,9 +130,9 @@ class TFT5Attention(tf.keras.layers.Layer):
self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.d_model = config.d_model
self.d_kv = config.d_kv
self.key_value_proj_dim = config.d_kv
self.n_heads = config.num_heads
self.inner_dim = self.n_heads * self.d_kv
self.inner_dim = self.n_heads * self.key_value_proj_dim
# Mesh TensorFlow initialization to avoid scaling before softmax
self.q = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name="q")
@ -175,46 +174,48 @@ class TFT5Attention(tf.keras.layers.Layer):
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
ret = 0
n = -relative_position
relative_buckets = 0
# n = -relative_position
if bidirectional:
num_buckets //= 2
ret += tf.dtypes.cast(tf.math.less(n, 0), tf.int32) * num_buckets
n = tf.math.abs(n)
relative_buckets += tf.dtypes.cast(tf.math.greater(relative_position, 0), tf.int32) * num_buckets
relative_position = tf.math.abs(relative_position)
else:
n = tf.math.maximum(n, 0)
relative_position = -tf.math.minimum(relative_position, 0)
# now n is in the range [0, inf)
max_exact = num_buckets // 2
is_small = tf.math.less(n, max_exact)
val_if_large = max_exact + tf.dtypes.cast(
tf.math.log(tf.dtypes.cast(n, tf.float32) / max_exact)
is_small = tf.math.less(relative_position, max_exact)
relative_position_if_large = max_exact + tf.dtypes.cast(
tf.math.log(tf.dtypes.cast(relative_position, tf.float32) / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact),
tf.int32,
)
val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
ret += tf.where(is_small, n, val_if_large)
return ret
relative_position_if_large = tf.math.minimum(relative_position_if_large, num_buckets - 1)
relative_buckets += tf.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
def compute_bias(self, qlen, klen):
def compute_bias(self, query_length, key_length):
""" Compute binned relative position bias """
context_position = tf.range(qlen)[:, None]
memory_position = tf.range(klen)[None, :]
relative_position = memory_position - context_position # shape (qlen, klen)
rp_bucket = self._relative_position_bucket(
context_position = tf.range(query_length)[:, None]
memory_position = tf.range(key_length)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position,
bidirectional=self.is_bidirectional,
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
)
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0) # shape (1, num_heads, qlen, klen)
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
values = tf.expand_dims(
tf.transpose(values, [2, 0, 1]), axis=0
) # shape (1, num_heads, query_length, key_length)
return values
def call(
self,
input,
hidden_states,
mask=None,
kv=None,
key_value_states=None,
position_bias=None,
past_key_value=None,
head_mask=None,
@ -224,95 +225,108 @@ class TFT5Attention(tf.keras.layers.Layer):
output_attentions=False,
):
"""
Self-attention (if kv is None) or attention over source sentence (provided by kv).
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
"""
# Input is (bs, qlen, dim)
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
# past_key_value[0] is (bs, n_heads, q_len - 1, dim_per_head)
bs, qlen, dim = shape_list(input)
# Input is (batch_size, query_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = shape_list(hidden_states)[:2]
real_seq_length = seq_length
if past_key_value is not None:
assert self.is_decoder is True, "Encoder cannot cache past key value states"
assert (
len(past_key_value) == 2
), "past_key_value should have 2 past states: keys and values. Got {} past states".format(
len(past_key_value)
)
real_qlen = qlen + shape_list(past_key_value[0])[2] if query_length is None else query_length
else:
real_qlen = qlen
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
if kv is None:
klen = real_qlen
else:
klen = shape_list(kv)[1]
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
def shape(x):
def shape(hidden_states):
""" projection """
return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, self.d_kv)), perm=(0, 2, 1, 3))
return tf.transpose(
tf.reshape(hidden_states, (batch_size, -1, self.n_heads, self.key_value_proj_dim)), perm=(0, 2, 1, 3)
)
def unshape(x):
def unshape(hidden_states):
""" compute context """
return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.inner_dim))
return tf.reshape(tf.transpose(hidden_states, perm=(0, 2, 1, 3)), (batch_size, -1, self.inner_dim))
q = shape(self.q(input)) # (bs, n_heads, qlen, dim_per_head)
def project(hidden_states, proj_layer, key_value_states, past_key_value):
""" projects hidden states correctly to key/query states """
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
if kv is None:
k = shape(self.k(input)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v(input)) # (bs, n_heads, qlen, dim_per_head)
elif past_key_value is None:
k = v = kv
k = shape(self.k(k)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v(v)) # (bs, n_heads, qlen, dim_per_head)
if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
hidden_states = tf.concat([past_key_value, hidden_states], axis=2)
else:
# cross-attn
hidden_states = past_key_value
return hidden_states
if past_key_value is not None:
if kv is None:
k_, v_ = past_key_value
k = tf.concat([k_, k], axis=2) # (bs, n_heads, klen, dim_per_head)
v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head)
else:
k, v = past_key_value
# get query
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, query_length, dim_per_head)
# get key/value
key_states = project(
hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)
# to cope with keras serialization
if self.is_decoder and cast_bool_to_primitive(use_cache, self.use_cache) is True:
present_key_value_state = ((k, v),)
present_key_value_state = (key_states, value_states)
else:
present_key_value_state = (None,)
present_key_value_state = None
scores = tf.einsum("bnqd,bnkd->bnqk", q, k) # (bs, n_heads, qlen, klen)
scores = tf.einsum(
"bnqd,bnkd->bnqk", query_states, key_states
) # (batch_size, n_heads, query_length, key_length)
if position_bias is None:
if not self.has_relative_attention_bias:
raise ValueError("No position_bias provided and no weights to compute position_bias")
position_bias = self.compute_bias(real_qlen, klen)
position_bias = tf.zeros((1, self.n_heads, real_seq_length, key_length), dtype=tf.float32)
else:
position_bias = self.compute_bias(real_seq_length, key_length)
# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -qlen:, :]
position_bias = position_bias[:, :, -seq_length:, :]
if mask is not None:
position_bias = position_bias + mask # (bs, n_heads, qlen, klen)
position_bias = position_bias + mask # (batch_size, n_heads, query_length, key_length)
scores += position_bias
weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen)
weights = tf.nn.softmax(scores, axis=-1) # (batch_size, n_heads, query_length, key_length)
weights = self.dropout(weights, training=training) # (batch_size, n_heads, query_length, key_length)
# Mask heads if we want to
if head_mask is not None:
weights = weights * head_mask
context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
context = unshape(context) # (bs, qlen, dim)
attn_output = tf.matmul(weights, value_states) # (batch_size, n_heads, query_length, dim_per_head)
context = self.o(context)
attn_output = self.o(unshape(attn_output))
outputs = (context,) + present_key_value_state
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
if output_attentions:
outputs = outputs + (weights,)
if self.has_relative_attention_bias:
outputs = outputs + (position_bias,)
return outputs
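The TF layer mirrors the PyTorch change: when no position_bias is passed and the layer has no relative_attention_bias of its own, it now falls back to a zero bias instead of raising, and it always returns (attn_output, present_key_value_state, position_bias, optionally weights). A hedged standalone sketch (module path and direct layer use are assumptions of this sketch):

import tensorflow as tf
from transformers import T5Config
from transformers.modeling_tf_t5 import TFT5Attention   # module path assumed at the time of this commit

config = T5Config()
attn = TFT5Attention(config, has_relative_attention_bias=False, name="attn")

hidden_states = tf.random.normal((1, 5, config.d_model))
# no position_bias and no bias table: the layer uses tf.zeros of shape
# (1, n_heads, seq_length, key_length) rather than raising a ValueError
attn_output, present_key_value_state, position_bias = attn(hidden_states)[:3]
print(attn_output.shape, position_bias.shape)   # (1, 5, d_model) and (1, n_heads, 5, 5)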
@ -322,7 +336,6 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
self.SelfAttention = TFT5Attention(
config,
has_relative_attention_bias=has_relative_attention_bias,
is_bidirectional=not config.is_decoder,
name="SelfAttention",
)
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
@ -339,9 +352,9 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
output_attentions=False,
training=False,
):
norm_x = self.layer_norm(hidden_states)
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention(
norm_x,
normed_hidden_states,
mask=attention_mask,
position_bias=position_bias,
head_mask=head_mask,
@ -350,19 +363,17 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
output_attentions=output_attentions,
training=training,
)
y = attention_output[0]
layer_output = hidden_states + self.dropout(y, training=training)
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
hidden_states = hidden_states + self.dropout(attention_output[0], training=training)
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
return outputs
class TFT5LayerCrossAttention(tf.keras.layers.Layer):
def __init__(self, config, has_relative_attention_bias=False, **kwargs):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.EncDecAttention = TFT5Attention(
config,
has_relative_attention_bias=has_relative_attention_bias,
is_bidirectional=True,
has_relative_attention_bias=False,
name="EncDecAttention",
)
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
@ -371,7 +382,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
def call(
self,
hidden_states,
kv,
key_value_states,
attention_mask=None,
position_bias=None,
head_mask=None,
@ -381,11 +392,11 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
output_attentions=False,
training=False,
):
norm_x = self.layer_norm(hidden_states)
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention(
norm_x,
normed_hidden_states,
mask=attention_mask,
kv=kv,
key_value_states=key_value_states,
position_bias=position_bias,
head_mask=head_mask,
past_key_value=past_key_value,
@ -394,9 +405,8 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
output_attentions=output_attentions,
training=training,
)
y = attention_output[0]
layer_output = hidden_states + self.dropout(y, training=training)
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
hidden_states = hidden_states + self.dropout(attention_output[0], training=training)
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
return outputs
@ -416,7 +426,6 @@ class TFT5Block(tf.keras.layers.Layer):
self.layer.append(
TFT5LayerCrossAttention(
config,
has_relative_attention_bias=has_relative_attention_bias,
name="layer_._1",
)
)
@ -477,7 +486,7 @@ class TFT5Block(tf.keras.layers.Layer):
cross_attention_outputs = self.layer[1](
hidden_states,
kv=encoder_hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias,
head_mask=head_mask,
@ -731,17 +740,18 @@ class TFT5MainLayer(tf.keras.layers.Layer):
# layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
hidden_states, present_key_value_state = layer_outputs[:2]
if i == 0:
# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[3 if output_attentions else 2]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 3]
# We share the position biases between the layers - the first layer stores them
# layer_outputs = hidden-states, past_key_values, (self-attention position bias),
# (self-attention weights), (cross-attention position bias), (cross-attention weights)
position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
present_key_value_states = present_key_value_states + (present_key_value_state,)
if output_attentions:
all_attentions = all_attentions + (layer_outputs[2],)
all_attentions = all_attentions + (layer_outputs[3],)
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)

View File

@ -461,7 +461,7 @@ class ModelTesterMixin:
inputs = self._prepare_for_class(inputs_dict, model_class).copy()
inputs["head_mask"] = head_mask
outputs = model(**inputs)
outputs = model(**inputs, return_dict=True)
# Test that we can get a gradient back for importance score computation
output = sum(t.sum() for t in outputs[0])

View File

@ -556,6 +556,32 @@ class T5ModelIntegrationTests(unittest.TestCase):
def tokenizer(self):
return T5Tokenizer.from_pretrained("t5-base")
@slow
def test_small_integration_test(self):
"""
For comparison run:
>>> import t5 # pip install t5==0.7.1
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
>>> path_to_mtf_small_t5_checkpoint = '<fill_in>'
>>> path_to_mtf_small_spm_model_path = '<fill_in>'
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_checkpoint, batch_size=1, tpu=None)
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
"""
model = T5ForConditionalGeneration.from_pretrained("t5-small", return_dict=True).to(torch_device)
tokenizer = T5Tokenizer.from_pretrained("t5-small")
input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
labels = tokenizer("Hi I am", return_tensors="pt").input_ids
loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss
mtf_score = -(labels.shape[-1] * loss.item())
EXPECTED_SCORE = -19.0845
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
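The Mesh TensorFlow score is the summed log-likelihood of the target tokens, while the Hugging Face loss is the mean cross-entropy per target token, so multiplying the loss by the target length and negating reproduces the reference number. A hedged re-statement of the same computation outside the test harness:

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small", return_dict=True)

input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
labels = tokenizer("Hi I am", return_tensors="pt").input_ids

with torch.no_grad():
    loss = model(input_ids, labels=labels).loss      # mean cross-entropy over the target tokens

mtf_score = -(labels.shape[-1] * loss.item())        # summed log-likelihood, as t5's MtfModel.score reports
print(mtf_score)                                     # ≈ -19.0845 per the test above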
@slow
def test_summarization(self):
model = self.model
@ -567,8 +593,8 @@ class T5ModelIntegrationTests(unittest.TestCase):
ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
expected_summaries = [
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video at the crash site . "one can hear cries of \'My God\' in several languages," one magazine says .',
"the Palestinians become the 123rd member of the international criminal court . the accession was marked by a ceremony at the Hague, where the court is based . as members of the court, Palestinians may be subject to counter-charges as well .",
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," one magazine says .',
"the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a preliminary examination into the situation in the occupied Palestinian territory . as members of the court, Palestinians may be subject to counter-charges as well .",
"the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and implement a rigorous inspection regime .",
'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
]

View File

@ -311,8 +311,8 @@ class TFT5ModelIntegrationTests(unittest.TestCase):
ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
expected_summaries = [
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video at the crash site . "one can hear cries of \'My God\' in several languages," one magazine says .',
"the Palestinians become the 123rd member of the international criminal court . the accession was marked by a ceremony at the Hague, where the court is based . as members of the court, Palestinians may be subject to counter-charges as well .",
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," one magazine says .',
"the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a preliminary examination into the situation in the occupied Palestinian territory . as members of the court, Palestinians may be subject to counter-charges as well .",
"the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and implement a rigorous inspection regime .",
'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
]