updating T5

thomwolf 2019-12-09 16:25:33 +01:00
parent f3776df0f3
commit 169fea6855


@@ -281,7 +281,7 @@ class T5Attention(nn.Module):
context_position = torch.arange(qlen, dtype=torch.long)[:, None]
memory_position = torch.arange(klen, dtype=torch.long)[None, :]
relative_position = memory_position - context_position # shape (qlen, klen)
-rp_bucket = self._relative_position_bucket(relative_position,
+rp_bucket = self._relative_position_bucket(relative_position, # shape (qlen, klen)
bidirectional=not self.is_decoder,
num_buckets=self.relative_attention_num_buckets)
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
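
For context, the hunk above builds a (qlen, klen) grid of relative offsets, buckets each offset, and looks the bucket up in a learned embedding to get one bias value per attention head. Below is a minimal sketch of that flow, assuming a plain nn.Embedding and a deliberately simplified bucketing function (clamping instead of T5's log-spaced buckets); it illustrates the shapes, not the library's exact implementation.

import torch
import torch.nn as nn

def toy_relative_bucket(relative_position, num_buckets=32):
    # Stand-in for T5's _relative_position_bucket: clamp offsets into
    # [0, num_buckets) instead of the real log-spaced bucketing.
    half = num_buckets // 2
    return relative_position.clamp(-half, half - 1) + half

qlen, klen, num_buckets, n_heads = 4, 6, 32, 8
relative_attention_bias = nn.Embedding(num_buckets, n_heads)

context_position = torch.arange(qlen, dtype=torch.long)[:, None]    # (qlen, 1)
memory_position = torch.arange(klen, dtype=torch.long)[None, :]     # (1, klen)
relative_position = memory_position - context_position              # (qlen, klen)

rp_bucket = toy_relative_bucket(relative_position, num_buckets)     # (qlen, klen)
values = relative_attention_bias(rp_bucket)                         # (qlen, klen, n_heads)
position_bias = values.permute(2, 0, 1).unsqueeze(0)                # (1, n_heads, qlen, klen)
print(position_bias.shape)   # torch.Size([1, 8, 4, 6])
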
@@ -337,14 +337,10 @@ class T5Attention(nn.Module):
if not self.has_relative_attention_bias:
raise ValueError("No position_bias provided and no weights to compute position_bias")
position_bias = self.compute_bias(qlen, klen)
+if mask is not None:
+    position_bias += mask # (bs, n_heads, qlen, klen)
scores += position_bias
-special_out = position_bias
-if mask is not None:
-    scores += mask
-    # mask = (mask == 0).expand_as(scores) # (bs, n_heads, qlen, klen)
-    # scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
weights = F.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen)
weights = F.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen)
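
The change in this hunk folds the attention mask into position_bias once, instead of adding it to the scores separately. Because both terms are just added to the raw scores before the softmax, the two formulations produce identical attention weights; a small sketch of that equivalence, with shapes and the -1e9 additive mask chosen only for the example:

import torch
import torch.nn.functional as F

bs, n_heads, qlen, klen = 2, 8, 4, 6
scores = torch.randn(bs, n_heads, qlen, klen)
position_bias = torch.randn(1, n_heads, qlen, klen)

# Additive mask: 0.0 where attention is allowed, -1e9 where it is masked out.
mask = torch.zeros(bs, 1, 1, klen)
mask[..., -2:] = -1e9

# Before: bias and mask added to the scores separately.
weights_separate = F.softmax(scores + position_bias + mask, dim=-1)

# After: mask folded into the bias once, then the combined bias is added.
position_bias_masked = position_bias + mask   # broadcasts to (bs, n_heads, qlen, klen)
weights_folded = F.softmax(scores + position_bias_masked, dim=-1)

print(torch.allclose(weights_separate, weights_folded))   # True
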
@@ -362,7 +358,7 @@ class T5Attention(nn.Module):
outputs = outputs + (weights,)
if self.has_relative_attention_bias:
outputs = outputs + (position_bias,)
-return outputs + (special_out,)
+return outputs
class T5LayerSelfAttention(nn.Module):
@@ -379,11 +375,9 @@ class T5LayerSelfAttention(nn.Module):
position_bias=position_bias,
head_mask=head_mask)
y = attention_output[0]
-special_out = attention_output[-1]
-attention_output = attention_output[:-1]
layer_output = hidden_states + self.dropout(y)
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
-return outputs + (special_out,)
+return outputs
class T5LayerCrossAttention(nn.Module):
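
The T5LayerSelfAttention hunk above keeps only the residual wiring: the attention output y goes through dropout and is added back onto the incoming hidden_states, and the remaining attention outputs are forwarded unchanged. A rough sketch of that sublayer pattern using torch.nn.MultiheadAttention as a stand-in (the pre-attention layer norm that T5 applies is omitted here for brevity):

import torch
import torch.nn as nn

class ResidualSelfAttention(nn.Module):
    # Illustrative sublayer following the pattern above: the attention output
    # is dropped out and added to the residual stream, and any extra outputs
    # (here the attention weights) are passed along with it.
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states):
        y, weights = self.attn(hidden_states, hidden_states, hidden_states)
        layer_output = hidden_states + self.dropout(y)
        return (layer_output, weights)

layer = ResidualSelfAttention(d_model=16, n_heads=4)
x = torch.randn(2, 5, 16)
out, attn = layer(x)
print(out.shape, attn.shape)   # torch.Size([2, 5, 16]) torch.Size([2, 5, 5])
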
@@ -426,8 +420,7 @@ class T5Block(nn.Module):
position_bias=position_bias,
head_mask=head_mask)
hidden_states = self_attention_outputs[0]
-special_out = self_attention_outputs[-1]
-outputs = self_attention_outputs[1:-1] # Keep self-attention outputs and relative position weights
+outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights
if not self.is_decoder:
hidden_states = self.layer[1](hidden_states)
@@ -442,7 +435,7 @@ class T5Block(nn.Module):
hidden_states = self.layer[2](hidden_states)
outputs = (hidden_states,) + outputs # add attentions if we output them
-return outputs + (special_out,) # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
+return outputs # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
class T5PreTrainedModel(PreTrainedModel):
@@ -536,6 +529,10 @@ class T5Stack(T5PreTrainedModel):
# positions we want to attend and -1e9 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
+# T5 has a mask that can compare sequence ids; we simulate this here with this transposition
+# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+extended_attention_mask = (extended_attention_mask == extended_attention_mask.transpose(-1, -2))
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
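
These added lines do what the surrounding comment describes: masked positions receive a large negative additive term (-1e9) so the softmax effectively drops them, and comparing the mask against its own transpose emulates mesh-tensorflow's sequence-id masking, so a position only attends to positions carrying the same mask value. A toy illustration of both steps (the segment-id mask below is an assumption made for the example):

import torch
import torch.nn.functional as F

# Toy attention mask carrying "sequence ids": two packed sequences (1 and 2)
# plus one padding position (0).
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 0]])

# Compare the broadcast mask against its own transpose: position q may only
# attend to position k when the two carry the same id.
extended = attention_mask[:, None, None, :]                    # (bs, 1, 1, klen)
extended = (extended == extended.transpose(-1, -2))            # (bs, 1, qlen, klen), bool
extended = extended.to(dtype=torch.float32)
extended = (1.0 - extended) * -1e9                             # 0 where allowed, -1e9 otherwise

scores = torch.randn(1, 1, 6, 6)                               # (bs, n_heads, qlen, klen)
weights = F.softmax(scores + extended, dim=-1)
print((weights[0, 0] > 1e-6).int())   # block-diagonal: attention stays within each id group
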
@@ -584,8 +581,6 @@ class T5Stack(T5PreTrainedModel):
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
head_mask=head_mask[i])
-if i == 0:
-    special_out = layer_outputs[-1]
# layer_outputs is a tuple with:
# hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
hidden_states = layer_outputs[0]
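
The position bias carried in layer_outputs is what lets the stack compute the relative attention bias only once: the first block returns it, and the loop feeds it back into every later block via the position_bias argument. A minimal sketch of that sharing pattern with a toy block (ToyBlock is illustrative only, not the T5Block API):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    # Illustrative block: only the first block owns a learned position bias;
    # every later block reuses the bias computed by the first one.
    def __init__(self, n_heads, seq_len, has_relative_attention_bias=False):
        super().__init__()
        self.has_relative_attention_bias = has_relative_attention_bias
        if has_relative_attention_bias:
            self.bias = nn.Parameter(torch.zeros(1, n_heads, seq_len, seq_len))

    def forward(self, hidden_states, position_bias=None):
        if position_bias is None:
            position_bias = self.bias          # computed only in the first block
        # ... attention using position_bias would happen here ...
        return hidden_states, position_bias

n_heads, seq_len = 8, 5
blocks = nn.ModuleList(
    [ToyBlock(n_heads, seq_len, has_relative_attention_bias=(i == 0)) for i in range(3)]
)

hidden_states = torch.randn(2, seq_len, 16)
position_bias = None
for block in blocks:
    hidden_states, position_bias = block(hidden_states, position_bias=position_bias)
print(position_bias.shape)   # torch.Size([1, 8, 5, 5])
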
@@ -610,7 +605,7 @@ class T5Stack(T5PreTrainedModel):
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
-return outputs + (special_out,) # last-layer hidden state, (all hidden states), (all attentions)
+return outputs # last-layer hidden state, (all hidden states), (all attentions)
T5_START_DOCSTRING = r""" The T5 model was proposed in