mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
process the raw CNN/Daily Mail dataset
the data provided by Li Dong et al. were already tokenized, which means that they are not compatible with all the models in the library. We thus process the raw data directly and tokenize them using the models' tokenizers.
This commit is contained in:
parent
67d10960ae
commit
447fffb21f
@ -17,9 +17,9 @@
|
||||
|
||||
We use the procedure described in [1] to finetune models for sequence
|
||||
generation. Let S1 and S2 be the source and target sequence respectively; we
|
||||
pack them using the start of sequence [SOS] and end of sequence [EOS] token:
|
||||
pack them using the start of sequence [EOS] and end of sequence [EOS] token:
|
||||
|
||||
[SOS] S1 [EOS] S2 [EOS]
|
||||
[CLS] S1 [EOS] S2 [EOS]
|
||||
|
||||
We then mask a fixed percentage of token from S2 at random and learn to predict
|
||||
the masked words. [EOS] can be masked during finetuning so the model learns to
|
||||
@ -31,6 +31,7 @@ Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dequeue
|
||||
import logging
|
||||
import pickle
|
||||
import random
|
||||
@ -54,7 +55,7 @@ def set_seed(args):
|
||||
class TextDataset(Dataset):
|
||||
""" Abstracts a dataset used to train seq2seq models.
|
||||
|
||||
A seq2seq dataset consists in two files:
|
||||
A seq2seq dataset consists of two files:
|
||||
- The source file that contains the source sequences, one line per sequence;
|
||||
- The target file contains the target sequences, one line per sequence.
|
||||
|
||||
@ -62,43 +63,53 @@ class TextDataset(Dataset):
|
||||
|
||||
CNN/Daily News:
|
||||
|
||||
The CNN/Daily News dataset downloaded from [1] consists of two files that
|
||||
respectively contain the stories and the associated summaries. Each line
|
||||
corresponds to a different story. The files contain WordPiece tokens.
|
||||
The CNN/Daily News raw datasets are downloaded from [1]. They consist in stories stored
|
||||
in different files where the summary sentences are indicated by the special `@highlight` token.
|
||||
To process the data, untar both datasets in the same folder, and path the path to this
|
||||
folder as the "train_data_file" argument. The formatting code was inspired by [2].
|
||||
|
||||
train.src: the longest story contains 6966 tokens, the shortest 12.
|
||||
Sentences are separated with `[SEP_i]` where i is an int between 0 and 9.
|
||||
|
||||
train.tgt: the longest summary contains 2467 tokens, the shortest 4.
|
||||
Sentences are separated with `[X_SEP]` tokens.
|
||||
|
||||
[1] https://github.com/microsoft/unilm
|
||||
[1] https://cs.nyu.edu/~kcho/
|
||||
[2] https://github.com/abisee/cnn-dailymail/
|
||||
"""
|
||||
def __init_(self, tokenizer, src_path='train.src', target_path='target.src' block_size=512):
|
||||
assert os.path.isfile(file_path)
|
||||
directory, filename = os.path.split(file_path)
|
||||
def __init_(self, tokenizer, data_dir='', block_size=512):
|
||||
assert os.path.isdir(data_dir)
|
||||
|
||||
cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, file_name)
|
||||
# Load features that have already been computed if present
|
||||
cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, data_dir)
|
||||
if os.path.exists(cached_features_file):
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
with open(cached_features_file, "rb") as source:
|
||||
self.examples = pickle.load(source)
|
||||
else:
|
||||
logger.info("Creating features from dataset at %s", directory)
|
||||
return
|
||||
|
||||
self.examples = []
|
||||
with open(src_path, encoding="utf-8") as source, open(target_path, encoding="utf-8") as target:
|
||||
for line_src, line_tgt in zip(source, target)
|
||||
src_sequence = line_src.read()
|
||||
tgt_sequence = line_tgt.read()
|
||||
example = _truncate_and_concatenate(src_sequence, tgt_sequence, block_size)
|
||||
if example is not None:
|
||||
example = tokenizer.convert_tokens_to_ids(example)
|
||||
self.examples.append(example)
|
||||
logger.info("Creating features from dataset at %s", directory)
|
||||
|
||||
logger.info("Saving features into cache file %s", cached_features_file)
|
||||
with open(cached_features_file, "wb") as sink:
|
||||
pickle.dump(self.examples, sink, protocole=pickle.HIGHEST_PROTOCOL)
|
||||
# we need to iterate over both the cnn and the dailymail dataset
|
||||
datasets = ['cnn', 'dailymail']
|
||||
for dataset in datasets:
|
||||
path_to_stories = os.path.join(data_dir, dataset, "stories")
|
||||
assert os.path.isdir(path_to_stories)
|
||||
|
||||
stories_files = os.listdir(path_to_stories)
|
||||
for story_file in stories_files:
|
||||
path_to_story = os.path.join(path_to_stories, "story_file")
|
||||
if !os.path.isfile(path_to_story):
|
||||
continue
|
||||
|
||||
with open(path_to_story, encoding="utf-8") as source:
|
||||
try:
|
||||
story, summary = process_story(source)
|
||||
except IndexError:
|
||||
continue
|
||||
|
||||
src_sequence = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
|
||||
tgt_sequence = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
|
||||
example = _truncate_and_concatenate(src_sequence, tgt_sequence, blocksize)
|
||||
self.examples.append(example)
|
||||
|
||||
logger.info("Saving features into cache file %s", cached_features_file)
|
||||
with open(cached_features_file, "wb") as sink:
|
||||
pickle.dump(self.examples, sink, protocole=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
@ -107,6 +118,46 @@ class TextDataset(Dataset):
|
||||
return torch.tensor(self.examples[items])
|
||||
|
||||
|
||||
def process_story(story_file):
|
||||
""" Process the text contained in a story file.
|
||||
Returns the story and the summary
|
||||
"""
|
||||
file_lines = list(filter(lambda x: len(x)!=0, [line.strip() for lines in story_file]))
|
||||
|
||||
# for some unknown reason some lines miss a period, add it
|
||||
file_lines = [_add_missing_period(line) for line in file_lines]
|
||||
|
||||
# gather article lines
|
||||
story_lines = []
|
||||
lines = dequeue(file_lines)
|
||||
while True:
|
||||
try:
|
||||
element = lines.popleft()
|
||||
if element.startswith("@highlight"):
|
||||
break
|
||||
story_lines.append(element)
|
||||
except IndexError as ie: # if "@highlight" absent from file
|
||||
raise ie
|
||||
|
||||
# gather summary lines
|
||||
highlights_lines = list(filter(lambda t: !t.startswith("@highlight"), lines))
|
||||
|
||||
# join the lines
|
||||
story = " ".join(story_lines)
|
||||
summary = " ".join(highlights_lines)
|
||||
|
||||
return story, summary
|
||||
|
||||
|
||||
def _add_missing_period(line):
|
||||
END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', u'\u2019', u'\u2019', ")"]
|
||||
if line == "@highlight":
|
||||
return line
|
||||
if line[-1] in END_TOKENS:
|
||||
return line
|
||||
return line + " ."
|
||||
|
||||
|
||||
def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size):
|
||||
""" Concatenate the sequences and adapt their lengths to the block size.
|
||||
|
||||
@ -123,12 +174,6 @@ def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size):
|
||||
SRC_MAX_LENGTH = int(0.75 * block_size) - 2 # CLS and EOS token
|
||||
TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1 # EOS token
|
||||
|
||||
# the dataset contains special separator tokens that we remove for now.
|
||||
# They are of the form `[SEP_i]` in the source file, and `[X_SEP]` in the
|
||||
# target file.
|
||||
src_tokens = list(filter(lambda t: "[SEP_" in t, src_sequence.split(" ")))
|
||||
tgt_tokens = list(filter(lambda t: "_SEP]" in t, tgt_sequence.split(" ")))
|
||||
|
||||
# we dump the examples that are too small to fit in the block size for the
|
||||
# sake of simplicity. You can modify this by adding model-specific padding.
|
||||
if len(src_tokens) + len(src_tokens) + 3 < block_size:
|
||||
@ -145,6 +190,7 @@ def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size):
|
||||
if len(tgt_tokens) > TGT_MAX_LENGTH:
|
||||
tgt_tokens = tgt_tokens[block_size - len(src_tokens) - 3]
|
||||
|
||||
# I add the special tokens manually, but this should be done by the tokenizer. That's the next step.
|
||||
return ["[CLS]"] + src_tokens + ["[EOS]"] + tgt_tokens + ["[EOS]"]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user