# MIT License

# Copyright (c) 2019 Yang Liu and the HuggingFace team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import copy
import math

import numpy as np
import torch
from torch import nn
from torch.nn.init import xavier_uniform_

from transformers import BertModel, BertConfig, PreTrainedModel

from configuration_bertabs import BertAbsConfig


MAX_SIZE = 5000
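# MAX_SIZE bounds the causal mask pre-computed in TransformerDecoderLayer
# below, i.e. the longest target sequence the decoder can attend over
# without rebuilding the mask buffer.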

BERTABS_FINETUNED_MODEL_MAP = {
    "bertabs-finetuned-cnndm": "https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization-pytorch_model.bin",
}


class BertAbsPreTrainedModel(PreTrainedModel):
    config_class = BertAbsConfig
    pretrained_model_archive_map = BERTABS_FINETUNED_MODEL_MAP
    load_tf_weights = False
    base_model_prefix = "bert"


class BertAbs(BertAbsPreTrainedModel):
    def __init__(self, args, checkpoint=None, bert_extractive_checkpoint=None):
        super(BertAbs, self).__init__(args)
        self.args = args
        self.bert = Bert()

        # If pre-trained weights are passed for Bert, load these.
        load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False
        if load_bert_pretrained_extractive:
            self.bert.model.load_state_dict(
                dict([(n[11:], p) for n, p in bert_extractive_checkpoint.items() if n.startswith("bert.model")]),
                strict=True,
            )

        self.vocab_size = self.bert.model.config.vocab_size

        if args.max_pos > 512:
            my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][
                None, :
            ].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings
        tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)

        tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)

        self.decoder = TransformerDecoder(
            self.args.dec_layers,
            self.args.dec_hidden_size,
            heads=self.args.dec_heads,
            d_ff=self.args.dec_ff_size,
            dropout=self.args.dec_dropout,
            embeddings=tgt_embeddings,
            vocab_size=self.vocab_size,
        )

        gen_func = nn.LogSoftmax(dim=-1)
        self.generator = nn.Sequential(nn.Linear(args.dec_hidden_size, args.vocab_size), gen_func)
        self.generator[0].weight = self.decoder.embeddings.weight

        load_from_checkpoints = False if checkpoint is None else True
        if load_from_checkpoints:
            self.load_state_dict(checkpoint)

    def init_weights(self):
        for module in self.decoder.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        for p in self.generator.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
            else:
                p.data.zero_()

    def forward(
        self, encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask,
    ):
        encoder_output = self.bert(
            input_ids=encoder_input_ids, token_type_ids=token_type_ids, attention_mask=encoder_attention_mask,
        )
        encoder_hidden_states = encoder_output[0]
        dec_state = self.decoder.init_decoder_state(encoder_input_ids, encoder_hidden_states)
        decoder_outputs, _ = self.decoder(decoder_input_ids[:, :-1], encoder_hidden_states, dec_state)
        return decoder_outputs
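
# Construction sketch (illustrative only): `args` is expected to be a
# BertAbsConfig (or the pretrained config loaded via `from_pretrained`),
# assuming the defaults in configuration_bertabs mirror bert-base-uncased:
#
#   config = BertAbsConfig()
#   model = BertAbs(config)
#   # forward() takes LongTensors of shape [batch, src_len] / [batch, tgt_len]
#   # plus the matching masks and returns decoder hidden states;
#   # `model.generator` maps those to per-token log-probabilities.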


class Bert(nn.Module):
    """ This class is not really necessary and should probably disappear. """

    def __init__(self):
        super(Bert, self).__init__()
        config = BertConfig.from_pretrained("bert-base-uncased")
        self.model = BertModel(config)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
        self.eval()
        with torch.no_grad():
            encoder_outputs, _ = self.model(
                input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, **kwargs
            )
        return encoder_outputs


class TransformerDecoder(nn.Module):
    """
    The Transformer decoder from "Attention is All You Need".

    Args:
        num_layers (int): number of decoder layers.
        d_model (int): size of the model
        heads (int): number of heads
        d_ff (int): size of the inner FF layer
        dropout (float): dropout parameters
        embeddings (:obj:`onmt.modules.Embeddings`):
            embeddings to use, should have positional encodings
        attn_type (str): if using a separate copy attention
    """

    def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings, vocab_size):
        super(TransformerDecoder, self).__init__()

        # Basic attributes.
        self.decoder_type = "transformer"
        self.num_layers = num_layers
        self.embeddings = embeddings
        self.pos_emb = PositionalEncoding(dropout, self.embeddings.embedding_dim)

        # Build TransformerDecoder.
        self.transformer_layers = nn.ModuleList(
            [TransformerDecoderLayer(d_model, heads, d_ff, dropout) for _ in range(num_layers)]
        )

        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(
        self,
        input_ids,
        encoder_hidden_states=None,
        state=None,
        attention_mask=None,
        memory_lengths=None,
        step=None,
        cache=None,
        encoder_attention_mask=None,
    ):
        """
        See :obj:`onmt.modules.RNNDecoderBase.forward()`
        memory_bank = encoder_hidden_states
        """
        # Name conversion
        tgt = input_ids
        memory_bank = encoder_hidden_states
        memory_mask = encoder_attention_mask

        src_words = state.src
        src_batch, src_len = src_words.size()

        padding_idx = self.embeddings.padding_idx

        # Decoder padding mask
        tgt_words = tgt
        tgt_batch, tgt_len = tgt_words.size()
        tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1).expand(tgt_batch, tgt_len, tgt_len)

        # Encoder padding mask
        if memory_mask is not None:
            src_len = memory_mask.size(-1)
            src_pad_mask = memory_mask.expand(src_batch, tgt_len, src_len)
        else:
            src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1).expand(src_batch, tgt_len, src_len)

        # Pass through the embeddings
        emb = self.embeddings(input_ids)
        output = self.pos_emb(emb, step)
        assert emb.dim() == 3  # batch x len x embedding_dim

        if state.cache is None:
            saved_inputs = []

        for i in range(self.num_layers):
            prev_layer_input = None
            if state.cache is None:
                if state.previous_input is not None:
                    prev_layer_input = state.previous_layer_inputs[i]

            output, all_input = self.transformer_layers[i](
                output,
                memory_bank,
                src_pad_mask,
                tgt_pad_mask,
                previous_input=prev_layer_input,
                layer_cache=state.cache["layer_{}".format(i)] if state.cache is not None else None,
                step=step,
            )
            if state.cache is None:
                saved_inputs.append(all_input)

        if state.cache is None:
            saved_inputs = torch.stack(saved_inputs)

        output = self.layer_norm(output)

        if state.cache is None:
            state = state.update_state(tgt, saved_inputs)

        # Decoders in transformers return a tuple. Beam search will fail
        # if we don't follow this convention.
        return output, state

    def init_decoder_state(self, src, memory_bank, with_cache=False):
        """ Init decoder state """
        state = TransformerDecoderState(src)
        if with_cache:
            state._init_cache(memory_bank, self.num_layers)
        return state
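
    # Decoding sketch (illustrative only, mirroring the call in
    # Translator._fast_translate_batch below): `src_ids` are the encoder input
    # ids, `memory` the BERT hidden states, and `next_ids` the last generated
    # token of each sequence, shape [batch, 1]:
    #
    #   state = decoder.init_decoder_state(src_ids, memory, with_cache=True)
    #   for step in range(max_len):
    #       out, state = decoder(next_ids, memory, state, step=step)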


class PositionalEncoding(nn.Module):
    def __init__(self, dropout, dim, max_len=5000):
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0)
        super(PositionalEncoding, self).__init__()
        self.register_buffer("pe", pe)
        self.dropout = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, emb, step=None):
        emb = emb * math.sqrt(self.dim)
        if step:
            emb = emb + self.pe[:, step][:, None, :]
        else:
            emb = emb + self.pe[:, : emb.size(1)]
        emb = self.dropout(emb)
        return emb

    def get_emb(self, emb):
        return self.pe[:, : emb.size(1)]
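
# Illustration (not part of the original module): the buffer holds
# pe[0, pos, 2i] = sin(pos / 10000^(2i/dim)) and
# pe[0, pos, 2i+1] = cos(pos / 10000^(2i/dim)). A quick check with dim=4:
#
#   enc = PositionalEncoding(dropout=0.0, dim=4)
#   out = enc(torch.zeros(1, 3, 4))                 # [batch, seq_len, dim]
#   assert torch.allclose(out[0, 0], enc.pe[0, 0])  # position 0 is [0, 1, 0, 1]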


class TransformerDecoderLayer(nn.Module):
    """
    Args:
        d_model (int): the dimension of keys/values/queries in
            MultiHeadedAttention, also the input size of
            the first-layer of the PositionwiseFeedForward.
        heads (int): the number of heads for MultiHeadedAttention.
        d_ff (int): the second-layer of the PositionwiseFeedForward.
        dropout (float): dropout probability (0-1.0).
        self_attn_type (string): type of self-attention scaled-dot, average
    """

    def __init__(self, d_model, heads, d_ff, dropout):
        super(TransformerDecoderLayer, self).__init__()

        self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)

        self.context_attn = MultiHeadedAttention(heads, d_model, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6)
        self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6)
        self.drop = nn.Dropout(dropout)
        mask = self._get_attn_subsequent_mask(MAX_SIZE)
        # Register self.mask as a buffer in TransformerDecoderLayer, so
        # it gets TransformerDecoderLayer's cuda behavior automatically.
        self.register_buffer("mask", mask)

    def forward(
        self, inputs, memory_bank, src_pad_mask, tgt_pad_mask, previous_input=None, layer_cache=None, step=None,
    ):
        """
        Args:
            inputs (`FloatTensor`): `[batch_size x 1 x model_dim]`
            memory_bank (`FloatTensor`): `[batch_size x src_len x model_dim]`
            src_pad_mask (`LongTensor`): `[batch_size x 1 x src_len]`
            tgt_pad_mask (`LongTensor`): `[batch_size x 1 x 1]`

        Returns:
            (`FloatTensor`, `FloatTensor`):

            * output `[batch_size x 1 x model_dim]`
            * all_input `[batch_size x current_step x model_dim]`
        """
        dec_mask = torch.gt(tgt_pad_mask + self.mask[:, : tgt_pad_mask.size(1), : tgt_pad_mask.size(1)], 0)
        input_norm = self.layer_norm_1(inputs)
        all_input = input_norm
        if previous_input is not None:
            all_input = torch.cat((previous_input, input_norm), dim=1)
            dec_mask = None

        query = self.self_attn(all_input, all_input, input_norm, mask=dec_mask, layer_cache=layer_cache, type="self",)

        query = self.drop(query) + inputs

        query_norm = self.layer_norm_2(query)
        mid = self.context_attn(
            memory_bank, memory_bank, query_norm, mask=src_pad_mask, layer_cache=layer_cache, type="context",
        )
        output = self.feed_forward(self.drop(mid) + query)

        return output, all_input

    def _get_attn_subsequent_mask(self, size):
        """
        Get an attention mask to avoid using the subsequent info.

        Args:
            size: int

        Returns:
            (`LongTensor`):

            * subsequent_mask `[1 x size x size]`
        """
        attn_shape = (1, size, size)
        subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8")
        subsequent_mask = torch.from_numpy(subsequent_mask)
        return subsequent_mask
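
    # Illustration (not part of the original module): for size=4 the mask is
    # strictly upper-triangular, i.e. position t is masked from every t' > t:
    #
    #   tensor([[[0, 1, 1, 1],
    #            [0, 0, 1, 1],
    #            [0, 0, 0, 1],
    #            [0, 0, 0, 0]]], dtype=torch.uint8)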


class MultiHeadedAttention(nn.Module):
    """
    Multi-Head Attention module from
    "Attention is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.

    Similar to standard `dot` attention but uses
    multiple attention distributions simultaneously
    to select relevant items.

    .. mermaid::

        graph BT
            A[key]
            B[value]
            C[query]
            O[output]
            subgraph Attn
                D[Attn 1]
                E[Attn 2]
                F[Attn N]
            end
            A --> D
            C --> D
            A --> E
            C --> E
            A --> F
            C --> F
            D --> O
            E --> O
            F --> O
            B --> O

    Also includes several additional tricks.

    Args:
        head_count (int): number of parallel heads
        model_dim (int): the dimension of keys/values/queries,
            must be divisible by head_count
        dropout (float): dropout parameter
    """

    def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True):
        assert model_dim % head_count == 0
        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim

        super(MultiHeadedAttention, self).__init__()
        self.head_count = head_count

        self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_values = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_query = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.use_final_linear = use_final_linear
        if self.use_final_linear:
            self.final_linear = nn.Linear(model_dim, model_dim)

    def forward(
        self, key, value, query, mask=None, layer_cache=None, type=None, predefined_graph_1=None,
    ):
        """
        Compute the context vector and the attention vectors.

        Args:
            key (`FloatTensor`): set of `key_len`
                key vectors `[batch, key_len, dim]`
            value (`FloatTensor`): set of `key_len`
                value vectors `[batch, key_len, dim]`
            query (`FloatTensor`): set of `query_len`
                query vectors `[batch, query_len, dim]`
            mask: binary mask indicating which keys have
                non-zero attention `[batch, query_len, key_len]`

        Returns:
            (`FloatTensor`, `FloatTensor`) :

            * output context vectors `[batch, query_len, dim]`
            * one of the attention vectors `[batch, query_len, key_len]`
        """
        batch_size = key.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count
        key_len = key.size(1)
        query_len = query.size(1)

        def shape(x):
            """ projection """
            return x.view(batch_size, -1, head_count, dim_per_head).transpose(1, 2)

        def unshape(x):
            """ compute context """
            return x.transpose(1, 2).contiguous().view(batch_size, -1, head_count * dim_per_head)

        # 1) Project key, value, and query.
        if layer_cache is not None:
            if type == "self":
                query, key, value = (
                    self.linear_query(query),
                    self.linear_keys(query),
                    self.linear_values(query),
                )

                key = shape(key)
                value = shape(value)

                if layer_cache is not None:
                    device = key.device
                    if layer_cache["self_keys"] is not None:
                        key = torch.cat((layer_cache["self_keys"].to(device), key), dim=2)
                    if layer_cache["self_values"] is not None:
                        value = torch.cat((layer_cache["self_values"].to(device), value), dim=2)
                    layer_cache["self_keys"] = key
                    layer_cache["self_values"] = value
            elif type == "context":
                query = self.linear_query(query)
                if layer_cache is not None:
                    if layer_cache["memory_keys"] is None:
                        key, value = self.linear_keys(key), self.linear_values(value)
                        key = shape(key)
                        value = shape(value)
                    else:
                        key, value = (
                            layer_cache["memory_keys"],
                            layer_cache["memory_values"],
                        )
                    layer_cache["memory_keys"] = key
                    layer_cache["memory_values"] = value
                else:
                    key, value = self.linear_keys(key), self.linear_values(value)
                    key = shape(key)
                    value = shape(value)
        else:
            key = self.linear_keys(key)
            value = self.linear_values(value)
            query = self.linear_query(query)
            key = shape(key)
            value = shape(value)

        query = shape(query)

        key_len = key.size(2)
        query_len = query.size(2)

        # 2) Calculate and scale scores.
        query = query / math.sqrt(dim_per_head)
        scores = torch.matmul(query, key.transpose(2, 3))

        if mask is not None:
            mask = mask.unsqueeze(1).expand_as(scores)
            scores = scores.masked_fill(mask, -1e18)

        # 3) Apply attention dropout and compute context vectors.

        attn = self.softmax(scores)

        if predefined_graph_1 is not None:
            attn_masked = attn[:, -1] * predefined_graph_1
            attn_masked = attn_masked / (torch.sum(attn_masked, 2).unsqueeze(2) + 1e-9)

            attn = torch.cat([attn[:, :-1], attn_masked.unsqueeze(1)], 1)

        drop_attn = self.dropout(attn)
        if self.use_final_linear:
            context = unshape(torch.matmul(drop_attn, value))
            output = self.final_linear(context)
            return output
        else:
            context = torch.matmul(drop_attn, value)
            return context
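
    # Shape-tracing sketch (illustrative only): with no cache,
    #
    #   attn = MultiHeadedAttention(head_count=8, model_dim=512)
    #   k = v = torch.randn(2, 10, 512)       # [batch, key_len, dim]
    #   q = torch.randn(2, 7, 512)            # [batch, query_len, dim]
    #   out = attn(k, v, q)                   # -> [2, 7, 512]
    #
    # internally, `scores` has shape [2, 8, 7, 10], one attention map per head.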


class DecoderState(object):
    """Interface for grouping together the current state of a recurrent
    decoder. In the simplest case just represents the hidden state of
    the model. But can also be used for implementing various forms of
    input_feeding and non-recurrent models.

    Modules need to implement this to utilize beam search decoding.
    """

    def detach(self):
        """ Need to document this """
        self.hidden = tuple([_.detach() for _ in self.hidden])
        self.input_feed = self.input_feed.detach()

    def beam_update(self, idx, positions, beam_size):
        """ Need to document this """
        for e in self._all:
            sizes = e.size()
            br = sizes[1]
            if len(sizes) == 3:
                sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2])[:, :, idx]
            else:
                sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2], sizes[3])[:, :, idx]

            sent_states.data.copy_(sent_states.data.index_select(1, positions))

    def map_batch_fn(self, fn):
        raise NotImplementedError()


class TransformerDecoderState(DecoderState):
    """ Transformer Decoder state base class """

    def __init__(self, src):
        """
        Args:
            src (FloatTensor): a sequence of source words tensors
                with optional feature tensors, of size (len x batch).
        """
        self.src = src
        self.previous_input = None
        self.previous_layer_inputs = None
        self.cache = None

    @property
    def _all(self):
        """
        Contains attributes that need to be updated in self.beam_update().
        """
        if self.previous_input is not None and self.previous_layer_inputs is not None:
            return (self.previous_input, self.previous_layer_inputs, self.src)
        else:
            return (self.src,)

    def detach(self):
        if self.previous_input is not None:
            self.previous_input = self.previous_input.detach()
        if self.previous_layer_inputs is not None:
            self.previous_layer_inputs = self.previous_layer_inputs.detach()
        self.src = self.src.detach()

    def update_state(self, new_input, previous_layer_inputs):
        state = TransformerDecoderState(self.src)
        state.previous_input = new_input
        state.previous_layer_inputs = previous_layer_inputs
        return state

    def _init_cache(self, memory_bank, num_layers):
        self.cache = {}

        for l in range(num_layers):
            layer_cache = {"memory_keys": None, "memory_values": None}
            layer_cache["self_keys"] = None
            layer_cache["self_values"] = None
            self.cache["layer_{}".format(l)] = layer_cache

    def repeat_beam_size_times(self, beam_size):
        """ Repeat beam_size times along batch dimension. """
        self.src = self.src.data.repeat(1, beam_size, 1)

    def map_batch_fn(self, fn):
        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        self.src = fn(self.src, 0)
        if self.cache is not None:
            _recursive_map(self.cache)


def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
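
# Note (illustrative): this is the tanh approximation of GELU; it tracks the
# exact erf-based form closely, e.g.
#
#   x = torch.linspace(-3, 3, 7)
#   exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
#   assert torch.allclose(gelu(x), exact, atol=1e-2)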


class PositionwiseFeedForward(nn.Module):
    """ A two-layer Feed-Forward-Network with residual layer norm.

    Args:
        d_model (int): the size of input for the first-layer of the FFN.
        d_ff (int): the hidden layer size of the second-layer
            of the FFN.
        dropout (float): dropout probability in :math:`[0, 1)`.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.actv = gelu
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        inter = self.dropout_1(self.actv(self.w_1(self.layer_norm(x))))
        output = self.dropout_2(self.w_2(inter))
        return output + x
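
    # Note (illustrative): this is a pre-norm residual block,
    # x + Dropout(W2(Dropout(GELU(W1(LayerNorm(x)))))), so the skip connection
    # carries the un-normalized input.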


#
# TRANSLATOR
# The following code is used to generate summaries using the
# pre-trained weights and beam search.
#


def build_predictor(args, tokenizer, symbols, model, logger=None):
    # we should be able to refactor the global scorer a lot
    scorer = GNMTGlobalScorer(args.alpha, length_penalty="wu")
    translator = Translator(args, model, tokenizer, symbols, global_scorer=scorer, logger=logger)
    return translator


class GNMTGlobalScorer(object):
    """
    NMT re-ranking score from
    "Google's Neural Machine Translation System" :cite:`wu2016google`

    Args:
        alpha (float): length parameter
        length_penalty (str): name of the length penalty to apply
    """

    def __init__(self, alpha, length_penalty):
        self.alpha = alpha
        penalty_builder = PenaltyBuilder(length_penalty)
        self.length_penalty = penalty_builder.length_penalty()

    def score(self, beam, logprobs):
        """
        Rescores a prediction based on penalty functions
        """
        normalized_probs = self.length_penalty(beam, logprobs, self.alpha)
        return normalized_probs


class PenaltyBuilder(object):
    """
    Returns the Length Penalty function for Beam Search.

    Args:
        length_pen (str): option name of length pen
    """

    def __init__(self, length_pen):
        self.length_pen = length_pen

    def length_penalty(self):
        if self.length_pen == "wu":
            return self.length_wu
        elif self.length_pen == "avg":
            return self.length_average
        else:
            return self.length_none

    # Below are all the different penalty terms implemented so far.

    def length_wu(self, beam, logprobs, alpha=0.0):
        """
        NMT length re-ranking score from
        "Google's Neural Machine Translation System" :cite:`wu2016google`.
        """

        modifier = ((5 + len(beam.next_ys)) ** alpha) / ((5 + 1) ** alpha)
        return logprobs / modifier
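
    # Worked example (illustrative): with alpha=0.95 and a hypothesis of
    # 9 tokens, the modifier is ((5 + 9) / 6) ** 0.95, about 2.24, so a
    # cumulative log-prob of -4.0 rescores to roughly -1.79, favouring longer
    # hypotheses over a flat sum of log-probs.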

    def length_average(self, beam, logprobs, alpha=0.0):
        """
        Returns the average probability of tokens in a sequence.
        """
        return logprobs / len(beam.next_ys)

    def length_none(self, beam, logprobs, alpha=0.0, beta=0.0):
        """
        Returns unmodified scores.
        """
        return logprobs


class Translator(object):
    """
    Uses a model to translate a batch of sentences.

    Args:
        model (:obj:`onmt.modules.NMTModel`):
            NMT model to use for translation
        fields (dict of Fields): data fields
        beam_size (int): size of beam to use
        n_best (int): number of translations produced
        max_length (int): maximum length output to produce
        global_scorer (:obj:`GlobalScorer`):
            object to rescore final translations
        copy_attn (bool): use copy attention during translation
        beam_trace (bool): trace beam search for debugging
        logger (logging.Logger): logger.
    """

    def __init__(self, args, model, vocab, symbols, global_scorer=None, logger=None):
        self.logger = logger

        self.args = args
        self.model = model
        self.generator = self.model.generator
        self.vocab = vocab
        self.symbols = symbols
        self.start_token = symbols["BOS"]
        self.end_token = symbols["EOS"]

        self.global_scorer = global_scorer
        self.beam_size = args.beam_size
        self.min_length = args.min_length
        self.max_length = args.max_length

    def translate(self, batch, step, attn_debug=False):
        """ Generates summaries from one batch of data. """
        self.model.eval()
        with torch.no_grad():
            batch_data = self.translate_batch(batch)
            translations = self.from_batch(batch_data)
        return translations

    def translate_batch(self, batch, fast=False):
        """
        Translate a batch of sentences.

        Mostly a wrapper around :obj:`Beam`.

        Args:
            batch (:obj:`Batch`): a batch from a dataset object
            fast (bool): enables fast beam search (may not support all features)

        Todo:
            Shouldn't need the original dataset.
        """
        with torch.no_grad():
            return self._fast_translate_batch(batch, self.max_length, min_length=self.min_length)

    # Where the beam search lives
    # I have no idea why it is being called from the method above
    def _fast_translate_batch(self, batch, max_length, min_length=0):
        """ Beam Search using the encoder inputs contained in `batch`. """

        # The batch object is funny
        # Instead of just looking at the size of the arguments we encapsulate
        # a size argument.
        # Where is it defined?
        beam_size = self.beam_size
        batch_size = batch.batch_size
        src = batch.src
        segs = batch.segs
        mask_src = batch.mask_src

        src_features = self.model.bert(src, segs, mask_src)
        dec_states = self.model.decoder.init_decoder_state(src, src_features, with_cache=True)
        device = src_features.device

        # Tile states and memory beam_size times.
        dec_states.map_batch_fn(lambda state, dim: tile(state, beam_size, dim=dim))
        src_features = tile(src_features, beam_size, dim=0)
        batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)
        beam_offset = torch.arange(0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device)
        alive_seq = torch.full([batch_size * beam_size, 1], self.start_token, dtype=torch.long, device=device)

        # Give full probability to the first beam on the first step.
        topk_log_probs = torch.tensor([0.0] + [float("-inf")] * (beam_size - 1), device=device).repeat(batch_size)

        # Structure that holds finished hypotheses.
        hypotheses = [[] for _ in range(batch_size)]  # noqa: F812

        results = {}
        results["predictions"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["scores"] = [[] for _ in range(batch_size)]  # noqa: F812
        results["gold_score"] = [0] * batch_size
        results["batch"] = batch

        for step in range(max_length):
            decoder_input = alive_seq[:, -1].view(1, -1)

            # Decoder forward.
            decoder_input = decoder_input.transpose(0, 1)

            dec_out, dec_states = self.model.decoder(decoder_input, src_features, dec_states, step=step)

            # Generator forward.
            log_probs = self.generator.forward(dec_out.transpose(0, 1).squeeze(0))
            vocab_size = log_probs.size(-1)

            if step < min_length:
                log_probs[:, self.end_token] = -1e20

            # Multiply probs by the beam probability.
            log_probs += topk_log_probs.view(-1).unsqueeze(1)

            alpha = self.global_scorer.alpha
            length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

            # Flatten probs into a list of possibilities.
            curr_scores = log_probs / length_penalty

            if self.args.block_trigram:
                cur_len = alive_seq.size(1)
                if cur_len > 3:
                    for i in range(alive_seq.size(0)):
                        fail = False
                        words = [int(w) for w in alive_seq[i]]
                        words = [self.vocab.ids_to_tokens[w] for w in words]
                        words = " ".join(words).replace(" ##", "").split()
                        if len(words) <= 3:
                            continue
                        trigrams = [(words[i - 1], words[i], words[i + 1]) for i in range(1, len(words) - 1)]
                        trigram = tuple(trigrams[-1])
                        if trigram in trigrams[:-1]:
                            fail = True
                        if fail:
                            curr_scores[i] = -10e20

            curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
            topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

            # Recover log probs.
            topk_log_probs = topk_scores * length_penalty

            # Resolve beam origin and true word ids.
            topk_beam_index = topk_ids.div(vocab_size)
            topk_ids = topk_ids.fmod(vocab_size)

            # Map beam_index to batch_index in the flat representation.
            batch_index = topk_beam_index + beam_offset[: topk_beam_index.size(0)].unsqueeze(1)
            select_indices = batch_index.view(-1)

            # Append last prediction.
            alive_seq = torch.cat([alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1)

            is_finished = topk_ids.eq(self.end_token)
            if step + 1 == max_length:
                is_finished.fill_(1)
            # End condition is top beam is finished.
            end_condition = is_finished[:, 0].eq(1)
            # Save finished hypotheses.
            if is_finished.any():
                predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
                for i in range(is_finished.size(0)):
                    b = batch_offset[i]
                    if end_condition[i]:
                        is_finished[i].fill_(1)
                    finished_hyp = is_finished[i].nonzero().view(-1)
                    # Store finished hypotheses for this batch.
                    for j in finished_hyp:
                        hypotheses[b].append((topk_scores[i, j], predictions[i, j, 1:]))
                    # If the batch reached the end, save the n_best hypotheses.
                    if end_condition[i]:
                        best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True)
                        score, pred = best_hyp[0]

                        results["scores"][b].append(score)
                        results["predictions"][b].append(pred)
                non_finished = end_condition.eq(0).nonzero().view(-1)
                # If all sentences are translated, no need to go further.
                if len(non_finished) == 0:
                    break
                # Remove finished batches for the next step.
                topk_log_probs = topk_log_probs.index_select(0, non_finished)
                batch_index = batch_index.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                alive_seq = predictions.index_select(0, non_finished).view(-1, alive_seq.size(-1))
            # Reorder states.
            select_indices = batch_index.view(-1)
            src_features = src_features.index_select(0, select_indices)
            dec_states.map_batch_fn(lambda state, dim: state.index_select(dim, select_indices))

        return results
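
    # Note on `block_trigram` above (illustrative): a beam whose newest trigram
    # already occurs earlier in its own hypothesis has its scores pushed to
    # -10e20, pruning repetitive continuations; e.g. "the cat sat ... the cat"
    # cannot extend with "sat" again. This is a common repetition-avoidance
    # heuristic in neural abstractive summarization.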

    def from_batch(self, translation_batch):
        batch = translation_batch["batch"]
        assert len(translation_batch["gold_score"]) == len(translation_batch["predictions"])
        batch_size = batch.batch_size

        preds, _, _, tgt_str, src = (
            translation_batch["predictions"],
            translation_batch["scores"],
            translation_batch["gold_score"],
            batch.tgt_str,
            batch.src,
        )

        translations = []
        for b in range(batch_size):
            pred_sents = self.vocab.convert_ids_to_tokens([int(n) for n in preds[b][0]])
            pred_sents = " ".join(pred_sents).replace(" ##", "")
            gold_sent = " ".join(tgt_str[b].split())
            raw_src = [self.vocab.ids_to_tokens[int(t)] for t in src[b]][:500]
            raw_src = " ".join(raw_src)
            translation = (pred_sents, gold_sent, raw_src)
            translations.append(translation)

        return translations


def tile(x, count, dim=0):
    """
    Tiles x on dimension dim count times.
    """
    perm = list(range(len(x.size())))
    if dim != 0:
        perm[0], perm[dim] = perm[dim], perm[0]
        x = x.permute(perm).contiguous()
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    x = x.view(batch, -1).transpose(0, 1).repeat(count, 1).transpose(0, 1).contiguous().view(*out_size)
    if dim != 0:
        x = x.permute(perm).contiguous()
    return x
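
# Illustration (not part of the original module): each row is repeated `count`
# times consecutively, which matches the beam layout used above:
#
#   >>> tile(torch.tensor([[1, 2], [3, 4]]), 2, dim=0)
#   tensor([[1, 2],
#           [1, 2],
#           [3, 4],
#           [3, 4]])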


#
# Optimizer for training. We keep this here in case we want to add
# a finetuning script.
#


class BertSumOptimizer(object):
    """ Specific optimizer for BertSum.

    As described in [1], the authors fine-tune BertSum for abstractive
    summarization using two Adam optimizers with different warm-up steps and
    learning rates. They also use a custom learning rate scheduler.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
        arXiv preprint arXiv:1908.08345 (2019).
    """

    def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8):
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.lr = lr
        self.warmup_steps = warmup_steps

        self.optimizers = {
            "encoder": torch.optim.Adam(
                model.encoder.parameters(), lr=lr["encoder"], betas=(beta_1, beta_2), eps=eps,
            ),
            "decoder": torch.optim.Adam(
                model.decoder.parameters(), lr=lr["decoder"], betas=(beta_1, beta_2), eps=eps,
            ),
        }

        self._step = 0
        self.current_learning_rates = {}

    def _update_rate(self, stack):
        return self.lr[stack] * min(self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5))
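
    # Illustration (not part of the original module): this is a Noam-style
    # schedule per stack: the rate grows linearly as
    # step * warmup_steps ** -1.5 during warm-up, then decays as step ** -0.5,
    # peaking exactly at step == warmup_steps[stack].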

    def zero_grad(self):
        # Zero the gradients of both stacks (encoder and decoder).
        for optimizer in self.optimizers.values():
            optimizer.zero_grad()

    def step(self):
        self._step += 1
        for stack, optimizer in self.optimizers.items():
            new_rate = self._update_rate(stack)
            for param_group in optimizer.param_groups:
                param_group["lr"] = new_rate
            optimizer.step()
            self.current_learning_rates[stack] = new_rate