From aeb18b9224e5ea20128e45f2c1d886422bc5a59e Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 4 Feb 2021 15:00:18 +0100
Subject: [PATCH] Adding new `encoder_no_repeat_ngram_size` to `generate`.
 (#9984)

Adding new `encoder_no_repeat_ngram_size` to `generate`.

Blenderbot results seemed off compared to the original ParlAI script:
`https://parl.ai/projects/recipes/`. Notably, the model seemed to repeat
a lot of what was said during the conversation.

The actual problem was that ParlAI's `no_repeat_ngram_size` applies to
the `encoder_input_ids`, whereas HF's `no_repeat_ngram_size` applies to
the previously generated ids (within the decoder). Blenderbot keeps the
conversation history in the `encoder` part, which explains why HF's
implementation produced the repetitions.

This fix focuses on blenderbot (*not* blenderbot-small) and adds tests
for it, because the two are quite different in configuration.

This change includes:

- Adding a new EncoderNoRepeatNGramLogitsProcessor.
- Adding 1 new arg to `generate` (`encoder_no_repeat_ngram_size`).
- Adding 1 new config parameter `encoder_no_repeat_ngram_size`.
- Adding 2 tests: one for the pipeline (high level, with inputs that
  exhibited the repeat behavior) and one low-level test for
  EncoderNoRepeatNGramLogitsProcessor.
- Factoring NoRepeatNGramLogitsProcessor so that its logic can be
  reused.

Further work:

- The Blenderbot conversational pipeline still does not behave
  correctly, as the way input is prepared within the pipeline is still
  incorrect (follow-up PR).
- Blenderbot allows the bot to have personas, which is done by
  prepending "your persona: XXXX" to the input; this could be explored
  too in a follow-up PR.

@patrickvonplaten @LysandreJik

* Update src/transformers/generation_logits_process.py

Co-authored-by: Patrick von Platen

* Update src/transformers/generation_utils.py

Co-authored-by: Patrick von Platen

* Update src/transformers/generation_utils.py

Co-authored-by: Patrick von Platen

* Update src/transformers/configuration_utils.py

Co-authored-by: Patrick von Platen

* Doc quality.

* Fixing test.

* Last fixes.

* Fixing to account for batch_size.

* Update src/transformers/configuration_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/generation_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Patrick von Platen
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 src/transformers/configuration_utils.py       |  4 +
 src/transformers/generation_logits_process.py | 96 ++++++++++++++-----
 src/transformers/generation_utils.py          | 25 +++++
 .../blenderbot/configuration_blenderbot.py    |  2 +
 tests/test_generation_logits_process.py       | 63 ++++++++++++
 tests/test_pipelines_conversational.py        | 41 ++++++++
 6 files changed, 209 insertions(+), 22 deletions(-)

diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index eeb8563cbe4..6a982d8b2d2 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -117,6 +117,9 @@ class PretrainedConfig(object):
     - **no_repeat_ngram_size** (:obj:`int`, `optional`, defaults to 0) -- Value that will be used by default in the
       :obj:`generate` method of the model for ``no_repeat_ngram_size``. If set to int > 0, all ngrams of that size
       can only occur once.
+    - **encoder_no_repeat_ngram_size** (:obj:`int`, `optional`, defaults to 0) -- Value that will be used by
+      default in the :obj:`generate` method of the model for ``encoder_no_repeat_ngram_size``. If set to int > 0,
+      all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the ``decoder_input_ids``.
     - **bad_words_ids** (:obj:`List[int]`, `optional`) -- List of token ids that are not allowed to be generated
       that will be used by default in the :obj:`generate` method of the model. In order to get the tokens of the
       words that should not appear in the generated text, use :obj:`tokenizer.encode(bad_word,
@@ -205,6 +208,7 @@ class PretrainedConfig(object):
         self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
         self.length_penalty = kwargs.pop("length_penalty", 1.0)
         self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
+        self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
         self.bad_words_ids = kwargs.pop("bad_words_ids", None)
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
         self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)

diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py
index ac77b6b224a..85d2c9df369 100644
--- a/src/transformers/generation_logits_process.py
+++ b/src/transformers/generation_logits_process.py
@@ -235,6 +235,41 @@ class TopKLogitsWarper(LogitsWarper):
         return scores
 
 
+def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
+    generated_ngrams = [{} for _ in range(num_hypos)]
+    for idx in range(num_hypos):
+        gen_tokens = prev_input_ids[idx].tolist()
+        generated_ngram = generated_ngrams[idx]
+        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
+            prev_ngram_tuple = tuple(ngram[:-1])
+            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
+    return generated_ngrams
+
+
+def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
+    # Before decoding the next token, prevent decoding of ngrams that have already appeared
+    start_idx = cur_len + 1 - ngram_size
+    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
+    return banned_ngrams.get(ngram_idx, [])
+
+
+def _calc_banned_ngram_tokens(
+    ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
+) -> List[Iterable[int]]:
+    """Copied from fairseq for no_repeat_ngram in beam_search"""
+    if cur_len + 1 < ngram_size:
+        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
+        return [[] for _ in range(num_hypos)]
+
+    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
+
+    banned_tokens = [
+        _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
+        for hypo_idx in range(num_hypos)
+    ]
+    return banned_tokens
+
+
 class NoRepeatNGramLogitsProcessor(LogitsProcessor):
     r"""
     :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq
@@ -253,36 +288,53 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         num_batch_hypotheses = scores.shape[0]
         cur_len = input_ids.shape[-1]
-        banned_batch_tokens = self._calc_banned_ngram_tokens(input_ids, num_batch_hypotheses, cur_len)
+        banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)
 
         for i, banned_tokens in enumerate(banned_batch_tokens):
             scores[i, banned_tokens] = -float("inf")
 
         return scores
 
-    def _calc_banned_ngram_tokens(
-        self, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
-    ) -> List[Iterable[int]]:
-        """Copied from fairseq for no_repeat_ngram in beam_search"""
-        if cur_len + 1 < self.ngram_size:
-            # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
-            return [[] for _ in range(num_hypos)]
-        generated_ngrams = [{} for _ in range(num_hypos)]
-        for idx in range(num_hypos):
-            gen_tokens = prev_input_ids[idx].tolist()
-            generated_ngram = generated_ngrams[idx]
-            for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
-                prev_ngram_tuple = tuple(ngram[:-1])
-                generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
 
-        def _get_generated_ngrams(hypo_idx):
-            # Before decoding the next token, prevent decoding of ngrams that have already appeared
-            start_idx = cur_len + 1 - self.ngram_size
-            ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
-            return generated_ngrams[hypo_idx].get(ngram_idx, [])
 
+class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
+    r"""
+    :class:`transformers.LogitsProcessor` that enforces no repetition of encoder input ids n-grams for the decoder ids.
+    See `ParlAI `__.
 
-        banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
-        return banned_tokens
+    Args:
+        encoder_ngram_size (:obj:`int`):
+            All ngrams of size :obj:`ngram_size` can only occur within the encoder input ids.
+        encoder_input_ids (:obj:`torch.LongTensor`):
+            The encoder_input_ids that should not be repeated within the decoder ids.
+ """ + + def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor): + if not isinstance(encoder_ngram_size, int) or encoder_ngram_size <= 0: + raise ValueError( + f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}" + ) + self.ngram_size = encoder_ngram_size + if len(encoder_input_ids.shape) == 1: + encoder_input_ids = encoder_input_ids.unsqueeze(0) + self.batch_size = encoder_input_ids.shape[0] + self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size) + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + # B x num_beams + num_hypos = scores.shape[0] + num_beams = num_hypos // self.batch_size + cur_len = input_ids.shape[-1] + banned_batch_tokens = [ + _get_generated_ngrams( + self.generated_ngrams[hypo_idx // num_beams], input_ids[hypo_idx], self.ngram_size, cur_len + ) + for hypo_idx in range(num_hypos) + ] + + for i, banned_tokens in enumerate(banned_batch_tokens): + scores[i, banned_tokens] = -float("inf") + + return scores class NoBadWordsLogitsProcessor(LogitsProcessor): diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 3f933940c69..fcf01ab401e 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -23,6 +23,7 @@ from torch.nn import functional as F from .file_utils import ModelOutput from .generation_beam_search import BeamScorer, BeamSearchScorer from .generation_logits_process import ( + EncoderNoRepeatNGramLogitsProcessor, HammingDiversityLogitsProcessor, LogitsProcessorList, MinLengthLogitsProcessor, @@ -537,6 +538,8 @@ class GenerationMixin: self, repetition_penalty: float, no_repeat_ngram_size: int, + encoder_no_repeat_ngram_size: int, + encoder_input_ids: torch.LongTensor, bad_words_ids: List[List[int]], min_length: int, eos_token_id: int, @@ -555,6 +558,11 @@ class GenerationMixin: no_repeat_ngram_size = ( no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size ) + encoder_no_repeat_ngram_size = ( + encoder_no_repeat_ngram_size + if encoder_no_repeat_ngram_size is not None + else self.config.encoder_no_repeat_ngram_size + ) bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids min_length = min_length if min_length is not None else self.config.min_length eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id @@ -574,6 +582,13 @@ class GenerationMixin: processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)) if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0: processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)) + if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0: + if self.config.is_encoder_decoder: + processors.append(EncoderNoRepeatNGramLogitsProcessor(encoder_no_repeat_ngram_size, encoder_input_ids)) + else: + raise ValueError( + "It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only architecture" + ) if bad_words_ids is not None: processors.append(NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)) if min_length is not None and eos_token_id is not None and min_length > -1: @@ -601,6 +616,7 @@ class GenerationMixin: eos_token_id: Optional[int] = None, length_penalty: Optional[float] = None, no_repeat_ngram_size: Optional[int] = None, + encoder_no_repeat_ngram_size: Optional[int] = None, num_return_sequences: Optional[int] = None, 
         decoder_start_token_id: Optional[int] = None,
         use_cache: Optional[bool] = None,
@@ -661,6 +677,9 @@ class GenerationMixin:
                 sequences.
             no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
                 If set to int > 0, all ngrams of that size can only occur once.
+            encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
+                If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in
+                the ``decoder_input_ids``.
             bad_words_ids(:obj:`List[List[int]]`, `optional`):
                 List of token ids that are not allowed to be generated. In order to get the tokens of the words that
                 should not appear in the generated text, use :obj:`tokenizer(bad_word,
@@ -820,6 +839,9 @@ class GenerationMixin:
             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             pad_token_id = eos_token_id
 
+        # Storing encoder_input_ids for logits_processor that could use them
+        encoder_input_ids = input_ids if self.config.is_encoder_decoder else None
+
         if self.config.is_encoder_decoder:
             # add encoder_outputs to model_kwargs
             model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
@@ -862,6 +884,8 @@ class GenerationMixin:
         logits_processor = self._get_logits_processor(
             repetition_penalty=repetition_penalty,
             no_repeat_ngram_size=no_repeat_ngram_size,
+            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
+            encoder_input_ids=encoder_input_ids,
             bad_words_ids=bad_words_ids,
             min_length=min_length,
             eos_token_id=eos_token_id,
@@ -1638,6 +1662,7 @@ class GenerationMixin:
             beam_idx = beam_outputs["next_beam_indices"]
 
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
+
             cur_len = cur_len + 1
 
             model_kwargs = self._update_model_kwargs_for_generation(

diff --git a/src/transformers/models/blenderbot/configuration_blenderbot.py b/src/transformers/models/blenderbot/configuration_blenderbot.py
index 242d5d36f51..4de6a9d12ad 100644
--- a/src/transformers/models/blenderbot/configuration_blenderbot.py
+++ b/src/transformers/models/blenderbot/configuration_blenderbot.py
@@ -128,6 +128,7 @@ class BlenderbotConfig(PretrainedConfig):
         pad_token_id=0,
         bos_token_id=1,
         eos_token_id=2,
+        encoder_no_repeat_ngram_size=3,
         **kwargs
     ):
         super().__init__(
@@ -136,6 +137,7 @@ class BlenderbotConfig(PretrainedConfig):
             eos_token_id=eos_token_id,
             is_encoder_decoder=is_encoder_decoder,
             decoder_start_token_id=decoder_start_token_id,
+            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
             **kwargs,
         )

diff --git a/tests/test_generation_logits_process.py b/tests/test_generation_logits_process.py
index 1aa2941047f..315417df335 100644
--- a/tests/test_generation_logits_process.py
+++ b/tests/test_generation_logits_process.py
@@ -27,6 +27,7 @@ if is_torch_available():
     import torch.nn.functional as F
 
     from transformers.generation_logits_process import (
+        EncoderNoRepeatNGramLogitsProcessor,
         HammingDiversityLogitsProcessor,
         LogitsProcessorList,
         MinLengthLogitsProcessor,
@@ -208,6 +209,68 @@ class LogitsProcessorTest(unittest.TestCase):
             torch.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]]
         )
 
+    def test_encoder_no_repeat_ngram_dist_processor(self):
+        vocab_size = 3
+        num_beams = 2
+        batch_size = 1
+
+        encoder_input_ids = torch.tensor([1, 2, 1, 1], device=torch_device, dtype=torch.long)
+
+        input_ids = torch.tensor([[1, 2, 1], [8, 0, 2]], device=torch_device, dtype=torch.long)
+        scores = self._get_uniform_logits(batch_size * num_beams, vocab_size)
+
+        no_repeat_proc_2_gram = EncoderNoRepeatNGramLogitsProcessor(2, encoder_input_ids=encoder_input_ids)
+        no_repeat_proc_3_gram = EncoderNoRepeatNGramLogitsProcessor(3, encoder_input_ids=encoder_input_ids)
+
+        filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone())
+        filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone())
+
+        # 2-gram would forbid tokens 1 and 2 at 1st beam and token 1 at 2nd beam
+        self.assertListEqual(torch.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [False, True, False]])
+
+        # 3-gram would forbid token 1 at 1st beam and no token at 2nd beam
+        self.assertListEqual(
+            torch.isinf(filtered_scores_3_gram).tolist(), [[False, True, False], [False, False, False]]
+        )
+
+        # Batched input
+        vocab_size = 3
+        num_beams = 2
+        batch_size = 2
+        encoder_input_ids = torch.tensor([[1, 2, 1, 1], [0, 0, 2, 1]], device=torch_device, dtype=torch.long)
+
+        input_ids = torch.tensor([[1, 2, 1], [1, 0, 2], [0, 0, 0], [0, 2, 2]], device=torch_device, dtype=torch.long)
+        scores = self._get_uniform_logits(batch_size * num_beams, vocab_size)
+
+        no_repeat_proc_2_gram = EncoderNoRepeatNGramLogitsProcessor(2, encoder_input_ids=encoder_input_ids)
+        no_repeat_proc_3_gram = EncoderNoRepeatNGramLogitsProcessor(3, encoder_input_ids=encoder_input_ids)
+
+        filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone())
+        filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone())
+
+        # 2-gram
+        # Batch 1
+        #   - Beam 1: tokens (1, 2) forbidden
+        #   - Beam 2: tokens (1) forbidden
+        # Batch 2
+        #   - Beam 1: tokens (0, 2) forbidden
+        #   - Beam 2: tokens (1) forbidden
+        self.assertListEqual(
+            torch.isinf(filtered_scores_2_gram).tolist(),
+            [[False, True, True], [False, True, False], [True, False, True], [False, True, False]],
+        )
+
+        # 3-gram
+        # Batch 1
+        #   - Beam 1: tokens (1) forbidden
+        #   - Beam 2: tokens () forbidden
+        # Batch 2
+        #   - Beam 1: tokens (2) forbidden
+        #   - Beam 2: tokens () forbidden
+        self.assertListEqual(
+            torch.isinf(filtered_scores_3_gram).tolist(),
+            [[False, True, False], [False, False, False], [False, False, True], [False, False, False]],
+        )
+
     def test_no_bad_words_dist_processor(self):
         vocab_size = 5
         batch_size = 2

diff --git a/tests/test_pipelines_conversational.py b/tests/test_pipelines_conversational.py
index 69fb88e480b..276c801d64d 100644
--- a/tests/test_pipelines_conversational.py
+++ b/tests/test_pipelines_conversational.py
@@ -276,6 +276,47 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
         self.assertEqual(result.past_user_inputs[1], "Is it an action movie?")
         self.assertEqual(result.generated_responses[1], "It's a comedy.")
 
+    @require_torch
+    @slow
+    def test_integration_torch_conversation_blenderbot_400M(self):
+        tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
+        model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")
+        nlp = ConversationalPipeline(model=model, tokenizer=tokenizer)
+
+        conversation_1 = Conversation("hello")
+        result = nlp(
+            conversation_1,
+        )
+        self.assertEqual(
+            result.generated_responses[0],
+            # ParlAI implementation output, we have a different one, but it's our
+            # second best, you can check by using num_return_sequences=10
+            # " Hello! How are you? I'm just getting ready to go to work, how about you?",
+            " Hello! How are you doing today? I just got back from a walk with my dog.",
+        )
+
+        conversation_1 = Conversation(" Lasagne hello")
+        result = nlp(conversation_1, encoder_no_repeat_ngram_size=3)
+        self.assertEqual(
+            result.generated_responses[0],
+            " Lasagne is my favorite Italian dish. Do you like lasagne?",
+        )
+
+        conversation_1 = Conversation(
+            "Lasagne hello Lasagne is my favorite Italian dish. Do you like lasagne? I like lasagne."
+        )
+        result = nlp(
+            conversation_1,
+            encoder_no_repeat_ngram_size=3,
+        )
+        self.assertEqual(
+            result.generated_responses[0],
+            # Here too the ParlAI implementation output differs; ours is the
+            # second best, you can check by using num_return_sequences=10
+            " Lasagne is a traditional Italian dish consisting of a yeasted flatbread typically topped with tomato sauce and cheese.",
+        )
+
     @require_torch
     @slow
     def test_integration_torch_conversation_encoder_decoder(self):
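
Reviewer note: below is a minimal usage sketch of the new argument, not part
of the patch itself. The checkpoint name and conversation text are
illustrative only; any encoder-decoder model accepts the same argument.

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-400M-distill")

    # The conversation history is consumed by the encoder. With
    # encoder_no_repeat_ngram_size=3, no 3-gram of that history can be copied
    # verbatim into the generated reply, mirroring ParlAI's behavior.
    history = "Lasagne is my favorite Italian dish. Do you like lasagne?"
    inputs = tokenizer([history], return_tensors="pt")

    reply_ids = model.generate(**inputs, encoder_no_repeat_ngram_size=3)
    print(tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0])

Since `BlenderbotConfig` now ships `encoder_no_repeat_ngram_size=3` by
default, calling `generate` without the explicit argument has the same effect
for this model; passing it explicitly remains useful for other seq2seq models.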