mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
304 lines
13 KiB
Python
import os
import math
import logging

import torch
import torch.nn as nn

from transformers.configuration_albert import AlbertConfig
from transformers.modeling_bert import BertEmbeddings, BertModel, BertSelfAttention, prune_linear_layer, gelu_new

logger = logging.getLogger(__name__)

def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
    """ Load TF checkpoints in a PyTorch model."""
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
                     "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))

    # Load weights from the TF checkpoint
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        original_name = name

        # Rename TF variable scopes so they match the PyTorch module hierarchy
        name = name.replace("ffn_1", "ffn")
        name = name.replace("ffn/intermediate/output", "ffn_output")
        name = name.replace("attention_1", "attention")
        name = name.replace("cls/predictions/transform", "predictions")
        name = name.replace("LayerNorm_1", "attention/LayerNorm")
        name = name.replace("inner_group_", "albert_layers/")
        name = name.replace("group_", "albert_layer_groups/")
        name = name.split('/')

        # Walk down the module hierarchy to the parameter this TF variable maps to
        pointer = model
        for m_name in name:
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                scope_names = re.split(r'_(\d+)', m_name)
            else:
                scope_names = [m_name]

            if scope_names[0] == 'kernel' or scope_names[0] == 'gamma':
                pointer = getattr(pointer, 'weight')
            elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta':
                pointer = getattr(pointer, 'bias')
            elif scope_names[0] == 'output_weights':
                pointer = getattr(pointer, 'weight')
            elif scope_names[0] == 'squad':
                pointer = getattr(pointer, 'classifier')
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]

        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            # TF stores dense kernels transposed with respect to nn.Linear weights
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {} from {}".format(name, original_name))
        pointer.data = torch.from_numpy(array)

    return model
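
# A minimal conversion sketch (the file names below are hypothetical placeholders, not
# assets shipped with this repository):
#
#     config = AlbertConfig.from_json_file("albert_config.json")
#     model = AlbertForMaskedLM(config)
#     load_tf_weights_in_albert(model, config, "albert_model.ckpt")
#     torch.save(model.state_dict(), "pytorch_model.bin")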

class AlbertEmbeddings(BertEmbeddings):
    """
    Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config):
        super(AlbertEmbeddings, self).__init__(config)

        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
        self.LayerNorm = torch.nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
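
# Why embedding_size instead of hidden_size: ALBERT factorizes the embedding matrix into a
# V x E lookup followed by an E x H projection (see AlbertTransformer.embedding_hidden_mapping_in).
# Rough parameter count with illustrative sizes V=30000, E=128, H=768:
#     factorized:   30000*128 + 128*768 = 3,938,304 parameters
#     unfactorized (BERT-style V x H):    30000*768 = 23,040,000 parameters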

class AlbertAttention(BertSelfAttention):
    def __init__(self, config):
        super(AlbertAttention, self).__init__(config)

        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        mask = torch.ones(self.num_attention_heads, self.attention_head_size)
        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()

        # Prune linear layers (the output projection is self.dense in this module)
        self.query = prune_linear_layer(self.query, index)
        self.key = prune_linear_layer(self.key, index)
        self.value = prune_linear_layer(self.value, index)
        self.dense = prune_linear_layer(self.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.num_attention_heads = self.num_attention_heads - len(heads)
        self.all_head_size = self.attention_head_size * self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, input_ids, attention_mask=None, head_mask=None):
        # NOTE: despite its name, `input_ids` is the hidden-state tensor produced by the previous
        # layer (shape: batch_size x seq_length x hidden_size), not a tensor of token ids.
        mixed_query_layer = self.query(input_ids)
        mixed_key_layer = self.key(input_ids)
        mixed_value_layer = self.value(input_ids)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        reshaped_context_layer = context_layer.view(*new_context_layer_shape)

        # Apply the output projection head by head with an einsum; this is equivalent to
        # self.dense(reshaped_context_layer) but keeps the per-head view of the weights.
        # Should find a better way to do this
        w = self.dense.weight.T.view(self.num_attention_heads, self.attention_head_size, self.hidden_size)
        b = self.dense.bias

        projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
        projected_context_layer = self.dropout(projected_context_layer)
        layernormed_context_layer = self.LayerNorm(input_ids + projected_context_layer)
        return layernormed_context_layer, projected_context_layer, reshaped_context_layer, context_layer, attention_scores, attention_probs, attention_mask
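
# Head-pruning sketch (illustrative only; the head indices are arbitrary examples):
#
#     attention = AlbertAttention(config)
#     attention.prune_heads([0, 2])   # drops heads 0 and 2, shrinking query/key/value and dense
#     out = attention(hidden_states, attention_mask)[0]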

class AlbertLayer(nn.Module):
    def __init__(self, config):
        super(AlbertLayer, self).__init__()

        self.config = config
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = AlbertAttention(config)
        self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
        self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        attention_output = self.attention(hidden_states, attention_mask)[0]
        ffn_output = self.ffn(attention_output)
        ffn_output = gelu_new(ffn_output)
        ffn_output = self.ffn_output(ffn_output)
        hidden_states = self.LayerNorm(ffn_output + attention_output)

        return hidden_states
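
# Shape flow through one AlbertLayer, assuming hidden_states of shape (batch, seq_len, hidden_size):
#     attention  -> (batch, seq_len, hidden_size)
#     ffn        -> (batch, seq_len, intermediate_size), then gelu_new
#     ffn_output -> (batch, seq_len, hidden_size), then residual add + LayerNorm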

class AlbertLayerGroup(nn.Module):
    def __init__(self, config):
        super(AlbertLayerGroup, self).__init__()

        self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        for albert_layer in self.albert_layers:
            hidden_states = albert_layer(hidden_states, attention_mask, head_mask)

        return hidden_states

class AlbertTransformer(nn.Module):
    def __init__(self, config):
        super(AlbertTransformer, self).__init__()

        self.config = config
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
        self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        # Project the factorized embeddings (embedding_size) up to the transformer width (hidden_size)
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)

        for layer_idx in range(self.config.num_hidden_layers):
            # Map the layer index to its parameter group; with the default num_hidden_groups == 1,
            # every layer reuses the same AlbertLayerGroup (ALBERT's cross-layer parameter sharing).
            group_idx = int(layer_idx / self.config.num_hidden_layers * self.config.num_hidden_groups)
            hidden_states = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask)

        return (hidden_states,)
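
# Group-index arithmetic sketch: with num_hidden_layers=12 and num_hidden_groups=1 (the released
# ALBERT setting), every layer maps to group 0, i.e. full cross-layer weight sharing:
#
#     >>> [int(i / 12 * 1) for i in range(12)]
#     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
#
# With num_hidden_groups=2, the first 6 layers would use group 0 and the last 6 group 1.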

class AlbertModel(BertModel):
    def __init__(self, config):
        super(AlbertModel, self).__init__(config)

        self.config = config
        self.embeddings = AlbertEmbeddings(config)
        self.encoder = AlbertTransformer(config)
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
        self.pooler_activation = nn.Tanh()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # Broadcast the 2D attention mask to the shape added to the raw attention scores
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        encoder_outputs = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = encoder_outputs[0]

        # Pool by taking the hidden state of the first ([CLS]) token
        pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0]))

        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
        return outputs
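
# Forward-pass sketch (illustrative; the config file name and input ids are made-up placeholders):
#
#     config = AlbertConfig.from_json_file("albert_config.json")
#     model = AlbertModel(config)
#     input_ids = torch.tensor([[31, 51, 99], [15, 5, 0]])       # (batch_size, seq_len)
#     attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
#     sequence_output, pooled_output = model(input_ids, attention_mask=attention_mask)[:2]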

class AlbertForMaskedLM(nn.Module):
    def __init__(self, config):
        super(AlbertForMaskedLM, self).__init__()

        self.config = config
        self.bert = AlbertModel(config)
        self.LayerNorm = nn.LayerNorm(config.embedding_size)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)
        self.word_embeddings = nn.Linear(config.embedding_size, config.vocab_size)

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        # NOTE: _tie_or_clone_weights is provided by PreTrainedModel; this class currently
        # subclasses nn.Module, so tie_weights only works once that base class is in place.
        self._tie_or_clone_weights(self.word_embeddings,
                                   self.bert.embeddings.word_embeddings)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        hidden_states = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
                                  position_ids=position_ids, head_mask=head_mask)[0]
        hidden_states = self.dense(hidden_states)
        hidden_states = gelu_new(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)

        logits = self.word_embeddings(hidden_states)

        return logits
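
# Masked-LM sketch (illustrative placeholders; loss computation is not part of this module):
#
#     config = AlbertConfig.from_json_file("albert_config.json")
#     model = AlbertForMaskedLM(config)
#     input_ids = torch.tensor([[31, 51, 99]])
#     labels = torch.tensor([[31, 51, 99]])            # placeholder target ids
#     logits = model(input_ids)                        # (batch_size, seq_len, vocab_size)
#     loss = nn.CrossEntropyLoss()(logits.view(-1, config.vocab_size), labels.view(-1))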