import os
from collections import deque

import torch
from torch.utils.data import Dataset


# ------------
# Data loading
# ------------


class CNNDailyMailDataset(Dataset):
    """Abstracts the dataset used to train seq2seq models.

    The CNN/Daily Mail raw datasets are downloaded from [1]. The stories are
    stored in different files; the summary appears at the end of the story as
    sentences that are prefixed by the special `@highlight` line. To process
    the data, untar both datasets in the same folder, and pass the path to
    this folder as the "data_dir" argument. The formatting code was inspired
    by [2].

    [1] https://cs.nyu.edu/~kcho/
    [2] https://github.com/abisee/cnn-dailymail/
    """

    def __init__(self, tokenizer, prefix="train", data_dir=""):
        assert os.path.isdir(data_dir)
        self.tokenizer = tokenizer

        # We initialize the class by listing all the files that contain
        # stories and summaries. Files are not read in memory given
        # the size of the corpus.
        self.stories_path = []
        datasets = ("cnn", "dailymail")
        for dataset in datasets:
            path_to_stories = os.path.join(data_dir, dataset, "stories")
            story_filenames_list = os.listdir(path_to_stories)
            for story_filename in story_filenames_list:
                path_to_story = os.path.join(path_to_stories, story_filename)
                if not os.path.isfile(path_to_story):
                    continue
                self.stories_path.append(path_to_story)

    def __len__(self):
        return len(self.stories_path)

    def __getitem__(self, idx):
        story_path = self.stories_path[idx]
        with open(story_path, encoding="utf-8") as source:
            raw_story = source.read()
        story_lines, summary_lines = process_story(raw_story)
        return story_lines, summary_lines


def process_story(raw_story):
    """Extract the story and summary from a story file.

    Arguments:
        raw_story (str): content of the story file as an utf-8 encoded string.

    Returns:
        A tuple (story_lines, summary_lines). If the story is empty or
        contains no `@highlight` lines, summary_lines is an empty list.
    """
    nonempty_lines = list(
        filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
    )

    # for some unknown reason some lines miss a period, add it
    nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]

    # gather article lines
    story_lines = []
    lines = deque(nonempty_lines)
    while True:
        try:
            element = lines.popleft()
            if element.startswith("@highlight"):
                break
            story_lines.append(element)
        except IndexError:
            # if "@highlight" is absent from the file we pop
            # all elements until the deque is empty.
            return story_lines, []

    # gather summary lines
    summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))

    return story_lines, summary_lines


def _add_missing_period(line):
    END_TOKENS = [".", "!", "?", "...", "'", "`", '"', "\u2019", "\u201d", ")"]
    if line.startswith("@highlight"):
        return line
    if line[-1] in END_TOKENS:
        return line
    return line + "."
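
# Illustrative sketch (not used by the training code): how `process_story`
# splits a raw CNN/Daily Mail story into article and summary lines. The story
# text below is made up for illustration.
def _example_process_story():
    raw_story = (
        "The quick brown fox jumped over the lazy dog\n\n"
        "It was never seen again\n\n"
        "@highlight\n\n"
        "Fox jumps over dog"
    )
    story_lines, summary_lines = process_story(raw_story)
    # story_lines   -> ["The quick brown fox jumped over the lazy dog.",
    #                   "It was never seen again."]
    # summary_lines -> ["Fox jumps over dog."]
    return story_lines, summary_lines
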
# --------------------------
# Encoding and preprocessing
# --------------------------


def fit_to_block_size(sequence, block_size, pad_token):
    """Adapt the source and target sequences' lengths to the block size.

    If the sequence is longer than the block size it is truncated; if it is
    shorter, it is padded to the right with the `pad_token` id.
    """
    if len(sequence) > block_size:
        return sequence[:block_size]
    else:
        sequence.extend([pad_token] * (block_size - len(sequence)))
        return sequence


def build_lm_labels(sequence, pad_token):
    """Padding tokens (encoded as 0 for BERT) are replaced by the value -1
    so they are not taken into account in the loss computation.
    """
    padded = sequence.clone()
    padded[padded == pad_token] = -1
    return padded


def build_mask(sequence, pad_token):
    """Builds the attention mask. The attention mechanism will only attend
    to positions with value 1.
    """
    mask = torch.ones_like(sequence)
    idx_pad_tokens = sequence == pad_token
    mask[idx_pad_tokens] = 0
    return mask


def encode_for_summarization(story_lines, summary_lines, tokenizer):
    """Encode the story and summary lines, and join them as specified in [1]
    by using `[SEP] [CLS]` tokens to separate sentences.
    """
    story_lines_token_ids = [
        tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line)) for line in story_lines
    ]
    summary_lines_token_ids = [
        tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line)) for line in summary_lines
    ]
    story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]
    summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]

    return story_token_ids, summary_token_ids


def compute_token_type_ids(batch, separator_token_id):
    """Segment embeddings as described in [1]. The values {0, 1} were found
    in the repository [2].

    Arguments:
        batch: torch.Tensor, size [batch_size, block_size]
            Batch of input.
        separator_token_id: int
            The value of the token that separates the segments.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained
        encoders." arXiv preprint arXiv:1908.08345 (2019).
    [2] https://github.com/nlpyang/PreSumm
        (/src/prepro/data_builder.py, commit fac1217)
    """
    batch_embeddings = []
    for sequence in batch:
        sentence_num = 0
        embeddings = []
        for s in sequence:
            if s == separator_token_id:
                sentence_num += 1
            embeddings.append(sentence_num % 2)
        batch_embeddings.append(embeddings)
    return torch.tensor(batch_embeddings)
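
# Illustrative sketch (not called anywhere): how the padding helpers above fit
# together. The token ids and the pad id (0, as for BERT) are made up for
# illustration.
def _example_padding_helpers():
    token_ids = [101, 2023, 2003, 2019, 2742, 102]  # hypothetical ids
    padded = fit_to_block_size(token_ids, block_size=8, pad_token=0)
    # -> [101, 2023, 2003, 2019, 2742, 102, 0, 0]
    padded = torch.tensor(padded)
    labels = build_lm_labels(padded, pad_token=0)
    # -> tensor([101, 2023, 2003, 2019, 2742, 102, -1, -1])
    mask = build_mask(padded, pad_token=0)
    # -> tensor([1, 1, 1, 1, 1, 1, 0, 0])
    return padded, labels, mask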
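
# Illustrative sketch (not called anywhere): `compute_token_type_ids` assigns
# alternating segment ids, switching every time the separator token is seen.
# Here 101/102 stand in for BERT's [CLS]/[SEP] ids, 0 for padding, and the
# remaining ids are made up.
def _example_token_type_ids():
    batch = torch.tensor([[101, 7592, 2088, 102, 101, 2129, 2024, 102, 0, 0]])
    return compute_token_type_ids(batch, separator_token_id=102)
    # -> tensor([[0, 0, 0, 1, 1, 1, 1, 0, 0, 0]])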