mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
clean up model
This commit is contained in:
parent
6cc651778a
commit
d0cb9fa2a7
163
modeling.py
163
modeling.py
@ -27,26 +27,28 @@ import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
def gelu(x):
|
||||
"""Implementation of the gelu activation function.
|
||||
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
||||
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
"""
|
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
# For information: OpenAI GPT gelu version is a bit different:
|
||||
# 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
|
||||
class BertConfig(object):
|
||||
"""Configuration for `BertModel`."""
|
||||
|
||||
"""Configuration class to store the configuration of a `BertModel`.
|
||||
"""
|
||||
def __init__(self,
|
||||
vocab_size,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
initializer_range=0.02):
|
||||
vocab_size,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
initializer_range=0.02):
|
||||
"""Constructs BertConfig.
|
||||
|
||||
Args:
|
||||
@ -110,42 +112,31 @@ class BertConfig(object):
|
||||
|
||||
class BERTLayerNorm(nn.Module):
|
||||
def __init__(self, config, variance_epsilon=1e-12):
|
||||
"Construct a layernorm module in the TF style (epsilon inside the square root)."
|
||||
"""Construct a layernorm module in the TF style (epsilon inside the square root).
|
||||
"""
|
||||
super(BERTLayerNorm, self).__init__()
|
||||
self.gamma = nn.Parameter(torch.ones(config.hidden_size))
|
||||
self.beta = nn.Parameter(torch.zeros(config.hidden_size))
|
||||
self.variance_epsilon = variance_epsilon
|
||||
|
||||
def forward(self, x):
|
||||
# TODO check it's identical to TF implementation in details (epsilon and axes)
|
||||
u = x.mean(-1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(-1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
||||
return self.gamma * x + self.beta
|
||||
# tf.contrib.layers.layer_norm(
|
||||
# inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
|
||||
|
||||
class BERTEmbeddings(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BERTEmbeddings, self).__init__()
|
||||
"""Construct the embedding module from word, position and token_type embeddings.
|
||||
"""
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
|
||||
# Position embeddings are (normally) a contiguous range so we could use a slice
|
||||
# Since the position embedding table is a learned variable, we create it
|
||||
# using a (long) sequence length `max_position_embeddings`. The actual
|
||||
# sequence length might be shorter than this, for faster training of
|
||||
# tasks that do not have long sequences.
|
||||
#
|
||||
# So `full_position_embeddings` is effectively an embedding table
|
||||
# for position [0, 1, 2, ..., max_position_embeddings-1], and the current
|
||||
# sequence has positions [0, 1, 2, ... seq_length-1], so we can just
|
||||
# perform a slice.
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
|
||||
# token_type_embeddings vocabulary is very small. TF used one-hot embeddings to speedup.
|
||||
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
||||
|
||||
self.LayerNorm = BERTLayerNorm(config) # Not snake-cased to stick with TF model variable name
|
||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||
# any TensorFlow checkpoint file
|
||||
self.LayerNorm = BERTLayerNorm(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None):
|
||||
@ -182,65 +173,37 @@ class BERTSelfAttention(nn.Module):
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x, is_key_tensor=False):
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||
x = x.view(*new_x_shape)
|
||||
if is_key_tensor:
|
||||
return x.permute(0, 2, 3, 1)
|
||||
else:
|
||||
return x.permute(0, 2, 1, 3)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
# Scalar dimensions referenced here:
|
||||
# B = batch size (number of sequences)
|
||||
# F = `from_tensor` sequence length
|
||||
# T = `to_tensor` sequence length
|
||||
# N = `num_attention_heads`
|
||||
# H = `size_per_head`
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
mixed_key_layer = self.key(hidden_states)
|
||||
mixed_value_layer = self.value(hidden_states)
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer) #, is_key_tensor=True)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
||||
value_layer = self.transpose_for_scores(mixed_value_layer)
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw
|
||||
# attention scores.
|
||||
# `attention_scores` = [B, N, F, T]
|
||||
attention_scores_no_norm = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
attention_scores_no_mask = attention_scores_no_norm / math.sqrt(self.attention_head_size)
|
||||
|
||||
# TODO clean up this (precompute)
|
||||
# MY PYTORCH: w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights
|
||||
# `attention_mask` = [B, 1, F, T]
|
||||
# attention_mask = tf.expand_dims(attention_mask, axis=[1])
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# adder = (1.0 - attention_mask) * -10000.0
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
attention_scores = attention_scores_no_mask + attention_mask
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
# `attention_probs` = [B, N, F, T]
|
||||
attention_probs_no_drop = nn.Softmax(dim=-1)(attention_scores)
|
||||
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_probs = self.dropout(attention_probs_no_drop)
|
||||
attention_probs = self.dropout(attention_probs)
|
||||
|
||||
context_layer = torch.matmul(attention_probs, value_layer)
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
|
||||
# aux_attention = attention_probs[0, 0, 0, :].view(1, 128, 1)
|
||||
# aux_attention = aux_attention.permute(0, 2, 1, 3).contiguous().view(1, 128, 768)
|
||||
# aux_attention = key_layer.permute(0, 2, 3, 1).contiguous().view(1, 128, 768)
|
||||
# aux_attention = key_layer.permute(0, 2, 1, 3).contiguous().view(1, 128, 768)
|
||||
|
||||
return context_layer
|
||||
|
||||
|
||||
@ -317,12 +280,6 @@ class BERTEncoder(nn.Module):
|
||||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
"""
|
||||
Args:
|
||||
hidden_states: float Tensor of shape [batch_size, seq_length, hidden_size]
|
||||
Return:
|
||||
float Tensor of shape [batch_size, seq_length, hidden_size]
|
||||
"""
|
||||
all_encoder_layers = []
|
||||
for layer_module in self.layer:
|
||||
hidden_states = layer_module(hidden_states, attention_mask)
|
||||
@ -337,14 +294,8 @@ class BERTPooler(nn.Module):
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
"""
|
||||
Args:
|
||||
hidden_states: float Tensor of shape [batch_size, seq_length, hidden_size]
|
||||
Return:
|
||||
float Tensor of shape [batch_size, hidden_size]
|
||||
"""
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token. We assume that this has been pre-trained
|
||||
# to the first token.
|
||||
first_token_tensor = hidden_states[:, 0]
|
||||
pooled_output = self.dense(first_token_tensor)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
@ -373,10 +324,6 @@ class BertModel(nn.Module):
|
||||
|
||||
Args:
|
||||
config: `BertConfig` instance.
|
||||
|
||||
Raises:
|
||||
ValueError: The config is invalid or one of the input tensor shapes
|
||||
is invalid.
|
||||
"""
|
||||
super(BertModel, self).__init__()
|
||||
self.embeddings = BERTEmbeddings(config)
|
||||
@ -384,26 +331,30 @@ class BertModel(nn.Module):
|
||||
self.pooler = BERTPooler(config)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None):
|
||||
# We create 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, from_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length]
|
||||
# It's more simple than the triangular masking of causal attention, just need to
|
||||
# prepare the broadcast here
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, from_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.float()
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
embedding_output = self.embeddings(input_ids, token_type_ids)
|
||||
all_encoder_layers = self.encoder(embedding_output, extended_attention_mask)
|
||||
sequence_output = all_encoder_layers[-1]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
# TODO DEbugging
|
||||
# all_encoder_layers = [attention_mask, embeddings_sum, embedding_output] + all_encoder_layers
|
||||
return all_encoder_layers, pooled_output
|
||||
|
||||
class BertForSequenceClassification(nn.Module):
|
||||
@ -435,9 +386,14 @@ class BertForSequenceClassification(nn.Module):
|
||||
|
||||
def init_weights(m):
|
||||
if isinstance(m, (nn.Linear, nn.Embedding)):
|
||||
# Slight difference here with the TF version which uses truncated_normal
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
m.weight.data.normal_(config.initializer_range)
|
||||
elif isinstance(m, BERTLayerNorm):
|
||||
m.beta.data.normal_(config.initializer_range)
|
||||
m.gamma.data.normal_(config.initializer_range)
|
||||
if isinstance(m, nn.Linear):
|
||||
m.bias.data.zero_()
|
||||
self.apply(init_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
|
||||
@ -474,13 +430,13 @@ class BertForQuestionAnswering(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertForQuestionAnswering, self).__init__()
|
||||
self.bert = BertModel(config)
|
||||
# TODO check if it's normal there is no dropout on SQuAD in the TF version
|
||||
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
|
||||
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
||||
|
||||
def init_weights(m):
|
||||
if isinstance(m, (nn.Linear, nn.Embedding)):
|
||||
# Slight difference here with the TF version which uses truncated_normal for initialization
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
m.weight.data.normal_(config.initializer_range)
|
||||
elif isinstance(m, BERTLayerNorm):
|
||||
@ -497,20 +453,17 @@ class BertForQuestionAnswering(nn.Module):
|
||||
start_logits, end_logits = logits.split(1, dim=-1)
|
||||
|
||||
if start_positions is not None and end_positions is not None:
|
||||
#loss_fct = CrossEntropyLoss()
|
||||
#start_loss = loss_fct(start_logits, start_positions)
|
||||
#end_loss = loss_fct(end_logits, end_positions)
|
||||
batch_size, seq_length = input_ids.size()
|
||||
|
||||
|
||||
def compute_loss(logits, positions):
|
||||
max_position = positions.max().item()
|
||||
one_hot = torch.FloatTensor(batch_size, max(max_position, seq_length) +1).zero_()
|
||||
one_hot = one_hot.scatter_(1, positions.cpu(), 1) # Second argument need to be LongTensor and not cuda.LongTensor
|
||||
one_hot = one_hot.scatter_(1, positions.cpu(), 1) # Do this on CPU
|
||||
one_hot = one_hot[:, :seq_length].to(input_ids.device)
|
||||
log_probs = nn.functional.log_softmax(logits, dim = -1).view(batch_size, seq_length)
|
||||
loss = -torch.mean(torch.sum(one_hot*log_probs), dim = -1)
|
||||
return loss
|
||||
|
||||
|
||||
start_loss = compute_loss(start_logits, start_positions)
|
||||
end_loss = compute_loss(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
|
Loading…
Reference in New Issue
Block a user