give transformers API to BertAbs

Rémi Louf 2019-11-23 00:18:44 +01:00 committed by Julien Chaumond
parent 4d18199902
commit 2403a66598
9 changed files with 2188 additions and 54 deletions

View File

@ -0,0 +1,161 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert BertExtAbs's checkpoints """
import argparse
from collections import namedtuple
import logging
import torch
from models.model_builder import AbsSummarizer # The authors' implementation
from model_bertabs import BertAbsSummarizer
from transformers import BertTokenizer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
SAMPLE_TEXT = 'Hello world! cécé herlolip'
BertAbsConfig = namedtuple(
"BertAbsConfig",
["temp_dir", "large", "use_bert_emb", "finetune_bert", "encoder", "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads", "dec_ff_size", "dec_dropout"],
)
def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
""" Copy/paste and tweak the pre-trained weights provided by the creators
of BertAbs for the internal architecture.
"""
# Instantiate the authors' model with the pre-trained weights
config = BertAbsConfig(
temp_dir=".",
finetune_bert=False,
large=False,
share_emb=True,
use_bert_emb=False,
encoder="bert",
max_pos=512,
enc_layers=6,
enc_hidden_size=512,
enc_heads=8,
enc_ff_size=512,
enc_dropout=0.2,
dec_layers=6,
dec_hidden_size=768,
dec_heads=8,
dec_ff_size=2048,
dec_dropout=0.2,
)
checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
original = AbsSummarizer(config, torch.device("cpu"), checkpoints)
original.eval()
new_model = BertAbsSummarizer(config, torch.device("cpu"))
new_model.eval()
# -------------------
# Convert the weights
# -------------------
logging.info("convert the model")
new_model.encoder.load_state_dict(original.bert.state_dict())
new_model.decoder.generator.load_state_dict(original.generator.state_dict())
new_model.decoder.embeddings.load_state_dict(original.decoder.embeddings.state_dict())
new_model.decoder.pos_emb.load_state_dict(original.decoder.pos_emb.state_dict())
new_model.decoder.transformer_layers.load_state_dict(original.decoder.transformer_layers.state_dict())
new_model.decoder.layer_norm.load_state_dict(original.decoder.layer_norm.state_dict())
# ----------------------------------
# Make sure the outputs are identical
# ----------------------------------
logging.info("Make sure that the models' outputs are identical")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# prepare the model inputs
encoder_input_ids = tokenizer.encode("This is sample éàalj'-.")
encoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(encoder_input_ids)))
encoder_input_ids = torch.tensor(encoder_input_ids).unsqueeze(0)
decoder_input_ids = tokenizer.encode("This is sample 3 éàalj'-.")
decoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(decoder_input_ids)))
decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(0)
# failsafe to make sure the weights reset does not affect the
# loaded weights.
assert torch.max(torch.abs(original.generator[0].weight - new_model.decoder.generator[0].weight)) == 0
# forward pass
src = encoder_input_ids
tgt = decoder_input_ids
segs = token_type_ids = None
clss = None
mask_src = encoder_attention_mask = None
mask_tgt = decoder_attention_mask = None
mask_cls = None
# The original model does not apply the generator layer immediately but rather in
# the beam search (where it combines softmax + linear layer). Since we already
# apply the softmax in our generation process we only apply the linear layer here.
# We make sure that the outputs of the full stack are identical
output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
output_original_model = original.generator(output_original_model)
output_converted_model = new_model(encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask)[0]
output_converted_model = torch.nn.functional.log_softmax(output_converted_model, dim=-1)
maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))
are_identical = torch.allclose(output_converted_model, output_original_model, atol=1e-3)
if are_identical:
logging.info("all weights are equal up to 1e-3")
else:
raise ValueError("the weights are different. The new model is likely different from the original one.")
# The model has been saved with torch.save(model) and this is bound to the exact
# directory structure. We save the state_dict instead.
logging.info("saving the model's state dictionary")
torch.save(new_model.state_dict(), "bert-ext-abs.pt")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--bertabs_checkpoint_path",
default=None,
type=str,
required=True,
help="Path the official PyTorch dump.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
required=True,
help="Path to the output PyTorch model.",
)
args = parser.parse_args()
convert_bertabs_checkpoints(
args.bertabs_checkpoint_path,
args.pytorch_dump_folder_path,
)
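# Not part of the commit: a quick sanity check on the saved state dict, assuming
# the file "bert-ext-abs.pt" was produced by the script above.
import torch

state_dict = torch.load("bert-ext-abs.pt", map_location="cpu")
print(len(state_dict), "tensors saved")
print(sorted(state_dict)[:5])  # parameter names, e.g. under the encoder.* and decoder.* prefixes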

View File

@ -0,0 +1,141 @@
# coding=utf-8
# Copyright 2019 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BertAbs configuration """
import json
import logging
import sys
from transformers import PretrainedConfig
logger = logging.getLogger(__name__)
BERTABS_FINETUNED_CONFIG_MAP = {
"bertabs-finetuned-cnndm": "https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization-config.json",
}
class BertAbsConfig(PretrainedConfig):
r""" Class to store the configuration of the BertAbs model.
Arguments:
temp_dir: string
Unused in the current situation. Kept for compatibility but will be removed.
finetune_bert: bool
Whether to fine-tune the model or not. Will be kept for reference
in case we want to add the possibility to fine-tune the model.
large: bool
Whether to use bert-large as a base.
share_emb: bool
Whether the embeddings are shared between the encoder and decoder.
encoder: string
Unclear what this controls; leave set to "bert" to use the pre-trained weights.
max_pos: int
The maximum sequence length that this model will be used with.
enc_layers: int
The number of hidden layers in the Transformer encoder.
enc_hidden_size: int
The size of the encoder's layers.
enc_heads: int
The number of attention heads for each attention layer in the encoder.
enc_ff_size: int
The size of the encoder's feed-forward layers.
enc_dropout: float
The dropout probability for all fully connected layers in the
embeddings, layers, pooler and also the attention probabilities in
the encoder.
dec_layers: int
The number of hidden layers in the decoder.
dec_hidden_size: int
The size of the decoder's layers.
dec_heads: int
The number of attention heads for each attention layer in the decoder.
dec_ff_size: int
The size of the decoder's feed-forward layers.
dec_dropout: float
The dropout probability for all fully connected layers in the
embeddings, layers, pooler and also the attention probabilities in
the decoder.
"""
pretrained_config_archive_map = BERTABS_FINETUNED_CONFIG_MAP
def __init__(
self,
vocab_size_or_config_json_file=30522,
temp_dir=".",
finetune_bert=False,
large=False,
share_emb=True,
encoder="bert",
max_pos=512,
enc_layers=6,
enc_hidden_size=512,
enc_heads=8,
enc_ff_size=512,
enc_dropout=0.2,
dec_layers=6,
dec_hidden_size=768,
dec_heads=8,
dec_ff_size=2048,
dec_dropout=0.2,
**kwargs,
):
super(BertAbsConfig, self).__init__(**kwargs)
if self._input_is_path_to_json(vocab_size_or_config_json_file):
path_to_json = vocab_size_or_config_json_file
with open(path_to_json, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.temp_dir = temp_dir
self.finetune_bert = finetune_bert
self.large = large
self.vocab_size = vocab_size_or_config_json_file
self.max_pos = max_pos
self.encoder = encoder
self.enc_layers = enc_layers
self.enc_hidden_size = enc_hidden_size
self.enc_heads = enc_heads
self.enc_ff_size = enc_ff_size
self.enc_dropout = enc_dropout
self.share_emb = share_emb
self.dec_layers = dec_layers
self.dec_hidden_size = dec_hidden_size
self.dec_heads = dec_heads
self.dec_ff_size = dec_ff_size
self.dec_dropout = dec_dropout
else:
raise ValueError(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
def _input_is_path_to_json(self, first_argument):
""" Checks whether the first argument passed to config
is the path to a JSON file that contains the config.
"""
is_python_2 = sys.version_info[0] == 2
if is_python_2:
return isinstance(first_argument, unicode)
else:
return isinstance(first_argument, str)
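# Not part of the commit: a minimal usage sketch, assuming this file is importable
# as `configuration_bertabs` and that the serialization helpers inherited from
# PretrainedConfig behave as in transformers 2.x.
from configuration_bertabs import BertAbsConfig

config = BertAbsConfig(vocab_size_or_config_json_file=30522, dec_dropout=0.1)
print(config.enc_hidden_size)   # 512, the encoder default
print(config.to_json_string())  # JSON view of the configuration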

View File

@ -0,0 +1,162 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert BertExtAbs's checkpoints
The file currently does not do much as we ended up copying the exact model
structure, but I leave it here in case we ever want to refactor the model.
"""
import argparse
from collections import namedtuple
import logging
import torch
from models.model_builder import AbsSummarizer # The authors' implementation
from model_bertabs import BertAbsSummarizer
from transformers import BertTokenizer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
SAMPLE_TEXT = 'Hello world! cécé herlolip'
BertAbsConfig = namedtuple(
"BertAbsConfig",
["temp_dir", "large", "use_bert_emb", "finetune_bert", "encoder", "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads", "dec_ff_size", "dec_dropout"],
)
def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
""" Copy/paste and tweak the pre-trained weights provided by the creators
of BertAbs for the internal architecture.
"""
# Instantiate the authors' model with the pre-trained weights
config = BertAbsConfig(
temp_dir=".",
finetune_bert=False,
large=False,
share_emb=True,
use_bert_emb=False,
encoder="bert",
max_pos=512,
enc_layers=6,
enc_hidden_size=512,
enc_heads=8,
enc_ff_size=512,
enc_dropout=0.2,
dec_layers=6,
dec_hidden_size=768,
dec_heads=8,
dec_ff_size=2048,
dec_dropout=0.2,
)
checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
original = AbsSummarizer(config, torch.device("cpu"), checkpoints)
original.eval()
new_model = BertAbsSummarizer(config, torch.device("cpu"))
new_model.eval()
# -------------------
# Convert the weights
# -------------------
logging.info("convert the model")
new_model.bert.load_state_dict(original.bert.state_dict())
new_model.decoder.load_state_dict(original.decoder.state_dict())
new_model.generator.load_state_dict(original.generator.state_dict())
# ----------------------------------
# Make sure the outputs are identical
# ----------------------------------
logging.info("Make sure that the models' outputs are identical")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# prepare the model inputs
encoder_input_ids = tokenizer.encode("This is sample éàalj'-.")
encoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(encoder_input_ids)))
encoder_input_ids = torch.tensor(encoder_input_ids).unsqueeze(0)
decoder_input_ids = tokenizer.encode("This is sample 3 éàalj'-.")
decoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(decoder_input_ids)))
decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(0)
# failsafe to make sure the weights reset does not affect the
# loaded weights.
assert torch.max(torch.abs(original.generator[0].weight - new_model.generator[0].weight)) == 0
# forward pass
src = encoder_input_ids
tgt = decoder_input_ids
segs = token_type_ids = None
clss = None
mask_src = encoder_attention_mask = None
mask_tgt = decoder_attention_mask = None
mask_cls = None
# The original model does not apply the generator layer immediately but rather in
# the beam search (where it combines softmax + linear layer). Since we already
# apply the softmax in our generation process we only apply the linear layer here.
# We make sure that the outputs of the full stack are identical
output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
output_original_generator = original.generator(output_original_model)
output_converted_model = new_model(encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask)[0]
output_converted_generator = new_model.generator(output_converted_model)
maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
print("Maximum absolute difference between the models' outputs: {:.2f}".format(maximum_absolute_difference))
maximum_absolute_difference = torch.max(torch.abs(output_converted_generator - output_original_generator)).item()
print("Maximum absolute difference between the generator outputs: {:.2f}".format(maximum_absolute_difference))
are_identical = torch.allclose(output_converted_model, output_original_model, atol=1e-3)
if are_identical:
logging.info("all weights are equal up to 1e-3")
else:
raise ValueError("the weights are different. The new model is likely different from the original one.")
# The model has been saved with torch.save(model) and this is bound to the exact
# directory structure. We save the state_dict instead.
logging.info("saving the model's state dictionary")
torch.save(new_model.state_dict(), "bertabs-finetuned-cnndm-extractive-abstractive-summarization-pytorch_model.bin")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--bertabs_checkpoint_path",
default=None,
type=str,
required=True,
help="Path the official PyTorch dump.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
required=True,
help="Path to the output PyTorch model.",
)
args = parser.parse_args()
convert_bertabs_checkpoints(
args.bertabs_checkpoint_path,
args.pytorch_dump_folder_path,
)
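# Not part of the commit: once converted, the checkpoint above is what
# modeling_bertabs.BertAbs loads in run_summarization.py, roughly as follows.
from modeling_bertabs import BertAbs

bertabs = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
bertabs.eval()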

File diff suppressed because it is too large

View File

@ -0,0 +1,271 @@
import argparse
from collections import namedtuple
import logging
import os
import sys
import torch
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm
from transformers import BertTokenizer
from modeling_bertabs import BertAbs, build_predictor
from utils_summarization import (
SummarizationDataset,
encode_for_summarization,
build_mask,
fit_to_block_size,
compute_token_type_ids,
)
logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
Batch = namedtuple(
"Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"]
)
def evaluate(args):
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
model = bertabs = BertAbs.from_pretrained(
"bertabs-finetuned-{}".format(args.finetuned_model)
)
bertabs.to(args.device)
bertabs.eval()
symbols = {
"BOS": tokenizer.vocab["[unused0]"],
"EOS": tokenizer.vocab["[unused1]"],
"PAD": tokenizer.vocab["[PAD]"],
}
# these (unused) arguments are defined to keep compatibility with the legacy
# code and will be removed in a later iteration.
args.result_path = ""
args.temp_dir = ""
data_iterator = build_data_iterator(args, tokenizer)
predictor = build_predictor(args, tokenizer, symbols, model)
logger.info("***** Running evaluation *****")
logger.info(" Number examples = %d", len(data_iterator.dataset))
logger.info(" Batch size = %d", args.batch_size)
logger.info("")
logger.info("***** Beam Search parameters *****")
logger.info(" Beam size = %d", args.beam_size)
logger.info(" Minimum length = %d", args.min_length)
logger.info(" Maximum length = %d", args.max_length)
logger.info(" Alpha (length penalty) = %.2f", args.alpha)
logger.info(" Trigrams %s be blocked", ("will" if args.block_trigram else "will NOT"))
for batch in tqdm(data_iterator):
batch_data = predictor.translate_batch(batch)
translations = predictor.from_batch(batch_data)
summaries = [format_summary(t) for t in translations]
save_summaries(summaries, args.summaries_output_dir, batch.document_names)
def format_summary(translation):
""" Transforms the output of the `from_batch` function
into nicely formatted summaries.
"""
raw_summary, _, _ = translation
summary = (
raw_summary.replace("[unused0]", "")
.replace("[unused3]", "")
.replace("[PAD]", "")
.replace("[unused1]", "")
.replace(r" +", " ")
.replace(" [unused2] ", ". ")
.replace("[unused2]", "")
.strip()
)
return summary
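# Illustration only (not part of the commit): how format_summary() cleans a raw
# beam-search output; the input tuple below is made up for the example.
example = ("[unused0] the cat sat on the mat [unused2] it was warm [unused1] [PAD]", None, None)
print(format_summary(example))  # -> "the cat sat on the mat. it was warm"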
def save_summaries(summaries, path, original_document_name):
""" Write the summaries in fies that are prefixed by the original
files' name with the `_summary` appended.
Attributes:
original_document_names: List[string]
Name of the document that was summarized.
path: string
Path were the summaries will be written
summaries: List[string]
The summaries that we produced.
"""
for summary, document_name in zip(summaries, original_document_name):
# Prepare the summary file's name
if "." in document_name:
bare_document_name = ".".join(document_name.split(".")[:-1])
extension = document_name.split(".")[-1]
name = bare_document_name + "_summary." + extension
else:
name = document_name + "_summary"
file_path = os.path.join(path, name)
with open(file_path, "w") as output:
output.write(summary)
#
# LOAD the dataset
#
def build_data_iterator(args, tokenizer):
dataset = load_and_cache_examples(args, tokenizer)
sampler = SequentialSampler(dataset)
collate_fn = lambda data: collate(data, tokenizer, block_size=512)
iterator = DataLoader(
dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,
)
return iterator
def load_and_cache_examples(args, tokenizer):
dataset = SummarizationDataset(args.documents_dir)
return dataset
def collate(data, tokenizer, block_size):
""" Collate formats the data passed to the data loader.
In particular we tokenize the data batch after batch to avoid keeping them
all in memory. We output the data as a namedtuple to fit the original BertAbs's
API.
"""
data = [x for x in data if not len(x[1]) == 0] # remove empty_files
names = [name for name, _, _ in data]
encoded_text = [
encode_for_summarization(story, summary, tokenizer) for _, story, summary in data
]
stories = torch.tensor(
[
fit_to_block_size(story, block_size, tokenizer.pad_token_id)
for story, _ in encoded_text
]
)
encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id)
encoder_mask = build_mask(stories, tokenizer.pad_token_id)
batch = Batch(
document_names=names,
batch_size=len(stories),
src=stories,
segs=encoder_token_type_ids,
mask_src=encoder_mask,
tgt_str=[""] * len(stories),
)
return batch
def decode_summary(summary_tokens, tokenizer):
""" Decode the summary and return it in a format
suitable for evaluation.
"""
summary_tokens = summary_tokens.to("cpu").numpy()
summary = tokenizer.decode(summary_tokens)
sentences = summary.split(".")
sentences = [s + "." for s in sentences]
return sentences
def main():
""" The main function defines the interface with the users.
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--documents_dir",
default=None,
type=str,
required=True,
help="The folder where the documents to summarize are located.",
)
parser.add_argument(
"--summaries_output_dir",
default=None,
type=str,
required=True,
help="The folder in wich the summaries should be written.",
)
parser.add_argument(
"--finetuned_model",
default="cnndm",
type=str,
help="The dataset the model was fine-tuned on; selects the bertabs-finetuned-<dataset> checkpoint loaded by evaluate().",
)
# EVALUATION options
parser.add_argument(
"--visible_gpus",
default=-1,
type=int,
help="Number of GPUs with which to do the training.",
)
parser.add_argument(
"--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
)
# BEAM SEARCH arguments
parser.add_argument(
"--min_length",
default=50,
type=int,
help="Minimum number of tokens for the summaries.",
)
parser.add_argument(
"--max_length",
default=200,
type=int,
help="Maixmum number of tokens for the summaries.",
)
parser.add_argument(
"--beam_size",
default=5,
type=int,
help="The number of beams to start with for each example.",
)
parser.add_argument(
"--alpha",
default=0.95,
type=float,
help="The value of alpha for the length penalty in the beam search.",
)
parser.add_argument(
"--block_trigram",
default=True,
type=bool,
help="Whether to block the existence of repeating trigrams in the text generated by beam search.",
)
args = parser.parse_args()
args.device = torch.device("cpu") if args.visible_gpus == -1 else torch.device("cuda")
if not documents_dir_is_valid(args.documents_dir):
raise FileNotFoundError(
"We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
)
maybe_create_output_dir(args.summaries_output_dir)
evaluate(args)
def documents_dir_is_valid(path):
if not os.path.exists(path):
return False
file_list = os.listdir(path)
if len(file_list) == 0:
return False
return True
def maybe_create_output_dir(path):
if not os.path.exists(path):
os.makedirs(path)
if __name__ == "__main__":
main()
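# Not part of the commit: a rough sketch of driving evaluate() without the CLI,
# assuming this script is importable as `run_summarization`. The paths are
# placeholders, and build_predictor() may expect attributes not visible in this diff.
import argparse
import torch
from run_summarization import evaluate

args = argparse.Namespace(
    documents_dir="documents/",
    summaries_output_dir="summaries/",
    finetuned_model="cnndm",  # loads the bertabs-finetuned-cnndm weights
    batch_size=4,
    beam_size=5,
    min_length=50,
    max_length=200,
    alpha=0.95,
    block_trigram=True,
    device=torch.device("cpu"),
)
evaluate(args)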

View File

@ -10,9 +10,14 @@ from torch.utils.data import Dataset
# ------------
class CNNDailyMailDataset(Dataset):
class SummarizationDataset(Dataset):
""" Abstracts the dataset used to train seq2seq models.
The class will process the documents that are located in the specified
folder. The preprocessing will work on any document that is reasonably
formatted. On the CNN/DailyMail dataset it will extract both the story
and the summary.
CNN/Daily News:
The CNN/Daily News raw datasets are downloaded from [1]. The stories are
@ -25,32 +30,31 @@ class CNNDailyMailDataset(Dataset):
[2] https://github.com/abisee/cnn-dailymail/
"""
def __init__(self, data_dir="", prefix="train"):
assert os.path.isdir(data_dir)
def __init__(self, path="", prefix="train"):
""" We initialize the class by listing all the documents to summarize.
Files are not read in memory due to the size of some datasets (like CNN/DailyMail).
"""
assert os.path.isdir(path)
# We initialize the class by listing all the files that contain
# stories and summaries. Files are not read in memory given
# the size of the corpus.
self.stories_path = []
datasets = ("cnn", "dailymail")
for dataset in datasets:
path_to_stories = os.path.join(data_dir, dataset, "stories")
story_filenames_list = os.listdir(path_to_stories)
for story_filename in story_filenames_list:
path_to_story = os.path.join(path_to_stories, story_filename)
if not os.path.isfile(path_to_story):
continue
self.stories_path.append(path_to_story)
self.documents = []
story_filenames_list = os.listdir(path)
for story_filename in story_filenames_list:
path_to_story = os.path.join(path, story_filename)
if not os.path.isfile(path_to_story):
continue
self.documents.append(path_to_story)
def __len__(self):
return len(self.stories_path)
""" Returns the number of documents. """
return len(self.documents)
def __getitem__(self, idx):
story_path = self.stories_path[idx]
with open(story_path, encoding="utf-8") as source:
document_path = self.documents[idx]
document_name = document_path.split("/")[-1]
with open(document_path, encoding="utf-8") as source:
raw_story = source.read()
story_lines, summary_lines = process_story(raw_story)
return story_lines, summary_lines
return document_name, story_lines, summary_lines
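# Illustration only (not part of the diff): reading one document with the
# refactored dataset once the module is fully imported (process_story below
# must be defined); "documents/" is a placeholder folder of plain-text files.
dataset = SummarizationDataset("documents/")
document_name, story_lines, summary_lines = dataset[0]
print(document_name, len(story_lines), len(summary_lines))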
def process_story(raw_story):
@ -80,7 +84,7 @@ def process_story(raw_story):
story_lines.append(element)
except IndexError:
# if "@highlight" is absent from the file we pop
# all elements until there is None.
# all elements until there is None, raising an exception.
return story_lines, []
# gather summary lines
@ -114,14 +118,6 @@ def fit_to_block_size(sequence, block_size, pad_token_id):
return sequence
def build_lm_labels(sequence, pad_token_id):
""" Padding token are replaced by the value -1 so they
are not taken into account in the loss computation. """
padded = sequence.clone()
padded[padded == pad_token_id] = -1
return padded
def build_mask(sequence, pad_token_id):
""" Builds the mask. The attention mechanism will only attend to positions
with value 1. """
@ -165,7 +161,7 @@ def compute_token_type_ids(batch, separator_token_id):
"""
batch_embeddings = []
for sequence in batch:
sentence_num = 0
sentence_num = -1
embeddings = []
for s in sequence:
if s == separator_token_id:

View File

@ -21,7 +21,6 @@ from utils_summarization import (
compute_token_type_ids,
fit_to_block_size,
build_mask,
build_lm_labels,
process_story,
)
@ -88,20 +87,6 @@ class SummarizationDataProcessingTest(unittest.TestCase):
expected_summary_lines = ["It was the best of times."]
self.assertEqual(expected_summary_lines, summary_lines)
def test_build_lm_labels_no_padding(self):
sequence = torch.tensor([1, 2, 3, 4])
expected = sequence
np.testing.assert_array_equal(
build_lm_labels(sequence, 0).numpy(), expected.numpy()
)
def test_build_lm_labels(self):
sequence = torch.tensor([1, 2, 3, 4, 0, 0, 0])
expected = torch.tensor([1, 2, 3, 4, -1, -1, -1])
np.testing.assert_array_equal(
build_lm_labels(sequence, 0).numpy(), expected.numpy()
)
def test_build_mask_no_padding(self):
sequence = torch.tensor([1, 2, 3, 4])
expected = torch.tensor([1, 1, 1, 1])
@ -125,7 +110,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
[[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]]
)
expected = torch.tensor(
[[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 1], [0, 1, 1, 1, 0, 0]]
[[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 1]]
)
result = compute_token_type_ids(batch, separator)

View File

@ -0,0 +1,158 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert BertExtAbs's checkpoints """
import argparse
from collections import namedtuple
import logging
import torch
from models.model_builder import AbsSummarizer # The authors' implementation
from transformers import BertConfig, Model2Model, BertModel, BertForMaskedLM
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
BertExtAbsConfig = namedtuple(
"BertExtAbsConfig",
["temp_dir", "large", "finetune_bert", "encoder", "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads", "dec_ff_size", "dec_dropout"],
)
def convert_bertextabs_checkpoints(path_to_checkpoints, dump_path):
""" Copy/paste and tweak the pre-trained weights provided by the creators
of BertExtAbs for the internal architecture.
"""
# Load checkpoints in memory
checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
# Instantiate the authors' model with the pre-trained weights
config = BertExtAbsConfig(
temp_dir=".",
finetune_bert=False,
large=False,
share_emb=True,
encoder="bert",
max_pos=512,
enc_layers=6,
enc_hidden_size=512,
enc_heads=8,
enc_ff_size=512,
enc_dropout=0.2,
dec_layers=6,
dec_hidden_size=768,
dec_heads=8,
dec_ff_size=2048,
dec_dropout=0.2,
)
bertextabs = AbsSummarizer(config, torch.device("cpu"), checkpoints)
bertextabs.eval()
# Instantiate our version of the model
decoder_config = BertConfig(
hidden_size=config.dec_hidden_size,
num_hidden_layers=config.dec_layers,
num_attention_heads=config.dec_heads,
intermediate_size=config.dec_ff_size,
hidden_dropout_prob=config.dec_dropout,
attention_probs_dropout_prob=config.dec_dropout,
is_decoder=True,
)
decoder_model = BertForMaskedLM(decoder_config)
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder_model)
model.eval()
# Let us now start the weight copying process
model.encoder.load_state_dict(bertextabs.bert.model.state_dict())
# Decoder
# Embeddings. The positional embeddings are equal to the word embedding plus a modulation
# that is computed at each forward pass. This may be a source of discrepancy.
model.decoder.bert.embeddings.word_embeddings.weight = bertextabs.decoder.embeddings.weight
model.decoder.bert.embeddings.position_embeddings.weight = bertextabs.decoder.embeddings.weight
model.decoder.bert.embeddings.token_type_embeddings.weight.data = torch.zeros_like(bertextabs.decoder.embeddings.weight) # not defined for BertExtAbs decoder
# In the original code the LayerNorms are applied twice in the layers, at the beginning and between the
# attention layers.
model.decoder.bert.embeddings.LayerNorm.weight = bertextabs.decoder.transformer_layers[0].layer_norm_1.weight
for i in range(config.dec_layers):
# self attention
model.decoder.bert.encoder.layer[i].attention.self.query.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_query.weight
model.decoder.bert.encoder.layer[i].attention.self.key.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_keys.weight
model.decoder.bert.encoder.layer[i].attention.self.value.weight = bertextabs.decoder.transformer_layers[i].self_attn.linear_values.weight
model.decoder.bert.encoder.layer[i].attention.output.dense.weight = bertextabs.decoder.transformer_layers[i].self_attn.final_linear.weight
model.decoder.bert.encoder.layer[i].attention.output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i].layer_norm_2.weight
# cross attention
model.decoder.bert.encoder.layer[i].crossattention.self.query.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_query.weight
model.decoder.bert.encoder.layer[i].crossattention.self.key.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_keys.weight
model.decoder.bert.encoder.layer[i].crossattention.self.value.weight = bertextabs.decoder.transformer_layers[i].context_attn.linear_values.weight
model.decoder.bert.encoder.layer[i].crossattention.output.dense.weight = bertextabs.decoder.transformer_layers[i].context_attn.final_linear.weight
model.decoder.bert.encoder.layer[i].crossattention.output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i].feed_forward.layer_norm.weight
# intermediate
model.decoder.bert.encoder.layer[i].intermediate.dense.weight = bertextabs.decoder.transformer_layers[i].feed_forward.w_1.weight
# output
model.decoder.bert.encoder.layer[i].output.dense.weight = bertextabs.decoder.transformer_layers[i].feed_forward.w_2.weight
try:
model.decoder.bert.encoder.layer[i].output.LayerNorm.weight = bertextabs.decoder.transformer_layers[i + 1].layer_norm_1.weight
except IndexError:
model.decoder.bert.encoder.layer[i].output.LayerNorm.weight = bertextabs.decoder.layer_norm.weight
# LM Head
"""
model.decoder.cls.predictions.transform.dense.weight
model.decoder.cls.predictions.transform.dense.bias
model.decoder.cls.predictions.transform.LayerNorm.weight
model.decoder.cls.predictions.transform.LayerNorm.bias
model.decoder.cls.predictions.decoder.weight
model.decoder.cls.predictions.decoder.bias
model.decoder.cls.predictions.bias.data
"""
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--bertextabs_checkpoint_path",
default=None,
type=str,
required=True,
help="Path the official PyTorch dump.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
required=True,
help="Path to the output PyTorch model.",
)
args = parser.parse_args()
convert_bertextabs_checkpoints(
args.bertextabs_checkpoint_path,
args.pytorch_dump_folder_path,
)

View File

@ -25,7 +25,6 @@ Use Beam Search to generate sequences using encoder-decoder models.
"""
import torch
from torch import nn
import logging
@ -45,6 +44,7 @@ class BeamSearch(object):
max_length,
alpha=0,
block_repeating_trigrams=True,
device=torch.device("cpu"),
):
r"""
Inputs:
@ -156,18 +156,24 @@ class BeamSearch(object):
kwargs_decoder["encoder_hidden_states"] = tile(
encoder_hidden_states, self.beam_size, dim=0
)
kwargs_decoder["encoder_attention_mask"] = tile(
kwargs_encoder["attention_mask"], self.beam_size, dim=0
try:
kwargs_decoder["encoder_attention_mask"] = tile(
kwargs_encoder["attention_mask"], self.beam_size, dim=0
)
except KeyError:
pass
kwargs_decoder["state"].src = tile(
kwargs_decoder["state"].src, self.beam_size, dim=0
)
# grow the beam iteratively
batch_size, block_size = encoder_input_ids.size()
self._init_beam_state(batch_size)
for step in range(self.max_length):
decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id)
kwargs_decoder["attention_mask"] = build_mask(decoder_input, self.pad_token_id)
outputs = self.model.decoder(decoder_input, **kwargs_decoder)
outputs, state = self.model.decoder(decoder_input, **kwargs_decoder)
next_token_scores = outputs[0][:, -1, :].squeeze(1)
log_probabilities = torch.nn.functional.log_softmax(next_token_scores, dim=0)
@ -178,9 +184,13 @@ class BeamSearch(object):
kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
"encoder_hidden_states"
].index_select(0, surviving_beams_rows)
kwargs_decoder["encoder_attention_mask"] = kwargs_decoder[
"encoder_attention_mask"
].index_select(0, surviving_beams_rows)
try:
kwargs_decoder["encoder_attention_mask"] = kwargs_decoder[
"encoder_attention_mask"
].index_select(0, surviving_beams_rows)
except KeyError:
pass
kwargs_decoder["state"] = state
return self.results