work in progress

Patrick von Platen 2020-03-06 14:39:28 +01:00
parent 5b3000d933
commit 7a11e925cf
6 changed files with 176 additions and 39 deletions

View File

@ -59,6 +59,7 @@ class BartConfig(PretrainedConfig):
output_past=False,
num_labels=3,
bos_token_id=0,
is_encoder_decoder=True,
**common_kwargs
):
r"""
@ -72,6 +73,7 @@ class BartConfig(PretrainedConfig):
output_past=output_past,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
is_encoder_decoder=is_encoder_decoder,
**common_kwargs,
)
self.vocab_size = vocab_size
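A minimal sketch (assuming the import paths of this commit) of what the new default means downstream: BART configs now report themselves as encoder-decoder models, so generate() can branch on the config flag instead of probing model attributes.

from transformers import BartConfig

config = BartConfig()
assert config.is_encoder_decoder is True       # new default, forwarded to PretrainedConfig below
config = BartConfig(is_encoder_decoder=False)  # the flag can still be overridden explicitly
assert config.is_encoder_decoder is False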

View File

@ -75,9 +75,13 @@ class T5Config(PretrainedConfig):
dropout_rate=0.1,
layer_norm_epsilon=1e-6,
initializer_factor=1.0,
is_encoder_decoder=True,
**kwargs
):
super().__init__(**kwargs)
super().__init__(
is_encoder_decoder=is_encoder_decoder,
**kwargs,
)
self.vocab_size = vocab_size
self.n_positions = n_positions
self.d_model = d_model
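The same sketch for T5 (again assuming this commit's import paths): the flag is forwarded through super().__init__ and ends up on the config like any other generation parameter.

from transformers import T5Config

config = T5Config()
assert config.is_encoder_decoder is True
config = T5Config(is_encoder_decoder=False)    # opting out remains possible
assert config.is_encoder_decoder is False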

View File

@ -65,11 +65,12 @@ class PretrainedConfig(object):
self.pruned_heads = kwargs.pop("pruned_heads", {})
# is_decoder is used in encoder-decoder models to differentiate the encoder from the decoder
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
self.is_decoder = kwargs.pop("is_decoder", False)
# Parameters for sequence generation
self.max_length = kwargs.pop("max_length", 20)
self.min_length = kwargs.pop("max_length", 0)
self.min_length = kwargs.pop("min_length", 0)
self.do_sample = kwargs.pop("do_sample", False)
self.early_stopping = kwargs.pop("early_stopping", False)
self.num_beams = kwargs.pop("num_beams", 1)
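A small sketch of the resulting attributes (hypothetical values; it assumes the base class can be instantiated directly, which the current code allows): min_length is now read from the correct key, and the encoder-decoder flag defaults to False unless a subclass passes it.

from transformers import PretrainedConfig

config = PretrainedConfig(min_length=55, max_length=140, is_encoder_decoder=True)
assert (config.min_length, config.max_length) == (55, 140)
assert config.is_encoder_decoder is True and config.is_decoder is False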

View File

@ -957,7 +957,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
}
@staticmethod
def prepare_inputs_for_generation(decoder_input_ids, past, encoder_inputs):
def prepare_inputs_for_generation(decoder_input_ids, past, encoder_inputs, attention_mask):
assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} != {}".format(attention_mask.shape, encoder_inputs.shape)
if past is None: # first step
encoder_outputs, decoder_cached_states = None, None
else:
@ -969,6 +970,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
"encoder_outputs": encoder_outputs,
"decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask
}
@staticmethod
@ -1132,6 +1134,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
lprobs[:, eos_token_id + 1 :] = -math.inf
assert self._do_output_past(outputs)
decoder_cache = outputs[1]
if repetition_penalty != 1.0:
self.enforce_repetition_penalty_(lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty)
num_hypos = batch_size * num_beams
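A hedged, self-contained sketch of what the modified hook returns on the first decoding step; the token ids are hypothetical and the behavior assumed here is exactly the hunk shown above (past is None, so no cached states exist yet).

import torch
from transformers import BartForConditionalGeneration

encoder_inputs = torch.tensor([[0, 31414, 232, 2]])    # (batch_size, src_len), hypothetical ids
attention_mask = torch.ones_like(encoder_inputs)        # must match encoder_inputs.shape per the new assert
decoder_input_ids = torch.tensor([[0]])                  # generation starts from a single bos token
inputs = BartForConditionalGeneration.prepare_inputs_for_generation(
    decoder_input_ids, past=None, encoder_inputs=encoder_inputs, attention_mask=attention_mask
)
assert inputs["encoder_outputs"] is None and inputs["attention_mask"] is not None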

View File

@ -19,6 +19,8 @@ import logging
import os
import typing
import ipdb
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
@ -623,6 +625,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
length_penalty=None,
no_repeat_ngram_size=None,
num_return_sequences=None,
attention_mask=None,
):
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
and beam-search.
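A usage sketch of the public entry point being extended here; the checkpoint names and hyperparameters are copied from the integration tests at the bottom of this commit, not invented.

from transformers import BartForConditionalGeneration, BartTokenizer

model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True)
tok = BartTokenizer.from_pretrained("bart-large")
input_ids = tok.encode("(CNN) Some long news article ...", return_tensors="pt")
summary_ids = model.generate(
    input_ids, num_beams=4, max_length=140 + 2, min_length=55 + 1,
    length_penalty=2.0, no_repeat_ngram_size=3, do_sample=False,
)
print(tok.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))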
@ -791,6 +794,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
num_beams >= num_return_sequences
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
# create attention mask if necessary
# TODO (PVP): this should later be handled by the forward fn() in each model
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
attention_mask = input_ids.ne(pad_token_id).long()
elif attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
# set pad_token_id to the first eos_token_id if it is not set. Important that this is done after
# the attention_mask is created
if pad_token_id is None and eos_token_ids is not None:
logger.warning(
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0])
@ -812,15 +824,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if num_return_sequences > 1 or num_beams > 1:
input_ids_len = input_ids.shape[-1]
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
attention_mask = attention_mask.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
input_ids = input_ids.contiguous().view(
effective_batch_size * num_beams, input_ids_len
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
attention_mask = attention_mask.contiguous().view(
effective_batch_size * num_beams, input_ids_len
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
# TODO (PVP): probably not the best way to check whether model is encoder decoder
is_encoder_decoder = (
hasattr(self, "model") and hasattr(self.model, "decoder") and hasattr(self.model, "encoder")
)
if is_encoder_decoder:
if self.config.is_encoder_decoder:
eos_token_id = eos_token_ids[0]
assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
assert eos_token_id is not None, "Encoder Decoder Models need to have an eos_token_id"
@ -828,8 +841,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
encoder_inputs = input_ids
input_ids = torch.full(
(effective_batch_size * num_beams, 1),
# eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case
# eos_token_id,
bos_token_id,
# eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case
dtype=torch.long,
device=next(self.parameters()).device,
)
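A small sketch of the expand-and-flatten step above, which is now applied to the attention mask as well as the input ids (hypothetical batch_size=2, num_beams=3, effective_batch_mult=1).

import torch

input_ids = torch.tensor([[10, 11, 12], [20, 21, 22]])   # (batch_size, input_ids_len)
expanded = input_ids.unsqueeze(1).expand(2, 3, 3)          # (batch_size, num_beams, input_ids_len)
flat = expanded.contiguous().view(2 * 3, 3)                # (batch_size * num_beams, input_ids_len)
assert flat.tolist()[:3] == [[10, 11, 12]] * 3             # each example is repeated once per beam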
@ -851,6 +865,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
top_p,
repetition_penalty,
no_repeat_ngram_size,
bos_token_id,
pad_token_id,
eos_token_ids,
effective_batch_size,
@ -859,6 +874,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
num_beams,
vocab_size,
encoder_inputs,
attention_mask,
)
else:
output = self._generate_no_beam_search(
@ -876,6 +892,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
eos_token_ids,
effective_batch_size,
encoder_inputs,
attention_mask,
)
return output
@ -896,6 +913,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
eos_token_ids,
batch_size,
encoder_inputs,
attention_mask
):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
@ -906,7 +924,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
past = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs)
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
@ -922,7 +940,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
if no_repeat_ngram_size > 0:
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len - 1)
for batch_idx in range(batch_size):
next_token_logits[
batch_idx, banned_tokens[batch_idx]
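A standalone illustration of the no-repeat-ngram rule enforced here; banned_next_tokens is a hypothetical helper written for this note, not the library's calc_banned_tokens, and only shows why a token gets banned: emitting it would repeat an n-gram already present in the sequence.

def banned_next_tokens(prev_tokens, ngram_size):
    # record every n-gram seen so far as prefix -> completion tokens
    generated = {}
    for i in range(len(prev_tokens) - ngram_size + 1):
        prefix = tuple(prev_tokens[i:i + ngram_size - 1])
        generated.setdefault(prefix, set()).add(prev_tokens[i + ngram_size - 1])
    # ban the completions of the n-gram that would end at the next position
    tail = tuple(prev_tokens[len(prev_tokens) - ngram_size + 1:])
    return generated.get(tail, set())

assert banned_next_tokens([5, 7, 9, 5, 7], 2) == {9}   # emitting 9 would repeat the bigram (7, 9)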
@ -968,6 +986,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if unfinished_sents.max() == 0:
break
# extend attention_mask for newly generated input
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)  # append one mask column per generated token
cur_len = cur_len + 1
# if there are different sentence lengths in the batch, some sequences have to be padded
@ -995,6 +1017,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
top_p,
repetition_penalty,
no_repeat_ngram_size,
bos_token_id,
pad_token_id,
eos_token_ids,
batch_size,
@ -1003,12 +1026,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
num_beams,
vocab_size,
encoder_inputs,
attention_mask,
):
""" Generate sequences for each example with beam search.
"""
is_encoder_decoder = (
hasattr(self, "model") and hasattr(self.model, "decoder") and hasattr(self.model, "encoder")
)
# generated hypotheses
generated_hyps = [
@ -1029,7 +1050,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs)
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, encoder_inputs=encoder_inputs, attention_mask=attention_mask)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
@ -1043,20 +1064,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_token_logits, batch_size, num_beams, input_ids, repetition_penalty,
)
if cur_len < min_length and eos_token_ids is not None:
for eos_token_id in eos_token_ids:
next_token_logits[:, eos_token_id] = -10000.0 # set eos token prob to 0 as is done for attention masks
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
if no_repeat_ngram_size > 0:
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len - 1)
for batch_idx in range(batch_size):
next_token_logits[
batch_idx, banned_tokens[batch_idx]
] = -10000.0 # set banned token probs to a very low value
if eos_token_ids is not None and cur_len < min_length:
for eos_token_id in eos_token_ids:
next_token_logits[
:, eos_token_id
] = -10000.0 # set eos token prob to 0 as is done for attention masks
# force eos to be chosen at end of generation for encoder-decoder models
# TODO (PVP): both of these are very hacky; see whether it might be possible to solve this differently
if self.config.is_encoder_decoder:
# self.prepare_logits_for_softmax(next_token_logits, cur_len, max_length)
if cur_len == 1:
self._force_token_ids_generation(next_token_logits, bos_token_id)
if cur_len == max_length - 1:
self._force_token_ids_generation(next_token_logits, eos_token_ids)
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
@ -1091,19 +1119,33 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# do greedy beam search
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
if is_encoder_decoder: # TODO(PVP) to be refactored later - do we need this boolean flag here?
# scores[scores != scores] = -math.inf # block nans => seems very hacky here
# scores[:, pad_token_id] = -math.inf => seems very hacky here
# if self.config.is_encoder_decoder: # TODO(PVP) to be refactored later - do we need this boolean flag here?
# import math
# scores[scores != scores] = -math.inf # block nans => seems very hacky here
# scores[:, pad_token_id] = -math.inf # => seems very hacky here
# TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
# if cur_len == 0: # Force BOS to be chosen => also very hacky ... seems also to work without this line
# scores[:, self.config.bos_token_id + 1 :] = -math.inf
if cur_len == max_length - 1: # FORCE EOS to be chosen
all_but_eos_mask = torch.tensor(
[x for x in range(vocab_size) if x not in eos_token_ids],
dtype=torch.long,
device=next(self.parameters()).device,
)
scores[:, all_but_eos_mask] = -10000.0
# if cur_len == 1: # Force BOS to be chosen => also very hacky ... seems also to work without this line
# scores[:, self.config.bos_token_id + 1 :] = -math.inf
# if cur_len == max_length - 1: # FORCE EOS to be chosen
# all_but_eos_mask = torch.tensor(
# [x for x in range(vocab_size) if x not in eos_token_ids],
# dtype=torch.long,
# device=next(self.parameters()).device,
# )
# scores[:, all_but_eos_mask] = -math.inf
# if eos_token_ids is not None and cur_len < min_length:
# for eos_token_id in eos_token_ids:
# scores[:, eos_token_id] = -math.inf # set eos token prob to 0 as is done for attention masks
#
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
# if no_repeat_ngram_size > 0:
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
# banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len - 1)
# for batch_idx in range(batch_size):
# scores[
# batch_idx, banned_tokens[batch_idx]
# ] = -math.inf # set eos token prob to 0 as is done for attention masks
assert scores.size() == (batch_size * num_beams, vocab_size)
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
@ -1126,7 +1168,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# if we are done with this sentence
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
next_scores[batch_idx].max().item()
next_scores[batch_idx].max().item(), cur_len=cur_len
)
if done[batch_idx]:
assert (
@ -1185,6 +1227,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
if all(done):
break
# extend attention_mask for newly generated input
if self.config.is_encoder_decoder is False:
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)  # append one mask column per generated token
# update current length
cur_len = cur_len + 1
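A small sketch of the mask-growth step a few lines above: for decoder-only models one column of ones is appended per newly generated token (this assumes the (batch_size, 1) shape used in the fix above).

import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])   # (batch_size, cur_len)
new_column = attention_mask.new_ones((attention_mask.shape[0], 1))
attention_mask = torch.cat([attention_mask, new_column], dim=-1)
assert attention_mask.tolist() == [[1, 1, 1, 0, 1]]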
@ -1243,11 +1289,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
assert (len(hypo) == max_length for hypo in best)
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
if is_encoder_decoder:
if self.config.is_encoder_decoder:
# do not return first <BOS> token
return decoded[:, 1:]
return decoded
# force one of token_ids to be generated by setting prob of all other tokens to 0.
def _force_token_ids_generation(self, logits, token_ids):
if isinstance(token_ids, int):
token_ids = [token_ids]
all_but_token_ids_mask = torch.tensor(
[x for x in range(self.config.vocab_size) if x not in token_ids],
dtype=torch.long,
device=next(self.parameters()).device,
)
assert len(logits.shape) == 2, "logits should be of rank 2 with shape: [batch_size, vocab_size]"
logits[:, all_but_token_ids_mask] = -10000.0
return logits
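A standalone sketch of the masking trick used by _force_token_ids_generation (hypothetical vocab_size): every token except the forced ones gets a large negative logit, so both argmax and softmax sampling pick a forced token.

import torch

vocab_size, forced = 10, [2]
logits = torch.zeros(3, vocab_size)              # (batch_size, vocab_size)
mask = torch.tensor([x for x in range(vocab_size) if x not in forced], dtype=torch.long)
logits[:, mask] = -10000.0
assert logits.argmax(dim=-1).tolist() == [2, 2, 2]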
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = []

View File

@ -16,6 +16,7 @@
import tempfile
import unittest
import ipdb
from transformers import is_torch_available
@ -425,7 +426,7 @@ class BartModelIntegrationTest(unittest.TestCase):
text = " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian"
tokens = tok.encode(text, return_tensors="pt").to(torch_device)
extra_len = 20
gen_tokens_bart = hf.generate_bart(tokens, num_beams=3, max_length=extra_len,) # repetition_penalty=10.,
gen_tokens_bart = hf.generate_bart(tokens, num_beams=4, max_length=extra_len,) # repetition_penalty=10.,
gen_tokens = hf.generate(
tokens, num_beams=4, max_length=extra_len + 2, do_sample=False
) # repetition_penalty=10.,
@ -436,7 +437,49 @@ class BartModelIntegrationTest(unittest.TestCase):
self.assertEqual(expected_result, generated[0])
@slow
def test_cnn_summarization_same_as_fairseq_hard(self):
def test_cnn_summarization_same_as_fairseq_hard_single(self):
hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
tok = BartTokenizer.from_pretrained("bart-large")
SHORTER_ARTICLE = ' (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.'
EXPECTED_SUMMARY_SHORTER = "The Palestinian Authority becomes the 123rd member of the International Criminal Court. The move gives the court jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a move toward greater justice."
tokens = tok.encode(SHORTER_ARTICLE, return_tensors="pt").to(torch_device)
num_beams = 4
length_penalty = 2.0
max_length = 140
min_length = 55
no_repeat_ngram_size = 3
gen_tokens = hf.generate(
tokens,
num_beams=num_beams,
max_length=max_length + 2,
min_length=min_length + 1,
length_penalty=length_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
do_sample=False
)
generated = [tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in gen_tokens]
self.assertEqual(EXPECTED_SUMMARY_SHORTER, generated[0])
gen_tokens_bart = hf.generate_bart(
tokens,
num_beams=num_beams,
max_length=max_length,
min_len=min_length,
length_penalty=length_penalty,
no_repeat_ngram_size=no_repeat_ngram_size
)
generated_bart = [tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in gen_tokens_bart]
self.assertEqual(EXPECTED_SUMMARY_SHORTER, generated_bart[0])
@slow
def test_cnn_summarization_same_as_fairseq_hard_batch(self):
hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
tok = BartTokenizer.from_pretrained("bart-large")
@ -459,22 +502,47 @@ class BartModelIntegrationTest(unittest.TestCase):
pad_to_max_length=True,
return_tensors="pt",
)
max_length = 140
min_length = 55
self.assertEqual(1024, dct["input_ids"].shape[1])
hypotheses_batch = hf.generate(
input_ids=dct["input_ids"].to(torch_device),
attention_mask=dct["attention_mask"].to(torch_device),
num_beams=4,
length_penalty=2.0,
max_length=140,
min_len=55,
max_length=max_length + 2,
min_length=min_length + 1,
no_repeat_ngram_size=3,
do_sample=False
)
decoded = [
tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
]
hypotheses_batch_bart = hf.generate_bart(
input_ids=dct["input_ids"].to(torch_device),
attention_mask=dct["attention_mask"].to(torch_device),
num_beams=4,
length_penalty=2.0,
max_length=max_length,
min_len=min_length,
no_repeat_ngram_size=3,
)
decoded_bart = [
tok.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch_bart
]
ipdb.set_trace()
self.assertListEqual(
[EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
decoded,
)
self.assertListEqual(
[EXPECTED_SUMMARY_FRANCE, EXPECTED_SUMMARY_SHORTER, EXPECTED_SUMMARY_IRAN, EXPECTED_SUMMARY_SUBWAY],
decoded_bart,
)
# TODO(SS): run fairseq again with num_beams=2, min_len=20.
# TODO(SS): add test case that hits max_length