transformers/examples/utils_summarization.py
2019-10-29 17:10:20 +01:00

185 lines
5.9 KiB
Python

from collections import deque
import os
import torch
from torch.utils.data import Dataset
# ------------
# Data loading
# ------------
class CNNDailyMailDataset(Dataset):
    """ Abstracts the dataset used to train seq2seq models.

    CNN/Daily Mail:

    The CNN/Daily Mail raw datasets are downloaded from [1]. The stories are
    stored in different files; the summary appears at the end of the story as
    sentences that are prefixed by the special `@highlight` line. To process
    the data, untar both datasets in the same folder, and pass the path to this
    folder as the "data_dir" argument. The formatting code was inspired by [2].

    Each item is a ``(story_lines, summary_lines)`` pair produced by
    ``process_story``; story files are only read when an item is requested.

    [1] https://cs.nyu.edu/~kcho/
    [2] https://github.com/abisee/cnn-dailymail/
    """

    def __init__(self, tokenizer, prefix="train", data_dir=""):
        """ Index every story file of the CNN and Daily Mail corpora.

        Args:
            tokenizer: stored as ``self.tokenizer``; not used by this class's
                own methods.
            prefix (str): currently unused; kept for interface compatibility.
            data_dir (str): folder containing the ``cnn/stories`` and
                ``dailymail/stories`` sub-directories.

        Raises:
            AssertionError: if ``data_dir`` is not a directory (note that the
                check is skipped when Python runs with ``-O``).
            OSError: if one of the ``stories`` sub-directories cannot be listed.
        """
        assert os.path.isdir(data_dir), "data_dir is not a directory: %r" % (data_dir,)
        self.tokenizer = tokenizer

        # We only list the files that contain stories and summaries; their
        # contents are not read in memory here given the size of the corpus.
        self.stories_path = []
        for dataset in ("cnn", "dailymail"):
            path_to_stories = os.path.join(data_dir, dataset, "stories")
            # os.listdir returns entries in arbitrary, filesystem-dependent
            # order; sorting makes the index -> file mapping reproducible.
            for story_filename in sorted(os.listdir(path_to_stories)):
                path_to_story = os.path.join(path_to_stories, story_filename)
                if os.path.isfile(path_to_story):
                    self.stories_path.append(path_to_story)

    def __len__(self):
        """Return the number of story files that were indexed."""
        return len(self.stories_path)

    def __getitem__(self, idx):
        """ Read story ``idx`` (UTF-8) and split it with ``process_story``.

        Returns:
            tuple: ``(story_lines, summary_lines)``, two lists of str.
        """
        story_path = self.stories_path[idx]
        with open(story_path, encoding="utf-8") as source:
            raw_story = source.read()
        story_lines, summary_lines = process_story(raw_story)
        return story_lines, summary_lines
def process_story(raw_story):
    """ Extract the story and summary from a story file.

    Non-empty (whitespace-stripped) lines before the first `@highlight` marker
    form the story; the non-empty lines after it, minus any further
    `@highlight` markers, form the summary. Every line is passed through
    ``_add_missing_period`` first.

    Args:
        raw_story (str): content of the story file.

    Returns:
        tuple: ``(story_lines, summary_lines)``, two lists of str. If the file
        contains no `@highlight` marker, ``summary_lines`` is empty; if it
        contains no non-empty line, both lists are empty. No exception is
        raised for such files.
    """
    stripped_lines = (line.strip() for line in raw_story.split("\n"))
    nonempty_lines = [line for line in stripped_lines if line]

    # for some unknown reason some lines miss a period, add it
    nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]

    # Gather article lines up to the first "@highlight" marker.
    story_lines = []
    for position, line in enumerate(nonempty_lines):
        if line.startswith("@highlight"):
            remaining_lines = nonempty_lines[position + 1:]
            break
        story_lines.append(line)
    else:
        # "@highlight" is absent from the file: everything is story text.
        return story_lines, []

    # Gather summary lines, dropping the remaining "@highlight" markers.
    summary_lines = [
        line for line in remaining_lines if not line.startswith("@highlight")
    ]
    return story_lines, summary_lines
def _add_missing_period(line):
    """ Append a period to ``line`` unless it already ends like a sentence.

    `@highlight` marker lines and empty strings are returned unchanged.

    Args:
        line (str): a single, already stripped, line of a story file.

    Returns:
        str: ``line`` itself, or ``line + "."``.
    """
    # U+2019 and U+201D are the typographic closing single and double quotes.
    END_TOKENS = (".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u201d", ")")
    if not line or line.startswith("@highlight"):
        return line
    # str.endswith accepts a tuple, so multi-character tokens such as "..."
    # are honoured too (a last-character comparison could never match them).
    if line.endswith(END_TOKENS):
        return line
    return line + "."
# --------------------------
# Encoding and preprocessing
# --------------------------
def fit_to_block_size(sequence, block_size, pad_token):
    """ Adapt a sequence's length to the block size.

    Sequences longer than ``block_size`` are truncated; shorter ones are
    right-padded with ``pad_token``. The input sequence is never modified.

    Args:
        sequence (list): token ids.
        block_size (int): target length.
        pad_token: value appended as padding (e.g. the tokenizer's pad id).

    Returns:
        list: a new list of ``block_size`` elements.
    """
    if len(sequence) > block_size:
        return sequence[:block_size]
    # Build a new list rather than extending in place, so the caller's list
    # is left untouched (consistent with the truncation branch above).
    return list(sequence) + [pad_token] * (block_size - len(sequence))
def build_lm_labels(sequence, pad_token, ignore_index=-1):
    """ Build language-model labels from a tensor of token ids.

    Positions equal to ``pad_token`` are replaced by ``ignore_index`` so that
    a loss configured with the same ignore index does not take them into
    account. The input tensor is not modified.

    Args:
        sequence (torch.Tensor): token ids.
        pad_token (int): id of the padding token.
        ignore_index (int): label written at padding positions. Defaults to
            -1 (the historical behaviour); note that
            ``torch.nn.CrossEntropyLoss`` ignores -100 by default.

    Returns:
        torch.Tensor: a new tensor with the same shape and dtype as
        ``sequence``.
    """
    return sequence.masked_fill(sequence == pad_token, ignore_index)
def build_mask(sequence, pad_token):
    """ Builds the mask. The attention mechanism will only attend to positions
    with value 1.

    Args:
        sequence (torch.Tensor): token ids.
        pad_token (int): id of the padding token.

    Returns:
        torch.Tensor: same shape and dtype as ``sequence``, 1 where the token
        is not ``pad_token`` and 0 where it is. The input is not modified.
    """
    # Derive the mask from the untouched input in one step. Rewriting the
    # tensor in place (non-pad -> 1, then pad -> 0) zeroes every position
    # when pad_token == 1, because the real tokens just set to 1 then match
    # the padding id.
    return (sequence != pad_token).to(sequence.dtype)
def encode_for_summarization(story_lines, summary_lines, tokenizer):
    """ Tokenize the story and summary sentences and concatenate them.

    Each sentence is encoded with ``tokenizer.encode`` and then wrapped with
    ``tokenizer.add_special_tokens_single_sequence``, so that consecutive
    sentences are separated by special tokens (`[SEP] [CLS]` for a BERT
    tokenizer), as in [1]. The wrapped sentences are then flattened into a
    single list of ids per side.

    Args:
        story_lines (list of str): sentences of the article.
        summary_lines (list of str): sentences of the summary.
        tokenizer: object providing ``encode`` and
            ``add_special_tokens_single_sequence``.

    Returns:
        tuple: ``(story_token_ids, summary_token_ids)``, two flat lists.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained
        encoders." arXiv preprint arXiv:1908.08345 (2019).
    """

    def _encode_and_concatenate(sentences):
        # Encode, wrap with special tokens, and append sentence by sentence.
        flat_ids = []
        for sentence in sentences:
            wrapped = tokenizer.add_special_tokens_single_sequence(
                tokenizer.encode(sentence)
            )
            flat_ids.extend(wrapped)
        return flat_ids

    story_ids = _encode_and_concatenate(story_lines)
    summary_ids = _encode_and_concatenate(summary_lines)
    return story_ids, summary_ids
def compute_token_type_ids(batch, separator_token_id):
    """ Segment (token type) ids as described in [1].

    The values {0,1} were found in the repository [2]. Every occurrence of
    ``separator_token_id`` starts a new segment, and segment ids alternate
    between 1 and 0. Tokens before the first separator get 0. Tokens from
    the first separator (inclusive) up to the second get 1, then 0, and so on.

    Args:
        batch: torch.Tensor (or an equal-length nested list convertible to
            one), size [batch_size, block_size]. Batch of input.
        separator_token_id: int
            The value of the token that separates the segments.

    Returns:
        torch.Tensor of int64, size [batch_size, block_size], on the same
        device as ``batch``.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
        arXiv preprint arXiv:1908.08345 (2019).
    [2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
    """
    batch = torch.as_tensor(batch)
    # Inclusive running count of the separators seen so far along each row;
    # its parity is the segment id. This replaces a Python loop over every
    # element of the batch with a single vectorized operation.
    sentence_num = (batch == separator_token_id).long().cumsum(dim=1)
    return sentence_num % 2