resolve PR comments

Rémi Louf 2019-10-29 17:10:20 +01:00
parent 4c3ac4a7d8
commit dfce409691
7 changed files with 647 additions and 473 deletions


@ -16,10 +16,9 @@
""" Finetuning seq2seq models for sequence generation."""
import argparse
from collections import deque
import functools
import logging
import os
import pickle
import random
import sys
@ -29,7 +28,22 @@ import torch
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import AutoTokenizer, PreTrainedSeq2seq, Model2Model
from transformers import (
AutoTokenizer,
BertForMaskedLM,
BertConfig,
PreTrainedSeq2seq,
Model2Model,
)
from utils_summarization import (
CNNDailyMailDataset,
encode_for_summarization,
fit_to_block_size,
build_lm_labels,
build_mask,
compute_token_type_ids,
)
logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
@ -46,194 +60,41 @@ def set_seed(args):
# ------------
class TextDataset(Dataset):
""" Abstracts the dataset used to train seq2seq models.
CNN/DailyMail:
The CNN/DailyMail raw datasets are downloaded from [1]. The stories are
stored in separate files; the summary appears at the end of the story as
sentences that are prefixed by the special `@highlight` line. To process
the data, untar both datasets in the same folder, and pass the path to this
folder as the "data_dir" argument. The formatting code was inspired by [2].
[1] https://cs.nyu.edu/~kcho/
[2] https://github.com/abisee/cnn-dailymail/
"""
def __init__(self, tokenizer, prefix="train", data_dir="", block_size=512):
assert os.path.isdir(data_dir)
# Load the features that have already been computed, if any
cached_features_file = os.path.join(
data_dir, "cached_lm_{}_{}".format(block_size, prefix)
)
if os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
with open(cached_features_file, "rb") as source:
self.examples = pickle.load(source)
return
logger.info("Creating features from dataset at %s", data_dir)
datasets = ["cnn", "dailymail"]
self.examples = {"source": [], "target": []}
for dataset in datasets:
path_to_stories = os.path.join(data_dir, dataset, "stories")
story_filenames_list = os.listdir(path_to_stories)
for story_filename in story_filenames_list:
path_to_story = os.path.join(path_to_stories, story_filename)
if not os.path.isfile(path_to_story):
continue
with open(path_to_story, encoding="utf-8") as source:
raw_story = source.read()
story_lines, summary_lines = process_story(raw_story)
if len(summary_lines) == 0 or len(story_lines) == 0:
continue
story_token_ids, summary_token_ids = _encode_for_summarization(
story_lines, summary_lines, tokenizer
)
story_seq = _fit_to_block_size(story_token_ids, block_size)
self.examples["source"].append(story_seq)
summary_seq = _fit_to_block_size(summary_token_ids, block_size)
self.examples["summary"].append(summary_seq)
logger.info("Saving features into cache file %s", cached_features_file)
with open(cached_features_file, "wb") as sink:
pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)
def __len__(self):
return len(self.examples)
def __getitem__(self, items):
return (
torch.tensor(self.examples["source"][items]),
torch.tensor(self.examples["target"][items]),
)
def process_story(raw_story):
""" Extract the story and summary from a story file.
Attributes:
raw_story (str): content of the story file as an utf-8 encoded string.
Raises:
IndexError: If the story is empty or contains no highlights.
"""
nonempty_lines = list(
filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
)
# for some unknown reason some lines are missing a period; add it
nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
# gather article lines
story_lines = []
lines = deque(nonempty_lines)
while True:
try:
element = lines.popleft()
if element.startswith("@highlight"):
break
story_lines.append(element)
except IndexError:
# if "@highlight" is absent from the file we pop
# all elements until there is None.
return story_lines, []
# gather summary lines
summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
return story_lines, summary_lines
def _encode_for_summarization(story_lines, summary_lines, tokenizer):
""" Encode the story and summary lines, and join them
as specified in [1] by using `[SEP] [CLS]` tokens to separate
sentences.
"""
story_lines_token_ids = [
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
for line in story_lines
]
summary_lines_token_ids = [
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
for line in summary_lines
]
story_token_ids = [
token for sentence in story_lines_token_ids for token in sentence
]
summary_token_ids = [
token for sentence in summary_lines_token_ids for token in sentence
]
return story_token_ids, summary_token_ids
def _add_missing_period(line):
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"]
if line.startswith("@highlight"):
return line
if line[-1] in END_TOKENS:
return line
return line + "."
def _fit_to_block_size(sequence, block_size):
""" Adapt the source and target sequences' lengths to the block size.
If the sequence is shorter than the block size we pad it with the id of
the padding token (0 for BERT).
"""
if len(sequence) > block_size:
return sequence[:block_size]
else:
sequence.extend([0] * (block_size - len(sequence)))
return sequence
def mask_padding_tokens(sequence):
""" Padding token, encoded as 0, are represented by the value -1 in the
masks """
padded = sequence.clone()
padded[padded == 0] = -1
return padded
def load_and_cache_examples(args, tokenizer):
dataset = TextDataset(tokenizer, data_dir=args.data_dir)
dataset = CNNDailyMailDataset(tokenizer, data_dir=args.data_dir)
return dataset
def compute_token_type_ids(batch, separator_token_id):
""" Segment embeddings as described in [1]
def collate(data, tokenizer, block_size):
""" List of tuple as an input. """
# remove the files with empty an story/summary, encode and fit to block
data = filter(lambda x: not (len(x[0]) == 0 or len(x[1]) == 0), data)
data = [
encode_for_summarization(story, summary, tokenizer) for story, summary in data
]
data = [
(
fit_to_block_size(story, block_size, tokenizer.pad_token_id),
fit_to_block_size(summary, block_size, tokenizer.pad_token_id),
)
for story, summary in data
]
The values {0,1} were found in the repository [2].
stories = torch.tensor([story for story, summary in data])
summaries = torch.tensor([summary for story, summary in data])
encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id)
encoder_mask = build_mask(stories, tokenizer.pad_token_id)
decoder_mask = build_mask(summaries, tokenizer.pad_token_id)
lm_labels = build_lm_labels(summaries, tokenizer.pad_token_id)
Attributes:
batch: torch.Tensor, size [batch_size, block_size]
Batch of input.
separator_token_id: int
The value of the token that separates the segments.
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
arXiv preprint arXiv:1908.08345 (2019).
[2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
"""
batch_embeddings = []
sentence_num = 0
for sequence in batch:
embeddings = []
for s in sequence:
if s == separator_token_id:
sentence_num += 1
embeddings.append(sentence_num % 2)
batch_embeddings.append(embeddings)
return torch.tensor(batch_embeddings)
return (
stories,
summaries,
encoder_token_type_ids,
encoder_mask,
decoder_mask,
lm_labels,
)
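# Minimal sketch of how `collate` plugs into a DataLoader; the checkpoint name, batch
# size and data path are illustrative placeholders rather than values taken from the CLI.
example_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
example_dataset = CNNDailyMailDataset(example_tokenizer, data_dir="/path/to/cnn_dailymail")
example_loader = DataLoader(
    example_dataset,
    batch_size=4,
    collate_fn=functools.partial(collate, tokenizer=example_tokenizer, block_size=512),
)
# Each batch unpacks into the six tensors returned above:
# stories, summaries, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels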
# ----------
@ -252,7 +113,7 @@ class BertSumOptimizer(object):
arXiv preprint arXiv:1908.08345 (2019).
"""
def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-9):
def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8):
self.encoder = model.encoder
self.decoder = model.decoder
self.lr = lr
@ -306,8 +167,12 @@ def train(args, model, tokenizer):
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_dataset = load_and_cache_examples(args, tokenizer)
train_sampler = RandomSampler(train_dataset)
model_collate_fn = functools.partial(collate, tokenizer=tokenizer, block_size=512)
train_dataloader = DataLoader(
train_dataset, sampler=train_sampler, batch_size=args.train_batch_size
train_dataset,
sampler=train_sampler,
batch_size=args.train_batch_size,
collate_fn=model_collate_fn,
)
# Training schedule
@ -351,26 +216,23 @@ def train(args, model, tokenizer):
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
for step, batch in enumerate(epoch_iterator):
source, target = batch
token_type_ids = compute_token_type_ids(source, tokenizer.cls_token_id)
labels_src = mask_padding_tokens(source)
labels_tgt = mask_padding_tokens(target)
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
source = source.to(args.device)
target = target.to(args.device)
token_type_ids = token_type_ids.to(args.device)
labels_src = labels_src.to(args.device)
labels_tgt = labels_tgt.to(args.device)
encoder_token_type_ids = encoder_token_type_ids.to(args.device)
encoder_mask = encoder_mask.to(args.device)
decoder_mask = decoder_mask.to(args.device)
lm_labels = lm_labels.to(args.device)
model.train()
outputs = model(
source,
target,
token_type_ids=token_type_ids,
decoder_encoder_attention_mask=labels_src,
decoder_attention_mask=labels_tgt,
decoder_lm_labels=labels_tgt,
decoder_initialize_randomly=True,
encoder_token_type_ids=encoder_token_type_ids,
encoder_attention_mask=encoder_mask,
decoder_attention_mask=decoder_mask,
decoder_lm_labels=lm_labels,
)
loss = outputs[0]
@ -421,21 +283,23 @@ def evaluate(args, model, tokenizer, prefix=""):
model.eval()
for batch in tqdm(eval_dataloader, desc="Evaluating"):
source, target = batch
labels_src = mask_padding_tokens(source)
labels_tgt = mask_padding_tokens(target)
source.to(args.device)
target.to(args.device)
labels_src.to(args.device)
labels_tgt.to(args.device)
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
source = source.to(args.device)
target = target.to(args.device)
encoder_token_type_ids = encoder_token_type_ids.to(args.device)
encoder_mask = encoder_mask.to(args.device)
decoder_mask = decoder_mask.to(args.device)
lm_labels = lm_labels.to(args.device)
with torch.no_grad():
outputs = model(
source,
target,
decoder_encoder_attention_mask=labels_src,
decoder_attention_mask=labels_tgt,
decoder_lm_labels=labels_tgt,
encoder_token_type_ids=encoder_token_type_ids,
encoder_attention_mask=encoder_mask,
decoder_attention_mask=decoder_mask,
decoder_lm_labels=lm_labels,
)
lm_loss = outputs[0]
eval_loss += lm_loss.mean().item()
@ -525,7 +389,7 @@ def main():
)
parser.add_argument(
"--num_train_epochs",
default=1,
default=10,
type=int,
help="Total number of training epochs to perform.",
)
@ -558,9 +422,13 @@ def main():
args.device = torch.device("cuda")
args.n_gpu = torch.cuda.device_count()
# Load pretrained model and tokenizer
# Load pretrained model and tokenizer. The decoder's weights are randomly initialized.
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
model = Model2Model.from_pretrained(args.model_name_or_path)
config = BertConfig.from_pretrained(args.model_name_or_path)
decoder_model = BertForMaskedLM(config)
model = Model2Model.from_pretrained(
args.model_name_or_path, decoder_model=decoder_model
)
# Setup logging
logging.basicConfig(


@ -1,76 +0,0 @@
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from run_summarization_finetuning import _fit_to_block_size, process_story
class DataLoaderTest(unittest.TestCase):
def setUp(self):
self.block_size = 10
def test_truncate_sequence_too_small(self):
""" Pad the sequence with 0 if the sequence is smaller than the block size."""
sequence = [1, 2, 3, 4]
expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output)
def test_truncate_sequence_fit_exactly(self):
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output)
def test_truncate_sequence_too_big(self):
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output)
def test_process_story_no_highlights(self):
""" Processing a story with no highlights should raise an exception.
"""
raw_story = """It was the year of Our Lord one thousand seven hundred and
seventy-five.\n\nSpiritual revelations were conceded to England at that
favoured period, as at this."""
_, summary = process_story(raw_story)
self.assertEqual(summary, [])
def test_process_empty_story(self):
""" An empty story should also raise and exception.
"""
raw_story = ""
story, summary = process_story(raw_story)
self.assertEqual(story, [])
self.assertEqual(summary, [])
def test_story_with_missing_period(self):
raw_story = (
"It was the year of Our Lord one thousand seven hundred and "
"seventy-five\n\nSpiritual revelations were conceded to England "
"at that favoured period, as at this.\n@highlight\n\nIt was the best of times"
)
story_lines, summary_lines = process_story(raw_story)
expected_story_lines = [
"It was the year of Our Lord one thousand seven hundred and seventy-five.",
"Spiritual revelations were conceded to England at that favoured period, as at this.",
]
self.assertEqual(expected_story_lines, story_lines)
expected_summary_lines = ["It was the best of times."]
self.assertEqual(expected_summary_lines, summary_lines)
if __name__ == "__main__":
unittest.main()


@ -0,0 +1,184 @@
from collections import deque
import os
import torch
from torch.utils.data import Dataset
# ------------
# Data loading
# ------------
class CNNDailyMailDataset(Dataset):
""" Abstracts the dataset used to train seq2seq models.
CNN/DailyMail:
The CNN/DailyMail raw datasets are downloaded from [1]. The stories are
stored in separate files; the summary appears at the end of the story as
sentences that are prefixed by the special `@highlight` line. To process
the data, untar both datasets in the same folder, and pass the path to this
folder as the "data_dir" argument. The formatting code was inspired by [2].
[1] https://cs.nyu.edu/~kcho/
[2] https://github.com/abisee/cnn-dailymail/
"""
def __init__(self, tokenizer, prefix="train", data_dir=""):
assert os.path.isdir(data_dir)
self.tokenizer = tokenizer
# We initialize the class by listing all the files that contain
# stories and summaries. Files are not read into memory given
# the size of the corpus.
self.stories_path = []
datasets = ("cnn", "dailymail")
for dataset in datasets:
path_to_stories = os.path.join(data_dir, dataset, "stories")
story_filenames_list = os.listdir(path_to_stories)
for story_filename in story_filenames_list:
path_to_story = os.path.join(path_to_stories, story_filename)
if not os.path.isfile(path_to_story):
continue
self.stories_path.append(path_to_story)
def __len__(self):
return len(self.stories_path)
def __getitem__(self, idx):
story_path = self.stories_path[idx]
with open(story_path, encoding="utf-8") as source:
raw_story = source.read()
story_lines, summary_lines = process_story(raw_story)
return story_lines, summary_lines
def process_story(raw_story):
""" Extract the story and summary from a story file.
Attributes:
raw_story (str): content of the story file as an utf-8 encoded string.
Raises:
IndexError: If the story is empty or contains no highlights.
"""
nonempty_lines = list(
filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
)
# for some unknown reason some lines are missing a period; add it
nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
# gather article lines
story_lines = []
lines = deque(nonempty_lines)
while True:
try:
element = lines.popleft()
if element.startswith("@highlight"):
break
story_lines.append(element)
except IndexError:
# if "@highlight" is absent from the file we pop
# all elements until there is None.
return story_lines, []
# gather summary lines
summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
return story_lines, summary_lines
def _add_missing_period(line):
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"]
if line.startswith("@highlight"):
return line
if line[-1] in END_TOKENS:
return line
return line + "."
# --------------------------
# Encoding and preprocessing
# --------------------------
def fit_to_block_size(sequence, block_size, pad_token):
""" Adapt the source and target sequences' lengths to the block size.
If the sequence is shorter than the block size we pad it with the
provided `pad_token` id.
"""
if len(sequence) > block_size:
return sequence[:block_size]
else:
sequence.extend([pad_token] * (block_size - len(sequence)))
return sequence
def build_lm_labels(sequence, pad_token):
""" Padding token, encoded as 0, are represented by the value -1 so they
are not taken into account in the loss computation. """
padded = sequence.clone()
padded[padded == pad_token] = -1
return padded
def build_mask(sequence, pad_token):
""" Builds the mask. The attention mechanism will only attend to positions
with value 1. """
mask = sequence.clone()
mask[mask != pad_token] = 1
mask[mask == pad_token] = 0
return mask
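# Small worked example of the three helpers above, taking the pad token id to be 0
# as in the unit tests below:
padded = fit_to_block_size([1, 2, 3], block_size=5, pad_token=0)  # [1, 2, 3, 0, 0]
labels = build_lm_labels(torch.tensor(padded), pad_token=0)       # tensor([ 1,  2,  3, -1, -1])
mask = build_mask(torch.tensor(padded), pad_token=0)              # tensor([1, 1, 1, 0, 0])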
def encode_for_summarization(story_lines, summary_lines, tokenizer):
""" Encode the story and summary lines, and join them
as specified in [1] by using `[SEP] [CLS]` tokens to separate
sentences.
"""
story_lines_token_ids = [
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
for line in story_lines
]
summary_lines_token_ids = [
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
for line in summary_lines
]
story_token_ids = [
token for sentence in story_lines_token_ids for token in sentence
]
summary_token_ids = [
token for sentence in summary_lines_token_ids for token in sentence
]
return story_token_ids, summary_token_ids
def compute_token_type_ids(batch, separator_token_id):
""" Segment embeddings as described in [1]
The values {0,1} were found in the repository [2].
Attributes:
batch: torch.Tensor, size [batch_size, block_size]
Batch of input.
separator_token_id: int
The value of the token that separates the segments.
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
arXiv preprint arXiv:1908.08345 (2019).
[2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
"""
batch_embeddings = []
for sequence in batch:
sentence_num = 0
embeddings = []
for s in sequence:
if s == separator_token_id:
sentence_num += 1
embeddings.append(sentence_num % 2)
batch_embeddings.append(embeddings)
return torch.tensor(batch_embeddings)
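# Worked example for `compute_token_type_ids` (mirrors the unit test below): segment
# ids flip between 0 and 1 each time the separator token is encountered.
example_batch = torch.tensor([[1, 101, 3, 4, 101, 6]])
compute_token_type_ids(example_batch, separator_token_id=101)  # -> tensor([[0, 1, 1, 1, 0, 0]])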


@ -0,0 +1,133 @@
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from utils_summarization import (
compute_token_type_ids,
fit_to_block_size,
build_mask,
build_lm_labels,
process_story,
)
class SummarizationDataProcessingTest(unittest.TestCase):
def setUp(self):
self.block_size = 10
def test_fit_to_block_sequence_too_small(self):
""" Pad the sequence with 0 if the sequence is smaller than the block size."""
sequence = [1, 2, 3, 4]
expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
self.assertEqual(
fit_to_block_size(sequence, self.block_size, 0), expected_output
)
def test_fit_to_block_sequence_fit_exactly(self):
""" Do nothing if the sequence is the right size. """
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
self.assertEqual(
fit_to_block_size(sequence, self.block_size, 0), expected_output
)
def test_fit_to_block_sequence_too_big(self):
""" Truncate the sequence if it is too long. """
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
self.assertEqual(
fit_to_block_size(sequence, self.block_size, 0), expected_output
)
def test_process_story_no_highlights(self):
""" Processing a story with no highlights returns an empty list for the summary.
"""
raw_story = """It was the year of Our Lord one thousand seven hundred and
seventy-five.\n\nSpiritual revelations were conceded to England at that
favoured period, as at this."""
_, summary_lines = process_story(raw_story)
self.assertEqual(summary_lines, [])
def test_process_empty_story(self):
""" An empty story returns an empty collection of lines.
"""
raw_story = ""
story_lines, summary_lines = process_story(raw_story)
self.assertEqual(story_lines, [])
self.assertEqual(summary_lines, [])
def test_process_story_with_missing_period(self):
raw_story = (
"It was the year of Our Lord one thousand seven hundred and "
"seventy-five\n\nSpiritual revelations were conceded to England "
"at that favoured period, as at this.\n@highlight\n\nIt was the best of times"
)
story_lines, summary_lines = process_story(raw_story)
expected_story_lines = [
"It was the year of Our Lord one thousand seven hundred and seventy-five.",
"Spiritual revelations were conceded to England at that favoured period, as at this.",
]
self.assertEqual(expected_story_lines, story_lines)
expected_summary_lines = ["It was the best of times."]
self.assertEqual(expected_summary_lines, summary_lines)
def test_build_lm_labels_no_padding(self):
sequence = torch.tensor([1, 2, 3, 4])
expected = sequence
np.testing.assert_array_equal(
build_lm_labels(sequence, 0).numpy(), expected.numpy()
)
def test_build_lm_labels(self):
sequence = torch.tensor([1, 2, 3, 4, 0, 0, 0])
expected = torch.tensor([1, 2, 3, 4, -1, -1, -1])
np.testing.assert_array_equal(
build_lm_labels(sequence, 0).numpy(), expected.numpy()
)
def test_build_mask_no_padding(self):
sequence = torch.tensor([1, 2, 3, 4])
expected = torch.tensor([1, 1, 1, 1])
np.testing.assert_array_equal(
build_mask(sequence, 0).numpy(), expected.numpy()
)
def test_build_mask(self):
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
np.testing.assert_array_equal(
build_mask(sequence, 23).numpy(), expected.numpy()
)
def test_compute_token_type_ids(self):
separator = 101
batch = torch.tensor(
[[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]]
)
expected = torch.tensor(
[[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 1], [0, 1, 1, 1, 0, 0]]
)
result = compute_token_type_ids(batch, separator)
np.testing.assert_array_equal(result, expected)
if __name__ == "__main__":
unittest.main()


@ -26,189 +26,220 @@ import torch
from torch import nn
class ModelWithBeamSearch(nn.Module):
class TransformerBeamSearch(nn.Module):
def __init__(
self,
model,
tokenizer,
batch_size,
beam_size,
start_token_id,
end_token_id,
pad_token_id,
min_length,
max_length,
alpha,
block_trigram=True,
alpha=0,
block_repeating_trigram=True,
):
"""
Beam search decoder built on top of an encoder-decoder model. The start,
end and pad token ids are read from the tokenizer.
"""
super(ModelWithBeamSearch, self).__init__()
super(TransformerBeamSearch, self).__init__()
self.model = model
self.tokenizer = tokenizer
self.start_token_id = tokenizer.start_token_id
self.end_token_id = tokenizer.end_token_id
self.pad_token_id = tokenizer.pad_token_id
self.beam_size = beam_size
self.start_token_id = start_token_id
self.end_token_id = end_token_id
self.pad_token_id = pad_token_id
self.min_length = min_length
self.max_length = max_length
self.alpha = alpha
self.block_trigram = block_trigram
def forward(self, input_ids, **kwargs):
# Separate the encoder- and decoder- specific kwargs. A kwarg is
# decoder-specific if the key starts with `decoder_`
self.block_repeating_trigram = block_repeating_trigram
self.apply_length_penalty = False if alpha == 0 else True
self.alpha = alpha
# State of the beam
self.hypotheses = [[] for _ in range(batch_size)]
self.batch_offset = torch.arange(batch_size, dtype=torch.long)
self.beam_offset = torch.arange(
0, batch_size * self.beam_size, step=self.beam_size, dtype=torch.long
)
self.growing_beam = torch.full(
(batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long
)
self.topk_log_probabilities = torch.tensor(
[0.0] + [float("-inf")] * (self.beam_size - 1), dtype=torch.float
).repeat(batch_size)
self.results = {
"prediction": [[] for _ in batch_size],
"scores": [[] for _ in batch_size],
}
self._step = 0
self.is_done = False
def step(self, log_probabilities):
""" Grows the beam by one step. """
self._step += 1
# The batch size changes as some beams finish so we define _B
vocab_size = log_probabilities.size(-1)
_B = log_probabilities.size(0) // self.beam_size
# Multiply each beam probability with the probability of the
# next token (conditioned on the words in the beam).
log_probabilities += self.topk_log_probabilities.view(-1, 1)
self.enforce_min_length(log_probabilities)
if self.block_repeating_trigram:
self.remove_repeating_trigrams(log_probabilities, _B)
# Find the `beam_size` (previous_beam + token) combinations with
# the highest score
topk_log_probabilities, topk_ids = log_probabilities.view(
_B, self.beam_size * vocab_size
).topk(self.beam_size, dim=1)
# Apply the length penalty. The +1 accounts for the [EOS] token
# that will be added if the beam ends.
topk_scores = topk_log_probabilities / self.length_penalty()
# Retrieve the corresponding respective beam and token id
# topk_token_ids[i] will be added to topk_beam_ids[i]
topk_beam_ids = topk_ids.div(vocab_size)
topk_token_ids = topk_ids.fmod(vocab_size)
# Retrieve the row index of the surviving beams in the original
# view of the log_probabilities tensor
surviving_beams_rows = (topk_beam_ids + self.beam_offset[:_B].view(-1, 1)).view(
-1
)
# Append the last predictions
self.growing_beam = torch.cat(
[
self.growing_beam.index_select(0, surviving_beams_rows),
topk_token_ids.view(-1, 1),
],
1,
)
# Check if any of the beam searches has ended during this
# growth step. Also if top beam (most probable) has ended
# for one element of the batch.
is_finished = topk_token_ids.eq(self.end_token_id)
self.enforce_max_length(is_finished)
is_top_beam_finished = is_finished[:, 0].eq(1)
# Save the finished searches
if is_finished.any():
predictions = self.growing_beam.view(
-1, self.beam_size, self.growing_beam.size(1)
)
for i in range(is_finished.size(0)):
if is_top_beam_finished[i]:
is_finished[i].fill_(1)
finished_hyp = is_finished[i].nonzero().view(-1)
# Store finished hypotheses for this batch.
b = self.batch_offset[i]
for j in finished_hyp:
self.hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))
# If the batch reached the end, save the best hypotheses
# in terms of length-penalized score.
if is_top_beam_finished[i]:
best_hyp = sorted(
self.hypotheses[b], key=lambda x: x[0], reverse=True
)
best_score, best_prediction = best_hyp[0]
self.results["scores"][b].append(best_score)
self.results["predictions"][b].append(best_prediction)
non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
if len(non_finished) == 0:
self.is_done = True
# Remove finished batches for the next step.
topk_log_probabilities = topk_log_probabilities.index_select(
0, non_finished
)
self.batch_offset = self.batch_offset.index_select(0, non_finished)
self.growing_beam = predictions.index_select(0, non_finished).view(
-1, self.growing_beam.size(-1)
)
surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished)
return surviving_beams_rows
def forward(self, encoder_input_ids, **kwargs):
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as a whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_encoder = {
argument: value
argument[len("encoder_"):]: value
for argument, value in kwargs.items()
if not argument.startswith("decoder_")
if argument.startswith("encoder_")
}
kwargs_decoder = {
argument[len("decoder_"):]: value
for argument, value in kwargs.items()
if argument.startswith("decoder_")
}
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
}
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
batch_size, _ = input_ids.size(0)
# Variables that keep track of the status of the search
hypotheses = [[] for _ in range(batch_size)]
batch_offset = torch.arange(batch_size, dtype=torch.long)
beam_offset = torch.arange(
0,
batch_size * self.beam_size,
step=self.beam_size,
dtype=torch.long,
)
growing_beam = torch.full(
(batch_size * self.beam_size, 1),
self.start_token_id,
dtype=torch.long,
)
topk_log_probabilities = torch.tensor(
[0.0] + [float("-inf")] * (self.beam_size - 1),
dtype=torch.float,
).repeat(batch_size)
# Forward pass on the encoder
encoder_outputs = self.encoder(input_ids, kwargs_encoder)
# forward pass on the encoder
encoder_outputs = self.model.encoder(encoder_input_ids, **kwargs_encoder)
kwargs_decoder["encoder_hidden_states"] = tile(
encoder_outputs[0], self.beam_size, dim=0
)
results = {}
results["predictions"] = [[] for _ in batch_size]
results["scores"] = [[] for _ in batch_size]
# grow the beam by generating sequences in an autoregressive way
self.growing_beam = torch.full(
(encoder_input_ids.size(0) * self.beam_size, 1), self.start_token_id, dtype=torch.long
)
for step in range(self.max_length):
decoder_input = growing_beam[:, -1]
outputs = self.decoder(decoder_input, kwargs_decoder)
decoder_input = self.growing_beam[:, -1]
outputs = self.model.decoder(decoder_input, **kwargs_decoder)
log_probabilities = torch.nn.functional.log_softmax(outputs[1], dim=-1)
vocab_size = log_probabilities.size(-1)
surviving_beams_rows = self.step(log_probabilities)
if self.is_done:
break
# The batch size changes as some beams finish so we define:
_B = log_probabilities.size(0) // self.beam_size
# Multiply each beam probability with the probability of the
# next token (conditioned on the words in the beam).
log_probabilities += topk_log_probabilities.view(-1, 1)
# if the beam has not attained the minimum required length we
# make the end token arbitrarily unlikely.
if step < self.min_length:
log_probabilities[self.end_token_id] = -1e20
# Remove repeating tri-grams
if(self.args.block_trigram):
if(step + 1 > 3):
for i in range(_B * self.beam_size):
tokens = [t for t in growing_beam[i]]
trigrams = [(tokens[i-1], tokens[i], tokens[i+1]) for i in range(1, len(words) - 1)]
last_trigram = tuple(trigrams[-1])
if last_trigram in trigrams[:-1]:
log_probabilities[i] = -1e20
# Find the `beam_size` (previous_beam + token) combinations with
# the highest score
topk_log_probabilities, topk_ids = log_probabilities.topk(
log_probabilities.view(_B, self.beam_size * vocab_size),
self.beam_size,
dim=1
)
# Apply the length penalty. The +1 accounts for the [EOS] token
# that will be added if the beam ends.
length_penalty = ((5.0 + (step + 1)) / 6.0) ** self.alpha
topk_scores = topk_log_probabilities / length_penalty
# Retrieve the corresponding respective beam and token id
# topk_token_ids[i] will be added to topk_beam_ids[i]
topk_beam_ids = topk_ids.div(vocab_size)
topk_token_ids = topk_ids.fmod(vocab_size)
# Retrieve the row index of the surviving beams in the original
# view of the log_probabilities tensor
surviving_beams_rows = (
topk_beam_ids + beam_offset[:_B].view(-1, 1)
).view(-1)
# Append the last predictions
growing_beam = torch.cat(
[
growing_beam.index_select(0, surviving_beams_rows),
topk_token_ids.view(-1, 1),
],
1,
)
# Check if any of the beam searches has ended during this
# growth step. Also if top beam (most probable) has ended
# for one element of the batch.
is_finished = topk_token_ids.eq(self.end_token_id)
if step + 1 == self.max_length:
is_finished.fill_(1)
is_top_beam_finished = is_finished[:, 0].eq(1)
# Save the finished searches
if is_finished.any():
predictions = growing_beam.view(-1, self.beam_size, growing_beam.size(1))
for i in range(is_finished.size(0)):
if is_top_beam_finished[i]:
is_finished[i].fill_(1)
finished_hyp = is_finished[i].nonzero().view(-1)
# Store finished hypotheses for this batch.
b = batch_offset[i]
for j in finished_hyp:
hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))
# If the batch reached the end, save the best hypotheses
# in terms of length-penalized score.
if is_top_beam_finished[i]:
best_hyp = sorted(
hypotheses[b], key=lambda x: x[0], reverse=True
)
best_score, best_prediction = best_hyp[0]
results["scores"][b].append(best_score)
results["predictions"][b].append(best_prediction)
non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
if len(non_finished) == 0:
break
# Remove finished batches for the next step.
topk_log_probabilities = topk_log_probabilities.index_select(0, non_finished)
batch_offset = batch_offset.index_select(0, non_finished)
growing_beam = predictions.index_select(0, non_finished).view(
-1, growing_beam.size(-1)
)
# Re-order the state for the next pass
surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished)
kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
"encoder_hidden_states"
].index_select(0, surviving_beams_rows)
return results
return self.results
def remove_repeating_trigrams(self, log_probabilities, _B):
if self._step + 1 > 3:
for i in range(_B * self.beam_size):
tokens = [t for t in self.growing_beam[i]]
trigrams = [(tokens[j - 1], tokens[j], tokens[j + 1]) for j in range(1, len(tokens) - 1)]
last_trigram = tuple(trigrams[-1])
if last_trigram in trigrams[:-1]:
log_probabilities[i] = -1e20
def enforce_min_length(self, log_probabilities):
if self._step < self.min_length:
log_probabilities[:, self.end_token_id] = -1e20
def enforce_max_length(self, is_finished):
if self._step + 1 == self.max_length:
is_finished.fill_(1)
def length_penalty(self):
return ((5.0 + (self._step + 1)) / 6.0) ** self.alpha
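# Sanity check for the length penalty above; it follows the GNMT-style formula
# ((5 + step + 1) / 6) ** alpha, so alpha = 0 disables it. The step and alpha
# values below are illustrative.
def _length_penalty_example(step, alpha):
    return ((5.0 + (step + 1)) / 6.0) ** alpha

assert _length_penalty_example(0, 0.0) == 1.0
assert abs(_length_penalty_example(9, 0.6) - 1.733) < 1e-2  # longer beams get a larger divisor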
def tile(x, count, dim=0):


@ -632,6 +632,8 @@ class BertModel(BertPreTrainedModel):
"""
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
@ -660,12 +662,15 @@ class BertModel(BertPreTrainedModel):
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# If a 2D encoder attention mask is provided for the cross-attention
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_attention_mask is not None:
encoder_attention_mask = encoder_attention_mask[:, None, None, :]
encoder_attention_mask = encoder_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0
if encoder_attention_mask.dim() == 3:
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
if encoder_attention_mask.dim() == 2:
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
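# Shape check for the mask extension above, with toy values: a 2D padding mask of
# shape (batch_size, src_len) becomes (batch_size, 1, 1, src_len), holding 0.0 at
# visible positions and -10000.0 at padded ones.
toy_mask = torch.tensor([[1.0, 1.0, 0.0]])
toy_extended = (1.0 - toy_mask[:, None, None, :]) * -10000.0  # tensor([[[[-0., -0., -10000.]]]])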
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@ -687,7 +692,7 @@ class BertModel(BertPreTrainedModel):
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask)
encoder_attention_mask=encoder_extended_attention_mask)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
@ -788,8 +793,10 @@ class BertForMaskedLM(BertPreTrainedModel):
in ``[0, ..., config.vocab_size]``
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
**masked_lm_loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
**next_token_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Next token prediction loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
@ -854,13 +861,13 @@ class BertForMaskedLM(BertPreTrainedModel):
if lm_labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
prediction_scores = prediction_scores[:, :-1, :]
lm_labels = lm_labels[:, 1:]
prediction_scores = prediction_scores[:, :-1, :].contiguous()
lm_labels = lm_labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss(ignore_index=-1)
seq2seq_loss = loss_fct(prediction_scores.reshape(-1, self.config.vocab_size), lm_labels.reshape(-1))
outputs = (seq2seq_loss,) + outputs
next_token_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
outputs = (next_token_loss,) + outputs
return outputs # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
return outputs # (masked_lm_loss), (next_token_loss), prediction_scores, (hidden_states), (attentions)
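# Toy illustration of the next-token shift above: the score at position t is matched
# against the token at position t + 1 (the ids and shapes here are made up).
toy_scores = torch.randn(1, 5, 30522)                # (batch_size, seq_len, vocab_size)
toy_lm_labels = torch.tensor([[101, 7, 8, 9, 102]])
shifted_scores = toy_scores[:, :-1, :].contiguous()  # predictions for positions 0..3
shifted_labels = toy_lm_labels[:, 1:].contiguous()   # targets: 7, 8, 9, 102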
@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,


@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class PreTrainedSeq2seq(nn.Module):
r"""
:class:`~transformers.Seq2seq` is a generic model class that will be
:class:`~transformers.PreTrainedSeq2seq` is a generic model class that will be
instantiated as a Seq2seq model with one of the base model classes of
the library as encoder and (optionally) as decoder when created with
the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
@ -49,8 +49,7 @@ class PreTrainedSeq2seq(nn.Module):
*model_args,
**kwargs
):
r""" Instantiates an encoder and a decoder from one or two base classes
of the library from pre-trained model checkpoints.
r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
@ -111,35 +110,44 @@ class PreTrainedSeq2seq(nn.Module):
model = PreTrainedSeq2seq.from_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
"""
# Separate the encoder- and decoder- specific kwargs. A kwarg is
# decoder-specific if the key starts with `decoder_`
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as a whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_encoder = {
argument: value
argument[len("encoder_"):]: value
for argument, value in kwargs.items()
if not argument.startswith("decoder_")
if argument.startswith("encoder_")
}
kwargs_decoder = {
argument[len("decoder_") :]: value
argument[len("decoder_"):]: value
for argument, value in kwargs.items()
if argument.startswith("decoder_")
}
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
}
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
# Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made
# by the value of the flag `is_decoder` that we need to set correctly.
encoder = kwargs_encoder.pop("encoder_model", None)
encoder = kwargs_encoder.pop("model", None)
if encoder is None:
kwargs_encoder["is_decoder"] = False
encoder = AutoModel.from_pretrained(
encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
)
encoder.config.is_decoder = False
decoder = kwargs_decoder.pop("model", None)
if decoder is None:
kwargs_decoder["is_decoder"] = True
decoder = AutoModelWithLMHead.from_pretrained(
decoder_pretrained_model_name_or_path, **kwargs_decoder
)
decoder.config.is_decoder = True
model = cls(encoder, decoder)
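# Illustration of the kwarg routing above with made-up keys: prefixed arguments go to
# their respective model, un-prefixed ones are shared by both.
toy_kwargs = {"encoder_attention_mask": "enc_mask", "decoder_lm_labels": "labels", "head_mask": None}
toy_encoder_kwargs = {k[len("encoder_"):]: v for k, v in toy_kwargs.items() if k.startswith("encoder_")}
toy_decoder_kwargs = {k[len("decoder_"):]: v for k, v in toy_kwargs.items() if k.startswith("decoder_")}
toy_common_kwargs = {k: v for k, v in toy_kwargs.items() if not (k.startswith("encoder_") or k.startswith("decoder_"))}
toy_encoder_kwargs = dict(toy_common_kwargs, **toy_encoder_kwargs)  # {"head_mask": None, "attention_mask": "enc_mask"}
toy_decoder_kwargs = dict(toy_common_kwargs, **toy_decoder_kwargs)  # {"head_mask": None, "lm_labels": "labels"}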
@ -169,37 +177,60 @@ class PreTrainedSeq2seq(nn.Module):
decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
Indices of decoder input sequence tokens in the vocabulary.
"""
# Separate the encoder- and decoder- specific kwargs. A kwarg is
# decoder-specific if the key starts with `decoder_`
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as a whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_encoder = {
argument: value
argument[len("encoder_"):]: value
for argument, value in kwargs.items()
if not argument.startswith("decoder_")
if argument.startswith("encoder_")
}
kwargs_decoder = {
argument[len("decoder_") :]: value
argument[len("decoder_"):]: value
for argument, value in kwargs.items()
if argument.startswith("decoder_")
}
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
}
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
# Encode if needed (training, first prediction pass)
encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
if encoder_hidden_states is None:
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0][
-1
] # output of the encoder *stack*
encoder_hidden_states = encoder_outputs[0]  # hidden states of the encoder's last layer
else:
encoder_outputs = ()
# Decode
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states[None, :, :]
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
return decoder_outputs + encoder_outputs
class Model2Model(PreTrainedSeq2seq):
r"""
:class:`~transformers.Model2Model` instantiates a Seq2Seq model
where both the encoder and the decoder are of the same family. If the
name of or the path to a pretrained model is specified, the encoder and
the decoder will be initialized with the pretrained weights (the
cross-attention layers will be initialized randomly if their weights are not
present).
It is possible to override this behavior and initialize, say, the decoder randomly
by creating it beforehand as follows
config = BertConfig.from_pretrained('bert-base-uncased')
decoder = BertForMaskedLM(config)
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder)
"""
def __init__(self, *args, **kwargs):
super(Model2Model, self).__init__(*args, **kwargs)
self.tie_weights()
@ -235,14 +266,10 @@ class Model2Model(PreTrainedSeq2seq):
model = super(Model2Model, cls).from_pretrained(
encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
*args,
**kwargs
)
# Some architectures require for the decoder to be initialized randomly
# before fine-tuning.
if kwargs.get("decoder_initialize_randomly", False):
model.decoder.init_weights()
return model