mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-18 03:58:25 +06:00

This is the result of: $ black --line-length 119 examples templates transformers utils hubconf.py setup.py There's a lot of fairly long lines in the project. As a consequence, I'm picking the longest widely accepted line length, 119 characters. This is also Thomas' preference, because it allows for explicit variable names, to make the code easier to understand.
168 lines
5.6 KiB
Python
168 lines
5.6 KiB
Python
from collections import deque
|
|
import os
|
|
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
# ------------
|
|
# Data loading
|
|
# ------------
|
|
|
|
|
|
class SummarizationDataset(Dataset):
    """Abstracts the dataset used to train seq2seq models.

    The class lists the documents located in the specified folder and reads
    them lazily, one at a time. The preprocessing will work on any document
    that is reasonably formatted. On the CNN/DailyMail dataset it will extract
    both the story and the summary.

    CNN/Daily Mail:

    The CNN/Daily Mail raw datasets are downloaded from [1]. The stories are
    stored in different files; the summary appears at the end of the story as
    sentences that are prefixed by the special `@highlight` line. To process
    the data, untar both datasets in the same folder, and pass the path to this
    folder as the "path" argument. The formatting code was inspired by [2].

    [1] https://cs.nyu.edu/~kcho/
    [2] https://github.com/abisee/cnn-dailymail/
    """

    def __init__(self, path="", prefix="train"):
        """List all the documents to summarize.

        Files are not read in memory due to the size of some datasets
        (like CNN/DailyMail); only their paths are stored.

        Args:
            path (str): folder that contains the story files. Entries whose
                name contains "summary" and entries that are not regular
                files are skipped.
            prefix (str): not used by this class; kept so existing callers
                that pass it keep working.

        Raises:
            AssertionError: if `path` is not an existing directory.
        """
        assert os.path.isdir(path), "`path` must be an existing directory, got {!r}".format(path)

        self.documents = []
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> document mapping reproducible across runs and machines.
        for story_filename in sorted(os.listdir(path)):
            if "summary" in story_filename:
                continue
            path_to_story = os.path.join(path, story_filename)
            if not os.path.isfile(path_to_story):
                continue
            self.documents.append(path_to_story)

    def __len__(self):
        """Return the number of documents."""
        return len(self.documents)

    def __getitem__(self, idx):
        """Read and preprocess the document at position `idx`.

        Returns:
            A tuple `(document_name, story_lines, summary_lines)` where
            `document_name` is the file name without its directory and the two
            other items are the values returned by `process_story`.
        """
        document_path = self.documents[idx]
        # basename honours the platform's path separator, unlike split("/").
        document_name = os.path.basename(document_path)
        with open(document_path, encoding="utf-8") as source:
            raw_story = source.read()
        story_lines, summary_lines = process_story(raw_story)
        return document_name, story_lines, summary_lines
|
|
|
|
|
|
def process_story(raw_story):
    """Extract the story and summary from a story file.

    The text is split on newlines, each line is stripped and empty lines are
    dropped. Every remaining line goes through `_add_missing_period`. Lines
    before the first line starting with "@highlight" form the story; lines
    after it that do not themselves start with "@highlight" form the summary.

    Args:
        raw_story (str): content of the story file, already decoded to a string.

    Returns:
        A tuple `(story_lines, summary_lines)` of lists of strings. When the
        text contains no "@highlight" line, `summary_lines` is empty and every
        non-empty line is part of the story. An empty text gives `([], [])`.
        No exception is raised for either case.
    """
    stripped_lines = (line.strip() for line in raw_story.split("\n"))
    # for some unknown reason some lines miss a period, add it
    nonempty_lines = [_add_missing_period(line) for line in stripped_lines if line]

    # Position of the first "@highlight" marker; when there is none the whole
    # document is the story and the summary slice below is empty.
    first_highlight = next(
        (index for index, line in enumerate(nonempty_lines) if line.startswith("@highlight")),
        len(nonempty_lines),
    )

    story_lines = nonempty_lines[:first_highlight]
    summary_lines = [line for line in nonempty_lines[first_highlight:] if not line.startswith("@highlight")]

    return story_lines, summary_lines
|
|
|
|
|
|
def _add_missing_period(line):
|
|
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"]
|
|
if line.startswith("@highlight"):
|
|
return line
|
|
if line[-1] in END_TOKENS:
|
|
return line
|
|
return line + "."
|
|
|
|
|
|
# --------------------------
|
|
# Encoding and preprocessing
|
|
# --------------------------
|
|
|
|
|
|
def fit_to_block_size(sequence, block_size, pad_token_id):
    """Adapt a sequence's length to the block size.

    A sequence longer than `block_size` is truncated on the right; a shorter
    one is padded on the right with `pad_token_id`.

    Args:
        sequence (list): token ids to fit.
        block_size (int): length of the returned list.
        pad_token_id: value appended to sequences that are too short.

    Returns:
        A new list of length `block_size`. The input `sequence` is never
        modified, whichever branch applies.
    """
    # Copy first so that padding never mutates the caller's list.
    fitted = list(sequence[:block_size])
    fitted.extend([pad_token_id] * (block_size - len(fitted)))
    return fitted
|
|
|
|
|
|
def build_mask(sequence, pad_token_id):
    """Build the attention mask for `sequence`.

    The returned tensor is created with `torch.ones_like(sequence)` and holds
    0 at every position where `sequence` equals `pad_token_id`, 1 elsewhere.
    The attention mechanism will only attend to positions with value 1.
    """
    is_padding = sequence.eq(pad_token_id)
    return torch.ones_like(sequence).masked_fill(is_padding, 0)
|
|
|
|
|
|
def encode_for_summarization(story_lines, summary_lines, tokenizer):
    """Encode the story and summary lines and flatten each into one list.

    Every line is passed to `tokenizer.encode` and the resulting token ids are
    concatenated in order, story and summary separately. Any separator tokens
    between sentences (e.g. `[SEP] [CLS]`) are whatever `tokenizer.encode`
    emits for a line; this function adds none itself.

    Returns:
        A tuple `(story_token_ids, summary_token_ids)` of flat lists.
    """

    def _encode_and_flatten(lines):
        # Encode every line first, then concatenate the per-line token ids.
        encoded_lines = list(map(tokenizer.encode, lines))
        flat_token_ids = []
        for encoded_line in encoded_lines:
            flat_token_ids.extend(encoded_line)
        return flat_token_ids

    story_token_ids = _encode_and_flatten(story_lines)
    summary_token_ids = _encode_and_flatten(summary_lines)
    return story_token_ids, summary_token_ids
|
|
|
|
|
|
def compute_token_type_ids(batch, separator_token_id):
    """Compute segment embeddings as described in [1].

    For each sequence a sentence counter starts at -1 and is incremented every
    time the separator token is met (the separator itself already belongs to
    the new sentence). The id of a position is the counter modulo 2, so
    positions before the first separator get 1.

    The values {0,1} were found in the repository [2].

    Args:
        batch: torch.Tensor, size [batch_size, block_size]
            Batch of input.
        separator_token_id: int
            The value of the token that separates the segments.

    Returns:
        A tensor built with `torch.tensor` from one list of ids per sequence.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
        arXiv preprint arXiv:1908.08345 (2019).
    [2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
    """

    def _segment_ids(sequence):
        segment_ids = []
        sentence_index = -1
        for token in sequence:
            sentence_index += 1 if token == separator_token_id else 0
            segment_ids.append(sentence_index % 2)
        return segment_ids

    return torch.tensor([_segment_ids(sequence) for sequence in batch])
|