WIP debugging

thomwolf 2019-12-02 15:47:00 +01:00
parent 268d4f2099
commit f3776df0f3


@@ -132,6 +132,21 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
# - PreTrainedModel for the models (itself a subclass of torch.nn.Module)
####################################################
class T5LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
""" Construct a layernorm module in the T5 style
No bias and no subtraction of mean.
"""
super(T5LayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
variance = x.pow(2).mean(-1, keepdim=True)
x = x / torch.sqrt(variance + self.variance_epsilon)
return self.weight * x
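Editorial note: this swap away from nn.LayerNorm changes the math, not just the module name. nn.LayerNorm centers the activations and adds a learned bias, while the T5-style norm only rescales by the root mean square of the last dimension. A minimal sketch of the difference, assuming only a plain PyTorch install (tensor shapes are illustrative):

import torch
import torch.nn as nn

hidden_size, eps = 8, 1e-6
x = torch.randn(2, 5, hidden_size)

# Standard LayerNorm: subtract the mean, divide by the std, learned weight and bias.
standard = nn.LayerNorm(hidden_size, eps=eps)(x)

# T5-style norm as defined above, with the weight still at its init value of ones:
# no centering and no bias, just division by the root mean square.
variance = x.pow(2).mean(-1, keepdim=True)
t5_style = x / torch.sqrt(variance + eps)

# The two differ whenever the per-position mean over the hidden dimension is non-zero.
print((standard - t5_style).abs().max())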
class T5DenseReluDense(nn.Module):
def __init__(self, config):
super(T5DenseReluDense, self).__init__()
@@ -151,7 +166,7 @@ class T5LayerFF(nn.Module):
def __init__(self, config):
super(T5LayerFF, self).__init__()
self.DenseReluDense = T5DenseReluDense(config)
self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, hidden_states):
@@ -316,13 +331,14 @@ class T5Attention(nn.Module):
cache[self.layer_id] = (k, v)
# q = q / math.sqrt(dim_per_head) # No scaling in T5
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen)
scores = torch.einsum('bnqd,bnkd->bnqk', q, k) # (bs, n_heads, qlen, klen)
if position_bias is None:
if not self.has_relative_attention_bias:
raise ValueError("No position_bias provided and no weights to compute position_bias")
position_bias = self.compute_bias(qlen, klen)
scores += position_bias
special_out = position_bias
if mask is not None:
scores += mask
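Editorial note: the einsum rewrite is a drop-in equivalent of the matmul/transpose form it replaces; both contract the per-head dimension and produce scores of shape (bs, n_heads, qlen, klen), to which position_bias is then added by broadcasting. A quick check of that equivalence, with illustrative shapes:

import torch

bs, n_heads, qlen, klen, d = 2, 4, 5, 7, 16
q = torch.randn(bs, n_heads, qlen, d)
k = torch.randn(bs, n_heads, klen, d)

scores_matmul = torch.matmul(q, k.transpose(2, 3))     # (bs, n_heads, qlen, klen)
scores_einsum = torch.einsum('bnqd,bnkd->bnqk', q, k)  # same contraction, written out
assert torch.allclose(scores_matmul, scores_einsum, atol=1e-5)

# position_bias broadcasts over the batch dimension when added to the scores.
position_bias = torch.randn(1, n_heads, qlen, klen)
scores = scores_einsum + position_bias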
@@ -346,14 +362,14 @@ class T5Attention(nn.Module):
outputs = outputs + (weights,)
if self.has_relative_attention_bias:
outputs = outputs + (position_bias,)
return outputs
return outputs + (special_out,)
class T5LayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super(T5LayerSelfAttention, self).__init__()
self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, hidden_states, attention_mask=None, position_bias=None, head_mask=None):
@@ -363,16 +379,18 @@ class T5LayerSelfAttention(nn.Module):
position_bias=position_bias,
head_mask=head_mask)
y = attention_output[0]
special_out = attention_output[-1]
attention_output = attention_output[:-1]
layer_output = hidden_states + self.dropout(y)
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
return outputs
return outputs + (special_out,)
class T5LayerCrossAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super(T5LayerCrossAttention, self).__init__()
self.EncDecAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, hidden_states, kv, attention_mask=None, position_bias=None, head_mask=None):
@@ -408,7 +426,8 @@ class T5Block(nn.Module):
position_bias=position_bias,
head_mask=head_mask)
hidden_states = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights
special_out = self_attention_outputs[-1]
outputs = self_attention_outputs[1:-1] # Keep self-attention outputs and relative position weights
if not self.is_decoder:
hidden_states = self.layer[1](hidden_states)
@@ -423,7 +442,7 @@ class T5Block(nn.Module):
hidden_states = self.layer[2](hidden_states)
outputs = (hidden_states,) + outputs # add attentions if we output them
return outputs # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
return outputs + (special_out,) # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
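Editorial note: the debugging changes in this commit all follow the same plumbing pattern, visible above: the innermost module appends the tensor it wants to surface as the last element of its output tuple, and each wrapper peels it off with [-1] / [:-1] before re-appending it to its own outputs. A toy sketch of that pattern outside of T5 (the module names here are made up for illustration):

import torch
import torch.nn as nn

class Inner(nn.Module):
    def forward(self, x):
        debug = x.mean()                  # tensor to surface for inspection
        return (x * 2,) + (debug,)        # regular outputs first, debug tensor last

class Outer(nn.Module):
    def __init__(self):
        super(Outer, self).__init__()
        self.inner = Inner()

    def forward(self, x):
        inner_outputs = self.inner(x)
        debug = inner_outputs[-1]         # peel the debug tensor off
        inner_outputs = inner_outputs[:-1]
        y = inner_outputs[0] + 1.0
        return (y,) + (debug,)            # re-append it at this level

outputs = Outer()(torch.randn(2, 3))
hidden, debug = outputs[0], outputs[-1]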
class T5PreTrainedModel(PreTrainedModel):
@@ -438,8 +457,7 @@ class T5PreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
""" Initialize the weights """
factor = self.config.initializer_factor # Used for testing weights initialization
if isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
if isinstance(module, T5LayerNorm):
module.weight.data.fill_(factor*1.0)
elif isinstance(module, (T5Model, T5WithLMHeadModel)):
# Mesh TensorFlow embeddings initialization
@@ -478,7 +496,7 @@ class T5Stack(T5PreTrainedModel):
self.block = nn.ModuleList([T5Block(config, has_relative_attention_bias=bool(i == 0))
for i in range(config.num_layers)])
self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
self.init_weights()
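Editorial note: only the first block is built with has_relative_attention_bias=True. It owns the relative-position embedding and computes position_bias once; the loop further down hands that tensor to every later block, which is why T5Attention raises when asked to attend with neither a bias nor the weights to compute one. A stripped-down toy of that sharing (this is not the real T5Attention; the distance bucketing is replaced by a clamp for brevity):

import torch
import torch.nn as nn

n_heads, qlen, klen, num_buckets = 4, 5, 5, 32

class ToyRelativeAttention(nn.Module):
    def __init__(self, has_relative_attention_bias=False):
        super(ToyRelativeAttention, self).__init__()
        self.has_relative_attention_bias = has_relative_attention_bias
        if has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(num_buckets, n_heads)

    def compute_bias(self, qlen, klen):
        # Real T5 buckets signed relative distances; a clamp stands in for that here.
        relative_position = torch.arange(klen)[None, :] - torch.arange(qlen)[:, None]
        buckets = relative_position.clamp(0, num_buckets - 1)
        values = self.relative_attention_bias(buckets)  # (qlen, klen, n_heads)
        return values.permute(2, 0, 1).unsqueeze(0)     # (1, n_heads, qlen, klen)

    def forward(self, scores, position_bias=None):
        if position_bias is None:
            if not self.has_relative_attention_bias:
                raise ValueError("No position_bias provided and no weights to compute position_bias")
            position_bias = self.compute_bias(qlen, klen)
        return scores + position_bias, position_bias

blocks = nn.ModuleList([ToyRelativeAttention(has_relative_attention_bias=bool(i == 0))
                        for i in range(3)])
scores = torch.zeros(1, n_heads, qlen, klen)
position_bias = None
for block in blocks:
    scores, position_bias = block(scores, position_bias)  # block 0 computes, the rest reuse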
@@ -515,11 +533,11 @@ class T5Stack(T5PreTrainedModel):
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# positions we want to attend and -1e9 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
if self.is_decoder:
# If a 2D or 3D attention mask is provided for the cross-attention
@@ -530,7 +548,7 @@ class T5Stack(T5PreTrainedModel):
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
else:
encoder_extended_attention_mask = None
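Editorial note: both mask conversions above apply the mask additively before the softmax: positions to keep map to 0.0 and masked positions to -1e9, which drives their attention weight to effectively zero once added to the raw scores. A small sketch of that construction for a 2D padding mask (shapes are illustrative):

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])  # 1 = attend, 0 = padding
extended = attention_mask[:, None, None, :]       # (bs, 1, 1, klen), broadcasts over heads and queries
extended = extended.to(torch.float32)             # in the model: the parameters' dtype, for fp16
extended = (1.0 - extended) * -1e9                # 0.0 for kept positions, -1e9 for masked ones

scores = torch.randn(1, 1, 5, 5)                  # (bs, n_heads, qlen, klen)
weights = torch.softmax(scores + extended, dim=-1)
print(weights[0, 0, 0])                           # last two entries are ~0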
@@ -553,6 +571,8 @@ class T5Stack(T5PreTrainedModel):
all_attentions = ()
position_bias = None
encoder_decoder_position_bias = None
hidden_states = self.dropout(hidden_states)
for i, layer_module in enumerate(self.block):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -564,6 +584,8 @@ class T5Stack(T5PreTrainedModel):
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
head_mask=head_mask[i])
if i == 0:
special_out = layer_outputs[-1]
# layer_outputs is a tuple with:
# hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
hidden_states = layer_outputs[0]
@@ -588,7 +610,7 @@ class T5Stack(T5PreTrainedModel):
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
return outputs + (special_out,) # last-layer hidden state, (all hidden states), (all attentions)
T5_START_DOCSTRING = r""" The T5 model was proposed in
@@ -707,9 +729,16 @@ class T5Model(T5PreTrainedModel):
# Encode if needed (training, first prediction pass)
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
encoder_attention_mask = kwargs_encoder.get("attention_mask", None)
if encoder_hidden_states is None:
encoder_inputs_ids = kwargs_encoder.pop("input_ids")
hidden_states = self.shared(encoder_inputs_ids) # Convert input ids to embeddings
if encoder_attention_mask is not None:
# Apply masking
encoder_attention_mask = (encoder_attention_mask != 0).to(hidden_states)
hidden_states = hidden_states * encoder_attention_mask.unsqueeze(-1)
encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0]
else:
@@ -719,7 +748,7 @@ class T5Model(T5PreTrainedModel):
decoder_inputs_ids = kwargs_decoder.pop("input_ids")
hidden_states = self.shared(decoder_inputs_ids) # Convert input ids to embeddings
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
kwargs_decoder["encoder_attention_mask"] = encoder_attention_mask
decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)
return decoder_outputs + encoder_outputs
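Editorial note: beyond the additive score mask, the new encoder path above also zeroes the embedding vectors at padding positions before the stack runs. A minimal sketch of just that step, with a toy embedding table standing in for self.shared (sizes are illustrative):

import torch
import torch.nn as nn

shared = nn.Embedding(100, 16)                      # stand-in for the shared embedding table
input_ids = torch.tensor([[5, 7, 9, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

hidden_states = shared(input_ids)                   # (bs, seq_len, d_model)
mask = (attention_mask != 0).to(hidden_states)      # float mask with the embeddings' dtype/device
hidden_states = hidden_states * mask.unsqueeze(-1)  # rows at padding positions become all zeros
print(hidden_states[0, -1].abs().sum())             # padding row sums to 0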