mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
resolve PR comments
This commit is contained in:
parent 4c3ac4a7d8
commit dfce409691
@@ -16,10 +16,9 @@
""" Finetuning seq2seq models for sequence generation."""
|
||||
|
||||
import argparse
|
||||
from collections import deque
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import sys
|
||||
|
||||
@ -29,7 +28,22 @@ import torch
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
|
||||
|
||||
from transformers import AutoTokenizer, PreTrainedSeq2seq, Model2Model
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
BertForMaskedLM,
|
||||
BertConfig,
|
||||
PreTrainedSeq2seq,
|
||||
Model2Model,
|
||||
)
|
||||
|
||||
from utils_summarization import (
|
||||
CNNDailyMailDataset,
|
||||
encode_for_summarization,
|
||||
fit_to_block_size,
|
||||
build_lm_labels,
|
||||
build_mask,
|
||||
compute_token_type_ids,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
@ -46,194 +60,41 @@ def set_seed(args):
|
||||
# ------------
|
||||
|
||||
|
||||
class TextDataset(Dataset):
|
||||
""" Abstracts the dataset used to train seq2seq models.
|
||||
|
||||
CNN/Daily News:
|
||||
|
||||
The CNN/Daily News raw datasets are downloaded from [1]. The stories are
|
||||
stored in different files; the summary appears at the end of the story as
|
||||
sentences that are prefixed by the special `@highlight` line. To process
|
||||
the data, untar both datasets in the same folder, and pass the path to this
|
||||
folder as the "data_dir argument. The formatting code was inspired by [2].
|
||||
|
||||
[1] https://cs.nyu.edu/~kcho/
|
||||
[2] https://github.com/abisee/cnn-dailymail/
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer, prefix="train", data_dir="", block_size=512):
|
||||
assert os.path.isdir(data_dir)
|
||||
|
||||
# Load the features that have already been computed, if any
|
||||
cached_features_file = os.path.join(
|
||||
data_dir, "cached_lm_{}_{}".format(block_size, prefix)
|
||||
)
|
||||
if os.path.exists(cached_features_file):
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
with open(cached_features_file, "rb") as source:
|
||||
self.examples = pickle.load(source)
|
||||
return
|
||||
|
||||
logger.info("Creating features from dataset at %s", data_dir)
|
||||
datasets = ["cnn", "dailymail"]
|
||||
|
||||
self.examples = {"source": [], "target": []}
|
||||
for dataset in datasets:
|
||||
path_to_stories = os.path.join(data_dir, dataset, "stories")
|
||||
story_filenames_list = os.listdir(path_to_stories)
|
||||
for story_filename in story_filenames_list:
|
||||
path_to_story = os.path.join(path_to_stories, story_filename)
|
||||
if not os.path.isfile(path_to_story):
|
||||
continue
|
||||
|
||||
with open(path_to_story, encoding="utf-8") as source:
|
||||
raw_story = source.read()
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
if len(summary_lines) == 0 or len(story_lines) == 0:
|
||||
continue
|
||||
|
||||
story_token_ids, summary_token_ids = _encode_for_summarization(
|
||||
story_lines, summary_lines, tokenizer
|
||||
)
|
||||
story_seq = _fit_to_block_size(story_token_ids, block_size)
|
||||
self.examples["source"].append(story_seq)
|
||||
|
||||
summary_seq = _fit_to_block_size(summary_token_ids, block_size)
|
||||
self.examples["summary"].append(summary_seq)
|
||||
|
||||
logger.info("Saving features into cache file %s", cached_features_file)
|
||||
with open(cached_features_file, "wb") as sink:
|
||||
pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
|
||||
def __getitem__(self, items):
|
||||
return (
|
||||
torch.tensor(self.examples["source"][items]),
|
||||
torch.tensor(self.examples["target"][items]),
|
||||
)
|
||||
|
||||
|
||||
def process_story(raw_story):
|
||||
""" Extract the story and summary from a story file.
|
||||
|
||||
Attributes:
|
||||
raw_story (str): content of the story file as an utf-8 encoded string.
|
||||
|
||||
Raises:
|
||||
IndexError: If the story is empty or contains no highlights.
|
||||
"""
|
||||
nonempty_lines = list(
|
||||
filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
|
||||
)
|
||||
|
||||
# for some unknown reason some lines miss a period, add it
|
||||
nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
|
||||
|
||||
# gather article lines
|
||||
story_lines = []
|
||||
lines = deque(nonempty_lines)
|
||||
while True:
|
||||
try:
|
||||
element = lines.popleft()
|
||||
if element.startswith("@highlight"):
|
||||
break
|
||||
story_lines.append(element)
|
||||
except IndexError:
|
||||
# if "@highlight" is absent from the file we pop
|
||||
# all elements until none are left.
|
||||
return story_lines, []
|
||||
|
||||
# gather summary lines
|
||||
summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
|
||||
|
||||
return story_lines, summary_lines
|
||||
|
||||
|
||||
def _encode_for_summarization(story_lines, summary_lines, tokenizer):
|
||||
""" Encode the story and summary lines, and join them
|
||||
as specified in [1] by using `[SEP] [CLS]` tokens to separate
|
||||
sentences.
|
||||
"""
|
||||
story_lines_token_ids = [
|
||||
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
|
||||
for line in story_lines
|
||||
]
|
||||
summary_lines_token_ids = [
|
||||
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
|
||||
for line in summary_lines
|
||||
]
|
||||
|
||||
story_token_ids = [
|
||||
token for sentence in story_lines_token_ids for token in sentence
|
||||
]
|
||||
summary_token_ids = [
|
||||
token for sentence in summary_lines_token_ids for token in sentence
|
||||
]
|
||||
|
||||
return story_token_ids, summary_token_ids
|
||||
|
||||
|
||||
def _add_missing_period(line):
|
||||
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u201d", ")"]
|
||||
if line.startswith("@highlight"):
|
||||
return line
|
||||
if line[-1] in END_TOKENS:
|
||||
return line
|
||||
return line + "."
|
||||
|
||||
|
||||
def _fit_to_block_size(sequence, block_size):
|
||||
""" Adapt the source and target sequences' lengths to the block size.
|
||||
If the sequence is shorter than the block size we pad it with -1 ids
|
||||
which correspond to padding tokens.
|
||||
"""
|
||||
if len(sequence) > block_size:
|
||||
return sequence[:block_size]
|
||||
else:
|
||||
sequence.extend([0] * (block_size - len(sequence)))
|
||||
return sequence
|
||||
|
||||
|
||||
def mask_padding_tokens(sequence):
|
||||
""" Padding token, encoded as 0, are represented by the value -1 in the
|
||||
masks """
|
||||
padded = sequence.clone()
|
||||
padded[padded == 0] = -1
|
||||
return padded
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer):
|
||||
dataset = TextDataset(tokenizer, data_dir=args.data_dir)
|
||||
dataset = CNNDailyMailDataset(tokenizer, data_dir=args.data_dir)
|
||||
return dataset
|
||||
|
||||
|
||||
def compute_token_type_ids(batch, separator_token_id):
|
||||
""" Segment embeddings as described in [1]
|
||||
def collate(data, tokenizer, block_size):
|
||||
""" List of tuple as an input. """
|
||||
# remove the examples with an empty story or summary, then encode and fit to the block size
|
||||
data = filter(lambda x: not (len(x[0]) == 0 or len(x[1]) == 0), data)
|
||||
data = [
|
||||
encode_for_summarization(story, summary, tokenizer) for story, summary in data
|
||||
]
|
||||
data = [
|
||||
(
|
||||
fit_to_block_size(story, block_size, tokenizer.pad_token_id),
|
||||
fit_to_block_size(summary, block_size, tokenizer.pad_token_id),
|
||||
)
|
||||
for story, summary in data
|
||||
]
|
||||
|
||||
The values {0,1} were found in the repository [2].
|
||||
stories = torch.tensor([story for story, summary in data])
|
||||
summaries = torch.tensor([summary for story, summary in data])
|
||||
encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id)
|
||||
encoder_mask = build_mask(stories, tokenizer.pad_token_id)
|
||||
decoder_mask = build_mask(summaries, tokenizer.pad_token_id)
|
||||
lm_labels = build_lm_labels(summaries, tokenizer.pad_token_id)
|
||||
|
||||
Attributes:
|
||||
batch: torch.Tensor, size [batch_size, block_size]
|
||||
Batch of input.
|
||||
separator_token_id: int
|
||||
The value of the token that separates the segments.
|
||||
|
||||
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
|
||||
arXiv preprint arXiv:1908.08345 (2019).
|
||||
[2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
|
||||
"""
|
||||
batch_embeddings = []
|
||||
sentence_num = 0
|
||||
for sequence in batch:
|
||||
embeddings = []
|
||||
for s in sequence:
|
||||
if s == separator_token_id:
|
||||
sentence_num += 1
|
||||
embeddings.append(sentence_num % 2)
|
||||
batch_embeddings.append(embeddings)
|
||||
return torch.tensor(batch_embeddings)
|
||||
return (
|
||||
stories,
|
||||
summaries,
|
||||
encoder_token_type_ids,
|
||||
encoder_mask,
|
||||
decoder_mask,
|
||||
lm_labels,
|
||||
)
|
||||
|
||||
|
||||
# ----------
|
||||
@ -252,7 +113,7 @@ class BertSumOptimizer(object):
|
||||
arXiv preprint arXiv:1908.08345 (2019).
|
||||
"""
|
||||
|
||||
def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-9):
|
||||
def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8):
|
||||
self.encoder = model.encoder
|
||||
self.decoder = model.decoder
|
||||
self.lr = lr
|
||||
@ -306,8 +167,12 @@ def train(args, model, tokenizer):
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_dataset = load_and_cache_examples(args, tokenizer)
|
||||
train_sampler = RandomSampler(train_dataset)
|
||||
model_collate_fn = functools.partial(collate, tokenizer=tokenizer, block_size=512)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset, sampler=train_sampler, batch_size=args.train_batch_size
|
||||
train_dataset,
|
||||
sampler=train_sampler,
|
||||
batch_size=args.train_batch_size,
|
||||
collate_fn=model_collate_fn,
|
||||
)
|
||||
|
||||
# Training schedule
|
||||
@ -351,26 +216,23 @@ def train(args, model, tokenizer):
|
||||
for _ in train_iterator:
|
||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
|
||||
for step, batch in enumerate(epoch_iterator):
|
||||
source, target = batch
|
||||
token_type_ids = compute_token_type_ids(source, tokenizer.cls_token_id)
|
||||
labels_src = mask_padding_tokens(source)
|
||||
labels_tgt = mask_padding_tokens(target)
|
||||
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
|
||||
|
||||
source = source.to(args.device)
|
||||
target = target.to(args.device)
|
||||
token_type_ids = token_type_ids.to(args.device)
|
||||
labels_src = labels_src.to(args.device)
|
||||
labels_tgt = labels_tgt.to(args.device)
|
||||
encoder_token_type_ids = encoder_token_type_ids.to(args.device)
|
||||
encoder_mask = encoder_mask.to(args.device)
|
||||
decoder_mask = decoder_mask.to(args.device)
|
||||
lm_labels = lm_labels.to(args.device)
|
||||
|
||||
model.train()
|
||||
outputs = model(
|
||||
source,
|
||||
target,
|
||||
token_type_ids=token_type_ids,
|
||||
decoder_encoder_attention_mask=labels_src,
|
||||
decoder_attention_mask=labels_tgt,
|
||||
decoder_lm_labels=labels_tgt,
|
||||
decoder_initialize_randomly=True,
|
||||
encoder_token_type_ids=encoder_token_type_ids,
|
||||
encoder_attention_mask=encoder_mask,
|
||||
decoder_attention_mask=decoder_mask,
|
||||
decoder_lm_labels=lm_labels,
|
||||
)
|
||||
|
||||
loss = outputs[0]
|
||||
@ -421,21 +283,23 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
model.eval()
|
||||
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
source, target = batch
|
||||
labels_src = mask_padding_tokens(source)
|
||||
labels_tgt = mask_padding_tokens(target)
|
||||
source.to(args.device)
|
||||
target.to(args.device)
|
||||
labels_src.to(args.device)
|
||||
labels_tgt.to(args.device)
|
||||
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
|
||||
|
||||
source = source.to(args.device)
|
||||
target = target.to(args.device)
|
||||
encoder_token_type_ids = encoder_token_type_ids.to(args.device)
|
||||
encoder_mask = encoder_mask.to(args.device)
|
||||
decoder_mask = decoder_mask.to(args.device)
|
||||
lm_labels = lm_labels.to(args.device)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(
|
||||
source,
|
||||
target,
|
||||
decoder_encoder_attention_mask=labels_src,
|
||||
decoder_attention_mask=labels_tgt,
|
||||
decoder_lm_labels=labels_tgt,
|
||||
encoder_token_type_ids=encoder_token_type_ids,
|
||||
encoder_attention_mask=encoder_mask,
|
||||
decoder_attention_mask=decoder_mask,
|
||||
decoder_lm_labels=lm_labels,
|
||||
)
|
||||
lm_loss = outputs[0]
|
||||
eval_loss += lm_loss.mean().item()
|
||||
@ -525,7 +389,7 @@ def main():
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_train_epochs",
|
||||
default=1,
|
||||
default=10,
|
||||
type=int,
|
||||
help="Total number of training epochs to perform.",
|
||||
)
|
||||
@ -558,9 +422,13 @@ def main():
|
||||
args.device = torch.device("cuda")
|
||||
args.n_gpu = torch.cuda.device_count()
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
# Load pretrained model and tokenizer. The decoder's weights are randomly initialized.
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
|
||||
model = Model2Model.from_pretrained(args.model_name_or_path)
|
||||
config = BertConfig.from_pretrained(args.model_name_or_path)
|
||||
decoder_model = BertForMaskedLM(config)
|
||||
model = Model2Model.from_pretrained(
|
||||
args.model_name_or_path, decoder_model=decoder_model
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
|
@ -1,76 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
from run_summarization_finetuning import _fit_to_block_size, process_story
|
||||
|
||||
|
||||
class DataLoaderTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.block_size = 10
|
||||
|
||||
def test_truncate_sequence_too_small(self):
|
||||
""" Pad the sequence with 0 if the sequence is smaller than the block size."""
|
||||
sequence = [1, 2, 3, 4]
|
||||
expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
|
||||
self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output)
|
||||
|
||||
def test_truncate_sequence_fit_exactly(self):
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output)
|
||||
|
||||
def test_truncate_sequence_too_big(self):
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output)
|
||||
|
||||
def test_process_story_no_highlights(self):
|
||||
""" Processing a story with no highlights should raise an exception.
|
||||
"""
|
||||
raw_story = """It was the year of Our Lord one thousand seven hundred and
|
||||
seventy-five.\n\nSpiritual revelations were conceded to England at that
|
||||
favoured period, as at this."""
|
||||
_, summary = process_story(raw_story)
|
||||
self.assertEqual(summary, [])
|
||||
|
||||
def test_process_empty_story(self):
|
||||
""" An empty story should also raise and exception.
|
||||
"""
|
||||
raw_story = ""
|
||||
story, summary = process_story(raw_story)
|
||||
self.assertEqual(story, [])
|
||||
self.assertEqual(summary, [])
|
||||
|
||||
def test_story_with_missing_period(self):
|
||||
raw_story = (
|
||||
"It was the year of Our Lord one thousand seven hundred and "
|
||||
"seventy-five\n\nSpiritual revelations were conceded to England "
|
||||
"at that favoured period, as at this.\n@highlight\n\nIt was the best of times"
|
||||
)
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
|
||||
expected_story_lines = [
|
||||
"It was the year of Our Lord one thousand seven hundred and seventy-five.",
|
||||
"Spiritual revelations were conceded to England at that favoured period, as at this.",
|
||||
]
|
||||
self.assertEqual(expected_story_lines, story_lines)
|
||||
|
||||
expected_summary_lines = ["It was the best of times."]
|
||||
self.assertEqual(expected_summary_lines, summary_lines)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
examples/utils_summarization.py (new file, 184 lines)
@@ -0,0 +1,184 @@
|
||||
from collections import deque
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
# ------------
|
||||
# Data loading
|
||||
# ------------
|
||||
|
||||
|
||||
class CNNDailyMailDataset(Dataset):
|
||||
""" Abstracts the dataset used to train seq2seq models.
|
||||
|
||||
CNN/Daily News:
|
||||
|
||||
The CNN/Daily News raw datasets are downloaded from [1]. The stories are
|
||||
stored in different files; the summary appears at the end of the story as
|
||||
sentences that are prefixed by the special `@highlight` line. To process
|
||||
the data, untar both datasets in the same folder, and pass the path to this
|
||||
folder as the "data_dir argument. The formatting code was inspired by [2].
|
||||
|
||||
[1] https://cs.nyu.edu/~kcho/
|
||||
[2] https://github.com/abisee/cnn-dailymail/
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer, prefix="train", data_dir=""):
|
||||
assert os.path.isdir(data_dir)
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
# We initialize the class by listing all the files that contain
|
||||
# stories and summaries. Files are not read in memory given
|
||||
# the size of the corpus.
|
||||
self.stories_path = []
|
||||
datasets = ("cnn", "dailymail")
|
||||
for dataset in datasets:
|
||||
path_to_stories = os.path.join(data_dir, dataset, "stories")
|
||||
story_filenames_list = os.listdir(path_to_stories)
|
||||
for story_filename in story_filenames_list:
|
||||
path_to_story = os.path.join(path_to_stories, story_filename)
|
||||
if not os.path.isfile(path_to_story):
|
||||
continue
|
||||
self.stories_path.append(path_to_story)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.stories_path)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
story_path = self.stories_path[idx]
|
||||
with open(story_path, encoding="utf-8") as source:
|
||||
raw_story = source.read()
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
return story_lines, summary_lines
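
A minimal usage sketch of this dataset, assuming the CNN and DailyMail archives have been untarred under a local ./cnn_dm_data directory (the path is only illustrative):

from transformers import AutoTokenizer
from utils_summarization import CNNDailyMailDataset

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = CNNDailyMailDataset(tokenizer, data_dir="./cnn_dm_data")
story_lines, summary_lines = dataset[0]  # lists of raw sentences, not yet tokenized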
|
||||
|
||||
|
||||
def process_story(raw_story):
|
||||
""" Extract the story and summary from a story file.
|
||||
|
||||
Attributes:
|
||||
raw_story (str): content of the story file as an utf-8 encoded string.
|
||||
|
||||
Returns:
story_lines, summary_lines: lists of sentences; both are empty when the story
is empty, and summary_lines is empty when the story contains no highlights.
|
||||
"""
|
||||
nonempty_lines = list(
|
||||
filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
|
||||
)
|
||||
|
||||
# for some unknown reason some lines miss a period, add it
|
||||
nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
|
||||
|
||||
# gather article lines
|
||||
story_lines = []
|
||||
lines = deque(nonempty_lines)
|
||||
while True:
|
||||
try:
|
||||
element = lines.popleft()
|
||||
if element.startswith("@highlight"):
|
||||
break
|
||||
story_lines.append(element)
|
||||
except IndexError:
|
||||
# if "@highlight" is absent from the file we pop
|
||||
# all elements until none are left.
|
||||
return story_lines, []
|
||||
|
||||
# gather summary lines
|
||||
summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
|
||||
|
||||
return story_lines, summary_lines
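
A small worked example of the splitting logic, consistent with the unit tests in utils_summarization_test.py (the missing period on the second sentence is added by _add_missing_period):

raw = "First sentence.\nSecond sentence\n@highlight\nIt was the best of times"
story_lines, summary_lines = process_story(raw)
# story_lines   == ["First sentence.", "Second sentence."]
# summary_lines == ["It was the best of times."]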
|
||||
|
||||
|
||||
def _add_missing_period(line):
|
||||
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u201d", ")"]
|
||||
if line.startswith("@highlight"):
|
||||
return line
|
||||
if line[-1] in END_TOKENS:
|
||||
return line
|
||||
return line + "."
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Encoding and preprocessing
|
||||
# --------------------------
|
||||
|
||||
|
||||
def fit_to_block_size(sequence, block_size, pad_token):
|
||||
""" Adapt the source and target sequences' lengths to the block size.
|
||||
If the sequence is shorter than the block size we pad it with the
provided pad token id; if it is longer we truncate it.
|
||||
"""
|
||||
if len(sequence) > block_size:
|
||||
return sequence[:block_size]
|
||||
else:
|
||||
sequence.extend([pad_token] * (block_size - len(sequence)))
|
||||
return sequence
|
||||
|
||||
|
||||
def build_lm_labels(sequence, pad_token):
|
||||
""" Padding token, encoded as 0, are represented by the value -1 so they
|
||||
are not taken into account in the loss computation. """
|
||||
padded = sequence.clone()
|
||||
padded[padded == pad_token] = -1
|
||||
return padded
|
||||
|
||||
|
||||
def build_mask(sequence, pad_token):
|
||||
""" Builds the mask. The attention mechanism will only attend to positions
|
||||
with value 1. """
|
||||
mask = sequence.clone()
|
||||
mask[mask != pad_token] = 1
|
||||
mask[mask == pad_token] = 0
|
||||
return mask
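
Taken together, these helpers turn a raw token list into a fixed-size block plus its attention mask and LM labels. A short illustration, assuming a pad token id of 0:

padded = fit_to_block_size([1, 2, 3, 4], 6, 0)      # [1, 2, 3, 4, 0, 0]
mask = build_mask(torch.tensor(padded), 0)          # tensor([1, 1, 1, 1, 0, 0])
labels = build_lm_labels(torch.tensor(padded), 0)   # tensor([1, 2, 3, 4, -1, -1])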
|
||||
|
||||
|
||||
def encode_for_summarization(story_lines, summary_lines, tokenizer):
|
||||
""" Encode the story and summary lines, and join them
|
||||
as specified in [1] by using `[SEP] [CLS]` tokens to separate
|
||||
sentences.
|
||||
"""
|
||||
story_lines_token_ids = [
|
||||
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
|
||||
for line in story_lines
|
||||
]
|
||||
summary_lines_token_ids = [
|
||||
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
|
||||
for line in summary_lines
|
||||
]
|
||||
|
||||
story_token_ids = [
|
||||
token for sentence in story_lines_token_ids for token in sentence
|
||||
]
|
||||
summary_token_ids = [
|
||||
token for sentence in summary_lines_token_ids for token in sentence
|
||||
]
|
||||
|
||||
return story_token_ids, summary_token_ids
|
||||
|
||||
|
||||
def compute_token_type_ids(batch, separator_token_id):
|
||||
""" Segment embeddings as described in [1]
|
||||
|
||||
The values {0,1} were found in the repository [2].
|
||||
|
||||
Attributes:
|
||||
batch: torch.Tensor, size [batch_size, block_size]
|
||||
Batch of input.
|
||||
separator_token_id: int
|
||||
The value of the token that separates the segments.
|
||||
|
||||
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
|
||||
arXiv preprint arXiv:1908.08345 (2019).
|
||||
[2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
|
||||
"""
|
||||
batch_embeddings = []
|
||||
for sequence in batch:
|
||||
sentence_num = 0
|
||||
embeddings = []
|
||||
for s in sequence:
|
||||
if s == separator_token_id:
|
||||
sentence_num += 1
|
||||
embeddings.append(sentence_num % 2)
|
||||
batch_embeddings.append(embeddings)
|
||||
return torch.tensor(batch_embeddings)
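
For instance, with a separator id of 101 the segment id flips every time a separator token is encountered (this mirrors the unit test in utils_summarization_test.py):

batch = torch.tensor([[1, 101, 3, 4, 101, 6]])
compute_token_type_ids(batch, 101)  # tensor([[0, 1, 1, 1, 0, 0]])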
|
examples/utils_summarization_test.py (new file, 133 lines)
@@ -0,0 +1,133 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from utils_summarization import (
|
||||
compute_token_type_ids,
|
||||
fit_to_block_size,
|
||||
build_mask,
|
||||
build_lm_labels,
|
||||
process_story,
|
||||
)
|
||||
|
||||
|
||||
class SummarizationDataProcessingTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.block_size = 10
|
||||
|
||||
def test_fit_to_block_sequence_too_small(self):
|
||||
""" Pad the sequence with 0 if the sequence is smaller than the block size."""
|
||||
sequence = [1, 2, 3, 4]
|
||||
expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
|
||||
self.assertEqual(
|
||||
fit_to_block_size(sequence, self.block_size, 0), expected_output
|
||||
)
|
||||
|
||||
def test_fit_to_block_sequence_fit_exactly(self):
|
||||
""" Do nothing if the sequence is the right size. """
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(
|
||||
fit_to_block_size(sequence, self.block_size, 0), expected_output
|
||||
)
|
||||
|
||||
def test_fit_to_block_sequence_too_big(self):
|
||||
""" Truncate the sequence if it is too long. """
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(
|
||||
fit_to_block_size(sequence, self.block_size, 0), expected_output
|
||||
)
|
||||
|
||||
def test_process_story_no_highlights(self):
|
||||
""" Processing a story with no highlights returns an empty list for the summary.
|
||||
"""
|
||||
raw_story = """It was the year of Our Lord one thousand seven hundred and
|
||||
seventy-five.\n\nSpiritual revelations were conceded to England at that
|
||||
favoured period, as at this."""
|
||||
_, summary_lines = process_story(raw_story)
|
||||
self.assertEqual(summary_lines, [])
|
||||
|
||||
def test_process_empty_story(self):
|
||||
""" An empty story returns an empty collection of lines.
|
||||
"""
|
||||
raw_story = ""
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
self.assertEqual(story_lines, [])
|
||||
self.assertEqual(summary_lines, [])
|
||||
|
||||
def test_process_story_with_missing_period(self):
|
||||
raw_story = (
|
||||
"It was the year of Our Lord one thousand seven hundred and "
|
||||
"seventy-five\n\nSpiritual revelations were conceded to England "
|
||||
"at that favoured period, as at this.\n@highlight\n\nIt was the best of times"
|
||||
)
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
|
||||
expected_story_lines = [
|
||||
"It was the year of Our Lord one thousand seven hundred and seventy-five.",
|
||||
"Spiritual revelations were conceded to England at that favoured period, as at this.",
|
||||
]
|
||||
self.assertEqual(expected_story_lines, story_lines)
|
||||
|
||||
expected_summary_lines = ["It was the best of times."]
|
||||
self.assertEqual(expected_summary_lines, summary_lines)
|
||||
|
||||
def test_build_lm_labels_no_padding(self):
|
||||
sequence = torch.tensor([1, 2, 3, 4])
|
||||
expected = sequence
|
||||
np.testing.assert_array_equal(
|
||||
build_lm_labels(sequence, 0).numpy(), expected.numpy()
|
||||
)
|
||||
|
||||
def test_build_lm_labels(self):
|
||||
sequence = torch.tensor([1, 2, 3, 4, 0, 0, 0])
|
||||
expected = torch.tensor([1, 2, 3, 4, -1, -1, -1])
|
||||
np.testing.assert_array_equal(
|
||||
build_lm_labels(sequence, 0).numpy(), expected.numpy()
|
||||
)
|
||||
|
||||
def test_build_mask_no_padding(self):
|
||||
sequence = torch.tensor([1, 2, 3, 4])
|
||||
expected = torch.tensor([1, 1, 1, 1])
|
||||
np.testing.assert_array_equal(
|
||||
build_mask(sequence, 0).numpy(), expected.numpy()
|
||||
)
|
||||
|
||||
def test_build_mask(self):
|
||||
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
|
||||
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
|
||||
np.testing.assert_array_equal(
|
||||
build_mask(sequence, 23).numpy(), expected.numpy()
|
||||
)
|
||||
|
||||
def test_compute_token_type_ids(self):
|
||||
separator = 101
|
||||
batch = torch.tensor(
|
||||
[[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]]
|
||||
)
|
||||
expected = torch.tensor(
|
||||
[[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 1], [0, 1, 1, 1, 0, 0]]
|
||||
)
|
||||
|
||||
result = compute_token_type_ids(batch, separator)
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -26,189 +26,220 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ModelWithBeamSearch(nn.Module):
|
||||
class TransformerBeamSearch(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
tokenizer,
|
||||
batch_size,
|
||||
beam_size,
|
||||
start_token_id,
|
||||
end_token_id,
|
||||
pad_token_id,
|
||||
min_length,
|
||||
max_length,
|
||||
alpha,
|
||||
block_trigram=True,
|
||||
alpha=0,
|
||||
block_repeating_trigram=True,
|
||||
):
|
||||
"""
|
||||
Attributes:
|
||||
mask_word_id: token id that corresponds to the mask
|
||||
"""
|
||||
super(ModelWithBeamSearch, self).__init__()
|
||||
super(TransformerBeamSearch, self).__init__()
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.start_token_id = tokenizer.start_token_id
|
||||
self.end_token_id = tokenizer.end_token_id
|
||||
self.pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
self.beam_size = beam_size
|
||||
self.start_token_id = start_token_id
|
||||
self.end_token_id = end_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.min_length = min_length
|
||||
self.max_length = max_length
|
||||
self.alpha = alpha
|
||||
self.block_trigram = block_trigram
|
||||
|
||||
def forward(self, input_ids, **kwargs):
|
||||
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
||||
# decoder-specific if the key starts with `decoder_`
|
||||
self.block_repeating_trigram = block_repeating_trigram
|
||||
self.apply_length_penalty = False if alpha == 0 else True
|
||||
self.alpha = alpha
|
||||
|
||||
# State of the beam
|
||||
self.hypotheses = [[] for _ in range(batch_size)]
|
||||
self.batch_offset = torch.arange(batch_size, dtype=torch.long)
|
||||
self.beam_offset = torch.arange(
|
||||
0, batch_size * self.beam_size, step=self.beam_size, dtype=torch.long
|
||||
)
|
||||
self.growing_beam = torch.full(
|
||||
(batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long
|
||||
)
|
||||
self.topk_log_probabilities = torch.tensor(
|
||||
[0.0] + [float("-inf")] * (self.beam_size - 1), dtype=torch.float
|
||||
).repeat(batch_size)
|
||||
self.results = {
|
||||
"prediction": [[] for _ in batch_size],
|
||||
"scores": [[] for _ in batch_size],
|
||||
}
|
||||
self._step = 0
|
||||
self.is_done = False
|
||||
|
||||
def step(self, log_probabilities):
|
||||
""" Grows the beam by one step. """
|
||||
self._step += 1
|
||||
|
||||
# The batch size changes as some beams finish so we define _B
|
||||
vocab_size = log_probabilities.size(-1)
|
||||
_B = log_probabilities.size(0) // self.beam_size
|
||||
|
||||
# Multiply each beam probability with the probability of the
|
||||
# next token (conditioned on the words in the beam).
|
||||
log_probabilities += self.topk_log_probabilities.view(-1, 1)
|
||||
|
||||
self.enforce_min_length(log_probabilities)
|
||||
if self.block_repeating_trigram:
|
||||
self.remove_repeating_trigrams(log_probabilities, _B)
|
||||
|
||||
# Find the `beam_size` (previous_beam + token) combinations with
|
||||
# the highest score
|
||||
topk_log_probabilities, topk_ids = log_probabilities.view(
_B, self.beam_size * vocab_size
).topk(self.beam_size, dim=1)
|
||||
|
||||
# Apply the length penalty. The +1 accounts for the [EOS] token
|
||||
# that will be added if the beam ends.
|
||||
topk_scores = topk_log_probabilities / self.length_penalty()
|
||||
|
||||
# Retrieve the corresponding respective beam and token id
|
||||
# topk_token_ids[i] will be added to topk_beam_ids[i]
|
||||
topk_beam_ids = topk_ids.div(vocab_size)
|
||||
topk_token_ids = topk_ids.fmod(vocab_size)
|
||||
|
||||
# Retrieve the row index of the surviving beams in the original
|
||||
# view of the log_probabilities tensor
|
||||
surviving_beams_rows = (topk_beam_ids + self.beam_offset[:_B].view(-1, 1)).view(
|
||||
-1
|
||||
)
|
||||
|
||||
# Append the last predictions
|
||||
self.growing_beam = torch.cat(
|
||||
[
|
||||
self.growing_beam.index_select(0, surviving_beams_rows),
|
||||
topk_token_ids.view(-1, 1),
|
||||
],
|
||||
1,
|
||||
)
|
||||
|
||||
# Check if any of the beam searches has ended during this
|
||||
# growth step. Also if top beam (most probable) has ended
|
||||
# for one element of the batch.
|
||||
is_finished = topk_token_ids.eq(self.end_token_id)
|
||||
self.enforce_max_length(is_finished)
|
||||
is_top_beam_finished = is_finished[:, 0].eq(1)
|
||||
|
||||
# Save the finished searches
|
||||
if is_finished.any():
|
||||
predictions = self.growing_beam.view(
|
||||
-1, self.beam_size, self.growing_beam.size(1)
|
||||
)
|
||||
for i in range(is_finished.size(0)):
|
||||
if is_top_beam_finished[i]:
|
||||
is_finished[i].fill_(1)
|
||||
finished_hyp = is_finished[i].nonzero().view(-1)
|
||||
|
||||
# Store finished hypotheses for this batch.
|
||||
b = self.batch_offset[i]
|
||||
for j in finished_hyp:
|
||||
self.hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))
|
||||
|
||||
# If the batch reached the end, save the best hypotheses
|
||||
# in terms of length-penalized score.
|
||||
if is_top_beam_finished[i]:
|
||||
best_hyp = sorted(
|
||||
self.hypotheses[b], key=lambda x: x[0], reverse=True
|
||||
)
|
||||
best_score, best_prediction = best_hyp[0]
|
||||
self.results["scores"][b].append(best_score)
|
||||
self.results["predictions"][b].append(best_prediction)
|
||||
|
||||
non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
|
||||
if len(non_finished) == 0:
|
||||
self.is_done = True
|
||||
|
||||
# Remove finished batches for the next step.
|
||||
topk_log_probabilities = topk_log_probabilities.index_select(
|
||||
0, non_finished
|
||||
)
|
||||
self.batch_offset = self.batch_offset.index_select(0, non_finished)
|
||||
self.growing_beam = predictions.index_select(0, non_finished).view(
|
||||
-1, self.growing_beam.size(-1)
|
||||
)
|
||||
|
||||
surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished)
|
||||
|
||||
return surviving_beams_rows
|
||||
|
||||
def forward(self, encoder_input_ids, **kwargs):
|
||||
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
|
||||
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
||||
# that apply to the model as a whole.
|
||||
# We let the specific kwargs override the common ones in case of conflict.
|
||||
kwargs_encoder = {
|
||||
argument: value
|
||||
argument[len("encoder_"):]: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("decoder_")
|
||||
if argument.startswith("encoder_")
|
||||
}
|
||||
kwargs_decoder = {
|
||||
argument[len("decoder_"):]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("decoder_")
|
||||
}
|
||||
kwargs_common = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
|
||||
}
|
||||
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
|
||||
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
|
||||
|
||||
batch_size, _ = input_ids.size(0)
|
||||
|
||||
# Variables that keep track of the status of the search
|
||||
hypotheses = [[] for _ in range(batch_size)]
|
||||
batch_offset = torch.arange(batch_size, dtype=torch.long)
|
||||
beam_offset = torch.arange(
|
||||
0,
|
||||
batch_size * self.beam_size,
|
||||
step=self.beam_size,
|
||||
dtype=torch.long,
|
||||
)
|
||||
growing_beam = torch.full(
|
||||
(batch_size * self.beam_size, 1),
|
||||
self.start_token_id,
|
||||
dtype=torch.long,
|
||||
)
|
||||
topk_log_probabilities = torch.tensor(
|
||||
[0.0] + [float("-inf")] * (self.beam_size - 1),
|
||||
dtype=torch.float,
|
||||
).repeat(batch_size)
|
||||
|
||||
# Forward pass on the encoder
|
||||
encoder_outputs = self.encoder(input_ids, kwargs_encoder)
|
||||
# forward pass on the encoder
|
||||
encoder_outputs = self.model.encoder.forward(encoder_input_ids, **kwargs_encoder)
|
||||
kwargs_decoder["encoder_hidden_states"] = tile(
|
||||
encoder_outputs, self.beam_size, dim=0
|
||||
)
|
||||
|
||||
results = {}
|
||||
results["predictions"] = [[] for _ in batch_size]
|
||||
results["scores"] = [[] for _ in batch_size]
|
||||
|
||||
# grow the beam by generating sequences in an autoregressive way
|
||||
self.growing_beam = torch.full(
|
||||
(self.batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long
|
||||
)
|
||||
for step in range(self.max_length):
|
||||
decoder_input = growing_beam[:, -1]
|
||||
outputs = self.decoder(decoder_input, kwargs_decoder)
|
||||
decoder_input = self.growing_beam[:, -1]
|
||||
outputs = self.model.decoder(decoder_input, **kwargs_decoder)
|
||||
log_probabilities = torch.nn.functional.log_softmax(outputs[1])
|
||||
vocab_size = log_probabilities.size(-1)
|
||||
surviving_beams_rows = self.step(log_probabilities)
|
||||
if self.is_done:
|
||||
break
|
||||
|
||||
# The batch size changes as some beams finish so we define:
|
||||
_B = log_probabilities.size(0) // self.beam_size
|
||||
|
||||
# Multiply each beam probability with the probability of the
|
||||
# next token (conditioned on the words in the beam).
|
||||
log_probabilities += topk_log_probabilities.view(-1, 1)
|
||||
|
||||
# if the beam has not attained the minimum required length we
|
||||
# make the end token arbitrarily unlikely.
|
||||
if step < self.min_length:
|
||||
log_probabilities[self.end_token_id] = -1e20
|
||||
|
||||
# Remove repeating tri-grams
|
||||
if(self.args.block_trigram):
|
||||
if(step + 1 > 3):
|
||||
for i in range(_B * self.beam_size):
|
||||
tokens = [t for t in growing_beam[i]]
|
||||
trigrams = [(tokens[i-1], tokens[i], tokens[i+1]) for i in range(1, len(words) - 1)]
|
||||
last_trigram = tuple(trigrams[-1])
|
||||
if last_trigram in trigrams[:-1]:
|
||||
log_probabilities[i] = -1e20
|
||||
|
||||
# Find the `beam_size` (previous_beam + token) combinations with
|
||||
# the highest score
|
||||
topk_log_probabilities, topk_ids = log_probabilities.topk(
|
||||
log_probabilities.view(_B, self.beam_size * vocab_size),
|
||||
self.beam_size,
|
||||
dim=1
|
||||
)
|
||||
|
||||
# Apply the length penalty. The +1 accounts for the [EOS] token
|
||||
# that will be added if the beam ends.
|
||||
length_penalty = ((5.0 + (step + 1)) / 6.0) ** self.alpha
|
||||
topk_scores = topk_log_probabilities / length_penalty
|
||||
|
||||
# Retrieve the corresponding respective beam and token id
|
||||
# topk_token_ids[i] will be added to topk_beam_ids[i]
|
||||
topk_beam_ids = topk_ids.div(vocab_size)
|
||||
topk_token_ids = topk_ids.fmod(vocab_size)
|
||||
|
||||
# Retrieve the row index of the surviving beams in the original
|
||||
# view of the log_probabilities tensor
|
||||
surviving_beams_rows = (
|
||||
topk_beam_ids + beam_offset[:_B].view(-1, 1)
|
||||
).view(-1)
|
||||
|
||||
# Append the last predictions
|
||||
growing_beam = torch.cat(
|
||||
[
|
||||
growing_beam.index_select(0, surviving_beams_rows),
|
||||
topk_token_ids.view(-1, 1),
|
||||
],
|
||||
1,
|
||||
)
|
||||
|
||||
# Check if any of the beam searches has ended during this
|
||||
# growth step. Also if top beam (most probable) has ended
|
||||
# for one element of the batch.
|
||||
is_finished = topk_token_ids.eq(self.end_token_id)
|
||||
if step + 1 == self.max_length:
|
||||
is_finished.fill_(1)
|
||||
is_top_beam_finished = is_finished[:, 0].eq(1)
|
||||
|
||||
# Save the finished searches
|
||||
if is_finished.any():
|
||||
predictions = growing_beam.view(-1, self.beam_size, growing_beam.size(1))
|
||||
for i in range(is_finished.size(0)):
|
||||
if is_top_beam_finished[i]:
|
||||
is_finished[i].fill_(1)
|
||||
finished_hyp = is_finished[i].nonzero().view(-1)
|
||||
|
||||
# Store finished hypotheses for this batch.
|
||||
b = batch_offset[i]
|
||||
for j in finished_hyp:
|
||||
hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))
|
||||
|
||||
# If the batch reached the end, save the best hypotheses
|
||||
# in terms of length-penalized score.
|
||||
if is_top_beam_finished[i]:
|
||||
best_hyp = sorted(
|
||||
hypotheses[b], key=lambda x: x[0], reverse=True
|
||||
)
|
||||
best_score, best_prediction = best_hyp[0]
|
||||
results["scores"][b].append(best_score)
|
||||
results["predictions"][b].append(best_prediction)
|
||||
|
||||
non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
|
||||
if len(non_finished) == 0:
|
||||
break
|
||||
|
||||
# Remove finished batches for the next step.
|
||||
topk_log_probabilities = topk_log_probabilities.index_select(0, non_finished)
|
||||
batch_offset = batch_offset.index_select(0, non_finished)
|
||||
growing_beam = predictions.index_select(0, non_finished).view(
|
||||
-1, growing_beam.size(-1)
|
||||
)
|
||||
|
||||
# Re-order the state for the next pass
|
||||
surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished)
|
||||
kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
|
||||
"encoder_hidden_states"
|
||||
].index_select(0, surviving_beams_rows)
|
||||
|
||||
return results
|
||||
return self.results
|
||||
|
||||
def remove_repeating_trigrams(self, log_probabilities, _B):
|
||||
if(self._step + 1 > 3):
|
||||
for i in range(_B * self.beam_size):
|
||||
tokens = [t for t in self.growing_beam[i]]
|
||||
trigrams = [(tokens[j-1], tokens[j], tokens[j+1]) for j in range(1, len(tokens) - 1)]
|
||||
last_trigram = tuple(trigrams[-1])
|
||||
if last_trigram in trigrams[:-1]:
|
||||
log_probabilities[i] = -1e20
|
||||
|
||||
def enforce_min_length(self, log_probabilities):
if self._step < self.min_length:
log_probabilities[:, self.end_token_id] = -1e20
|
||||
|
||||
def enforce_max_length(self, is_finished):
if self._step + 1 == self.max_length:
is_finished.fill_(1)
|
||||
|
||||
def length_penalty(self):
|
||||
return ((5.0 + (self._step + 1)) / 6.0) ** self.alpha
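# This matches the GNMT-style length penalty ((5 + len) / 6) ** alpha: with
# alpha = 0 the penalty is 1.0 at every step (no re-ranking by length), while
# e.g. alpha = 0.6 at step 19 gives ((5 + 20) / 6) ** 0.6 ≈ 2.35, which favours
# longer finished hypotheses when the log-probabilities are divided by it.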
|
||||
|
||||
|
||||
def tile(x, count, dim=0):
|
||||
|
@ -632,6 +632,8 @@ class BertModel(BertPreTrainedModel):
|
||||
"""
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones_like(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
@ -660,12 +662,15 @@ class BertModel(BertPreTrainedModel):
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
# If a 2D encoder attention mask is provided for the cross-attention
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if encoder_attention_mask is not None:
|
||||
encoder_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||
encoder_attention_mask = encoder_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0
|
||||
if encoder_attention_mask.dim() == 3:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
|
||||
if encoder_attention_mask.dim() == 2:
|
||||
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||
|
||||
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
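# Shape walkthrough (illustrative): a 2D mask of shape [batch_size, src_seq_length]
# becomes [batch_size, 1, 1, src_seq_length] and broadcasts against the cross-attention
# scores of shape [batch_size, num_heads, tgt_seq_length, src_seq_length]; masked
# positions receive a large negative bias (-10000.0) before the softmax.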
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
@ -687,7 +692,7 @@ class BertModel(BertPreTrainedModel):
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask)
|
||||
encoder_attention_mask=encoder_extended_attention_mask)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
@ -788,8 +793,10 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
**masked_lm_loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Masked language modeling loss.
|
||||
**next_token_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Next token prediction loss.
|
||||
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
@ -854,13 +861,13 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
|
||||
if lm_labels is not None:
|
||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||
prediction_scores = prediction_scores[:, :-1, :]
|
||||
lm_labels = lm_labels[:, 1:]
|
||||
prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
||||
lm_labels = lm_labels[:, 1:].contiguous()
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
seq2seq_loss = loss_fct(prediction_scores.reshape(-1, self.config.vocab_size), lm_labels.reshape(-1))
|
||||
outputs = (seq2seq_loss,) + outputs
|
||||
next_token_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
|
||||
outputs = (next_token_loss,) + outputs
|
||||
|
||||
return outputs # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
|
||||
return outputs # (masked_lm_loss), (next_token_loss), prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
|
||||
|
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class PreTrainedSeq2seq(nn.Module):
|
||||
r"""
|
||||
:class:`~transformers.Seq2seq` is a generic model class that will be
|
||||
:class:`~transformers.PreTrainedSeq2seq` is a generic model class that will be
|
||||
instantiated as a Seq2seq model with one of the base model classes of
|
||||
the library as encoder and (optionally) as decoder when created with
|
||||
the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
|
||||
@ -49,8 +49,7 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
*model_args,
|
||||
**kwargs
|
||||
):
|
||||
r""" Instantiates an encoder and a decoder from one or two base classes
|
||||
of the library from pre-trained model checkpoints.
|
||||
r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints.
|
||||
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||
@ -111,35 +110,44 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
model = PreTrainedSeq2seq.from_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
|
||||
"""
|
||||
|
||||
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
||||
# decoder-specific if the key starts with `decoder_`
|
||||
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
|
||||
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
||||
# that apply to the model as a whole.
|
||||
# We let the specific kwargs override the common ones in case of conflict.
|
||||
kwargs_encoder = {
|
||||
argument: value
|
||||
argument[len("encoder_"):]: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("decoder_")
|
||||
if argument.startswith("encoder_")
|
||||
}
|
||||
kwargs_decoder = {
|
||||
argument[len("decoder_") :]: value
|
||||
argument[len("decoder_"):]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("decoder_")
|
||||
}
|
||||
kwargs_common = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
|
||||
}
|
||||
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
|
||||
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
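# Illustration of the routing above (the kwarg names are only examples):
#   from_pretrained("bert-base-uncased", "bert-base-uncased",
#                   output_attentions=True,            # common -> goes to both models
#                   decoder_output_attentions=False)   # decoder-specific, overrides the common value
# yields kwargs_encoder == {"output_attentions": True}
# and kwargs_decoder == {"output_attentions": False}.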
|
||||
|
||||
# Load and initialize the encoder and decoder
|
||||
# The distinction between encoder and decoder at the model level is made
|
||||
# by the value of the flag `is_decoder` that we need to set correctly.
|
||||
encoder = kwargs_encoder.pop("encoder_model", None)
|
||||
encoder = kwargs_encoder.pop("model", None)
|
||||
if encoder is None:
|
||||
kwargs_encoder["is_decoder"] = False
|
||||
encoder = AutoModel.from_pretrained(
|
||||
encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
|
||||
)
|
||||
encoder.config.is_decoder = False
|
||||
|
||||
decoder = kwargs_decoder.pop("model", None)
|
||||
if decoder is None:
|
||||
kwargs_decoder["is_decoder"] = True
|
||||
decoder = AutoModelWithLMHead.from_pretrained(
|
||||
decoder_pretrained_model_name_or_path, **kwargs_decoder
|
||||
)
|
||||
decoder.config.is_decoder = True
|
||||
|
||||
model = cls(encoder, decoder)
|
||||
|
||||
@ -169,37 +177,60 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
|
||||
Indices of decoder input sequence tokens in the vocabulary.
|
||||
"""
|
||||
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
||||
# decoder-specific if the key starts with `decoder_`
|
||||
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
|
||||
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
||||
# that apply to the model as a whole.
|
||||
# We let the specific kwargs override the common ones in case of conflict.
|
||||
kwargs_encoder = {
|
||||
argument: value
|
||||
argument[len("encoder_"):]: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("decoder_")
|
||||
if argument.startswith("encoder_")
|
||||
}
|
||||
kwargs_decoder = {
|
||||
argument[len("decoder_") :]: value
|
||||
argument[len("decoder_"):]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("decoder_")
|
||||
}
|
||||
kwargs_common = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
|
||||
}
|
||||
kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
|
||||
kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
|
||||
|
||||
# Encode if needed (training, first prediction pass)
|
||||
encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
|
||||
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
|
||||
if encoder_hidden_states is None:
|
||||
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
||||
encoder_hidden_states = encoder_outputs[0][
|
||||
-1
|
||||
] # output of the encoder *stack*
|
||||
encoder_hidden_states = encoder_outputs[0] # output the last layer hidden state
|
||||
else:
|
||||
encoder_outputs = ()
|
||||
|
||||
# Decode
|
||||
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states[None, :, :]
|
||||
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
|
||||
kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
|
||||
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
|
||||
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
|
||||
class Model2Model(PreTrainedSeq2seq):
|
||||
r"""
|
||||
:class:`~transformers.Model2Model` instantiates a Seq2Seq model
where both the encoder and the decoder belong to the same family. If the
name of or the path to a pretrained model is specified, the encoder and
the decoder will be initialized with the pretrained weights (the
cross-attention will be initialized randomly if its weights are not
present).
|
||||
|
||||
It is possible to override this behavior and initialize, say, the decoder randomly
|
||||
by creating it beforehand as follows
|
||||
|
||||
config = BertConfig.from_pretrained('bert-base-uncased')
|
||||
decoder = BertForMaskedLM(config)
|
||||
model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder)
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Model2Model, self).__init__(*args, **kwargs)
|
||||
self.tie_weights()
|
||||
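A minimal sketch of how the prefixed keyword arguments are routed at call time, mirroring the fine-tuning loop above (source, target and the masks are placeholder tensors):

model = Model2Model.from_pretrained("bert-base-uncased")
outputs = model(
    source,                                          # encoder input ids
    target,                                          # decoder input ids
    encoder_token_type_ids=encoder_token_type_ids,   # routed to the encoder
    encoder_attention_mask=encoder_mask,             # routed to the encoder
    decoder_attention_mask=decoder_mask,             # routed to the decoder
    decoder_lm_labels=lm_labels,                     # routed to the decoder
)
loss = outputs[0]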
@ -235,14 +266,10 @@ class Model2Model(PreTrainedSeq2seq):
|
||||
model = super(Model2Model, cls).from_pretrained(
|
||||
encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Some architectures require the decoder to be initialized randomly
|
||||
# before fine-tuning.
|
||||
if kwargs.get("decoder_initialize_randomly", False):
|
||||
model.decoder.init_weights()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
|