From 2403a6659859ad18a9f20e1c2e84179718d8dfd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Sat, 23 Nov 2019 00:18:44 +0100 Subject: [PATCH] give transformers API to BertAbs --- ..._original_pytorch_checkpoint_to_pytorch.py | 161 +++ .../summarization/configuration_bertabs.py | 141 ++ ...ert_bertabs_original_pytorch_checkpoint.py | 162 +++ examples/summarization/modeling_bertabs.py | 1250 +++++++++++++++++ examples/summarization/run_summarization.py | 271 ++++ .../utils_summarization.py | 56 +- .../utils_summarization_test.py | 17 +- ..._original_pytorch_checkpoint_to_pytorch.py | 158 +++ transformers/generate/beam_search.py | 26 +- 9 files changed, 2188 insertions(+), 54 deletions(-) create mode 100644 examples/convert_bertextabs_original_pytorch_checkpoint_to_pytorch.py create mode 100644 examples/summarization/configuration_bertabs.py create mode 100644 examples/summarization/convert_bertabs_original_pytorch_checkpoint.py create mode 100644 examples/summarization/modeling_bertabs.py create mode 100644 examples/summarization/run_summarization.py rename examples/{ => summarization}/utils_summarization.py (77%) rename examples/{ => summarization}/utils_summarization_test.py (88%) create mode 100644 transformers/convert_bertextabs_original_pytorch_checkpoint_to_pytorch.py diff --git a/examples/convert_bertextabs_original_pytorch_checkpoint_to_pytorch.py b/examples/convert_bertextabs_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 00000000000..c245d0eae55 --- /dev/null +++ b/examples/convert_bertextabs_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,161 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Convert BertExtAbs's checkpoints """ + +import argparse +from collections import namedtuple +import logging +import pdb +import torch + +from models.model_builder import AbsSummarizer # The authors' implementation +from model_bertabs import BertAbsSummarizer + +from transformers import BertTokenizer + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +SAMPLE_TEXT = 'Hello world! cécé herlolip' + + +BertAbsConfig = namedtuple( + "BertAbsConfig", + ["temp_dir", "large", "use_bert_emb", "finetune_bert", "encoder", "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads", "dec_ff_size", "dec_dropout"], +) + + +def convert_bertabs_checkpoints(path_to_checkpoints, dump_path): + """ Copy/paste and tweak the pre-trained weights provided by the creators + of BertAbs for the internal architecture. 
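+
+    Illustrative invocation (both paths below are placeholders for the
+    downloaded original checkpoint and the desired output location)::
+
+        python convert_bertextabs_original_pytorch_checkpoint_to_pytorch.py \
+            --bertabs_checkpoint_path /path/to/original_bertabs_checkpoint.pt \
+            --pytorch_dump_folder_path /path/to/converted_model/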
+ """ + + # Instantiate the authors' model with the pre-trained weights + config = BertAbsConfig( + temp_dir=".", + finetune_bert=False, + large=False, + share_emb=True, + use_bert_emb=False, + encoder="bert", + max_pos=512, + enc_layers=6, + enc_hidden_size=512, + enc_heads=8, + enc_ff_size=512, + enc_dropout=0.2, + dec_layers=6, + dec_hidden_size=768, + dec_heads=8, + dec_ff_size=2048, + dec_dropout=0.2, + ) + checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage) + original = AbsSummarizer(config, torch.device("cpu"), checkpoints) + original.eval() + + new_model = BertAbsSummarizer(config, torch.device("cpu")) + new_model.eval() + + # ------------------- + # Convert the weights + # ------------------- + + logging.info("convert the model") + new_model.encoder.load_state_dict(original.bert.state_dict()) + + new_model.decoder.generator.load_state_dict(original.generator.state_dict()) + new_model.decoder.embeddings.load_state_dict(original.decoder.embeddings.state_dict()) + new_model.decoder.pos_emb.load_state_dict(original.decoder.pos_emb.state_dict()) + new_model.decoder.transformer_layers.load_state_dict(original.decoder.transformer_layers.state_dict()) + new_model.decoder.layer_norm.load_state_dict(original.decoder.layer_norm.state_dict()) + + # ---------------------------------- + # Make sure the outpus are identical + # ---------------------------------- + + logging.info("Make sure that the models' outputs are identical") + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + + # prepare the model inputs + encoder_input_ids = tokenizer.encode("This is sample éàalj'-.") + encoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(encoder_input_ids))) + encoder_input_ids = torch.tensor(encoder_input_ids).unsqueeze(0) + decoder_input_ids = tokenizer.encode("This is sample 3 éàalj'-.") + decoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(decoder_input_ids))) + decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(0) + + # failsafe to make sure the weights reset does not affect the + # loaded weights. + assert torch.max(torch.abs(original.generator[0].weight - new_model.decoder.generator[0].weight)) == 0 + + # forward pass + src = encoder_input_ids + tgt = decoder_input_ids + segs = token_type_ids = None + clss = None + mask_src = encoder_attention_mask = None + mask_tgt = decoder_attention_mask = None + mask_cls = None + + # The original model does not apply the geneator layer immediatly but rather in + # the beam search (where it combines softmax + linear layer). Since we already + # apply the softmax in our generation process we only apply the linear layer here. + # We make sure that the outputs of the full stack are identical + output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0] + output_original_model = original.generator(output_original_model) + + output_converted_model = new_model(encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask)[0] + output_converted_model = torch.nn.functional.log_softmax(output_converted_model, dim=-1) + + maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item() + print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference)) + + are_identical = torch.allclose(output_converted_model, output_original_model, atol=1e-3) + if are_identical: + logging.info("all weights are equal up to 1e-3") + else: + raise ValueError("the weights are different. 
The new model is likely different from the original one.") + + # The model has been saved with torch.save(model) and this is bound to the exact + # directory structure. We save the state_dict instead. + logging.info("saving the model's state dictionary") + torch.save(new_model.state_dict(), "bert-ext-abs.pt") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--bertabs_checkpoint_path", + default=None, + type=str, + required=True, + help="Path the official PyTorch dump.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + required=True, + help="Path to the output PyTorch model.", + ) + args = parser.parse_args() + + convert_bertabs_checkpoints( + args.bertabs_checkpoint_path, + args.pytorch_dump_folder_path, + ) diff --git a/examples/summarization/configuration_bertabs.py b/examples/summarization/configuration_bertabs.py new file mode 100644 index 00000000000..ff3171f9a80 --- /dev/null +++ b/examples/summarization/configuration_bertabs.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright 2019 The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" BertAbs configuration """ +import json +import logging +import sys + +from transformers import PretrainedConfig + + +logger = logging.getLogger(__name__) + + +BERTABS_FINETUNED_CONFIG_MAP = { + "bertabs-finetuned-cnndm": "https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization-config.json", +} + + +class BertAbsConfig(PretrainedConfig): + r""" Class to store the configuration of the BertAbs model. + + Arguments: + temp_dir: string + Unused in the current situation. Kept for compatibility but will be removed. + finetune_bert: bool + Whether to fine-tune the model or not. Will be kept for reference + in case we want to add the possibility to fine-tune the model. + large: bool + Whether to use bert-large as a base. + share_emb: book + Whether the embeddings are shared between the encoder and decoder. + encoder: string + Not clear what this does. Leave to "bert" for pre-trained weights. + max_pos: int + The maximum sequence length that this model will be used with. + enc_layer: int + The numner of hidden layers in the Transformer encoder. + enc_hidden_size: int + The size of the encoder's layers. + enc_heads: int + The number of attention heads for each attention layer in the encoder. + enc_ff_size: int + The size of the encoder's feed-forward layers. + enc_dropout: int + The dropout probabilitiy for all fully connected layers in the + embeddings, layers, pooler and also the attention probabilities in + the encoder. + dec_layer: int + The numner of hidden layers in the decoder. + dec_hidden_size: int + The size of the decoder's layers. + dec_heads: int + The number of attention heads for each attention layer in the decoder. + dec_ff_size: int + The size of the decoder's feed-forward layers. 
+ dec_dropout: int + The dropout probabilitiy for all fully connected layers in the + embeddings, layers, pooler and also the attention probabilities in + the decoder. + """ + + pretrained_config_archive_map = BERTABS_FINETUNED_CONFIG_MAP + + def __init__( + self, + vocab_size_or_config_json_file=30522, + temp_dir=".", + finetune_bert=False, + large=False, + share_emb=True, + encoder="bert", + max_pos=512, + enc_layers=6, + enc_hidden_size=512, + enc_heads=8, + enc_ff_size=512, + enc_dropout=0.2, + dec_layers=6, + dec_hidden_size=768, + dec_heads=8, + dec_ff_size=2048, + dec_dropout=0.2, + **kwargs, + ): + super(BertAbsConfig, self).__init__(**kwargs) + + if self._input_is_path_to_json(vocab_size_or_config_json_file): + path_to_json = vocab_size_or_config_json_file + with open(path_to_json, "r", encoding="utf-8") as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.temp_dir = temp_dir + self.finetune_bert = finetune_bert + self.large = large + self.vocab_size = vocab_size_or_config_json_file + self.max_pos = max_pos + + self.encoder = encoder + self.enc_layers = enc_layers + self.enc_hidden_size = enc_hidden_size + self.enc_heads = enc_heads + self.enc_ff_size = enc_ff_size + self.enc_dropout = enc_dropout + + self.share_emb = share_emb + + self.dec_layers = dec_layers + self.dec_hidden_size = dec_hidden_size + self.dec_heads = dec_heads + self.dec_ff_size = dec_ff_size + self.dec_dropout = dec_dropout + else: + raise ValueError( + "First argument must be either a vocabulary size (int)" + "or the path to a pretrained model config file (str)" + ) + + def _input_is_path_to_json(self, first_argument): + """ Checks whether the first argument passed to config + is the path to a JSON file that contains the config. + """ + is_python_2 = sys.version_info[0] == 2 + if is_python_2: + return isinstance(first_argument, unicode) + else: + return isinstance(first_argument, str) diff --git a/examples/summarization/convert_bertabs_original_pytorch_checkpoint.py b/examples/summarization/convert_bertabs_original_pytorch_checkpoint.py new file mode 100644 index 00000000000..786a29ef136 --- /dev/null +++ b/examples/summarization/convert_bertabs_original_pytorch_checkpoint.py @@ -0,0 +1,162 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Convert BertExtAbs's checkpoints + +The file currently does not do much as we ended up copying the exact model +structure, but I leave it here in case we ever want to refactor the model. +""" + +import argparse +from collections import namedtuple +import logging +import torch + +from models.model_builder import AbsSummarizer # The authors' implementation +from model_bertabs import BertAbsSummarizer + +from transformers import BertTokenizer + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +SAMPLE_TEXT = 'Hello world! 
cécé herlolip' + + +BertAbsConfig = namedtuple( + "BertAbsConfig", + ["temp_dir", "large", "use_bert_emb", "finetune_bert", "encoder", "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads", "dec_ff_size", "dec_dropout"], +) + + +def convert_bertabs_checkpoints(path_to_checkpoints, dump_path): + """ Copy/paste and tweak the pre-trained weights provided by the creators + of BertAbs for the internal architecture. + """ + + # Instantiate the authors' model with the pre-trained weights + config = BertAbsConfig( + temp_dir=".", + finetune_bert=False, + large=False, + share_emb=True, + use_bert_emb=False, + encoder="bert", + max_pos=512, + enc_layers=6, + enc_hidden_size=512, + enc_heads=8, + enc_ff_size=512, + enc_dropout=0.2, + dec_layers=6, + dec_hidden_size=768, + dec_heads=8, + dec_ff_size=2048, + dec_dropout=0.2, + ) + checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage) + original = AbsSummarizer(config, torch.device("cpu"), checkpoints) + original.eval() + + new_model = BertAbsSummarizer(config, torch.device("cpu")) + new_model.eval() + + # ------------------- + # Convert the weights + # ------------------- + + logging.info("convert the model") + new_model.bert.load_state_dict(original.bert.state_dict()) + new_model.decoder.load_state_dict(original.decoder.state_dict()) + new_model.generator.load_state_dict(original.generator.state_dict()) + + # ---------------------------------- + # Make sure the outpus are identical + # ---------------------------------- + + logging.info("Make sure that the models' outputs are identical") + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + + # prepare the model inputs + encoder_input_ids = tokenizer.encode("This is sample éàalj'-.") + encoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(encoder_input_ids))) + encoder_input_ids = torch.tensor(encoder_input_ids).unsqueeze(0) + decoder_input_ids = tokenizer.encode("This is sample 3 éàalj'-.") + decoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(decoder_input_ids))) + decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(0) + + # failsafe to make sure the weights reset does not affect the + # loaded weights. + assert torch.max(torch.abs(original.generator[0].weight - new_model.generator[0].weight)) == 0 + + # forward pass + src = encoder_input_ids + tgt = decoder_input_ids + segs = token_type_ids = None + clss = None + mask_src = encoder_attention_mask = None + mask_tgt = decoder_attention_mask = None + mask_cls = None + + # The original model does not apply the geneator layer immediatly but rather in + # the beam search (where it combines softmax + linear layer). Since we already + # apply the softmax in our generation process we only apply the linear layer here. 
+ # We make sure that the outputs of the full stack are identical + output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0] + output_original_generator = original.generator(output_original_model) + + output_converted_model = new_model(encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask)[0] + output_converted_generator = new_model.generator(output_converted_model) + + maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item() + print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference)) + maximum_absolute_difference = torch.max(torch.abs(output_converted_generator - output_original_generator)).item() + print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference)) + + are_identical = torch.allclose(output_converted_model, output_original_model, atol=1e-3) + if are_identical: + logging.info("all weights are equal up to 1e-3") + else: + raise ValueError("the weights are different. The new model is likely different from the original one.") + + # The model has been saved with torch.save(model) and this is bound to the exact + # directory structure. We save the state_dict instead. + logging.info("saving the model's state dictionary") + torch.save(new_model.state_dict(), "bertabs-finetuned-cnndm-extractive-abstractive-summarization-pytorch_model.bin") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--bertabs_checkpoint_path", + default=None, + type=str, + required=True, + help="Path the official PyTorch dump.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + required=True, + help="Path to the output PyTorch model.", + ) + args = parser.parse_args() + + convert_bertabs_checkpoints( + args.bertabs_checkpoint_path, + args.pytorch_dump_folder_path, + ) diff --git a/examples/summarization/modeling_bertabs.py b/examples/summarization/modeling_bertabs.py new file mode 100644 index 00000000000..0189a2ad2b0 --- /dev/null +++ b/examples/summarization/modeling_bertabs.py @@ -0,0 +1,1250 @@ +# MIT License + +# Copyright (c) 2019 Yang Liu + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
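+
+# This module groups everything needed to run the abstractive summarizer:
+# the BertAbs model (a pre-trained BERT encoder followed by a Transformer
+# decoder), the beam-search Translator used at generation time, ROUGE
+# helpers, and the BertSumOptimizer used for fine-tuning. The code is
+# adapted from Yang Liu's original implementation (see the license above).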
+import copy +import math +import shutil +import time +import os + +import numpy as np +import torch +from torch import nn +from torch.nn.init import xavier_uniform_ + +from transformers import BertModel, BertConfig, PreTrainedModel + +from configuration_bertabs import BertAbsConfig + + +MAX_SIZE = 5000 + +BERTABS_FINETUNED_MODEL_MAP = { + "bertabs-finetuned-cnndm": "https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization-pytorch_model.bin", +} + + +class BertAbsPreTrainedModel(PreTrainedModel): + config_class = BertAbsConfig + pretrained_model_archive_map = BERTABS_FINETUNED_MODEL_MAP + load_tf_weights = False + base_model_prefix = "bert" + + +class BertAbs(BertAbsPreTrainedModel): + def __init__(self, args, checkpoint=None, bert_extractive_checkpoint=None): + super(BertAbs, self).__init__(args) + self.args = args + self.bert = Bert(args.large, args.temp_dir, args.finetune_bert) + + # If pre-trained weights are passed for Bert, load these. + load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False + if load_bert_pretrained_extractive: + self.bert.model.load_state_dict( + dict( + [ + (n[11:], p) + for n, p in bert_extractive_checkpoint.items() + if n.startswith("bert.model") + ] + ), + strict=True, + ) + + if args.encoder == "baseline": + bert_config = BertConfig( + self.bert.model.config.vocab_size, + hidden_size=args.enc_hidden_size, + num_hidden_layers=args.enc_layers, + num_attention_heads=8, + intermediate_size=args.enc_ff_size, + hidden_dropout_prob=args.enc_dropout, + attention_probs_dropout_prob=args.enc_dropout, + ) + self.bert.model = BertModel(bert_config) + + self.vocab_size = self.bert.model.config.vocab_size + + if args.max_pos > 512: + my_pos_embeddings = nn.Embedding( + args.max_pos, self.bert.model.config.hidden_size + ) + my_pos_embeddings.weight.data[ + :512 + ] = self.bert.model.embeddings.position_embeddings.weight.data + my_pos_embeddings.weight.data[ + 512: + ] = self.bert.model.embeddings.position_embeddings.weight.data[-1][ + None, : + ].repeat( + args.max_pos - 512, 1 + ) + self.bert.model.embeddings.position_embeddings = my_pos_embeddings + tgt_embeddings = nn.Embedding( + self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0 + ) + if self.args.share_emb: + tgt_embeddings.weight = copy.deepcopy( + self.bert.model.embeddings.word_embeddings.weight + ) + + self.decoder = TransformerDecoder( + self.args.dec_layers, + self.args.dec_hidden_size, + heads=self.args.dec_heads, + d_ff=self.args.dec_ff_size, + dropout=self.args.dec_dropout, + embeddings=tgt_embeddings, + vocab_size=self.vocab_size, + ) + + gen_func = nn.LogSoftmax(dim=-1) + self.generator = nn.Sequential( + nn.Linear(args.dec_hidden_size, args.vocab_size), gen_func + ) + self.generator[0].weight = self.decoder.embeddings.weight + + load_from_checkpoints = False if checkpoint is None else True + if load_from_checkpoints: + self.load_state_dict(checkpoint) + + def init_weights(self): + for module in self.decoder.modules(): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + for p in self.generator.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + else: + p.data.zero_() + + def maybe_tie_embeddings(self, args): + if args.use_bert_emb: + tgt_embeddings = nn.Embedding( + 
self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0 + ) + tgt_embeddings.weight = copy.deepcopy( + self.bert.model.embeddings.word_embeddings.weight + ) + self.decoder.embeddings = tgt_embeddings + + def forward( + self, + encoder_input_ids, + decoder_input_ids, + token_type_ids, + encoder_attention_mask, + decoder_attention_mask, + ): + encoder_output = self.bert( + input_ids=encoder_input_ids, + token_type_ids=token_type_ids, + attention_mask=encoder_attention_mask, + ) + encoder_hidden_states = encoder_output[0] + dec_state = self.decoder.init_decoder_state( + encoder_input_ids, encoder_hidden_states + ) + decoder_outputs, _ = self.decoder( + decoder_input_ids[:, :-1], encoder_hidden_states, dec_state + ) + return decoder_outputs + + +class Bert(nn.Module): + """ This class is not really necessary and should probably disappear. + """ + + def __init__(self, large, temp_dir, finetune=False): + super(Bert, self).__init__() + if large: + self.model = BertModel.from_pretrained("bert-large-uncased", cache_dir=temp_dir) + else: + self.model = BertModel.from_pretrained("bert-base-uncased", cache_dir=temp_dir) + + self.finetune = finetune + + def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs): + self.eval() + with torch.no_grad(): + encoder_outputs, _ = self.model( + input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + **kwargs + ) + return encoder_outputs + + +class TransformerDecoder(nn.Module): + """ + The Transformer decoder from "Attention is All You Need". + + Args: + num_layers (int): number of encoder layers. + d_model (int): size of the model + heads (int): number of heads + d_ff (int): size of the inner FF layer + dropout (float): dropout parameters + embeddings (:obj:`onmt.modules.Embeddings`): + embeddings to use, should have positional encodings + attn_type (str): if using a seperate copy attention + """ + + def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings, vocab_size): + super(TransformerDecoder, self).__init__() + + # Basic attributes. + self.decoder_type = "transformer" + self.num_layers = num_layers + self.embeddings = embeddings + self.pos_emb = PositionalEncoding(dropout, self.embeddings.embedding_dim) + + # Build TransformerDecoder. 
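+        # The decoder is a stack of `num_layers` identical blocks; each block
+        # applies masked self-attention over the target tokens, "context"
+        # (cross-)attention over the BERT encoder states, and a position-wise
+        # feed-forward network. A final LayerNorm is applied on top of the stack.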
+ self.transformer_layers = nn.ModuleList( + [ + TransformerDecoderLayer(d_model, heads, d_ff, dropout) + for _ in range(num_layers) + ] + ) + + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + + # forward(input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask) + # def forward(self, input_ids, state, attention_mask=None, memory_lengths=None, + # step=None, cache=None, encoder_attention_mask=None, encoder_hidden_states=None, memory_masks=None): + def forward( + self, + input_ids, + encoder_hidden_states=None, + state=None, + attention_mask=None, + memory_lengths=None, + step=None, + cache=None, + encoder_attention_mask=None, + ): + """ + See :obj:`onmt.modules.RNNDecoderBase.forward()` + memory_bank = encoder_hidden_states + """ + # Name conversion + tgt = input_ids + memory_bank = encoder_hidden_states + memory_mask = encoder_attention_mask + + # src_words = state.src + src_words = state.src + src_batch, src_len = src_words.size() + + padding_idx = self.embeddings.padding_idx + + # Decoder padding mask + tgt_words = tgt + tgt_batch, tgt_len = tgt_words.size() + tgt_pad_mask = ( + tgt_words.data.eq(padding_idx).unsqueeze(1).expand(tgt_batch, tgt_len, tgt_len) + ) + + # Encoder padding mask + if memory_mask is not None: + src_len = memory_mask.size(-1) + src_pad_mask = memory_mask.expand(src_batch, tgt_len, src_len) + else: + src_pad_mask = ( + src_words.data.eq(padding_idx) + .unsqueeze(1) + .expand(src_batch, tgt_len, src_len) + ) + + # Pass through the embeddings + emb = self.embeddings(input_ids) + output = self.pos_emb(emb, step) + assert emb.dim() == 3 # len x batch x embedding_dim + + if state.cache is None: + saved_inputs = [] + + for i in range(self.num_layers): + prev_layer_input = None + if state.cache is None: + if state.previous_input is not None: + prev_layer_input = state.previous_layer_inputs[i] + + output, all_input = self.transformer_layers[i]( + output, + memory_bank, + src_pad_mask, + tgt_pad_mask, + previous_input=prev_layer_input, + layer_cache=state.cache["layer_{}".format(i)] + if state.cache is not None + else None, + step=step, + ) + if state.cache is None: + saved_inputs.append(all_input) + + if state.cache is None: + saved_inputs = torch.stack(saved_inputs) + + output = self.layer_norm(output) + + if state.cache is None: + state = state.update_state(tgt, saved_inputs) + + # Decoders in transformers return a tuple. Beam search will fail + # if we don't follow this convention. 
+ return output, state # , state + + def init_decoder_state(self, src, memory_bank, with_cache=False): + """ Init decoder state """ + state = TransformerDecoderState(src) + if with_cache: + state._init_cache(memory_bank, self.num_layers) + return state + + +class PositionalEncoding(nn.Module): + def __init__(self, dropout, dim, max_len=5000): + pe = torch.zeros(max_len, dim) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp( + (torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim)) + ) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + pe = pe.unsqueeze(0) + super(PositionalEncoding, self).__init__() + self.register_buffer("pe", pe) + self.dropout = nn.Dropout(p=dropout) + self.dim = dim + + def forward(self, emb, step=None): + emb = emb * math.sqrt(self.dim) + if step: + emb = emb + self.pe[:, step][:, None, :] + + else: + emb = emb + self.pe[:, : emb.size(1)] + emb = self.dropout(emb) + return emb + + def get_emb(self, emb): + return self.pe[:, : emb.size(1)] + + +class TransformerDecoderLayer(nn.Module): + """ + Args: + d_model (int): the dimension of keys/values/queries in + MultiHeadedAttention, also the input size of + the first-layer of the PositionwiseFeedForward. + heads (int): the number of heads for MultiHeadedAttention. + d_ff (int): the second-layer of the PositionwiseFeedForward. + dropout (float): dropout probability(0-1.0). + self_attn_type (string): type of self-attention scaled-dot, average + """ + + def __init__(self, d_model, heads, d_ff, dropout): + super(TransformerDecoderLayer, self).__init__() + + self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout) + + self.context_attn = MultiHeadedAttention(heads, d_model, dropout=dropout) + self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) + self.layer_norm_1 = nn.LayerNorm(d_model, eps=1e-6) + self.layer_norm_2 = nn.LayerNorm(d_model, eps=1e-6) + self.drop = nn.Dropout(dropout) + mask = self._get_attn_subsequent_mask(MAX_SIZE) + # Register self.mask as a buffer in TransformerDecoderLayer, so + # it gets TransformerDecoderLayer's cuda behavior automatically. 
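+        # `mask` is a fixed upper-triangular (causal) matrix of shape
+        # [1 x MAX_SIZE x MAX_SIZE]; forward() slices it to the current target
+        # length so that each position only attends to previous positions.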
+ self.register_buffer("mask", mask) + + def forward( + self, + inputs, + memory_bank, + src_pad_mask, + tgt_pad_mask, + previous_input=None, + layer_cache=None, + step=None, + ): + """ + Args: + inputs (`FloatTensor`): `[batch_size x 1 x model_dim]` + memory_bank (`FloatTensor`): `[batch_size x src_len x model_dim]` + src_pad_mask (`LongTensor`): `[batch_size x 1 x src_len]` + tgt_pad_mask (`LongTensor`): `[batch_size x 1 x 1]` + + Returns: + (`FloatTensor`, `FloatTensor`, `FloatTensor`): + + * output `[batch_size x 1 x model_dim]` + * attn `[batch_size x 1 x src_len]` + * all_input `[batch_size x current_step x model_dim]` + + """ + dec_mask = torch.gt( + tgt_pad_mask + self.mask[:, : tgt_pad_mask.size(1), : tgt_pad_mask.size(1)], 0 + ) + input_norm = self.layer_norm_1(inputs) + all_input = input_norm + if previous_input is not None: + all_input = torch.cat((previous_input, input_norm), dim=1) + dec_mask = None + + query = self.self_attn( + all_input, + all_input, + input_norm, + mask=dec_mask, + layer_cache=layer_cache, + type="self", + ) + + query = self.drop(query) + inputs + + query_norm = self.layer_norm_2(query) + mid = self.context_attn( + memory_bank, + memory_bank, + query_norm, + mask=src_pad_mask, + layer_cache=layer_cache, + type="context", + ) + output = self.feed_forward(self.drop(mid) + query) + + return output, all_input + # return output + + def _get_attn_subsequent_mask(self, size): + """ + Get an attention mask to avoid using the subsequent info. + + Args: + size: int + + Returns: + (`LongTensor`): + + * subsequent_mask `[1 x size x size]` + """ + attn_shape = (1, size, size) + subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8") + subsequent_mask = torch.from_numpy(subsequent_mask) + return subsequent_mask + + +class MultiHeadedAttention(nn.Module): + """ + Multi-Head Attention module from + "Attention is All You Need" + :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`. + + Similar to standard `dot` attention but uses + multiple attention distributions simulataneously + to select relevant items. + + .. mermaid:: + + graph BT + A[key] + B[value] + C[query] + O[output] + subgraph Attn + D[Attn 1] + E[Attn 2] + F[Attn N] + end + A --> D + C --> D + A --> E + C --> E + A --> F + C --> F + D --> O + E --> O + F --> O + B --> O + + Also includes several additional tricks. + + Args: + head_count (int): number of parallel heads + model_dim (int): the dimension of keys/values/queries, + must be divisible by head_count + dropout (float): dropout parameter + """ + + def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True): + assert model_dim % head_count == 0 + self.dim_per_head = model_dim // head_count + self.model_dim = model_dim + + super(MultiHeadedAttention, self).__init__() + self.head_count = head_count + + self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head) + self.linear_values = nn.Linear(model_dim, head_count * self.dim_per_head) + self.linear_query = nn.Linear(model_dim, head_count * self.dim_per_head) + self.softmax = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + self.use_final_linear = use_final_linear + if self.use_final_linear: + self.final_linear = nn.Linear(model_dim, model_dim) + + def forward( + self, + key, + value, + query, + mask=None, + layer_cache=None, + type=None, + predefined_graph_1=None, + ): + """ + Compute the context vector and the attention vectors. 
+ + Args: + key (`FloatTensor`): set of `key_len` + key vectors `[batch, key_len, dim]` + value (`FloatTensor`): set of `key_len` + value vectors `[batch, key_len, dim]` + query (`FloatTensor`): set of `query_len` + query vectors `[batch, query_len, dim]` + mask: binary mask indicating which keys have + non-zero attention `[batch, query_len, key_len]` + Returns: + (`FloatTensor`, `FloatTensor`) : + + * output context vectors `[batch, query_len, dim]` + * one of the attention vectors `[batch, query_len, key_len]` + """ + batch_size = key.size(0) + dim_per_head = self.dim_per_head + head_count = self.head_count + key_len = key.size(1) + query_len = query.size(1) + + def shape(x): + """ projection """ + return x.view(batch_size, -1, head_count, dim_per_head).transpose(1, 2) + + def unshape(x): + """ compute context """ + return ( + x.transpose(1, 2) + .contiguous() + .view(batch_size, -1, head_count * dim_per_head) + ) + + # 1) Project key, value, and query. + if layer_cache is not None: + if type == "self": + query, key, value = ( + self.linear_query(query), + self.linear_keys(query), + self.linear_values(query), + ) + + key = shape(key) + value = shape(value) + + if layer_cache is not None: + device = key.device + if layer_cache["self_keys"] is not None: + key = torch.cat((layer_cache["self_keys"].to(device), key), dim=2) + if layer_cache["self_values"] is not None: + value = torch.cat( + (layer_cache["self_values"].to(device), value), dim=2 + ) + layer_cache["self_keys"] = key + layer_cache["self_values"] = value + elif type == "context": + query = self.linear_query(query) + if layer_cache is not None: + if layer_cache["memory_keys"] is None: + key, value = self.linear_keys(key), self.linear_values(value) + key = shape(key) + value = shape(value) + else: + key, value = ( + layer_cache["memory_keys"], + layer_cache["memory_values"], + ) + layer_cache["memory_keys"] = key + layer_cache["memory_values"] = value + else: + key, value = self.linear_keys(key), self.linear_values(value) + key = shape(key) + value = shape(value) + else: + key = self.linear_keys(key) + value = self.linear_values(value) + query = self.linear_query(query) + key = shape(key) + value = shape(value) + + query = shape(query) + + key_len = key.size(2) + query_len = query.size(2) + + # 2) Calculate and scale scores. + query = query / math.sqrt(dim_per_head) + scores = torch.matmul(query, key.transpose(2, 3)) + + if mask is not None: + mask = mask.unsqueeze(1).expand_as(scores) + scores = scores.masked_fill(mask, -1e18) + + # 3) Apply attention dropout and compute context vectors. + + attn = self.softmax(scores) + + if not predefined_graph_1 is None: + attn_masked = attn[:, -1] * predefined_graph_1 + attn_masked = attn_masked / (torch.sum(attn_masked, 2).unsqueeze(2) + 1e-9) + + attn = torch.cat([attn[:, :-1], attn_masked.unsqueeze(1)], 1) + + drop_attn = self.dropout(attn) + if self.use_final_linear: + context = unshape(torch.matmul(drop_attn, value)) + output = self.final_linear(context) + return output + else: + context = torch.matmul(drop_attn, value) + return context + + +class DecoderState(object): + """Interface for grouping together the current state of a recurrent + decoder. In the simplest case just represents the hidden state of + the model. But can also be used for implementing various forms of + input_feeding and non-recurrent models. + + Modules need to implement this to utilize beam search decoding. 
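+
+    For this model the concrete implementation is `TransformerDecoderState`
+    below, which keeps a reference to the source tokens and, during beam
+    search, a per-layer cache of self- and context-attention keys and values.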
+ """ + + def detach(self): + """ Need to document this """ + self.hidden = tuple([_.detach() for _ in self.hidden]) + self.input_feed = self.input_feed.detach() + + def beam_update(self, idx, positions, beam_size): + """ Need to document this """ + for e in self._all: + sizes = e.size() + br = sizes[1] + if len(sizes) == 3: + sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2])[ + :, :, idx + ] + else: + sent_states = e.view( + sizes[0], beam_size, br // beam_size, sizes[2], sizes[3] + )[:, :, idx] + + sent_states.data.copy_(sent_states.data.index_select(1, positions)) + + def map_batch_fn(self, fn): + raise NotImplementedError() + + +class TransformerDecoderState(DecoderState): + """ Transformer Decoder state base class """ + + def __init__(self, src): + """ + Args: + src (FloatTensor): a sequence of source words tensors + with optional feature tensors, of size (len x batch). + """ + self.src = src + self.previous_input = None + self.previous_layer_inputs = None + self.cache = None + + @property + def _all(self): + """ + Contains attributes that need to be updated in self.beam_update(). + """ + if self.previous_input is not None and self.previous_layer_inputs is not None: + return (self.previous_input, self.previous_layer_inputs, self.src) + else: + return (self.src,) + + def detach(self): + if self.previous_input is not None: + self.previous_input = self.previous_input.detach() + if self.previous_layer_inputs is not None: + self.previous_layer_inputs = self.previous_layer_inputs.detach() + self.src = self.src.detach() + + def update_state(self, new_input, previous_layer_inputs): + state = TransformerDecoderState(self.src) + state.previous_input = new_input + state.previous_layer_inputs = previous_layer_inputs + return state + + def _init_cache(self, memory_bank, num_layers): + self.cache = {} + + for l in range(num_layers): + layer_cache = {"memory_keys": None, "memory_values": None} + layer_cache["self_keys"] = None + layer_cache["self_values"] = None + self.cache["layer_{}".format(l)] = layer_cache + + def repeat_beam_size_times(self, beam_size): + """ Repeat beam_size times along batch dimension. """ + self.src = self.src.data.repeat(1, beam_size, 1) + + def map_batch_fn(self, fn): + def _recursive_map(struct, batch_dim=0): + for k, v in struct.items(): + if v is not None: + if isinstance(v, dict): + _recursive_map(v) + else: + struct[k] = fn(v, batch_dim) + + self.src = fn(self.src, 0) + if self.cache is not None: + _recursive_map(self.cache) + + +def gelu(x): + return ( + 0.5 + * x + * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +class PositionwiseFeedForward(nn.Module): + """ A two-layer Feed-Forward-Network with residual layer norm. + + Args: + d_model (int): the size of input for the first-layer of the FFN. + d_ff (int): the hidden layer size of the second-layer + of the FNN. + dropout (float): dropout probability in :math:`[0, 1)`. 
+ """ + + def __init__(self, d_model, d_ff, dropout=0.1): + super(PositionwiseFeedForward, self).__init__() + self.w_1 = nn.Linear(d_model, d_ff) + self.w_2 = nn.Linear(d_ff, d_model) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + self.actv = gelu + self.dropout_1 = nn.Dropout(dropout) + self.dropout_2 = nn.Dropout(dropout) + + def forward(self, x): + inter = self.dropout_1(self.actv(self.w_1(self.layer_norm(x)))) + output = self.dropout_2(self.w_2(inter)) + return output + x + + +# +# TRANSLATOR +# The following code is used to generate summaries using the +# pre-trained weights and beam search. +# + + +def build_predictor(args, tokenizer, symbols, model, logger=None): + # we should be able to refactor the global scorer a lot + scorer = GNMTGlobalScorer(args.alpha, length_penalty="wu") + translator = Translator( + args, model, tokenizer, symbols, global_scorer=scorer, logger=logger + ) + return translator + + +class GNMTGlobalScorer(object): + """ + NMT re-ranking score from + "Google's Neural Machine Translation System" :cite:`wu2016google` + + Args: + alpha (float): length parameter + beta (float): coverage parameter + """ + + def __init__(self, alpha, length_penalty): + self.alpha = alpha + penalty_builder = PenaltyBuilder(length_penalty) + self.length_penalty = penalty_builder.length_penalty() + + def score(self, beam, logprobs): + """ + Rescores a prediction based on penalty functions + """ + normalized_probs = self.length_penalty(beam, logprobs, self.alpha) + return normalized_probs + + +class PenaltyBuilder(object): + """ + Returns the Length and Coverage Penalty function for Beam Search. + + Args: + length_pen (str): option name of length pen + cov_pen (str): option name of cov pen + """ + + def __init__(self, length_pen): + self.length_pen = length_pen + + def length_penalty(self): + if self.length_pen == "wu": + return self.length_wu + elif self.length_pen == "avg": + return self.length_average + else: + return self.length_none + + """ + Below are all the different penalty terms implemented so far + """ + + def length_wu(self, beam, logprobs, alpha=0.0): + """ + NMT length re-ranking score from + "Google's Neural Machine Translation System" :cite:`wu2016google`. + """ + + modifier = ((5 + len(beam.next_ys)) ** alpha) / ((5 + 1) ** alpha) + return logprobs / modifier + + def length_average(self, beam, logprobs, alpha=0.0): + """ + Returns the average probability of tokens in a sequence. + """ + return logprobs / len(beam.next_ys) + + def length_none(self, beam, logprobs, alpha=0.0, beta=0.0): + """ + Returns unmodified scores. + """ + return logprobs + + +class Translator(object): + """ + Uses a model to translate a batch of sentences. + + Args: + model (:obj:`onmt.modules.NMTModel`): + NMT model to use for translation + fields (dict of Fields): data fields + beam_size (int): size of beam to use + n_best (int): number of translations produced + max_length (int): maximum length output to produce + global_scores (:obj:`GlobalScorer`): + object to rescore final translations + copy_attn (bool): use copy attention during translation + cuda (bool): use cuda + beam_trace (bool): trace beam search for debugging + logger(logging.Logger): logger. 
+ """ + + def __init__(self, args, model, vocab, symbols, global_scorer=None, logger=None): + self.logger = logger + self.cuda = args.visible_gpus != "-1" + + self.args = args + self.model = model + self.generator = self.model.generator + self.vocab = vocab + self.symbols = symbols + self.start_token = symbols["BOS"] + self.end_token = symbols["EOS"] + + self.global_scorer = global_scorer + self.beam_size = args.beam_size + self.min_length = args.min_length + self.max_length = args.max_length + + def translate(self, batch, step, attn_debug=False): + """ Generates summaries from one batch of data. + """ + self.model.eval() + with torch.no_grad(): + batch_data = self.translate_batch(batch) + translations = self.from_batch(batch_data) + return translations + + def translate_batch(self, batch, fast=False): + """ + Translate a batch of sentences. + + Mostly a wrapper around :obj:`Beam`. + + Args: + batch (:obj:`Batch`): a batch from a dataset object + data (:obj:`Dataset`): the dataset object + fast (bool): enables fast beam search (may not support all features) + + Todo: + Shouldn't need the original dataset. + """ + with torch.no_grad(): + return self._fast_translate_batch( + batch, self.max_length, min_length=self.min_length + ) + + # Where the beam search lives + # I have no idea why it is being called from the method above + def _fast_translate_batch(self, batch, max_length, min_length=0): + """ Beam Search using the encoder inputs contained in `batch`. + """ + + # The batch object is funny + # Instead of just looking at the size of the arguments we encapsulate + # a size argument. + # Where is it defined? + beam_size = self.beam_size + batch_size = batch.batch_size + src = batch.src + segs = batch.segs + mask_src = batch.mask_src + + src_features = self.model.bert(src, segs, mask_src) + dec_states = self.model.decoder.init_decoder_state( + src, src_features, with_cache=True + ) + device = src_features.device + + # Tile states and memory beam_size times. + dec_states.map_batch_fn(lambda state, dim: tile(state, beam_size, dim=dim)) + src_features = tile(src_features, beam_size, dim=0) + batch_offset = torch.arange(batch_size, dtype=torch.long, device=device) + beam_offset = torch.arange( + 0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device + ) + alive_seq = torch.full( + [batch_size * beam_size, 1], self.start_token, dtype=torch.long, device=device + ) + + # Give full probability to the first beam on the first step. + topk_log_probs = torch.tensor( + [0.0] + [float("-inf")] * (beam_size - 1), device=device + ).repeat(batch_size) + + # Structure that holds finished hypotheses. + hypotheses = [[] for _ in range(batch_size)] # noqa: F812 + + results = {} + results["predictions"] = [[] for _ in range(batch_size)] # noqa: F812 + results["scores"] = [[] for _ in range(batch_size)] # noqa: F812 + results["gold_score"] = [0] * batch_size + results["batch"] = batch + + for step in range(max_length): + decoder_input = alive_seq[:, -1].view(1, -1) + + # Decoder forward. + decoder_input = decoder_input.transpose(0, 1) + + dec_out, dec_states = self.model.decoder( + decoder_input, src_features, dec_states, step=step + ) + + # Generator forward. + log_probs = self.generator.forward(dec_out.transpose(0, 1).squeeze(0)) + vocab_size = log_probs.size(-1) + + if step < min_length: + log_probs[:, self.end_token] = -1e20 + + # Multiply probs by the beam probability. 
+ log_probs += topk_log_probs.view(-1).unsqueeze(1) + + alpha = self.global_scorer.alpha + length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha + + # Flatten probs into a list of possibilities. + curr_scores = log_probs / length_penalty + + if self.args.block_trigram: + cur_len = alive_seq.size(1) + if cur_len > 3: + for i in range(alive_seq.size(0)): + fail = False + words = [int(w) for w in alive_seq[i]] + words = [self.vocab.ids_to_tokens[w] for w in words] + words = " ".join(words).replace(" ##", "").split() + if len(words) <= 3: + continue + trigrams = [ + (words[i - 1], words[i], words[i + 1]) + for i in range(1, len(words) - 1) + ] + trigram = tuple(trigrams[-1]) + if trigram in trigrams[:-1]: + fail = True + if fail: + curr_scores[i] = -10e20 + + curr_scores = curr_scores.reshape(-1, beam_size * vocab_size) + topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1) + + # Recover log probs. + topk_log_probs = topk_scores * length_penalty + + # Resolve beam origin and true word ids. + topk_beam_index = topk_ids.div(vocab_size) + topk_ids = topk_ids.fmod(vocab_size) + + # Map beam_index to batch_index in the flat representation. + batch_index = topk_beam_index + beam_offset[ + : topk_beam_index.size(0) + ].unsqueeze(1) + select_indices = batch_index.view(-1) + + # Append last prediction. + alive_seq = torch.cat( + [alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1 + ) + + is_finished = topk_ids.eq(self.end_token) + if step + 1 == max_length: + is_finished.fill_(1) + # End condition is top beam is finished. + end_condition = is_finished[:, 0].eq(1) + # Save finished hypotheses. + if is_finished.any(): + predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1)) + for i in range(is_finished.size(0)): + b = batch_offset[i] + if end_condition[i]: + is_finished[i].fill_(1) + finished_hyp = is_finished[i].nonzero().view(-1) + # Store finished hypotheses for this batch. + for j in finished_hyp: + hypotheses[b].append((topk_scores[i, j], predictions[i, j, 1:])) + # If the batch reached the end, save the n_best hypotheses. + if end_condition[i]: + best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True) + score, pred = best_hyp[0] + + results["scores"][b].append(score) + results["predictions"][b].append(pred) + non_finished = end_condition.eq(0).nonzero().view(-1) + # If all sentences are translated, no need to go further. + if len(non_finished) == 0: + break + # Remove finished batches for the next step. + topk_log_probs = topk_log_probs.index_select(0, non_finished) + batch_index = batch_index.index_select(0, non_finished) + batch_offset = batch_offset.index_select(0, non_finished) + alive_seq = predictions.index_select(0, non_finished).view( + -1, alive_seq.size(-1) + ) + # Reorder states. 
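+            # After dropping finished hypotheses, re-index the encoder states
+            # and the decoder cache so that they stay aligned with the
+            # surviving beams.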
+ select_indices = batch_index.view(-1) + src_features = src_features.index_select(0, select_indices) + dec_states.map_batch_fn( + lambda state, dim: state.index_select(dim, select_indices) + ) + + return results + + def from_batch(self, translation_batch): + batch = translation_batch["batch"] + assert len(translation_batch["gold_score"]) == len(translation_batch["predictions"]) + batch_size = batch.batch_size + + preds, _, _, tgt_str, src = ( + translation_batch["predictions"], + translation_batch["scores"], + translation_batch["gold_score"], + batch.tgt_str, + batch.src, + ) + + translations = [] + for b in range(batch_size): + pred_sents = self.vocab.convert_ids_to_tokens([int(n) for n in preds[b][0]]) + pred_sents = " ".join(pred_sents).replace(" ##", "") + gold_sent = " ".join(tgt_str[b].split()) + raw_src = [self.vocab.ids_to_tokens[int(t)] for t in src[b]][:500] + raw_src = " ".join(raw_src) + translation = (pred_sents, gold_sent, raw_src) + translations.append(translation) + + return translations + + def _report_rouge(self, gold_path, can_path): + self.logger.info("Calculating Rouge") + results_dict = test_rouge(self.args.temp_dir, can_path, gold_path) + return results_dict + + +def tile(x, count, dim=0): + """ + Tiles x on dimension dim count times. + """ + perm = list(range(len(x.size()))) + if dim != 0: + perm[0], perm[dim] = perm[dim], perm[0] + x = x.permute(perm).contiguous() + out_size = list(x.size()) + out_size[0] *= count + batch = x.size(0) + x = ( + x.view(batch, -1) + .transpose(0, 1) + .repeat(count, 1) + .transpose(0, 1) + .contiguous() + .view(*out_size) + ) + if dim != 0: + x = x.permute(perm).contiguous() + return x + + +# +# All things ROUGE. Uses `pyrouge` which is a hot mess. +# + + +def test_rouge(temp_dir, cand, ref): + candidates = [line.strip() for line in open(cand, encoding="utf-8")] + references = [line.strip() for line in open(ref, encoding="utf-8")] + print(len(candidates)) + print(len(references)) + assert len(candidates) == len(references) + + cnt = len(candidates) + current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}".format(current_time)) + if not os.path.isdir(tmp_dir): + os.mkdir(tmp_dir) + os.mkdir(tmp_dir + "/candidate") + os.mkdir(tmp_dir + "/reference") + try: + + for i in range(cnt): + if len(references[i]) < 1: + continue + with open( + tmp_dir + "/candidate/cand.{}.txt".format(i), "w", encoding="utf-8" + ) as f: + f.write(candidates[i]) + with open( + tmp_dir + "/reference/ref.{}.txt".format(i), "w", encoding="utf-8" + ) as f: + f.write(references[i]) + r = pyrouge.Rouge155(temp_dir=temp_dir) + r.model_dir = tmp_dir + "/reference/" + r.system_dir = tmp_dir + "/candidate/" + r.model_filename_pattern = "ref.#ID#.txt" + r.system_filename_pattern = r"cand.(\d+).txt" + rouge_results = r.convert_and_evaluate() + print(rouge_results) + results_dict = r.output_to_dict(rouge_results) + finally: + pass + if os.path.isdir(tmp_dir): + shutil.rmtree(tmp_dir) + return results_dict + + +def rouge_results_to_str(results_dict): + return ">> ROUGE-F(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format( + results_dict["rouge_1_f_score"] * 100, + results_dict["rouge_2_f_score"] * 100, + results_dict["rouge_l_f_score"] * 100, + results_dict["rouge_1_recall"] * 100, + results_dict["rouge_2_recall"] * 100, + results_dict["rouge_l_recall"] * 100, + ) + + +class BertSumOptimizer(object): + """ Specific optimizer for BertSum. 
+ + As described in [1], the authors fine-tune BertSum for abstractive + summarization using two Adam Optimizers with different warm-up steps and + learning rate. They also use a custom learning rate scheduler. + + [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders." + arXiv preprint arXiv:1908.08345 (2019). + """ + + def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8): + self.encoder = model.encoder + self.decoder = model.decoder + self.lr = lr + self.warmup_steps = warmup_steps + + self.optimizers = { + "encoder": torch.optim.Adam( + model.encoder.parameters(), + lr=lr["encoder"], + betas=(beta_1, beta_2), + eps=eps, + ), + "decoder": torch.optim.Adam( + model.decoder.parameters(), + lr=lr["decoder"], + betas=(beta_1, beta_2), + eps=eps, + ), + } + + self._step = 0 + self.current_learning_rates = {} + + def _update_rate(self, stack): + return self.lr[stack] * min( + self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5) + ) + + def zero_grad(self): + self.optimizer_decoder.zero_grad() + self.optimizer_encoder.zero_grad() + + def step(self): + self._step += 1 + for stack, optimizer in self.optimizers.items(): + new_rate = self._update_rate(stack) + for param_group in optimizer.param_groups: + param_group["lr"] = new_rate + optimizer.step() + self.current_learning_rates[stack] = new_rate diff --git a/examples/summarization/run_summarization.py b/examples/summarization/run_summarization.py new file mode 100644 index 00000000000..e3b974acd97 --- /dev/null +++ b/examples/summarization/run_summarization.py @@ -0,0 +1,271 @@ +import argparse +from collections import namedtuple +import logging +import os +import sys + +import torch +from torch.utils.data import DataLoader, SequentialSampler +from tqdm import tqdm + +from transformers import BertTokenizer + +from modeling_bertabs import BertAbs, build_predictor + +from utils_summarization import ( + SummarizationDataset, + encode_for_summarization, + build_mask, + fit_to_block_size, + compute_token_type_ids, +) + +logger = logging.getLogger(__name__) +logging.basicConfig(stream=sys.stdout, level=logging.INFO) + + +Batch = namedtuple( + "Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"] +) + + +def evaluate(args): + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True) + model = bertabs = BertAbs.from_pretrained( + "bertabs-finetuned-{}".format(args.finetuned_model) + ) + bertabs.to(args.device) + bertabs.eval() + + symbols = { + "BOS": tokenizer.vocab["[unused0]"], + "EOS": tokenizer.vocab["[unused1]"], + "PAD": tokenizer.vocab["[PAD]"], + } + + # these (unused) arguments are defined to keep the compatibility + # with the legacy code and will be deleted in a next iteration. 
+ args.result_path = "" + args.temp_dir = "" + + data_iterator = build_data_iterator(args, tokenizer) + predictor = build_predictor(args, tokenizer, symbols, model) + + logger.info("***** Running evaluation *****") + logger.info(" Number examples = %d", len(data_iterator.dataset)) + logger.info(" Batch size = %d", args.batch_size) + logger.info("") + logger.info("***** Beam Search parameters *****") + logger.info(" Beam size = %d", args.beam_size) + logger.info(" Minimum length = %d", args.min_length) + logger.info(" Maximum length = %d", args.max_length) + logger.info(" Alpha (length penalty) = %.2f", args.alpha) + logger.info(" Trigrams %s be blocked", ("will" if args.block_trigram else "will NOT")) + + for batch in tqdm(data_iterator): + batch_data = predictor.translate_batch(batch) + translations = predictor.from_batch(batch_data) + summaries = [format_summary(t) for t in translations] + save_summaries(summaries, args.summaries_output_dir, batch.document_names) + + +def format_summary(translation): + """ Transforms the output of the `from_batch` function + into nicely formatted summaries. + """ + raw_summary, _, _ = translation + summary = ( + raw_summary.replace("[unused0]", "") + .replace("[unused3]", "") + .replace("[PAD]", "") + .replace("[unused1]", "") + .replace(r" +", " ") + .replace(" [unused2] ", ". ") + .replace("[unused2]", "") + .strip() + ) + + return summary + + +def save_summaries(summaries, path, original_document_name): + """ Write the summaries in fies that are prefixed by the original + files' name with the `_summary` appended. + + Attributes: + original_document_names: List[string] + Name of the document that was summarized. + path: string + Path were the summaries will be written + summaries: List[string] + The summaries that we produced. + """ + for summary, document_name in zip(summaries, original_document_name): + # Prepare the summary file's name + if "." in document_name: + bare_document_name = ".".join(document_name.split(".")[:-1]) + extension = document_name.split(".")[-1] + name = bare_document_name + "_summary." + extension + else: + name = document_name + "_summary" + + file_path = os.path.join(path, name) + with open(file_path, "w") as output: + output.write(summary) + + +# +# LOAD the dataset +# + + +def build_data_iterator(args, tokenizer): + dataset = load_and_cache_examples(args, tokenizer) + sampler = SequentialSampler(dataset) + collate_fn = lambda data: collate(data, tokenizer, block_size=512) + iterator = DataLoader( + dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn, + ) + + return iterator + + +def load_and_cache_examples(args, tokenizer): + dataset = SummarizationDataset(args.documents_dir) + return dataset + + +def collate(data, tokenizer, block_size): + """ Collate formats the data passed to the data loader. + + In particular we tokenize the data batch after batch to avoid keeping them + all in memory. We output the data as a namedtuple to fit the original BertAbs's + API. 
+    """
+    data = [x for x in data if len(x[1]) > 0]  # remove empty files
+    names = [name for name, _, _ in data]
+
+    encoded_text = [
+        encode_for_summarization(story, summary, tokenizer) for _, story, summary in data
+    ]
+    stories = torch.tensor(
+        [
+            fit_to_block_size(story, block_size, tokenizer.pad_token_id)
+            for story, _ in encoded_text
+        ]
+    )
+    encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id)
+    encoder_mask = build_mask(stories, tokenizer.pad_token_id)
+
+    batch = Batch(
+        document_names=names,
+        batch_size=len(stories),
+        src=stories,
+        segs=encoder_token_type_ids,
+        mask_src=encoder_mask,
+        tgt_str=[""] * len(stories),
+    )
+
+    return batch
+
+
+def decode_summary(summary_tokens, tokenizer):
+    """ Decode the summary and return it in a format
+    suitable for evaluation.
+    """
+    summary_tokens = summary_tokens.to("cpu").numpy()
+    summary = tokenizer.decode(summary_tokens)
+    sentences = summary.split(".")
+    sentences = [s + "." for s in sentences]
+    return sentences
+
+
+def main():
+    """ The main function defines the interface with the users.
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--documents_dir",
+        default=None,
+        type=str,
+        required=True,
+        help="The folder where the documents to summarize are located.",
+    )
+    parser.add_argument(
+        "--summaries_output_dir",
+        default=None,
+        type=str,
+        required=True,
+        help="The folder in which the summaries should be written.",
+    )
+    # EVALUATION options
+    parser.add_argument(
+        "--visible_gpus",
+        default=-1,
+        type=int,
+        help="The GPU on which to run the evaluation; -1 means run on CPU.",
+    )
+    parser.add_argument(
+        "--batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation.",
+    )
+    # BEAM SEARCH arguments
+    parser.add_argument(
+        "--min_length",
+        default=50,
+        type=int,
+        help="Minimum number of tokens for the summaries.",
+    )
+    parser.add_argument(
+        "--max_length",
+        default=200,
+        type=int,
+        help="Maximum number of tokens for the summaries.",
+    )
+    parser.add_argument(
+        "--beam_size",
+        default=5,
+        type=int,
+        help="The number of beams to start with for each example.",
+    )
+    parser.add_argument(
+        "--alpha",
+        default=0.95,
+        type=float,
+        help="The value of alpha for the length penalty in the beam search.",
+    )
+    parser.add_argument(
+        "--block_trigram",
+        default=True,
+        type=bool,
+        help="Whether to block repeating trigrams in the text generated by beam search.",
+    )
+    args = parser.parse_args()
+    args.device = torch.device("cpu") if args.visible_gpus == -1 else torch.device("cuda")
+
+    if not documents_dir_is_valid(args.documents_dir):
+        raise FileNotFoundError(
+            "We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
+ ) + maybe_create_output_dir(args.summaries_output_dir) + + evaluate(args) + + +def documents_dir_is_valid(path): + if not os.path.exists(path): + return False + + file_list = os.listdir(path) + if len(file_list) == 0: + return False + + return True + + +def maybe_create_output_dir(path): + if not os.path.exists(path): + os.makedirs(path) + + +if __name__ == "__main__": + main() diff --git a/examples/utils_summarization.py b/examples/summarization/utils_summarization.py similarity index 77% rename from examples/utils_summarization.py rename to examples/summarization/utils_summarization.py index 8e95a04e192..e7401b1754d 100644 --- a/examples/utils_summarization.py +++ b/examples/summarization/utils_summarization.py @@ -10,9 +10,14 @@ from torch.utils.data import Dataset # ------------ -class CNNDailyMailDataset(Dataset): +class SummarizationDataset(Dataset): """ Abstracts the dataset used to train seq2seq models. + The class will process the documents that are located in the specified + folder. The preprocessing will work on any document that is reasonably + formatted. On the CNN/DailyMail dataset it will extract both the story + and the summary. + CNN/Daily News: The CNN/Daily News raw datasets are downloaded from [1]. The stories are @@ -25,32 +30,31 @@ class CNNDailyMailDataset(Dataset): [2] https://github.com/abisee/cnn-dailymail/ """ - def __init__(self, data_dir="", prefix="train"): - assert os.path.isdir(data_dir) + def __init__(self, path="", prefix="train"): + """ We initialize the class by listing all the documents to summarize. + Files are not read in memory due to the size of some datasets (like CNN/DailyMail). + """ + assert os.path.isdir(path) - # We initialize the class by listing all the files that contain - # stories and summaries. Files are not read in memory given - # the size of the corpus. - self.stories_path = [] - datasets = ("cnn", "dailymail") - for dataset in datasets: - path_to_stories = os.path.join(data_dir, dataset, "stories") - story_filenames_list = os.listdir(path_to_stories) - for story_filename in story_filenames_list: - path_to_story = os.path.join(path_to_stories, story_filename) - if not os.path.isfile(path_to_story): - continue - self.stories_path.append(path_to_story) + self.documents = [] + story_filenames_list = os.listdir(path) + for story_filename in story_filenames_list: + path_to_story = os.path.join(path, story_filename) + if not os.path.isfile(path_to_story): + continue + self.documents.append(path_to_story) def __len__(self): - return len(self.stories_path) + """ Returns the number of documents. """ + return len(self.documents) def __getitem__(self, idx): - story_path = self.stories_path[idx] - with open(story_path, encoding="utf-8") as source: + document_path = self.documents[idx] + document_name = document_path.split("/")[-1] + with open(document_path, encoding="utf-8") as source: raw_story = source.read() story_lines, summary_lines = process_story(raw_story) - return story_lines, summary_lines + return document_name, story_lines, summary_lines def process_story(raw_story): @@ -80,7 +84,7 @@ def process_story(raw_story): story_lines.append(element) except IndexError: # if "@highlight" is absent from the file we pop - # all elements until there is None. + # all elements until there is None, raising an exception. 
return story_lines, [] # gather summary lines @@ -114,14 +118,6 @@ def fit_to_block_size(sequence, block_size, pad_token_id): return sequence -def build_lm_labels(sequence, pad_token_id): - """ Padding token are replaced by the value -1 so they - are not taken into account in the loss computation. """ - padded = sequence.clone() - padded[padded == pad_token_id] = -1 - return padded - - def build_mask(sequence, pad_token_id): """ Builds the mask. The attention mechanism will only attend to positions with value 1. """ @@ -165,7 +161,7 @@ def compute_token_type_ids(batch, separator_token_id): """ batch_embeddings = [] for sequence in batch: - sentence_num = 0 + sentence_num = -1 embeddings = [] for s in sequence: if s == separator_token_id: diff --git a/examples/utils_summarization_test.py b/examples/summarization/utils_summarization_test.py similarity index 88% rename from examples/utils_summarization_test.py rename to examples/summarization/utils_summarization_test.py index 1d56ff08031..8bfbf6ab231 100644 --- a/examples/utils_summarization_test.py +++ b/examples/summarization/utils_summarization_test.py @@ -21,7 +21,6 @@ from utils_summarization import ( compute_token_type_ids, fit_to_block_size, build_mask, - build_lm_labels, process_story, ) @@ -88,20 +87,6 @@ class SummarizationDataProcessingTest(unittest.TestCase): expected_summary_lines = ["It was the best of times."] self.assertEqual(expected_summary_lines, summary_lines) - def test_build_lm_labels_no_padding(self): - sequence = torch.tensor([1, 2, 3, 4]) - expected = sequence - np.testing.assert_array_equal( - build_lm_labels(sequence, 0).numpy(), expected.numpy() - ) - - def test_build_lm_labels(self): - sequence = torch.tensor([1, 2, 3, 4, 0, 0, 0]) - expected = torch.tensor([1, 2, 3, 4, -1, -1, -1]) - np.testing.assert_array_equal( - build_lm_labels(sequence, 0).numpy(), expected.numpy() - ) - def test_build_mask_no_padding(self): sequence = torch.tensor([1, 2, 3, 4]) expected = torch.tensor([1, 1, 1, 1]) @@ -125,7 +110,7 @@ class SummarizationDataProcessingTest(unittest.TestCase): [[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]] ) expected = torch.tensor( - [[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 1], [0, 1, 1, 1, 0, 0]] + [[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 1]] ) result = compute_token_type_ids(batch, separator) diff --git a/transformers/convert_bertextabs_original_pytorch_checkpoint_to_pytorch.py b/transformers/convert_bertextabs_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 00000000000..4f158966e10 --- /dev/null +++ b/transformers/convert_bertextabs_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" Convert BertExtAbs's checkpoints """ + +import argparse +from collections import namedtuple +import logging + +import torch + +from models.model_builder import AbsSummarizer # The authors' implementation + +from transformers import BertConfig, Model2Model, BertModel, BertForMaskedLM + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +BertExtAbsConfig = namedtuple( + "BertExtAbsConfig", + ["temp_dir", "large", "finetune_bert", "encoder", "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads", "dec_ff_size", "dec_dropout"], +) + + +def convert_bertextabs_checkpoints(path_to_checkpoints, dump_path): + """ Copy/paste and tweak the pre-trained weights provided by the creators + of BertExtAbs for the internal architecture. + """ + + # Load checkpoints in memory + checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage) + + # Instantiate the authors' model with the pre-trained weights + config = BertExtAbsConfig( + temp_dir=".", + finetune_bert=False, + large=False, + share_emb=True, + encoder="bert", + max_pos=512, + enc_layers=6, + enc_hidden_size=512, + enc_heads=8, + enc_ff_size=512, + enc_dropout=0.2, + dec_layers=6, + dec_hidden_size=768, + dec_heads=8, + dec_ff_size=2048, + dec_dropout=0.2, + ) + bertextabs = AbsSummarizer(config, torch.device("cpu"), checkpoints) + bertextabs.eval() + + # Instantiate our version of the model + decoder_config = BertConfig( + hidden_size=config.dec_hidden_size, + num_hidden_layers=config.dec_layers, + num_attention_heads=config.dec_heads, + intermediate_size=config.dec_ff_size, + hidden_dropout_prob=config.dec_dropout, + attention_probs_dropout_prob=config.dec_dropout, + is_decoder=True, + ) + + decoder_model = BertForMaskedLM(decoder_config) + model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder_model) + model.eval() + + # Let us now start the weight copying process + model.encoder.load_state_dict(bertextabs.bert.model.state_dict()) + + # Decoder + + # Embeddings. The positional embeddings are equal to the word embedding plus a modulation + # that is computed at each forward pass. This may be a source of discrepancy. + model.decoder.bert.embeddings.word_embeddings.weight = bertextabs.decoder.embeddings.weight + model.decoder.bert.embeddings.position_embeddings.weight = bertextabs.decoder.embeddings.weight + model.decoder.bert.embeddings.token_type_embeddings.weight.data = torch.zeros_like(bertextabs.decoder.embeddings.weight) # not defined for BertExtAbs decoder + + # In the original code the LayerNorms are applied twice in the layers, at the beginning and between the + # attention layers. 
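+    # As a consequence the mapping below is shifted by one: our embeddings' LayerNorm
+    # takes the first layer's `layer_norm_1`, layer `i`'s output LayerNorm takes
+    # layer `i + 1`'s `layer_norm_1`, and the last layer falls back to the decoder's
+    # final `layer_norm` (see the IndexError fallback in the loop below).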
+ model.decoder.bert.embeddings.LayerNorm.weight = bertextabs.decoder.transformer_layers[0].layer_norm_1.weight + + for i in range(config.dec_layers): + + # self attention + model.decoder.bert.encoder.layer[i].attention.self.query.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_query.weight + model.decoder.bert.encoder.layer[i].attention.self.key.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_keys.weight + model.decoder.bert.encoder.layer[i].attention.self.value.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_values.weight + model.decoder.bert.encoder.layer[i].attention.output.dense.weight = bertextabs.decoder.transformer_layers[i].self_attn.final_linear.weight + model.decoder.bert.encoder.layer[i].attention.output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i].layer_norm_2.weight + + # attention + model.decoder.bert.encoder.layer[i].crossattention.self.query.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_query.weight + model.decoder.bert.encoder.layer[i].crossattention.self.key.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_keys.weight + model.decoder.bert.encoder.layer[i].crossattention.self.value.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_values.weight + model.decoder.bert.encoder.layer[i].crossattention.output.dense.weight = bertextabs.decoder.transformer_layers[i].context_attn.final_linear.weight + model.decoder.bert.encoder.layer[i].crossattention.output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i].feed_forward.layer_norm.weight + + # intermediate + model.decoder.bert.encoder.layer[i].intermediate.dense.weight = bertextabs.decoder.transformer_layers[i].feed_forward.w_1.weight + + # output + model.decoder.bert.encoder.layer[i].output.dense.weight = bertextabs.decoder.transformer_layers[i].feed_forward.w_2.weight + + try: + model.decoder.bert.encoder.layer[i].output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i + 1].layer_norm_1.weight + except IndexError: + model.decoder.bert.encoder.layer[i].output.LayerNorm.weight = bertextabs.decoder.layer_norm.weight + + # LM Head + """ + model.decoder.cls.predictions.transform.dense.weight + model.decoder.cls.predictions.transform.dense.biais + model.decoder.cls.predictions.transform.LayerNorm.weight + model.decoder.cls.predictions.transform.LayerNorm.biais + model.decoder.cls.predictions.decoder.weight + model.decoder.cls.predictions.decoder.biais + model.decoder.cls.predictions.biais.data + """ + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--bertextabs_checkpoint_path", + default=None, + type=str, + required=True, + help="Path the official PyTorch dump.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + required=True, + help="Path to the output PyTorch model.", + ) + args = parser.parse_args() + + convert_bertextabs_checkpoints( + args.bertextabs_checkpoint_path, + args.pytorch_dump_folder_path, + ) diff --git a/transformers/generate/beam_search.py b/transformers/generate/beam_search.py index abe31860498..b56ebbabb86 100644 --- a/transformers/generate/beam_search.py +++ b/transformers/generate/beam_search.py @@ -25,7 +25,6 @@ Use Beam Search to generate sequences using encoder-decoder models. 
""" import torch from torch import nn - import logging @@ -45,6 +44,7 @@ class BeamSearch(object): max_length, alpha=0, block_repeating_trigrams=True, + device=torch.device("cpu"), ): r""" Inputs: @@ -156,18 +156,24 @@ class BeamSearch(object): kwargs_decoder["encoder_hidden_states"] = tile( encoder_hidden_states, self.beam_size, dim=0 ) - kwargs_decoder["encoder_attention_mask"] = tile( - kwargs_encoder["attention_mask"], self.beam_size, dim=0 + try: + kwargs_decoder["encoder_attention_mask"] = tile( + kwargs_encoder["attention_mask"], self.beam_size, dim=0 + ) + except: + pass + kwargs_decoder["state"].src = tile( + kwargs_decoder["state"].src, self.beam_size, dim=0 ) # grow the beam iteratively batch_size, block_size = encoder_input_ids.size() self._init_beam_state(batch_size) for step in range(self.max_length): - decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id) kwargs_decoder["attention_mask"] = build_mask(decoder_input, self.pad_token_id) - outputs = self.model.decoder(decoder_input, **kwargs_decoder) + + outputs, state = self.model.decoder(decoder_input, **kwargs_decoder) next_token_scores = outputs[0][:, -1, :].squeeze(1) log_probabilities = torch.nn.functional.log_softmax(next_token_scores, dim=0) @@ -178,9 +184,13 @@ class BeamSearch(object): kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[ "encoder_hidden_states" ].index_select(0, surviving_beams_rows) - kwargs_decoder["encoder_attention_mask"] = kwargs_decoder[ - "encoder_attention_mask" - ].index_select(0, surviving_beams_rows) + try: + kwargs_decoder["encoder_attention_mask"] = kwargs_decoder[ + "encoder_attention_mask" + ].index_select(0, surviving_beams_rows) + except: + pass + kwargs_decoder["state"] = state return self.results