load and prepare CNN/Daily Mail data
We write a function to load and preprocess the CNN/Daily Mail dataset as provided by Li Dong et al. The issue is that this dataset has already been tokenized by the authors, so we will need to find the original, plain-text dataset if we want to apply it to all models.
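For reference, a minimal sketch (not part of this commit) of one way to rebuild plain-text source/target files, assuming the Hugging Face `datasets` library and its `cnn_dailymail` configuration; the config name, split and column names below are assumptions:

    from datasets import load_dataset

    # Fetch the plain-text articles and summaries and write them out in the
    # one-line-per-example layout used by the UniLM release.
    dataset = load_dataset("cnn_dailymail", "3.0.0", split="train")
    with open("train.src", "w", encoding="utf-8") as src, open("train.tgt", "w", encoding="utf-8") as tgt:
        for example in dataset:
            src.write(example["article"].replace("\n", " ") + "\n")
            tgt.write(example["highlights"].replace("\n", " ") + "\n")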
This commit is contained in:
parent d9d387afce
commit 67d10960ae
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright 2018 The Microsoft Research team and The HuggingFace Inc. team.
# Copyright (c) 2018 Microsoft and The HuggingFace Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -32,10 +32,13 @@ Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197

import argparse
import logging
import pickle
import random
import os

import numpy as np
import torch
from torch.utils.data import Dataset

from transformers import BertConfig, Bert2Rnd, BertTokenizer

@@ -48,8 +51,107 @@ def set_seed(args):
    torch.manual_seed(args.seed)


class TextDataset(Dataset):
    """ Abstracts a dataset used to train seq2seq models.

    A seq2seq dataset consists of two files:
    - The source file contains the source sequences, one line per sequence;
    - The target file contains the target sequences, one line per sequence.

    The matching between source and target sequences is made on the basis of line numbers.

    CNN/Daily Mail:

    The CNN/Daily Mail dataset downloaded from [1] consists of two files that
    respectively contain the stories and the associated summaries. Each line
    corresponds to a different story. The files contain WordPiece tokens.

    train.src: the longest story contains 6966 tokens, the shortest 12.
        Sentences are separated with `[SEP_i]` tokens, where i is an integer between 0 and 9.

    train.tgt: the longest summary contains 2467 tokens, the shortest 4.
        Sentences are separated with `[X_SEP]` tokens.

    [1] https://github.com/microsoft/unilm
    """

    def __init__(self, tokenizer, src_path="train.src", target_path="train.tgt", block_size=512):
        assert os.path.isfile(src_path) and os.path.isfile(target_path)
        directory, filename = os.path.split(src_path)

        cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, filename))
        if os.path.exists(cached_features_file):
            logger.info("Loading features from cached file %s", cached_features_file)
            with open(cached_features_file, "rb") as source:
                self.examples = pickle.load(source)
        else:
            logger.info("Creating features from dataset at %s", directory)

            self.examples = []
            with open(src_path, encoding="utf-8") as source, open(target_path, encoding="utf-8") as target:
                # source and target sequences are matched line by line
                for line_src, line_tgt in zip(source, target):
                    src_sequence = line_src.strip()
                    tgt_sequence = line_tgt.strip()
                    example = _truncate_and_concatenate(src_sequence, tgt_sequence, block_size)
                    if example is not None:
                        example = tokenizer.convert_tokens_to_ids(example)
                        self.examples.append(example)

            logger.info("Saving features into cached file %s", cached_features_file)
            with open(cached_features_file, "wb") as sink:
                pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        return torch.tensor(self.examples[item])
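

# Illustration only (not part of this commit): a minimal sketch of how the
# dataset above might be consumed. The checkpoint and file names are
# assumptions, and the files are still expected to contain the WordPiece
# tokens from the UniLM release.
def _example_build_dataset():
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    dataset = TextDataset(tokenizer, src_path="train.src", target_path="train.tgt", block_size=512)
    return dataset[0]  # a 1-D torch tensor of token ids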


def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size):
    """ Concatenate the sequences and adapt their lengths to the block size.

    Following [1] we perform the following transformations:
    - Add a [CLS] token at the beginning of the source sequence;
    - Add an [EOS] token at the end of the source and target sequences;
    - Concatenate the source and target sequences (plus the special tokens). If the
      concatenated sequence is longer than 512 we follow the 75%/25% rule in [1]:
      limit the source sequence's length to 384 and the target sequence's length to 128.

    [1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
        Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
    """
    SRC_MAX_LENGTH = int(0.75 * block_size) - 2  # [CLS] and [EOS] tokens
    TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1  # [EOS] token

    # The dataset contains special separator tokens that we remove for now.
    # They are of the form `[SEP_i]` in the source file, and `[X_SEP]` in the
    # target file.
    src_tokens = list(filter(lambda t: "[SEP_" not in t, src_sequence.split(" ")))
    tgt_tokens = list(filter(lambda t: "[X_SEP]" not in t, tgt_sequence.split(" ")))

    # We drop the examples that are too short to fill the block size, for the
    # sake of simplicity. You can modify this by adding model-specific padding.
    if len(src_tokens) + len(tgt_tokens) + 3 < block_size:
        return None

    # Truncate the sequences that exceed their maximum length, following the 75%/25% rule.
    if len(src_tokens) > SRC_MAX_LENGTH:
        if len(tgt_tokens) > TGT_MAX_LENGTH:
            src_tokens = src_tokens[:SRC_MAX_LENGTH]
            tgt_tokens = tgt_tokens[:TGT_MAX_LENGTH]
        else:
            src_tokens = src_tokens[: block_size - len(tgt_tokens) - 3]
    else:
        if len(tgt_tokens) > TGT_MAX_LENGTH:
            tgt_tokens = tgt_tokens[: block_size - len(src_tokens) - 3]

    return ["[CLS]"] + src_tokens + ["[EOS]"] + tgt_tokens + ["[EOS]"]
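
# Worked example (illustration, not part of this commit): with the default
# block_size of 512 the limits above evaluate to 382 source and 129 target
# tokens, slightly below the 384/128 quoted in the docstring because the
# [CLS]/[EOS] markers are subtracted from the budget.
assert int(0.75 * 512) - 2 == 382  # SRC_MAX_LENGTH
assert 512 - 382 - 1 == 129  # TGT_MAX_LENGTH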


def load_and_cache_examples(args, tokenizer):
    raise NotImplementedError
    dataset = TextDataset(tokenizer, src_path=args.train_data_file)
    return dataset


def train(args, train_dataset, model, tokenizer):
@@ -102,4 +204,4 @@ def main():


if __name__ == "__main__":
    main()